ddf.security.realm.sts.AbstractStsRealm.java Source code

Java tutorial

Introduction

Here is the source code for ddf.security.realm.sts.AbstractStsRealm.java

Source

/**
 * Copyright (c) Codice Foundation
 * <p>
 * This is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
 * General Public License as published by the Free Software Foundation, either version 3 of the
 * License, or any later version.
 * <p>
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
 * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details. A copy of the GNU Lesser General Public License
 * is distributed along with this program and can be found at
 * <http://www.gnu.org/licenses/lgpl.html>.
 */
package ddf.security.realm.sts;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import javax.xml.stream.XMLStreamException;

import org.apache.cxf.Bus;
import org.apache.cxf.BusFactory;
import org.apache.cxf.bus.CXFBusFactory;
import org.apache.cxf.bus.spring.SpringBusFactory;
import org.apache.cxf.staxutils.W3CDOMStreamWriter;
import org.apache.cxf.ws.security.SecurityConstants;
import org.apache.cxf.ws.security.tokenstore.SecurityToken;
import org.apache.cxf.ws.security.trust.STSClient;
import org.apache.cxf.ws.security.trust.STSUtils;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAuthenticationInfo;
import org.apache.shiro.authc.credential.CredentialsMatcher;
import org.apache.shiro.realm.AuthenticatingRealm;
import org.apache.shiro.subject.SimplePrincipalCollection;
import org.codice.ddf.configuration.PropertyResolver;
import org.codice.ddf.security.handler.api.BaseAuthenticationToken;
import org.codice.ddf.security.handler.api.SAMLAuthenticationToken;
import org.codice.ddf.security.policy.context.ContextPolicy;
import org.codice.ddf.security.policy.context.ContextPolicyManager;
import org.slf4j.LoggerFactory;
import org.slf4j.ext.XLogger;
import org.w3c.dom.DOMImplementation;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.ls.DOMImplementationLS;
import org.w3c.dom.ls.LSSerializer;

import com.google.common.base.Splitter;

import ddf.security.PropertiesLoader;
import ddf.security.assertion.SecurityAssertion;
import ddf.security.assertion.impl.SecurityAssertionImpl;
import ddf.security.common.audit.SecurityLogger;
import ddf.security.sts.client.configuration.STSClientConfiguration;

public abstract class AbstractStsRealm extends AuthenticatingRealm implements STSClientConfiguration {
    // Class-wide logger (SLF4J extended logger wrapping the class's logger).
    private static final XLogger LOGGER = new XLogger(LoggerFactory.getLogger(AbstractStsRealm.class));

    // Realm name under which principals are registered in doGetAuthenticationInfo.
    private static final String NAME = AbstractStsRealm.class.getSimpleName();

    // WS-Addressing namespace set on the STS client in configureBaseStsClient.
    private static final String ADDRESSING_NAMESPACE = "http://www.w3.org/2005/08/addressing";

    // Splits a comma-separated claims string, trimming entries and dropping empty ones.
    private static final Splitter SPLITTER = Splitter.on(',').trimResults().omitEmptyStrings();

    // CXF bus obtained from getBus() in the constructor and handed to the STSClient.
    protected Bus bus;

    // STS address wrapped in a PropertyResolver; resolved on every getAddress() call.
    PropertyResolver address = null;

    // Endpoint name applied to the STS client when non-null.
    String endpointName = null;

    // Service name applied to the STS client when non-null.
    String serviceName = null;

    // Stored via setUsername; not read anywhere else in this class.
    String username = null;

    // Stored via setPassword; not read anywhere else in this class.
    String password = null;

    // Stored via setSignatureUsername; not read anywhere else in this class.
    String signatureUsername = null;

    // Location of the signature properties loaded in addStsProperties.
    String signatureProperties = null;

    // Stored via setEncryptionUsername; not read anywhere else in this class.
    String encryptionUsername = null;

    // Location of the encryption properties loaded in addStsProperties.
    String encryptionProperties = null;

    // Stored via setTokenUsername; not read anywhere else in this class.
    String tokenUsername = null;

    // Location of the STS token properties loaded in addStsProperties.
    String tokenProperties = null;

    // Configured claim URIs; combined with context-policy attributes in createClaimsElement.
    List<String> claims = new ArrayList<>();

    // STS client created in configureBaseStsClient and reused for every request/renew call.
    private STSClient stsClient;

