org.atricore.idbus.capabilities.sso.main.binding.SamlR2HttpRedirectBinding.java Source code

Java tutorial

Introduction

Here is the source code for org.atricore.idbus.capabilities.sso.main.binding.SamlR2HttpRedirectBinding.java

Source

/*
 * Atricore IDBus
 *
 * Copyright (c) 2009, Atricore Inc.
 *
 * This is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of
 * the License, or (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this software; if not, write to the Free
 * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
 */

package org.atricore.idbus.capabilities.sso.main.binding;

import oasis.names.tc.saml._2_0.protocol.RequestAbstractType;
import oasis.names.tc.saml._2_0.protocol.StatusResponseType;
import org.apache.camel.Exchange;
import org.apache.camel.Message;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.atricore.idbus.capabilities.sso.support.binding.SSOBinding;
import org.atricore.idbus.capabilities.sso.support.core.signature.SamlR2Signer;
import org.atricore.idbus.capabilities.sso.support.core.util.XmlUtils;
import org.atricore.idbus.kernel.main.federation.metadata.EndpointDescriptor;
import org.atricore.idbus.kernel.main.mediation.Channel;
import org.atricore.idbus.kernel.main.mediation.MediationMessage;
import org.atricore.idbus.kernel.main.mediation.MediationMessageImpl;
import org.atricore.idbus.kernel.main.mediation.MediationState;
import org.atricore.idbus.kernel.main.mediation.camel.component.binding.AbstractMediationHttpBinding;
import org.atricore.idbus.kernel.main.mediation.camel.component.binding.CamelMediationMessage;
import org.w3._1999.xhtml.Html;

import java.io.*;
import java.net.URLEncoder;
import java.util.zip.*;

/**
 * @author <a href="mailto:sgonzalez@atricore.org">Sebastian Gonzalez Oyuela</a>
 * @version $Id$
 */
public class SamlR2HttpRedirectBinding extends AbstractMediationHttpBinding {

    private static final Log logger = LogFactory.getLog(SamlR2HttpRedirectBinding.class);

    public SamlR2HttpRedirectBinding(Channel channel) {
        super(SSOBinding.SAMLR2_REDIRECT.getValue(), channel);
    }

    public MediationMessage createMessage(CamelMediationMessage message) {

        // The nested exchange contains HTTP information
        Exchange exchange = message.getExchange().getExchange();
        logger.debug("Create Message Body from exchange " + exchange.getClass().getName());

        Message httpMsg = exchange.getIn();

        if (httpMsg.getHeader("http.requestMethod") == null
                || !httpMsg.getHeader("http.requestMethod").equals("GET")) {
            throw new IllegalArgumentException("Unknown message, no valid HTTP Method header found!");
        }

        try {

            // HTTP Request Parameters from HTTP Request body
            MediationState state = createMediationState(exchange);

            // HTTP Redirect SSOBinding supports the following parameters
            String base64SAMLRequest = state.getTransientVariable("SAMLRequest");
            String base64SAMLResponse = state.getTransientVariable("SAMLResponse");
            String relayState = state.getTransientVariable("RelayState");

            // Follow SAML 2.0 Binding, section 3.4.4.1 DEFLATE Encoding , regarding Digital Siganture usage.
            String sigAlg = state.getTransientVariable("SigAlg"); // Use HTTP Redirect binding Signature Algorithm
            String signature = state.getTransientVariable("Signature"); // Validate HTTP Redirect binding Signature

            if (base64SAMLRequest != null && base64SAMLResponse != null) {
                throw new IllegalStateException("Received both SAML Request and SAML Response");
            }

            if (base64SAMLRequest == null && base64SAMLResponse == null) {
                throw new IllegalStateException("Recevied no SAML Request nor SAML Response");
            }

            if (base64SAMLRequest != null) {

                // By default, we use inflate/deflate
                base64SAMLRequest = inflateFromRedirect(base64SAMLRequest, true);

                if (logger.isDebugEnabled())
                    logger.debug("Received SAML 2.0 Request [" + base64SAMLRequest + "]");

                // SAML Request
                RequestAbstractType samlRequest = XmlUtils.unmarshalSamlR2Request(base64SAMLRequest, false);

                if (logger.isDebugEnabled())
                    logger.debug("Received SAML 2.0 Request " + samlRequest.getID());

                // Store relay state to send it back later
                if (relayState != null) {
                    state.setLocalVariable(
                            "urn:org:atricore:idbus:samr2:protocol:relayState:" + samlRequest.getID(), relayState);
                }

                return new MediationMessageImpl<RequestAbstractType>(httpMsg.getMessageId(), samlRequest,
                        base64SAMLRequest, null, relayState, null, state);
            } else {

                // SAML Response
                base64SAMLResponse = inflateFromRedirect(base64SAMLResponse, true);

                if (logger.isDebugEnabled())
                    logger.debug("Received SAML 2.0 Response [" + base64SAMLResponse + "]");

                StatusResponseType samlResponse = XmlUtils.unmarshalSamlR2Response(base64SAMLResponse, false);

                if (logger.isDebugEnabled())
                    logger.debug("Received SAML 2.0 Response " + samlResponse.getID());

                return new MediationMessageImpl<StatusResponseType>(httpMsg.getMessageId(), samlResponse,
                        base64SAMLResponse, null, relayState, null, state);
            }

        } catch (Exception e) {
            throw new RuntimeException(e);
        }

    }

