org.mitre.oauth2.service.impl.DefaultOAuth2ClientDetailsEntityService.java Source code

Java tutorial

Introduction

Here is the source code for org.mitre.oauth2.service.impl.DefaultOAuth2ClientDetailsEntityService.java

Source

/*******************************************************************************
 * Copyright 2014 The MITRE Corporation
 *   and the MIT Kerberos and Internet Trust Consortium
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package org.mitre.oauth2.service.impl;

import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import org.apache.commons.codec.binary.Base64;
import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.DefaultHttpClient;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.repository.OAuth2ClientRepository;
import org.mitre.oauth2.repository.OAuth2TokenRepository;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.oauth2.service.SystemScopeService;
import org.mitre.openid.connect.model.WhitelistedSite;
import org.mitre.openid.connect.service.ApprovedSiteService;
import org.mitre.openid.connect.service.BlacklistedSiteService;
import org.mitre.openid.connect.service.StatsService;
import org.mitre.openid.connect.service.WhitelistedSiteService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import com.google.common.base.Strings;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.UncheckedExecutionException;
import com.google.gson.JsonElement;
import com.google.gson.JsonParser;

/**
 * Default {@link ClientDetailsEntityService} implementation backed by an
 * {@link OAuth2ClientRepository}.
 * 
 * Saving and updating a client run the same validations, in the same order:
 * the registered redirect URIs are checked against the blacklist, the
 * offline-access scope is aligned with the client's refresh flag, the redirect
 * URIs are checked against the client's sector identifier document (if any),
 * and restricted system scopes are stripped.
 */