    // Set to true after the first configureStsClient() call made from doGetAuthenticationInfo.
    private boolean settingsConfigured;

    // Optional source of allowed attribute names that are added to the requested claims.
    private ContextPolicyManager contextPolicyManager;

    // Passed to STSClient.setTokenType before requesting/renewing a token.
    private String assertionType = null;

    // Passed to STSClient.setKeyType before requesting/renewing a token.
    private String keyType = null;

    // Key size as a string; parsed with Integer.parseInt before being set on the STS client.
    private String keySize = null;

    // Written (as a string) to the STS_TOKEN_USE_CERT_FOR_KEYINFO client property.
    private Boolean useKey = null;

    public AbstractStsRealm() {
        this.bus = getBus();
        setCredentialsMatcher(new STSCredentialsMatcher());
    }

    /**
     * Returns the context policy manager consulted when building the claims element.
     *
     * @return the context policy manager, or {@code null} if none has been set
     */
    public ContextPolicyManager getContextPolicyManager() {
        return this.contextPolicyManager;
    }

    /**
     * Sets the context policy manager whose policies contribute allowed attribute names to the
     * requested claims.
     *
     * @param contextPolicyManager the manager to use; may be {@code null}
     */
    public void setContextPolicyManager(ContextPolicyManager contextPolicyManager) {
        this.contextPolicyManager = contextPolicyManager;
    }

    /**
     * Decides whether this realm can process the supplied token.
     * <p>
     * A token is supported when it is non-null and carries non-null credentials; a
     * {@link BaseAuthenticationToken} must additionally report an {@code isUseWssSts()} value equal
     * to {@link #shouldHandleWss()}. The decision is logged at debug level.
     *
     * @param token the token to inspect; may be {@code null}
     * @return {@code true} if this realm supports the token
     */
    @Override
    public boolean supports(AuthenticationToken token) {
        if (token == null) {
            LOGGER.debug("The supplied authentication token is null. Sending back not supported.");
            return false;
        }

        boolean supported = token.getCredentials() != null;
        // The WSS comparison is only evaluated for tokens that already carry credentials.
        if (supported && token instanceof BaseAuthenticationToken) {
            final boolean tokenUsesWss = ((BaseAuthenticationToken) token).isUseWssSts();
            supported = tokenUsesWss == shouldHandleWss();
        }

        final String realmName = AbstractStsRealm.class.getName();
        if (supported) {
            LOGGER.debug("Token {} is supported by {}.", token.getClass(), realmName);
        } else {
            LOGGER.debug("Token {} is not supported by {}.", token.getClass(), realmName);
        }

        return supported;
    }

    /**
     * Indicates which WSS/STS mode this realm handles. {@link #supports(AuthenticationToken)} only
     * accepts a {@link BaseAuthenticationToken} whose {@code isUseWssSts()} value equals the value
     * returned here.
     *
     * @return the value a token's {@code isUseWssSts()} must equal for this realm to support it
     */
    protected abstract boolean shouldHandleWss();

