org.opentestsystem.shared.security.oauth.client.grant.samlbearer.SamlAssertionAccessTokenProvider.java Source code

Java tutorial

Introduction

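Here is the source code for org.opentestsystem.shared.security.oauth.client.grant.samlbearer.SamlAssertionAccessTokenProvider.java, a Spring Security OAuth2 AccessTokenProvider that Base64-encodes the current user's SAML assertion and submits it to the token endpoint using the "urn:ietf:params:oauth:grant-type:saml2-bearer" grant type.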

Source

/*******************************************************************************
 * Educational Online Test Delivery System 
 * Copyright (c) 2014 American Institutes for Research
 *   
 * Distributed under the AIR Open Source License, Version 1.0 
 * See accompanying file AIR-License-1_0.txt or at
 * https://bitbucket.org/sbacoss/eotds/wiki/AIR_Open_Source_License
 ******************************************************************************/
package org.opentestsystem.shared.security.oauth.client.grant.samlbearer;

import java.io.StringWriter;
import java.nio.charset.StandardCharsets;

import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;

import org.apache.commons.ssl.Base64;
import org.opensaml.saml2.core.Assertion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.resource.BaseOAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
import org.springframework.security.oauth2.client.token.AccessTokenProvider;
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
import org.springframework.security.oauth2.client.token.OAuth2AccessTokenSupport;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.saml.SAMLCredential;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

/**
 * {@link AccessTokenProvider} that requests an OAuth2 access token with the
 * {@value #SAML2_BEARER_GRANT_TYPE} grant type.
 *
 * <p>The assertion sent to the token endpoint is taken from the {@link SAMLCredential} held by the
 * {@link Authentication} in the current {@link SecurityContextHolder} context, serialized to XML and
 * Base64-encoded. Refreshing tokens is not supported by this provider.
 */
public class SamlAssertionAccessTokenProvider extends OAuth2AccessTokenSupport implements AccessTokenProvider {
    public static final String SAML2_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:saml2-bearer";
    private static final Logger LOGGER = LoggerFactory.getLogger(SamlAssertionAccessTokenProvider.class);

    /**
     * Serializes the authentication assertion of the current user's {@link SAMLCredential} to XML
     * and Base64-encodes it.
     *
     * <p>The assertion is a bearer credential, so neither its XML nor its encoded form is written to
     * the log.
     *
     * @return the Base64-encoded assertion XML, or an empty string when the current security context
     *         holds no SAML credential/assertion or the assertion cannot be serialized
     */
    public String getSamlAssertion() {
        final Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        if (auth == null || !(auth.getCredentials() instanceof SAMLCredential)) {
            LOGGER.error("No SAML credential is available in the current security context");
            return "";
        }

        final Assertion assertion = ((SAMLCredential) auth.getCredentials()).getAuthenticationAssertion();
        if (assertion == null || assertion.getDOM() == null) {
            LOGGER.error("The current SAML credential does not carry a serializable authentication assertion");
            return "";
        }

        try {
            final StringWriter output = new StringWriter();
            final Transformer transformer = TransformerFactory.newInstance().newTransformer();
            transformer.transform(new DOMSource(assertion.getDOM()), new StreamResult(output));

            // Use an explicit charset so the encoded bytes do not depend on the platform default.
            // NOTE(review): RFC 7522 asks for base64url encoding; standard Base64 is kept here
            // because that is what this client has always sent - confirm against the token server.
            final byte[] bytesEncoded = Base64.encodeBase64(output.toString().getBytes(StandardCharsets.UTF_8));
            final String encodedStr = new String(bytesEncoded, StandardCharsets.US_ASCII);

            LOGGER.debug("Serialized SAML assertion to {} Base64 characters", encodedStr.length());
            return encodedStr;
        } catch (final TransformerException e) {
            LOGGER.error("There was an issue processing the SAML assertion", e);
            return "";
        }
    }

    /**
     * Reports whether the resource is configured for the SAML assertion grant.
     *
     * @param resource the protected resource to inspect
     * @return {@code true} when the resource's grant type equals
     *         {@code SamlAssertionResourceDetails.GRANT_TYPE_SAML_ASSERTION}
     */
    @Override
    public boolean supportsResource(final OAuth2ProtectedResourceDetails resource) {
        return SamlAssertionResourceDetails.GRANT_TYPE_SAML_ASSERTION.equals(resource.getGrantType());
    }

    /**
     * Refresh is not supported by this provider.
     *
     * @param resource the protected resource (ignored)
     * @return always {@code false}
     */
    @Override
    public boolean supportsRefresh(final OAuth2ProtectedResourceDetails resource) {
        return false;
    }

    /**
     * Not implemented: {@link #supportsRefresh(OAuth2ProtectedResourceDetails)} returns
     * {@code false}, so callers that honour it never reach this method.
     *
     * @param resource the protected resource (ignored)
     * @param refreshToken the refresh token (ignored)
     * @param request the current token request (ignored)
     * @return always {@code null}
     */
    @Override
    public OAuth2AccessToken refreshAccessToken(final OAuth2ProtectedResourceDetails resource,
            final OAuth2RefreshToken refreshToken, final AccessTokenRequest request)
            throws UserRedirectRequiredException, OAuth2AccessDeniedException {
        // A refresh_token grant (form fields "grant_type"="refresh_token" and "refresh_token"=<value>,
        // posted through retrieveToken) was sketched here previously but is intentionally disabled.
        return null;
    }

    /**
     * Requests an access token by posting the current user's SAML assertion to the token endpoint.
     *
     * @param details the protected resource; must be a {@link BaseOAuth2ProtectedResourceDetails}
     * @param request the current token request
     * @return the access token returned by {@code retrieveToken}
     * @throws OAuth2AccessDeniedException when no SAML assertion can be obtained for the current user
     */
    @Override
    public OAuth2AccessToken obtainAccessToken(final OAuth2ProtectedResourceDetails details,
            final AccessTokenRequest request)
            throws UserRedirectRequiredException, AccessDeniedException, OAuth2AccessDeniedException {
        final BaseOAuth2ProtectedResourceDetails resource = (BaseOAuth2ProtectedResourceDetails) details;
        final String assertion = getSamlAssertion();
        if (!StringUtils.hasText(assertion)) {
            // Fail fast instead of sending a token request that carries an empty assertion.
            throw new OAuth2AccessDeniedException("Unable to obtain a SAML assertion for the current user",
                    resource);
        }
        return retrieveToken(request, resource, getParametersForTokenRequest(resource, assertion),
                new HttpHeaders());
    }

    /**
     * Builds the form parameters of the token request: {@code grant_type}, {@code assertion},
     * {@code client_id} and, for scoped resources, a space-delimited {@code scope}.
     *
     * @param resource the protected resource supplying client id and scope
     * @param assertion the Base64-encoded SAML assertion
     * @return the form parameters to post to the token endpoint
     */
    private MultiValueMap<String, String> getParametersForTokenRequest(
            final BaseOAuth2ProtectedResourceDetails resource, final String assertion) {
        final MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
        form.set("grant_type", SAML2_BEARER_GRANT_TYPE);
        form.set("assertion", assertion);
        form.set("client_id", resource.getClientId());
        LOGGER.debug("Requesting SAML bearer token for client_id {}", resource.getClientId());
        if (resource.isScoped()) {
            final String scopeString = resource.getScope() != null
                    ? StringUtils.collectionToDelimitedString(resource.getScope(), " ")
                    : "";
            form.set("scope", scopeString);
            LOGGER.debug("Requesting SAML bearer token with scope {}", scopeString);
        }

        return form;
    }
}