@Service
public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEntityService {

    private static final Logger logger = LoggerFactory.getLogger(DefaultOAuth2ClientDetailsEntityService.class);

    @Autowired
    private OAuth2ClientRepository clientRepository;

    @Autowired
    private OAuth2TokenRepository tokenRepository;

    @Autowired
    private ApprovedSiteService approvedSiteService;

    @Autowired
    private WhitelistedSiteService whitelistedSiteService;

    @Autowired
    private BlacklistedSiteService blacklistedSiteService;

    @Autowired
    private SystemScopeService scopeService;

    @Autowired
    private StatsService statsService;

    // map of sector URI -> list of redirect URIs; entries expire one hour after last access
    private LoadingCache<String, List<String>> sectorRedirects = CacheBuilder.newBuilder()
            .expireAfterAccess(1, TimeUnit.HOURS).maximumSize(100).build(new SectorIdentifierLoader());

    /**
     * Validate and persist a client that has not been saved before.
     * 
     * A random client ID is generated when the client has none. No client
     * secret is generated here, since public clients have no secret.
     * The creation timestamp is set to the current time and the stats cache
     * is reset after the client has been saved.
     * 
     * @param client the new client; its internal ID must be null
     * @return the client as returned by the repository
     * @throws IllegalArgumentException if the client already has an internal ID,
     *             if one of its redirect URIs is blacklisted, if its sector
     *             identifier URI cannot be loaded, or if one of its redirect
     *             URIs is not listed at the sector identifier URI
     */
    @Override
    public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
        if (client.getId() != null) { // if it's not null, it's already been saved, this is an error
            throw new IllegalArgumentException("Tried to save a new client with an existing ID: " + client.getId());
        }

        ensureRedirectUrisNotBlacklisted(client);

        // assign a random clientid if it's empty
        // NOTE: don't assign a random client secret without asking, since public clients have no secret
        if (Strings.isNullOrEmpty(client.getClientId())) {
            client = generateClientId(client);
        }

        alignOfflineAccessScope(client);

        // timestamp this to right now
        client.setCreatedAt(new Date());

        ensureRedirectUrisListedAtSector(client);

        // make sure a client doesn't get any special system scopes
        client.setScope(scopeService.removeRestrictedScopes(client.getScope()));

        ClientDetailsEntity saved = clientRepository.saveClient(client);

        statsService.resetCache();

        return saved;
    }

    /**
     * Get the client by its internal ID
     * 
     * @param id the internal (repository) ID
     * @return whatever the repository returns for that ID
     */
    @Override
    public ClientDetailsEntity getClientById(Long id) {
        return clientRepository.getById(id);
    }

    /**
     * Get the client for the given ClientID
     * 
     * @param clientId the OAuth client ID; must not be null or empty
     * @return the matching client, never null
     * @throws InvalidClientException if the repository has no client with that ID
     * @throws IllegalArgumentException if clientId is null or empty
     */
    @Override
    public ClientDetailsEntity loadClientByClientId(String clientId)
            throws OAuth2Exception, InvalidClientException, IllegalArgumentException {
        if (Strings.isNullOrEmpty(clientId)) {
            throw new IllegalArgumentException("Client id must not be empty!");
        }

        ClientDetailsEntity client = clientRepository.getClientByClientId(clientId);
        if (client == null) {
            throw new InvalidClientException("Client with id " + clientId + " was not found");
        }
        return client;
    }

    /**
     * Delete a client and all its associated tokens, approved sites and
     * whitelisted site, then reset the stats cache.
     * 
     * @param client the client to delete
     * @throws InvalidClientException if the repository has no client with the
     *             given client's internal ID
     */
    @Override
    public void deleteClient(ClientDetailsEntity client) throws InvalidClientException {

        if (clientRepository.getById(client.getId()) == null) {
            throw new InvalidClientException("Client with id " + client.getClientId() + " was not found");
        }

        // clean out any tokens that this client had issued
        tokenRepository.clearTokensForClient(client);

        // clean out any approved sites for this client
        approvedSiteService.clearApprovedSitesForClient(client);

        // clear out any whitelisted sites for this client
        WhitelistedSite whitelistedSite = whitelistedSiteService.getByClientId(client.getClientId());
        if (whitelistedSite != null) {
            whitelistedSiteService.remove(whitelistedSite);
        }

        // take care of the client itself
        clientRepository.deleteClient(client);

        statsService.resetCache();

    }

    /**
     * Update the oldClient with information from the newClient. The
     * id from oldClient is retained.
     * 
     * Checks to make sure the redirect URIs aren't blacklisted.
     * 
     * Checks to make sure the refresh grant type and
     * the scopes are set appropriately.
     * 
     * Attempts to load the sector identifier URI (possibly cached) to check the
     * redirect URIs against the contents there.
     * 
     * @param oldClient the stored client whose internal ID identifies the record
     * @param newClient the client carrying the new values
     * @return the client as returned by the repository
     * @throws IllegalArgumentException if either argument is null, if a redirect
     *             URI is blacklisted, if the sector identifier URI cannot be
     *             loaded, or if a redirect URI is not listed there
     */
    @Override
    public ClientDetailsEntity updateClient(ClientDetailsEntity oldClient, ClientDetailsEntity newClient)
            throws IllegalArgumentException {
        if (oldClient == null || newClient == null) {
            throw new IllegalArgumentException("Neither old client or new client can be null!");
        }

        ensureRedirectUrisNotBlacklisted(newClient);

        alignOfflineAccessScope(newClient);

        ensureRedirectUrisListedAtSector(newClient);

        // make sure a client doesn't get any special system scopes
        newClient.setScope(scopeService.removeRestrictedScopes(newClient.getScope()));

        return clientRepository.updateClient(oldClient.getId(), newClient);
    }

    /**
     * Get all clients in the system
     */
    @Override
    public Collection<ClientDetailsEntity> getAllClients() {
        return clientRepository.getAllClients();
    }

    /**
     * Generates a clientId for the given client and sets it to the client's clientId field. Returns the client that was passed in, now with id set.
     * 
     * The generated ID is a random UUID string.
     */
    @Override
    public ClientDetailsEntity generateClientId(ClientDetailsEntity client) {
        client.setClientId(UUID.randomUUID().toString());
        return client;
    }

    /**
     * Generates a new clientSecret for the given client and sets it to the client's clientSecret field. Returns the client that was passed in, now with secret set.
     * 
     * The secret is the URL-safe Base64 encoding of a 512-bit random number
     * drawn from a {@link SecureRandom}.
     */
    @Override
    public ClientDetailsEntity generateClientSecret(ClientDetailsEntity client) {
        client.setClientSecret(Base64
                .encodeBase64URLSafeString(new BigInteger(512, new SecureRandom()).toByteArray()).replace("=", ""));
        return client;
    }

    /**
     * Reject the client if any of its registered redirect URIs is blacklisted.
     * A client with no registered redirect URIs (null) passes.
     */
    private void ensureRedirectUrisNotBlacklisted(ClientDetailsEntity client) {
        if (client.getRegisteredRedirectUri() == null) {
            return;
        }
        for (String uri : client.getRegisteredRedirectUri()) {
            if (blacklistedSiteService.isBlacklisted(uri)) {
                throw new IllegalArgumentException("Client URI is blacklisted: " + uri);
            }
        }
    }

    /**
     * Add the offline-access scope when the client allows refresh tokens,
     * remove it otherwise.
     */
    private void alignOfflineAccessScope(ClientDetailsEntity client) {
        if (client.isAllowRefresh()) {
            client.getScope().add(SystemScopeService.OFFLINE_ACCESS);
        } else {
            client.getScope().remove(SystemScopeService.OFFLINE_ACCESS);
        }
    }

    /**
     * If the client has a sector identifier URI, load the redirect URIs published
     * there (possibly from cache) and reject the client if any of its registered
     * redirect URIs is missing from that list.
     */
    private void ensureRedirectUrisListedAtSector(ClientDetailsEntity client) {
        String sectorUri = client.getSectorIdentifierUri();
        if (Strings.isNullOrEmpty(sectorUri)) {
            return;
        }

        // The document is loaded even when the client has no redirect URIs, so an
        // unreachable or malformed sector identifier is always reported.
        List<String> redirects;
        try {
            redirects = sectorRedirects.get(sectorUri);
        } catch (UncheckedExecutionException e) {
            // the loader failed with a runtime exception (HTTP error, bad JSON, ...)
            throw new IllegalArgumentException("Unable to load sector identifier URI: " + sectorUri, e);
        } catch (ExecutionException e) {
            // the loader failed with a checked exception
            throw new IllegalArgumentException("Unable to load sector identifier URI: " + sectorUri, e);
        }

        if (client.getRegisteredRedirectUri() == null) {
            return;
        }
        for (String uri : client.getRegisteredRedirectUri()) {
            if (!redirects.contains(uri)) {
                throw new IllegalArgumentException("Requested Redirect URI " + uri
                        + " is not listed at sector identifier " + redirects);
            }
        }
    }

    /**
     * Utility class to load a sector identifier's set of authorized redirect URIs.
     * 
     * Declared static because it needs no state from the enclosing service.
     * 
     * @author jricher
     *
     */
    private static class SectorIdentifierLoader extends CacheLoader<String, List<String>> {
        private final HttpClient httpClient = new DefaultHttpClient();
        private final HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(
                httpClient);
        private final RestTemplate restTemplate = new RestTemplate(httpFactory);
        private final JsonParser parser = new JsonParser();

        /**
         * Fetch the document at the sector identifier URI and read it as a JSON
         * array of redirect URI strings.
         * 
         * @param key the sector identifier URI
         * @return the redirect URIs listed in the document, never null
         * @throws IllegalArgumentException if the document is not a JSON array
         * @throws Exception any failure raised while fetching or parsing the document
         */
        @Override
        public List<String> load(String key) throws Exception {

            if (!key.startsWith("https")) {
                // TODO: this should optionally throw an error (#506)
                logger.error("Sector identifier doesn't start with https, loading anyway...");
            }

            // key is the sector URI
            String jsonString = restTemplate.getForObject(key, String.class);
            JsonElement json = parser.parse(jsonString);

            if (!json.isJsonArray()) {
                // A CacheLoader must never return null; fail with a descriptive
                // exception so callers can report the bad sector identifier.
                throw new IllegalArgumentException("Sector identifier document at " + key + " is not a JSON array");
            }

            List<String> redirectUris = new ArrayList<String>();
            for (JsonElement el : json.getAsJsonArray()) {
                redirectUris.add(el.getAsString());
            }

            logger.info("Found {} for sector {}", redirectUris, key);

            return redirectUris;
        }

    }

}