    /**
     * Authenticates the supplied token against the STS and builds the resulting authentication
     * info.
     * <p>
     * The credential is extracted from the token, the STS client is configured on first use (on
     * later calls only its claims are refreshed), and then either the existing security token is
     * renewed (SAML token carrying a {@link SecurityToken}) or a new security token is requested on
     * behalf of the credential.
     *
     * @param token the authentication token to authenticate
     * @return authentication info whose principals are the assertion's principal and the assertion
     *         itself, and whose credentials are the credential extracted from the token
     * @throws AuthenticationException if the token yields no credential or the STS call fails
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) {
        String method = "doGetAuthenticationInfo(    AuthenticationToken token )";
        LOGGER.entry(method);

        Object credential = extractCredential(token);
        if (credential == null) {
            String msg = "Unable to authenticate credential.  A NULL credential was provided in the supplied authentication token. This may be due to an error with the SSO server that created the token.";
            LOGGER.error(msg);
            throw new AuthenticationException(msg);
        }
        // Credentials are deliberately kept out of the log message.
        LOGGER.debug("Received credentials.");

        if (!settingsConfigured) {
            configureStsClient();
            settingsConfigured = true;
        } else {
            // Client already exists; only the claims may have changed since the last call.
            setClaimsOnStsClient(createClaimsElement());
        }

        SecurityToken securityToken;
        if (token instanceof SAMLAuthenticationToken && credential instanceof SecurityToken) {
            securityToken = renewSecurityToken((SecurityToken) credential);
        } else {
            securityToken = requestSecurityToken(credential);
        }

        LOGGER.debug("Creating token authentication information with SAML.");
        SimpleAuthenticationInfo simpleAuthenticationInfo = new SimpleAuthenticationInfo();
        SimplePrincipalCollection principals = new SimplePrincipalCollection();
        SecurityAssertion assertion = new SecurityAssertionImpl(securityToken);
        principals.add(assertion.getPrincipal(), NAME);
        principals.add(assertion, NAME);
        simpleAuthenticationInfo.setPrincipals(principals);
        simpleAuthenticationInfo.setCredentials(credential);

        LOGGER.exit(method);
        return simpleAuthenticationInfo;
    }

    /**
     * Pulls the credential out of a token without dereferencing a missing value, so that the caller
     * can report a missing credential as an {@link AuthenticationException} instead of failing with
     * a {@link NullPointerException}.
     *
     * @param token the token to read; may be {@code null}
     * @return the raw credentials for a SAML token, the XML string form for any other
     *         {@link BaseAuthenticationToken}, the {@code toString()} of the credentials otherwise,
     *         or {@code null} when the token or its credentials are absent
     */
    private static Object extractCredential(AuthenticationToken token) {
        if (token == null) {
            return null;
        }
        if (token instanceof SAMLAuthenticationToken) {
            return token.getCredentials();
        }
        if (token instanceof BaseAuthenticationToken) {
            return ((BaseAuthenticationToken) token).getCredentialsAsXMLString();
        }
        Object rawCredentials = token.getCredentials();
        return rawCredentials == null ? null : rawCredentials.toString();
    }

    /**
     * Asks the STS for a security token (SAML assertion) on behalf of the given auth token.
     * <p>
     * The shared STS client is pointed at the resolved STS address and configured with the current
     * assertion type, key type and key size before the request is made. Nothing is requested when
     * {@code authToken} is {@code null}.
     *
     * @param authToken the subject the security token is being requested for
     * @return the issued security token, or {@code null} if {@code authToken} is {@code null}
     * @throws AuthenticationException if anything fails while talking to the STS
     */
    protected SecurityToken requestSecurityToken(Object authToken) {
        final String stsAddress = getAddress();
        SecurityToken issued = null;

        try {
            LOGGER.debug("Requesting security token from STS at: " + stsAddress + ".");

            if (authToken == null) {
                return issued;
            }

            final String startNotice = "Telling the STS to request a security token on behalf of the auth token";
            LOGGER.debug(startNotice);
            SecurityLogger.logInfo(startNotice);

            stsClient.setWsdlLocation(stsAddress);
            stsClient.setOnBehalfOf(authToken);
            stsClient.setTokenType(getAssertionType());
            stsClient.setKeyType(getKeyType());
            final int requestedKeySize = Integer.parseInt(getKeySize());
            stsClient.setKeySize(requestedKeySize);
            issued = stsClient.requestSecurityToken(stsAddress);

            final String doneNotice = "Finished requesting security token.";
            LOGGER.debug(doneNotice);
            SecurityLogger.logInfo(doneNotice);

            SecurityLogger.logSecurityAssertionInfo(issued);
        } catch (Exception e) {
            final String failure = "Error requesting the security token from STS at: " + stsAddress + ".";
            LOGGER.error(failure, e);
            SecurityLogger.logError(failure);
            throw new AuthenticationException(failure, e);
        }

        return issued;
    }

