io.gravitee.policy.oauth2.Oauth2Policy.java Source code

Java tutorial

Introduction

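Here is the source code for io.gravitee.policy.oauth2.Oauth2Policy.java, a Gravitee gateway policy that validates OAuth2 bearer access tokens against a configured authorization server.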

Source

/**
 * Copyright (C) 2015 The Gravitee team (http://gravitee.io)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.gravitee.policy.oauth2;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import io.gravitee.common.http.HttpHeaders;
import io.gravitee.common.http.HttpStatusCode;
import io.gravitee.common.http.MediaType;
import io.gravitee.gateway.api.ExecutionContext;
import io.gravitee.gateway.api.Request;
import io.gravitee.gateway.api.Response;
import io.gravitee.gateway.api.handler.Handler;
import io.gravitee.policy.api.PolicyChain;
import io.gravitee.policy.api.PolicyResult;
import io.gravitee.policy.api.annotations.OnRequest;
import io.gravitee.policy.oauth2.configuration.OAuth2PolicyConfiguration;
import io.gravitee.resource.api.ResourceManager;
import io.gravitee.resource.oauth2.api.OAuth2Resource;
import io.gravitee.resource.oauth2.api.OAuth2Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;

import java.io.IOException;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Gateway policy that protects an API with OAuth2 bearer access tokens.
 *
 * <p>For every request it extracts the {@code Bearer} token from the {@code Authorization} header,
 * asks the configured {@link OAuth2Resource} to introspect it, and then either continues the policy
 * chain or fails it, adding a {@code WWW-Authenticate} challenge to the response.
 *
 * <p>On successful introspection the access token, the {@code client_id} (when present) and,
 * if configured, the raw introspection payload are stored as execution-context attributes whose
 * names start with {@value #CONTEXT_ATTRIBUTE_PREFIX}.
 *
 * @author David BRASSELY (david.brassely at graviteesource.com)
 * @author Titouan COMPIEGNE (titouan.compiegne at graviteesource.com)
 * @author GraviteeSource Team
 */
public class Oauth2Policy {

    private static final Logger logger = LoggerFactory.getLogger(Oauth2Policy.class);

    static final String BEARER_AUTHORIZATION_TYPE = "Bearer";
    static final String OAUTH_PAYLOAD_SCOPE_NODE = "scope";
    static final String OAUTH_PAYLOAD_CLIENT_ID_NODE = "client_id";

    static final String CONTEXT_ATTRIBUTE_PREFIX = "oauth.";
    static final String CONTEXT_ATTRIBUTE_OAUTH_PAYLOAD = CONTEXT_ATTRIBUTE_PREFIX + "payload";
    static final String CONTEXT_ATTRIBUTE_OAUTH_ACCESS_TOKEN = CONTEXT_ATTRIBUTE_PREFIX + "access_token";
    static final String CONTEXT_ATTRIBUTE_CLIENT_ID = CONTEXT_ATTRIBUTE_PREFIX + "client_id";

    /** Scope separator used when the resource provides none (OAuth2 scope strings are space-delimited). */
    private static final String DEFAULT_SCOPE_SEPARATOR = " ";

    private static final String NO_AUTHORIZATION_SERVER_MESSAGE =
            "No OAuth authorization server has been configured";

    static final ObjectMapper MAPPER = new ObjectMapper();

    private final OAuth2PolicyConfiguration oAuth2PolicyConfiguration;

    public Oauth2Policy(OAuth2PolicyConfiguration oAuth2PolicyConfiguration) {
        this.oAuth2PolicyConfiguration = oAuth2PolicyConfiguration;
    }

    /**
     * Extracts the bearer access token from the request and starts its asynchronous introspection.
     * The chain is failed immediately when no OAuth2 resource is configured or when no usable
     * bearer token is present; otherwise it is completed by {@link #handleResponse}.
     */
    @OnRequest
    public void onRequest(Request request, Response response, ExecutionContext executionContext,
            PolicyChain policyChain) {
        logger.debug("Read access_token from request {}", request.id());

        OAuth2Resource oauth2 = getOAuth2Resource(executionContext);
        if (oauth2 == null) {
            policyChain.failWith(PolicyResult.failure(HttpStatusCode.UNAUTHORIZED_401,
                    NO_AUTHORIZATION_SERVER_MESSAGE));
            return;
        }

        List<String> authorizationHeaders = request.headers().get(HttpHeaders.AUTHORIZATION);
        if (authorizationHeaders == null || authorizationHeaders.isEmpty()) {
            sendError(response, policyChain, "invalid_request", "No OAuth authorization header was supplied");
            return;
        }

        Optional<String> optionalHeaderAccessToken = authorizationHeaders.stream()
                .filter(Oauth2Policy::isBearerHeader)
                .findFirst();
        if (!optionalHeaderAccessToken.isPresent()) {
            sendError(response, policyChain, "invalid_request", "No OAuth authorization header was supplied");
            return;
        }

        String accessToken = optionalHeaderAccessToken.get().substring(BEARER_AUTHORIZATION_TYPE.length()).trim();
        if (accessToken.isEmpty()) {
            sendError(response, policyChain, "invalid_request", "No OAuth access token was supplied");
            return;
        }

        // Expose the raw access token to downstream policies.
        executionContext.setAttribute(CONTEXT_ATTRIBUTE_OAUTH_ACCESS_TOKEN, accessToken);

        // Validate the access token; the handler completes the chain.
        oauth2.introspect(accessToken, handleResponse(policyChain, request, response, executionContext));
    }

    /**
     * Builds the callback invoked with the introspection result. Every path through the callback
     * either continues the chain ({@code doNext}) or fails it ({@code failWith}).
     */
    Handler<OAuth2Response> handleResponse(PolicyChain policyChain, Request request, Response response,
            ExecutionContext executionContext) {
        return oauth2response -> {
            if (oauth2response.isSuccess()) {
                handleSuccessfulIntrospection(oauth2response, policyChain, request, response, executionContext);
            } else {
                handleFailedIntrospection(oauth2response, policyChain, response);
            }
        };
    }

    /**
     * Handles a successful introspection response. It parses the payload, publishes the client id,
     * enforces required scopes when configured, optionally stores the payload, and then continues
     * the chain.
     */
    private void handleSuccessfulIntrospection(OAuth2Response oauth2response, PolicyChain policyChain,
            Request request, Response response, ExecutionContext executionContext) {
        JsonNode oauthResponseNode = readPayload(oauth2response.getPayload());
        if (oauthResponseNode == null) {
            sendError(response, policyChain, "server_error", "Invalid response from authorization server");
            return;
        }

        // Publish client_id only when it carries a real value: asText() on a JSON null would
        // otherwise yield the literal string "null".
        JsonNode clientIdNode = oauthResponseNode.path(OAUTH_PAYLOAD_CLIENT_ID_NODE);
        if (!clientIdNode.isNull()) {
            String clientId = clientIdNode.asText();
            if (!clientId.trim().isEmpty()) {
                executionContext.setAttribute(CONTEXT_ATTRIBUTE_CLIENT_ID, clientId);
            }
        }

        // Check the scopes required to access the resource.
        if (oAuth2PolicyConfiguration.isCheckRequiredScopes()) {
            // The resource is looked up again here. Guard against a null result so the chain is
            // always completed instead of an NPE escaping from this callback.
            OAuth2Resource oauth2 = getOAuth2Resource(executionContext);
            if (oauth2 == null) {
                policyChain.failWith(PolicyResult.failure(HttpStatusCode.UNAUTHORIZED_401,
                        NO_AUTHORIZATION_SERVER_MESSAGE));
                return;
            }

            if (!hasRequiredScopes(oauthResponseNode, oAuth2PolicyConfiguration.getRequiredScopes(),
                    oauth2.getScopeSeparator())) {
                sendError(response, policyChain, "insufficient_scope",
                        "The request requires higher privileges than provided by the access token.");
                return;
            }
        }

        // Store the raw OAuth2 payload in the execution context if required.
        if (oAuth2PolicyConfiguration.isExtractPayload()) {
            executionContext.setAttribute(CONTEXT_ATTRIBUTE_OAUTH_PAYLOAD, oauth2response.getPayload());
        }

        policyChain.doNext(request, response);
    }

    /**
     * Handles an unsuccessful introspection response. Without a throwable, the introspection
     * payload is relayed as a JSON 401. With a throwable, the request fails with 503.
     */
    private void handleFailedIntrospection(OAuth2Response oauth2response, PolicyChain policyChain,
            Response response) {
        // NOTE(review): unlike sendError, the realm here is unquoted and has a trailing space; kept
        // byte-for-byte because clients or tests may match this exact value.
        response.headers().add(HttpHeaders.WWW_AUTHENTICATE,
                BEARER_AUTHORIZATION_TYPE + " realm=gravitee.io ");

        if (oauth2response.getThrowable() == null) {
            // Presumably the authorization server answered and rejected the token.
            policyChain.failWith(PolicyResult.failure(HttpStatusCode.UNAUTHORIZED_401,
                    oauth2response.getPayload(), MediaType.APPLICATION_JSON));
        } else {
            // Presumably the introspection call itself failed (e.g. server unreachable).
            policyChain.failWith(PolicyResult.failure(HttpStatusCode.SERVICE_UNAVAILABLE_503,
                    "temporarily_unavailable"));
        }
    }

    /**
     * Resolves the OAuth2 resource named in the policy configuration. The result may be
     * {@code null}; callers treat that as "no authorization server configured".
     */
    private OAuth2Resource getOAuth2Resource(ExecutionContext executionContext) {
        return executionContext.getComponent(ResourceManager.class)
                .getResource(oAuth2PolicyConfiguration.getOauthResource(), OAuth2Resource.class);
    }

    /**
     * Tells whether an {@code Authorization} header value uses the {@code Bearer} scheme.
     *
     * <p>The scheme is matched case-insensitively and must be the whole first word. It is either
     * the entire value or followed by whitespace, so a value such as {@code "BearerXYZ"} is not
     * mistaken for bearer credentials. {@code null} values are rejected.
     */
    private static boolean isBearerHeader(String headerValue) {
        if (!StringUtils.startsWithIgnoreCase(headerValue, BEARER_AUTHORIZATION_TYPE)) {
            return false;
        }
        int schemeLength = BEARER_AUTHORIZATION_TYPE.length();
        return headerValue.length() == schemeLength
                || Character.isWhitespace(headerValue.charAt(schemeLength));
    }

    /**
     * As per https://tools.ietf.org/html/rfc6750#page-7:
     *
     *      HTTP/1.1 401 Unauthorized
     *      WWW-Authenticate: Bearer realm="example",
     *      error="invalid_token",
     *      error_description="The access token expired"
     *
     * Adds such a challenge to the response and fails the chain with a 401 and no body.
     * {@code error} and {@code description} are inserted unescaped, so callers must pass fixed
     * strings that contain no double quotes.
     */
    private void sendError(Response response, PolicyChain policyChain, String error, String description) {
        String headerValue = BEARER_AUTHORIZATION_TYPE + " realm=\"gravitee.io\"," + " error=\"" + error + "\","
                + " error_description=\"" + description + "\"";
        response.headers().add(HttpHeaders.WWW_AUTHENTICATE, headerValue);
        policyChain.failWith(PolicyResult.failure(HttpStatusCode.UNAUTHORIZED_401, null));
    }

    /**
     * Parses the introspection payload as JSON.
     *
     * @return the parsed tree, or {@code null} when the payload is absent or cannot be parsed
     */
    private JsonNode readPayload(String oauthPayload) {
        if (oauthPayload == null) {
            // readTree(null) throws an unchecked exception rather than IOException. If it escaped
            // the callback, the chain would never be completed.
            logger.error("Introspection endpoint returned no payload");
            return null;
        }
        try {
            return MAPPER.readTree(oauthPayload);
        } catch (IOException ioe) {
            logger.error("Unable to parse introspection endpoint payload: {}", oauthPayload, ioe);
            return null;
        }
    }

    /**
     * Checks that the introspection response grants every required scope.
     *
     * <p>The {@value #OAUTH_PAYLOAD_SCOPE_NODE} node may be a JSON array of scopes or a single
     * string. A string is split on {@code scopeSeparator}, taken literally rather than as a regex.
     * A {@code null} or empty separator falls back to a single space.
     *
     * @param oauthResponseNode parsed introspection payload
     * @param requiredScopes    scopes that must all be present; {@code null} means no requirement
     * @param scopeSeparator    literal delimiter between scopes in a string-valued scope node
     * @return {@code true} if every required scope is granted
     */
    static boolean hasRequiredScopes(JsonNode oauthResponseNode, List<String> requiredScopes,
            String scopeSeparator) {
        if (requiredScopes == null) {
            return true;
        }

        JsonNode scopesNode = oauthResponseNode.path(OAUTH_PAYLOAD_SCOPE_NODE);

        List<String> scopes;
        if (scopesNode instanceof ArrayNode) {
            scopes = new ArrayList<>(scopesNode.size());
            for (JsonNode scopeNode : scopesNode) {
                scopes.add(scopeNode.asText());
            }
        } else {
            String separator = (scopeSeparator == null || scopeSeparator.isEmpty())
                    ? DEFAULT_SCOPE_SEPARATOR
                    : scopeSeparator;
            // String.split takes a regex: quote the separator so "|", "." etc. act as literals.
            scopes = Arrays.asList(scopesNode.asText().split(Pattern.quote(separator)));
        }

        return scopes.containsAll(requiredScopes);
    }
}