org.springframework.security.oauth.provider.filter.OAuthProviderProcessingFilter.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.security.oauth.provider.filter.OAuthProviderProcessingFilter.java

Source

/*
 * Copyright 2008 Web Cohesion
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.security.oauth.provider.filter;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceAware;
import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.SpringSecurityMessageSource;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth.common.OAuthConsumerParameter;
import org.springframework.security.oauth.common.OAuthException;
import org.springframework.security.oauth.common.signature.*;
import org.springframework.security.oauth.provider.ConsumerAuthentication;
import org.springframework.security.oauth.provider.ConsumerCredentials;
import org.springframework.security.oauth.provider.ConsumerDetails;
import org.springframework.security.oauth.provider.ConsumerDetailsService;
import org.springframework.security.oauth.provider.InvalidOAuthParametersException;
import org.springframework.security.oauth.provider.OAuthAuthenticationDetails;
import org.springframework.security.oauth.provider.OAuthProcessingFilterEntryPoint;
import org.springframework.security.oauth.provider.OAuthProviderSupport;
import org.springframework.security.oauth.provider.OAuthVersionUnsupportedException;
import org.springframework.security.oauth.provider.nonce.ExpiringTimestampNonceServices;
import org.springframework.security.oauth.provider.nonce.OAuthNonceServices;
import org.springframework.security.oauth.provider.token.OAuthProviderToken;
import org.springframework.security.oauth.provider.token.OAuthProviderTokenServices;
import org.springframework.util.Assert;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * OAuth processing filter. This filter should be applied to requests for OAuth protected resources (see OAuth Core 1.0).
 *
 * @author Ryan Heaton
 */
public abstract class OAuthProviderProcessingFilter implements Filter, InitializingBean, MessageSourceAware {

    /**
     * Attribute for indicating that OAuth processing has already occurred.
     */
    public static final String OAUTH_PROCESSING_HANDLED = "org.springframework.security.oauth.provider.OAuthProviderProcessingFilter#SKIP_PROCESSING";

    private final Log log = LogFactory.getLog(getClass());
    private final List<String> allowedMethods = new ArrayList<String>(Arrays.asList("GET", "POST"));
    private OAuthProcessingFilterEntryPoint authenticationEntryPoint = new OAuthProcessingFilterEntryPoint();
    protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();
    private String filterProcessesUrl = "/oauth_filter";
    private OAuthProviderSupport providerSupport = new CoreOAuthProviderSupport();
    private OAuthSignatureMethodFactory signatureMethodFactory = new CoreOAuthSignatureMethodFactory();
    private OAuthNonceServices nonceServices = new ExpiringTimestampNonceServices();
    private boolean ignoreMissingCredentials = false;
    private OAuthProviderTokenServices tokenServices;

    private ConsumerDetailsService consumerDetailsService;

    public void afterPropertiesSet() throws Exception {
        Assert.notNull(consumerDetailsService, "A consumer details service is required.");
        Assert.notNull(tokenServices, "Token services are required.");
    }

    public void init(FilterConfig ignored) throws ServletException {
    }