    /**
     * Asks the STS to renew an existing security token (SAML assertion).
     * <p>
     * The shared STS client is pointed at the resolved STS address, configured with the current
     * assertion type, key type and key size, and switched to allow renewing before the call.
     * Nothing is renewed when {@code securityToken} is {@code null}.
     *
     * @param securityToken the token being renewed
     * @return the renewed security token, or {@code null} if {@code securityToken} is {@code null}
     * @throws AuthenticationException if anything fails while talking to the STS
     */
    protected SecurityToken renewSecurityToken(SecurityToken securityToken) {
        final String stsAddress = getAddress();
        SecurityToken renewed = null;

        try {
            LOGGER.debug("Renewing security token from STS at: " + stsAddress + ".");

            if (securityToken == null) {
                return renewed;
            }

            final String startNotice = "Telling the STS to renew a security token on behalf of the auth token";
            LOGGER.debug(startNotice);
            SecurityLogger.logInfo(startNotice);

            stsClient.setWsdlLocation(stsAddress);
            stsClient.setTokenType(getAssertionType());
            stsClient.setKeyType(getKeyType());
            final int requestedKeySize = Integer.parseInt(getKeySize());
            stsClient.setKeySize(requestedKeySize);
            stsClient.setAllowRenewing(true);
            renewed = stsClient.renewSecurityToken(securityToken);

            final String doneNotice = "Finished renewing security token.";
            LOGGER.debug(doneNotice);
            SecurityLogger.logInfo(doneNotice);

            SecurityLogger.logSecurityAssertionInfo(renewed);
        } catch (Exception e) {
            final String failure = "Error renewing the security token from STS at: " + stsAddress + ".";
            LOGGER.error(failure, e);
            SecurityLogger.logError(failure);
            throw new AuthenticationException(failure, e);
        }

        return renewed;
    }

    /**
     * Writes the current STS client configuration (WSDL location, service and endpoint names, and
     * the client property map) to the debug log as a single multi-line message.
     * <p>
     * Property values that are {@link Properties} objects are rendered entry by entry with the
     * values of password-like keys masked, so secrets held in loaded property files are not
     * written to the log.
     */
    private void logStsClientConfiguration() {
        StringBuilder builder = new StringBuilder();

        builder.append("\nSTS Client configuration:\n");
        builder.append("STS WSDL location: ").append(stsClient.getWsdlLocation()).append('\n');
        builder.append("STS service name: ").append(stsClient.getServiceQName()).append('\n');
        builder.append("STS endpoint name: ").append(stsClient.getEndpointQName()).append('\n');

        builder.append("\nSTS Client properties:\n");
        for (Map.Entry<String, Object> entry : stsClient.getProperties().entrySet()) {
            builder.append("key: ").append(entry.getKey())
                    .append("; value: ").append(describeForLog(entry.getValue()))
                    .append('\n');
        }

        LOGGER.debug(builder.toString());
    }

    /**
     * Renders a client property value for logging, masking the values of any {@link Properties}
     * entry whose key contains "password" (case-insensitively).
     *
     * @param value the property value to render; may be {@code null}
     * @return a log-safe string form of the value
     */
    private static String describeForLog(Object value) {
        if (!(value instanceof Properties)) {
            return String.valueOf(value);
        }

        Properties properties = (Properties) value;
        StringBuilder rendered = new StringBuilder("{");
        String separator = "";
        for (String propertyName : properties.stringPropertyNames()) {
            // "password" contains no locale-sensitive letters, so default-locale lowering is safe.
            boolean secret = propertyName.toLowerCase().contains("password");
            rendered.append(separator)
                    .append(propertyName)
                    .append('=')
                    .append(secret ? "*****" : properties.getProperty(propertyName));
            separator = ", ";
        }
        return rendered.append('}').toString();
    }

    /**
     * Creates a new CXF bus through a {@link CXFBusFactory} and registers it as both the default bus
     * and the thread default bus.
     *
     * @return the newly created bus
     */
    protected Bus getBus() {
        final BusFactory factory = new CXFBusFactory();
        final Bus createdBus = factory.createBus();

        SpringBusFactory.setDefaultBus(createdBus);
        SpringBusFactory.setThreadDefaultBus(createdBus);
        return createdBus;
    }

    /**
     * Builds the STS client property map and installs it on the STS client.
     * <p>
     * Signature, encryption and STS token properties are loaded from their configured locations
     * (when a location is set) and stored under the matching {@link SecurityConstants} keys. The
     * "use certificate for key info" flag is taken from {@link #getUseKey()} and the client realm is
     * set to {@code "DDF"}.
     */
    private void addStsProperties() {
        Map<String, Object> map = new HashMap<>();

        putLoadedProperties(map, SecurityConstants.SIGNATURE_PROPERTIES, getSignatureProperties(),
                "signature");
        putLoadedProperties(map, SecurityConstants.ENCRYPT_PROPERTIES, getEncryptionProperties(),
                "encryption");
        putLoadedProperties(map, SecurityConstants.STS_TOKEN_PROPERTIES, getTokenProperties(), "sts");

        // NOTE: the callback handler registration below is commented out (see DDF-733), so no
        // handler is actually placed in the map despite the log message.
        LOGGER.debug("Setting callback handler on STSClient");
        //DDF-733 map.put(SecurityConstants.CALLBACK_HANDLER, new CommonCallbackHandler());

        // Log the value that is really applied instead of a hard-coded "true".
        String useKeyValue = String.valueOf(getUseKey());
        LOGGER.debug("Setting STS TOKEN USE CERT FOR KEY INFO to \"{}\"", useKeyValue);
        map.put(SecurityConstants.STS_TOKEN_USE_CERT_FOR_KEYINFO, useKeyValue);

        LOGGER.debug("Adding in realm information to the STSClient");
        map.put("CLIENT_REALM", "DDF");

        stsClient.setProperties(map);
    }