    public void copyMessageToExchange(CamelMediationMessage samlOut, Exchange exchange) {
        try {

            MediationMessage out = samlOut.getMessage();
            EndpointDescriptor ed = out.getDestination();

            // ------------------------------------------------------------
            // Validate received message
            // ------------------------------------------------------------
            assert ed != null : "Mediation Response MUST Provide a destination";
            if (out.getContent() == null) {
                throw new NullPointerException(
                        "Cannot Create form with null content for action " + ed.getLocation());
            }
            String location = ed.getResponseLocation() != null ? ed.getResponseLocation() : ed.getLocation();

            // ------------------------------------------------------------
            // Create HTML Form for response body
            // ------------------------------------------------------------

            if (logger.isDebugEnabled())
                logger.debug("Creating HTML Redirect to " + location);

            String msgName = null;
            java.lang.Object msgValue = null;
            String element = out.getContentType();
            boolean isResponse = false;
            String relayState = out.getRelayState();

            boolean signMsg = false;

            if (out.getContent() instanceof RequestAbstractType) {

                msgName = "SAMLRequest";

                RequestAbstractType req = (RequestAbstractType) out.getContent();

                if (req.getSignature() != null) {
                    // Strip DS information from request/response!
                    req.setSignature(null);
                    signMsg = true;
                }

                // Marshall
                String s = XmlUtils.marshalSamlR2Request((RequestAbstractType) out.getContent(), element, false);

                // Use default DEFLATE (rfc 1951)
                msgValue = deflateForRedirect(s, true);
                msgValue = URLEncoder.encode((String) msgValue, "UTF-8");

            } else if (out.getContent() instanceof StatusResponseType) {
                msgName = "SAMLResponse";
                isResponse = true;

                // Strip DS information from request/response!
                StatusResponseType res = (StatusResponseType) out.getContent();
                if (res.getSignature() != null) {
                    res.setSignature(null);
                    signMsg = true;
                }

                // Marshall
                String s = XmlUtils.marshalSamlR2Response((StatusResponseType) out.getContent(), element, false);

                // Use default DEFLATE (rfc 1951)
                msgValue = deflateForRedirect(s, true);
                msgValue = URLEncoder.encode((String) msgValue, "UTF-8");

                StatusResponseType samlResponse = (StatusResponseType) out.getContent();
                if (samlResponse.getInResponseTo() != null) {
                    String rs = (String) out.getState().getLocalVariable(
                            "urn:org:atricore:idbus:samr2:protocol:relayState:" + samlResponse.getInResponseTo());
                    if (relayState != null && rs != null && !relayState.equals(rs)) {
                        relayState = rs;
                        logger.warn("Provided relay state does not match stored state : " + relayState + " : " + rs
                                + ", forcing " + relayState);
                    }
                }

            }

            if (out.getContent() == null) {
                throw new NullPointerException("Cannot REDIRECT null content to " + location);
            }
            if (!(msgValue instanceof String)) {
                throw new IllegalArgumentException(
                        "Cannot REDIRECT content of type " + msgValue.getClass().getName());
            }

            String qryString = msgName + "=" + (String) msgValue;
            if (out.getRelayState() != null) {
                qryString += "&RelayState=" + relayState;
            }

            // Follow SAML 2.0 Binding, section 3.4.4.1 DEFLATE Encoding , regarding Digital Siganture usage.
            // Generate HTTP Redirect binding SigAlg and Signature parameters (this requires access to Provider signer!)
            MediationState state = samlOut.getMessage().getState();
            SamlR2Signer signer = (SamlR2Signer) state.getAttribute("SAMLR2Signer");
            if (signer != null) {
                qryString = "?" + signer.signQueryString(qryString);
            } else {
                qryString = "?" + qryString;
            }

            Message httpOut = exchange.getOut();
            Message httpIn = exchange.getIn();
            String redirLocation = this.buildHttpTargetLocation(httpIn, ed, isResponse) + qryString;

            // ------------------------------------------------------------
            // Prepare HTTP Resposne
            // ------------------------------------------------------------
            copyBackState(out.getState(), exchange);

            httpOut.getHeaders().put("Cache-Control", "no-cache, no-store");
            httpOut.getHeaders().put("Pragma", "no-cache");
            httpOut.getHeaders().put("http.responseCode", 302);
            httpOut.getHeaders().put("Content-Type", "text/html");
            httpOut.getHeaders().put("Location", redirLocation);
            handleCrossOriginResourceSharing(exchange);

        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    public static String deflateForRedirect(String redirStr, boolean encode) {

        int n = redirStr.length();
        byte[] redirIs = null;
        try {
            redirIs = redirStr.getBytes("UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException(e);
        }

        byte[] deflated = new byte[n];

        Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
        deflater.setInput(redirIs);
        deflater.finish();
        int len = deflater.deflate(deflated);
        deflater.end();

        byte[] exact = new byte[len];

        System.arraycopy(deflated, 0, exact, 0, len);

        if (encode) {
            byte[] base64Str = new Base64().encode(exact);
            return new String(base64Str);
        }

        return new String(exact);
    }

    public static String inflateFromRedirect(String redirStr, boolean decode) throws Exception {

        if (redirStr == null || redirStr.length() == 0) {
            throw new RuntimeException("Redirect string cannot be null or empty");
        }

        byte[] redirBin = null;
        if (decode)
            redirBin = new Base64().decode(removeNewLineChars(redirStr).getBytes());
        else
            redirBin = redirStr.getBytes();

        // Decompress the bytes
        Inflater inflater = new Inflater(true);
        inflater.setInput(redirBin);

        ByteArrayOutputStream baos = new ByteArrayOutputStream(8192);

        try {
            int resultLength = 0;
            int buffSize = 1024;
            byte[] buff = new byte[buffSize];
            while (!inflater.finished()) {
                resultLength = inflater.inflate(buff);
                baos.write(buff, 0, resultLength);
            }

        } catch (DataFormatException e) {
            throw new RuntimeException("Cannot inflate SAML message : " + e.getMessage(), e);
        }

        inflater.end();

        // Decode the bytes into a String
        String outputString = null;
        try {
            outputString = new String(baos.toByteArray(), "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException("Cannot convert byte array to string " + e.getMessage(), e);
        }
        return outputString;
    }

    public static String removeNewLineChars(String s) {
        String retString = null;
        if ((s != null) && (s.length() > 0) && (s.indexOf('\n') != -1)) {
            char[] chars = s.toCharArray();
            int len = chars.length;
            StringBuffer sb = new StringBuffer(len);
            for (int i = 0; i < len; i++) {
                char c = chars[i];
                if (c != '\n') {
                    sb.append(c);
                }
            }
            retString = sb.toString();
        } else {
            retString = s;
        }
        return retString;
    }

}