    public void destroy() {
    }

    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        if (!skipProcessing(request)) {
            if (requiresAuthentication(request, response, chain)) {
                if (!allowMethod(request.getMethod().toUpperCase())) {
                    if (log.isDebugEnabled()) {
                        log.debug("Method " + request.getMethod() + " not supported.");
                    }

                    response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
                    return;
                }

                try {
                    Map<String, String> oauthParams = getProviderSupport().parseParameters(request);

                    if (parametersAreAdequate(oauthParams)) {

                        if (log.isDebugEnabled()) {
                            StringBuilder builder = new StringBuilder("OAuth parameters parsed: ");
                            for (String param : oauthParams.keySet()) {
                                builder.append(param).append('=').append(oauthParams.get(param)).append(' ');
                            }
                            log.debug(builder.toString());
                        }

                        String consumerKey = oauthParams.get(OAuthConsumerParameter.oauth_consumer_key.toString());
                        if (consumerKey == null) {
                            throw new InvalidOAuthParametersException(messages.getMessage(
                                    "OAuthProcessingFilter.missingConsumerKey", "Missing consumer key."));
                        }

                        //load the consumer details.
                        ConsumerDetails consumerDetails = getConsumerDetailsService()
                                .loadConsumerByConsumerKey(consumerKey);
                        if (log.isDebugEnabled()) {
                            log.debug("Consumer details loaded for " + consumerKey + ": " + consumerDetails);
                        }

                        //validate the parameters for the consumer.
                        validateOAuthParams(consumerDetails, oauthParams);
                        if (log.isDebugEnabled()) {
                            log.debug("Parameters validated.");
                        }

                        //extract the credentials.
                        String token = oauthParams.get(OAuthConsumerParameter.oauth_token.toString());
                        String signatureMethod = oauthParams
                                .get(OAuthConsumerParameter.oauth_signature_method.toString());
                        String signature = oauthParams.get(OAuthConsumerParameter.oauth_signature.toString());
                        String signatureBaseString = getProviderSupport().getSignatureBaseString(request);
                        ConsumerCredentials credentials = new ConsumerCredentials(consumerKey, signature,
                                signatureMethod, signatureBaseString, token);

                        //create an authentication request.
                        ConsumerAuthentication authentication = new ConsumerAuthentication(consumerDetails,
                                credentials, oauthParams);
                        authentication.setDetails(createDetails(request, consumerDetails));

                        Authentication previousAuthentication = SecurityContextHolder.getContext()
                                .getAuthentication();
                        try {
                            //set the authentication request (unauthenticated) into the context.
                            SecurityContextHolder.getContext().setAuthentication(authentication);

                            //validate the signature.
                            validateSignature(authentication);

                            //mark the authentication request as validated.
                            authentication.setSignatureValidated(true);

                            //mark that processing has been handled.
                            request.setAttribute(OAUTH_PROCESSING_HANDLED, Boolean.TRUE);

                            if (log.isDebugEnabled()) {
                                log.debug("Signature validated.");
                            }

                            //go.
                            onValidSignature(request, response, chain);
                        } finally {
                            //clear out the consumer authentication to make sure it doesn't get cached.
                            resetPreviousAuthentication(previousAuthentication);
                        }
                    } else if (!isIgnoreInadequateCredentials()) {
                        throw new InvalidOAuthParametersException(
                                messages.getMessage("OAuthProcessingFilter.missingCredentials",
                                        "Inadequate OAuth consumer credentials."));
                    } else {
                        if (log.isDebugEnabled()) {
                            log.debug("Supplied OAuth parameters are inadequate. Ignoring.");
                        }
                        chain.doFilter(request, response);
                    }
                } catch (AuthenticationException ae) {
                    fail(request, response, ae);
                } catch (ServletException e) {
                    if (e.getRootCause() instanceof AuthenticationException) {
                        fail(request, response, (AuthenticationException) e.getRootCause());
                    } else {
                        throw e;
                    }
                }
            } else {
                if (log.isDebugEnabled()) {
                    log.debug("Request does not require authentication.  OAuth processing skipped.");
                }

                chain.doFilter(servletRequest, servletResponse);
            }
        } else {
            if (log.isDebugEnabled()) {
                log.debug("Processing explicitly skipped.");
            }

            chain.doFilter(servletRequest, servletResponse);
        }
    }

    /**
     * By default, OAuth parameters are adequate if a consumer key is present.
     *
     * @param oauthParams The oauth params.
     * @return Whether the parsed parameters are adequate.
     */
    protected boolean parametersAreAdequate(Map<String, String> oauthParams) {
        return oauthParams.containsKey(OAuthConsumerParameter.oauth_consumer_key.toString());
    }

    protected void resetPreviousAuthentication(Authentication previousAuthentication) {
        SecurityContextHolder.getContext().setAuthentication(previousAuthentication);
    }

    /**
     * Create the details for the authentication request.
     *
     * @param request The request.
     * @param consumerDetails The consumer details.
     * @return The authentication details.
     */
    protected Object createDetails(HttpServletRequest request, ConsumerDetails consumerDetails) {
        return new OAuthAuthenticationDetails(request, consumerDetails);
    }

    /**
     * Whether to allow the specified HTTP method.
     *
     * @param method The HTTP method to check for allowing.
     * @return Whether to allow the specified method.
     */
    protected boolean allowMethod(String method) {
        return allowedMethods.contains(method);
    }

    /**
     * Validate the signature of the request given the authentication request.
     *
     * @param authentication The authentication request.
     */
    protected void validateSignature(ConsumerAuthentication authentication) throws AuthenticationException {
        SignatureSecret secret = authentication.getConsumerDetails().getSignatureSecret();
        String token = authentication.getConsumerCredentials().getToken();
        OAuthProviderToken authToken = null;
        if (token != null && !"".equals(token)) {
            authToken = getTokenServices().getToken(token);
        }

        String signatureMethod = authentication.getConsumerCredentials().getSignatureMethod();
        OAuthSignatureMethod method;
        try {
            method = getSignatureMethodFactory().getSignatureMethod(signatureMethod, secret,
                    authToken != null ? authToken.getSecret() : null);
        } catch (UnsupportedSignatureMethodException e) {
            throw new OAuthException(e.getMessage(), e);
        }

        String signatureBaseString = authentication.getConsumerCredentials().getSignatureBaseString();
        String signature = authentication.getConsumerCredentials().getSignature();
        if (log.isDebugEnabled()) {
            log.debug("Verifying signature " + signature + " for signature base string " + signatureBaseString
                    + " with method " + method.getName() + ".");
        }
        method.verify(signatureBaseString, signature);
    }

    /**
     * Logic executed on valid signature. The security context can be assumed to hold a verified, authenticated
     * {@link org.springframework.security.oauth.provider.ConsumerAuthentication}
     *
     * Default implementation continues the chain.
     *
     * @param request  The request.
     * @param response The response
     * @param chain    The filter chain.
     */
    protected abstract void onValidSignature(HttpServletRequest request, HttpServletResponse response,
            FilterChain chain) throws IOException, ServletException;

    /**
     * Validates the OAuth parameters for the given consumer. Base implementation validates only those parameters
     * that are required for all OAuth requests. This includes the nonce and timestamp, but not the signature.
     *
     * @param consumerDetails The consumer details.
     * @param oauthParams     The OAuth parameters to validate.
     * @throws InvalidOAuthParametersException If the OAuth parameters are invalid.
     */
    protected void validateOAuthParams(ConsumerDetails consumerDetails, Map<String, String> oauthParams)
            throws InvalidOAuthParametersException {
        String version = oauthParams.get(OAuthConsumerParameter.oauth_version.toString());
        if ((version != null) && (!"1.0".equals(version))) {
            throw new OAuthVersionUnsupportedException("Unsupported OAuth version: " + version);
        }

        String realm = oauthParams.get("realm");
        realm = realm == null || "".equals(realm) ? null : realm;
        if ((realm != null) && (!realm.equals(this.authenticationEntryPoint.getRealmName()))) {
            throw new InvalidOAuthParametersException(messages.getMessage("OAuthProcessingFilter.incorrectRealm",
                    new Object[] { realm, this.getAuthenticationEntryPoint().getRealmName() },
                    "Response realm name '{0}' does not match system realm name of '{1}'"));
        }

        String signatureMethod = oauthParams.get(OAuthConsumerParameter.oauth_signature_method.toString());
        if (signatureMethod == null) {
            throw new InvalidOAuthParametersException(messages
                    .getMessage("OAuthProcessingFilter.missingSignatureMethod", "Missing signature method."));
        }

        String signature = oauthParams.get(OAuthConsumerParameter.oauth_signature.toString());
        if (signature == null) {
            throw new InvalidOAuthParametersException(
                    messages.getMessage("OAuthProcessingFilter.missingSignature", "Missing signature."));
        }

        String timestamp = oauthParams.get(OAuthConsumerParameter.oauth_timestamp.toString());
        if (timestamp == null) {
            throw new InvalidOAuthParametersException(
                    messages.getMessage("OAuthProcessingFilter.missingTimestamp", "Missing timestamp."));
        }

        String nonce = oauthParams.get(OAuthConsumerParameter.oauth_nonce.toString());
        if (nonce == null) {
            throw new InvalidOAuthParametersException(
                    messages.getMessage("OAuthProcessingFilter.missingNonce", "Missing nonce."));
        }

        try {
            getNonceServices().validateNonce(consumerDetails, Long.parseLong(timestamp), nonce);
        } catch (NumberFormatException e) {
            throw new InvalidOAuthParametersException(messages.getMessage("OAuthProcessingFilter.invalidTimestamp",
                    new Object[] { timestamp }, "Timestamp must be a positive integer. Invalid value: {0}"));
        }

        validateAdditionalParameters(consumerDetails, oauthParams);
    }

    /**
     * Do any additional validation checks for the specified oauth params.  Default implementation is a no-op.
     *
     * @param consumerDetails The consumer details.
     * @param oauthParams The params.
     */
    protected void validateAdditionalParameters(ConsumerDetails consumerDetails, Map<String, String> oauthParams) {
    }

    /**
     * Logic to be performed on a new timestamp.  The default behavior expects that the timestamp should not be new.
     *
     * @throws org.springframework.security.core.AuthenticationException
     *          If the timestamp shouldn't be new.
     */
    protected void onNewTimestamp() throws AuthenticationException {
        throw new InvalidOAuthParametersException(messages.getMessage("OAuthProcessingFilter.timestampNotNew",
                "A new timestamp should not be used in a request for an access token."));
    }

    /**
     * Common logic for OAuth failed.
     *
     * @param request  The request.
     * @param response The response.
     * @param failure  The failure.
     * @throws IOException thrown when there's an underlying IO exception
     * @throws ServletException thrown in the case of an underlying Servlet exception 
     */
    protected void fail(HttpServletRequest request, HttpServletResponse response, AuthenticationException failure)
            throws IOException, ServletException {
        SecurityContextHolder.getContext().setAuthentication(null);

        if (log.isDebugEnabled()) {
            log.debug(failure);
        }

        authenticationEntryPoint.commence(request, response, failure);
    }

    /**
     * Whether this filter is configured to process the specified request.
     *
     * @param request     The request.
     * @param response    The response
     * @param filterChain The filter chain
     * @return Whether this filter is configured to process the specified request.
     */
    protected boolean requiresAuthentication(HttpServletRequest request, HttpServletResponse response,
            FilterChain filterChain) {
        //copied from org.springframework.security.ui.AbstractProcessingFilter.requiresAuthentication
        String uri = request.getRequestURI();
        int pathParamIndex = uri.indexOf(';');

        if (pathParamIndex > 0) {
            // strip everything after the first semi-colon
            uri = uri.substring(0, pathParamIndex);
        }

        if ("".equals(request.getContextPath())) {
            return uri.endsWith(filterProcessesUrl);
        }

        boolean match = uri.endsWith(request.getContextPath() + filterProcessesUrl);
        if (log.isDebugEnabled()) {
            log.debug(uri + (match ? " matches " : " does not match ") + filterProcessesUrl);
        }
        return match;
    }

    /**
     * Whether to skip processing for the specified request.
     *
     * @param request The request.
     * @return Whether to skip processing.
     */
    protected boolean skipProcessing(HttpServletRequest request) {
        return ((request.getAttribute(OAUTH_PROCESSING_HANDLED) != null)
                && (Boolean.TRUE.equals(request.getAttribute(OAUTH_PROCESSING_HANDLED))));
    }

    /**
     * The authentication entry point.
     *
     * @return The authentication entry point.
     */
    public OAuthProcessingFilterEntryPoint getAuthenticationEntryPoint() {
        return authenticationEntryPoint;
    }

    /**
     * The authentication entry point.
     *
     * @param authenticationEntryPoint The authentication entry point.
     */
    @Autowired(required = false)
    public void setAuthenticationEntryPoint(OAuthProcessingFilterEntryPoint authenticationEntryPoint) {
        this.authenticationEntryPoint = authenticationEntryPoint;
    }

    /**
     * The consumer details service.
     *
     * @return The consumer details service.
     */
    public ConsumerDetailsService getConsumerDetailsService() {
        return consumerDetailsService;
    }

    /**
     * The consumer details service.
     *
     * @param consumerDetailsService The consumer details service.
     */
    @Autowired
    public void setConsumerDetailsService(ConsumerDetailsService consumerDetailsService) {
        this.consumerDetailsService = consumerDetailsService;
    }

    /**
     * The nonce services.
     *
     * @return The nonce services.
     */
    public OAuthNonceServices getNonceServices() {
        return nonceServices;
    }

    /**
     * The nonce services.
     *
     * @param nonceServices The nonce services.
     */
    @Autowired(required = false)
    public void setNonceServices(OAuthNonceServices nonceServices) {
        this.nonceServices = nonceServices;
    }

    /**
     * Get the OAuth token services.
     *
     * @return The OAuth token services.
     */
    public OAuthProviderTokenServices getTokenServices() {
        return tokenServices;
    }

    /**
     * The OAuth token services.
     *
     * @param tokenServices The OAuth token services.
     */
    @Autowired
    public void setTokenServices(OAuthProviderTokenServices tokenServices) {
        this.tokenServices = tokenServices;
    }

    /**
     * The URL for which this filter will be applied.
     *
     * @return The URL for which this filter will be applied.
     */
    public String getFilterProcessesUrl() {
        return filterProcessesUrl;
    }

    /**
     * The URL for which this filter will be applied.
     *
     * @param filterProcessesUrl The URL for which this filter will be applied.
     */
    public void setFilterProcessesUrl(String filterProcessesUrl) {
        this.filterProcessesUrl = filterProcessesUrl;
    }

    /**
     * Set the message source.
     *
     * @param messageSource The message source.
     */
    public void setMessageSource(MessageSource messageSource) {
        this.messages = new MessageSourceAccessor(messageSource);
    }

    /**
     * The OAuth provider support.
     *
     * @return The OAuth provider support.
     */
    public OAuthProviderSupport getProviderSupport() {
        return providerSupport;
    }

    /**
     * The OAuth provider support.
     *
     * @param providerSupport The OAuth provider support.
     */
    @Autowired(required = false)
    public void setProviderSupport(OAuthProviderSupport providerSupport) {
        this.providerSupport = providerSupport;
    }

    /**
     * The OAuth signature method factory.
     *
     * @return The OAuth signature method factory.
     */
    public OAuthSignatureMethodFactory getSignatureMethodFactory() {
        return signatureMethodFactory;
    }

    /**
     * The OAuth signature method factory.
     *
     * @param signatureMethodFactory The OAuth signature method factory.
     */
    @Autowired(required = false)
    public void setSignatureMethodFactory(OAuthSignatureMethodFactory signatureMethodFactory) {
        this.signatureMethodFactory = signatureMethodFactory;
    }

    /**
     * Whether to ignore missing OAuth credentials.
     *
     * @return Whether to ignore missing OAuth credentials.
     */
    public boolean isIgnoreInadequateCredentials() {
        return ignoreMissingCredentials;
    }

    /**
     * Whether to ignore missing OAuth credentials.
     *
     * @param ignoreMissingCredentials Whether to ignore missing OAuth credentials.
     */
    public void setIgnoreMissingCredentials(boolean ignoreMissingCredentials) {
        this.ignoreMissingCredentials = ignoreMissingCredentials;
    }

    /**
     * The allowed set of HTTP methods.
     *
     * @param allowedMethods The allowed set of methods.
     */
    public void setAllowedMethods(List<String> allowedMethods) {
        this.allowedMethods.clear();
        if (allowedMethods != null) {
            for (String allowedMethod : allowedMethods) {
                this.allowedMethods.add(allowedMethod.toUpperCase());
            }
        }
    }
}