    /**
     * Loads the properties at {@code path} and stores them in {@code target} under {@code key};
     * does nothing when {@code path} is {@code null} or empty.
     *
     * @param target the client property map being assembled
     * @param key the property-map key to store the loaded properties under
     * @param path the location handed to {@link PropertiesLoader#loadProperties(String)}
     * @param label short name of the property group, used only in the debug message
     */
    private static void putLoadedProperties(Map<String, Object> target, String key, String path,
            String label) {
        if (path == null || path.isEmpty()) {
            return;
        }
        LOGGER.debug("Setting {} properties on STSClient: {}", label, path);
        Properties loaded = PropertiesLoader.loadProperties(path);
        target.put(key, loaded);
    }

    /**
     * Creates a fresh {@link STSClient} on this realm's bus and applies the basic connection
     * settings: WSDL location, service name and endpoint name (each only when non-null) and the
     * WS-Addressing namespace.
     */
    private void configureBaseStsClient() {
        stsClient = new STSClient(bus);

        final String wsdlLocation = getAddress();
        final String service = getServiceName();
        final String endpoint = getEndpointName();

        if (wsdlLocation != null) {
            LOGGER.debug("Setting WSDL location on STSClient: " + wsdlLocation);
            stsClient.setWsdlLocation(wsdlLocation);
        }
        if (service != null) {
            LOGGER.debug("Setting service name on STSClient: " + service);
            stsClient.setServiceName(service);
        }
        if (endpoint != null) {
            LOGGER.debug("Setting endpoint name on STSClient: " + endpoint);
            stsClient.setEndpointName(endpoint);
        }

        LOGGER.debug("Setting addressing namespace on STSClient: " + ADDRESSING_NAMESPACE);
        stsClient.setAddressingNamespace(ADDRESSING_NAMESPACE);
    }

    /**
     * Fully configures the STS client: creates it with its base settings, adds the client
     * properties, applies the current claims and, when debug logging is enabled, logs the resulting
     * configuration.
     */
    protected void configureStsClient() {
        LOGGER.debug("Configuring the STS client.");

        configureBaseStsClient();
        addStsProperties();

        final Element currentClaims = createClaimsElement();
        setClaimsOnStsClient(currentClaims);

        if (!LOGGER.isDebugEnabled()) {
            return;
        }
        logStsClientConfiguration();
    }

    /**
     * Applies the given claims element to the STS client; a {@code null} element leaves the client
     * untouched.
     *
     * @param claimsElement the claims element to set, or {@code null} to do nothing
     */
    private void setClaimsOnStsClient(Element claimsElement) {
        if (claimsElement == null) {
            return;
        }

        if (LOGGER.isDebugEnabled()) {
            final String prettyClaims = this.getFormattedXml(claimsElement);
            LOGGER.debug(" Setting STS claims to:\n" + prettyClaims);
        }

        stsClient.setClaims(claimsElement);
    }

    /**
     * Builds the {@code wst:Claims} element sent to the STS.
     * <p>
     * The claim list starts with the configured claims and, when a context policy manager is set,
     * is extended with the allowed attribute names of every context policy. Each claim becomes an
     * optional {@code ic:ClaimType} child carrying the claim as its {@code Uri} attribute.
     *
     * @return the claims element, or {@code null} when there are no claims or the element could not
     *         be written
     */
    protected Element createClaimsElement() {
        final List<String> requestedClaims = new ArrayList<>(getClaims());

        if (contextPolicyManager != null) {
            final Collection<ContextPolicy> policies = contextPolicyManager.getAllContextPolicies();
            // A set keeps the policy attributes unique while preserving their first-seen order.
            final Set<String> policyAttributes = new LinkedHashSet<>();
            if (policies != null) {
                for (ContextPolicy policy : policies) {
                    policyAttributes.addAll(policy.getAllowedAttributeNames());
                }
            }
            requestedClaims.addAll(policyAttributes);
        }

        if (requestedClaims.isEmpty()) {
            LOGGER.debug("There are no claims to process.");
            return null;
        }

        final String identityNamespace = "http://schemas.xmlsoap.org/ws/2005/05/identity";
        Element builtElement = null;
        W3CDOMStreamWriter writer = null;

        try {
            writer = new W3CDOMStreamWriter();

            writer.writeStartElement("wst", "Claims", STSUtils.WST_NS_05_12);
            writer.writeNamespace("wst", STSUtils.WST_NS_05_12);
            writer.writeNamespace("ic", identityNamespace);
            writer.writeAttribute("Dialect", identityNamespace);

            for (String claimUri : requestedClaims) {
                LOGGER.trace("Claim: " + claimUri);
                writer.writeStartElement("ic", "ClaimType", identityNamespace);
                writer.writeAttribute("Uri", claimUri);
                writer.writeAttribute("Optional", "true");
                writer.writeEndElement();
            }

            writer.writeEndElement();

            builtElement = writer.getDocument().getDocumentElement();
        } catch (XMLStreamException e) {
            LOGGER.error("Unable to create claims.", e);
            builtElement = null;
        } finally {
            if (writer != null) {
                try {
                    writer.close();
                } catch (XMLStreamException ignore) {
                    // A failure to close the writer does not affect the element already built.
                }
            }
        }

        if (LOGGER.isDebugEnabled() && builtElement != null) {
            LOGGER.debug("\nClaims:\n" + getFormattedXml(builtElement));
        }

        return builtElement;
    }

    /**
     * Serializes a DOM node to pretty-printed XML for logging.
     * <p>
     * The node is deep-copied into a scratch document (so the original tree is left untouched) and
     * written out with a DOM Load/Save serializer. Pretty printing is requested only when the
     * serializer reports that it supports the parameter, so an implementation without that feature
     * still yields unformatted XML instead of throwing.
     *
     * @param node the element to serialize; it must belong to a document
     * @return the serialized XML, or an empty string when no Load/Save implementation is available
     */
    private String getFormattedXml(Node node) {
        // Scratch document with a placeholder root that is swapped for the copied node below.
        Document document = node.getOwnerDocument().getImplementation().createDocument("", "fake", null);
        Element copy = (Element) document.importNode(node, true);
        document.removeChild(document.getDocumentElement());
        document.appendChild(copy);

        DOMImplementation domImpl = document.getImplementation();
        DOMImplementationLS domImplLs = (DOMImplementationLS) domImpl.getFeature("LS", "3.0");
        if (domImplLs == null) {
            return "";
        }

        LSSerializer serializer = domImplLs.createLSSerializer();
        if (serializer.getDomConfig().canSetParameter("format-pretty-print", Boolean.TRUE)) {
            serializer.getDomConfig().setParameter("format-pretty-print", Boolean.TRUE);
        }
        return serializer.writeToString(document);
    }

    /**
     * Returns the STS address with any property placeholders resolved.
     *
     * @return the resolved address, or {@code null} if no address has been set yet
     */
    @Override
    public String getAddress() {
        // The field starts out null; callers in this class already treat a null address as "unset".
        final PropertyResolver currentAddress = address;
        return currentAddress == null ? null : currentAddress.getResolvedString();
    }

    /**
     * Sets the STS address. The value is wrapped in a {@link PropertyResolver} and resolved each
     * time {@link #getAddress()} is called.
     *
     * @param address the STS address, possibly containing property placeholders
     */
    @Override
    public void setAddress(String address) {
        this.address = new PropertyResolver(address);
    }

    /**
     * Returns the STS endpoint name that is applied to the STS client when it is configured.
     *
     * @return the endpoint name, or {@code null} if none has been set
     */
    @Override
    public String getEndpointName() {
        return this.endpointName;
    }

    /**
     * Sets the STS endpoint name.
     *
     * @param endpointName the endpoint name to store
     */
    @Override
    public void setEndpointName(String endpointName) {
        this.endpointName = endpointName;
    }

    /**
     * Returns the STS service name that is applied to the STS client when it is configured.
     *
     * @return the service name, or {@code null} if none has been set
     */
    @Override
    public String getServiceName() {
        return this.serviceName;
    }

    /**
     * Sets the STS service name.
     *
     * @param serviceName the service name to store
     */
    @Override
    public void setServiceName(String serviceName) {
        this.serviceName = serviceName;
    }

    /**
     * Returns the stored username. This class only stores and returns the value.
     *
     * @return the username, or {@code null} if none has been set
     */
    @Override
    public String getUsername() {
        return this.username;
    }

    /**
     * Stores the username.
     *
     * @param username the username to store
     */
    @Override
    public void setUsername(String username) {
        this.username = username;
    }

    /**
     * Returns the stored password. This class only stores and returns the value.
     *
     * @return the password, or {@code null} if none has been set
     */
    @Override
    public String getPassword() {
        return this.password;
    }

    /**
     * Stores the password.
     *
     * @param password the password to store
     */
    @Override
    public void setPassword(String password) {
        this.password = password;
    }

    /**
     * Returns the stored signature username. This class only stores and returns the value.
     *
     * @return the signature username, or {@code null} if none has been set
     */
    @Override
    public String getSignatureUsername() {
        return this.signatureUsername;
    }

    /**
     * Stores the signature username.
     *
     * @param signatureUsername the signature username to store
     */
    @Override
    public void setSignatureUsername(String signatureUsername) {
        this.signatureUsername = signatureUsername;
    }

    /**
     * Returns the location of the signature properties that are loaded when the STS client
     * properties are assembled.
     *
     * @return the signature properties location, or {@code null} if none has been set
     */
    @Override
    public String getSignatureProperties() {
        return this.signatureProperties;
    }

    /**
     * Sets the location of the signature properties.
     *
     * @param signatureProperties the signature properties location to store
     */
    @Override
    public void setSignatureProperties(String signatureProperties) {
        this.signatureProperties = signatureProperties;
    }

    /**
     * Returns the stored encryption username. This class only stores and returns the value.
     *
     * @return the encryption username, or {@code null} if none has been set
     */
    @Override
    public String getEncryptionUsername() {
        return this.encryptionUsername;
    }

    /**
     * Stores the encryption username.
     *
     * @param encryptionUsername the encryption username to store
     */
    @Override
    public void setEncryptionUsername(String encryptionUsername) {
        this.encryptionUsername = encryptionUsername;
    }

    /**
     * Returns the location of the encryption properties that are loaded when the STS client
     * properties are assembled.
     *
     * @return the encryption properties location, or {@code null} if none has been set
     */
    @Override
    public String getEncryptionProperties() {
        return this.encryptionProperties;
    }

    /**
     * Sets the location of the encryption properties.
     *
     * @param encryptionProperties the encryption properties location to store
     */
    @Override
    public void setEncryptionProperties(String encryptionProperties) {
        this.encryptionProperties = encryptionProperties;
    }

    /**
     * Returns the stored token username. This class only stores and returns the value.
     *
     * @return the token username, or {@code null} if none has been set
     */
    @Override
    public String getTokenUsername() {
        return this.tokenUsername;
    }

    /**
     * Stores the token username.
     *
     * @param tokenUsername the token username to store
     */
    @Override
    public void setTokenUsername(String tokenUsername) {
        this.tokenUsername = tokenUsername;
    }

    /**
     * Returns the location of the STS token properties that are loaded when the STS client
     * properties are assembled.
     *
     * @return the token properties location, or {@code null} if none has been set
     */
    @Override
    public String getTokenProperties() {
        return this.tokenProperties;
    }

    /**
     * Sets the location of the STS token properties.
     *
     * @param tokenProperties the token properties location to store
     */
    @Override
    public void setTokenProperties(String tokenProperties) {
        this.tokenProperties = tokenProperties;
    }

    /**
     * Returns the configured claim URIs that seed the claims element sent to the STS.
     *
     * @return the current list of configured claims
     */
    @Override
    public List<String> getClaims() {
        return this.claims;
    }

    /**
     * Replaces the configured claims with an unmodifiable snapshot of the given list.
     * <p>
     * The list is copied, so later changes the caller makes to its own list do not alter this
     * realm's claims. A {@code null} list is treated as "no claims".
     *
     * @param claims the claim URIs to use; may be {@code null}
     */
    @Override
    public void setClaims(List<String> claims) {
        if (claims == null) {
            this.claims = Collections.emptyList();
        } else {
            this.claims = Collections.unmodifiableList(new ArrayList<>(claims));
        }
    }

    /**
     * Replaces the configured claims with the entries of a comma-separated string. Entries are
     * trimmed and empty entries are dropped; a {@code null} string is treated as "no claims".
     *
     * @param claimsListAsString comma-separated claim URIs; may be {@code null}
     */
    @Override
    public void setClaims(String claimsListAsString) {
        if (claimsListAsString == null) {
            setClaims(Collections.<String>emptyList());
            return;
        }
        setClaims(SPLITTER.splitToList(claimsListAsString));
    }

    /**
     * Returns the assertion type that is set as the STS client's token type before a token is
     * requested or renewed.
     *
     * @return the assertion type, or {@code null} if none has been set
     */
    @Override
    public String getAssertionType() {
        return this.assertionType;
    }

    /**
     * Sets the assertion type.
     *
     * @param assertionType the assertion type to store
     */
    @Override
    public void setAssertionType(String assertionType) {
        this.assertionType = assertionType;
    }

    /**
     * Returns the key type that is set on the STS client before a token is requested or renewed.
     *
     * @return the key type, or {@code null} if none has been set
     */
    @Override
    public String getKeyType() {
        return this.keyType;
    }

    /**
     * Sets the key type.
     *
     * @param keyType the key type to store
     */
    @Override
    public void setKeyType(String keyType) {
        this.keyType = keyType;
    }

    /**
     * Returns the key size as a string. The value is parsed with {@link Integer#parseInt(String)}
     * before it is set on the STS client.
     *
     * @return the key size, or {@code null} if none has been set
     */
    @Override
    public String getKeySize() {
        return this.keySize;
    }

    /**
     * Sets the key size.
     *
     * @param keySize the key size, as a string, to store
     */
    @Override
    public void setKeySize(String keySize) {
        this.keySize = keySize;
    }

    /**
     * Returns the flag whose string form is written to the STS client's
     * {@code STS_TOKEN_USE_CERT_FOR_KEYINFO} property.
     *
     * @return the flag, or {@code null} if none has been set
     */
    @Override
    public Boolean getUseKey() {
        return this.useKey;
    }

    /**
     * Sets the "use certificate for key info" flag.
     *
     * @param useKey the flag value to store
     */
    @Override
    public void setUseKey(Boolean useKey) {
        this.useKey = useKey;
    }

    /**
     * Credentials matcher that checks the credentials stored in the {@link AuthenticationInfo}
     * against those presented in the {@link AuthenticationToken}.
     * <p>
     * SAML tokens match when both sides hold a {@link SecurityToken} with the same non-null id.
     * Other {@link BaseAuthenticationToken}s are compared by their XML string form, and any other
     * token by {@code equals} on its credentials. Missing or unexpectedly typed credentials are a
     * non-match rather than an exception.
     */
    protected static class STSCredentialsMatcher implements CredentialsMatcher {

        /**
         * Compares the presented token with the stored authentication info.
         *
         * @param token the token presented for authentication
         * @param info the authentication info produced by the realm
         * @return {@code true} only when the credentials on both sides are present and match
         */
        @Override
        public boolean doCredentialsMatch(AuthenticationToken token, AuthenticationInfo info) {
            if (token == null || info == null) {
                return false;
            }

            Object storedCredentials = info.getCredentials();

            if (token instanceof SAMLAuthenticationToken) {
                Object presentedCredentials = token.getCredentials();
                // Guard the casts: either side may be absent or hold something else.
                if (!(presentedCredentials instanceof SecurityToken)
                        || !(storedCredentials instanceof SecurityToken)) {
                    return false;
                }
                String presentedId = ((SecurityToken) presentedCredentials).getId();
                String storedId = ((SecurityToken) storedCredentials).getId();
                return presentedId != null && presentedId.equals(storedId);
            }

            Object presentedCredentials;
            if (token instanceof BaseAuthenticationToken) {
                presentedCredentials = ((BaseAuthenticationToken) token).getCredentialsAsXMLString();
            } else {
                presentedCredentials = token.getCredentials();
            }

            return presentedCredentials != null && storedCredentials != null
                    && presentedCredentials.equals(storedCredentials);
        }
    }
}