org.asimba.wa.integrationtest.saml2.model.AuthnRequest.java Source code

Java tutorial

Introduction

Here is the source code for org.asimba.wa.integrationtest.saml2.model.AuthnRequest.java

Source

/*
 * Asimba - Serious Open Source SSO
 * 
 * Copyright (C) 2014 Asimba
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 * 
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see www.gnu.org/licenses
 * 
 * Asimba - Serious Open Source SSO - More information on www.asimba.org
 * 
 */
package org.asimba.wa.integrationtest.saml2.model;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.nio.charset.Charset;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.KeyStore;
import java.security.KeyStore.Entry;
import java.security.KeyStore.PasswordProtection;
import java.security.KeyStore.PrivateKeyEntry;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableEntryException;
import java.security.cert.CertificateException;
import java.text.SimpleDateFormat;
import java.util.Collections;
import java.util.Date;
import java.util.TimeZone;
import java.util.UUID;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.Inflater;
import java.util.zip.InflaterInputStream;

import javax.xml.crypto.MarshalException;
import javax.xml.crypto.dsig.CanonicalizationMethod;
import javax.xml.crypto.dsig.DigestMethod;
import javax.xml.crypto.dsig.Reference;
import javax.xml.crypto.dsig.SignatureMethod;
import javax.xml.crypto.dsig.SignedInfo;
import javax.xml.crypto.dsig.Transform;
import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.crypto.dsig.XMLSignatureException;
import javax.xml.crypto.dsig.XMLSignatureFactory;
import javax.xml.crypto.dsig.dom.DOMSignContext;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import javax.xml.crypto.dsig.keyinfo.KeyInfoFactory;
import javax.xml.crypto.dsig.keyinfo.KeyValue;
import javax.xml.crypto.dsig.spec.C14NMethodParameterSpec;
import javax.xml.crypto.dsig.spec.TransformParameterSpec;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.stream.XMLOutputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamWriter;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerConfigurationException;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;

import org.apache.commons.codec.binary.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

/**
 * Test-side model of a SAML 2.0 {@code <samlp:AuthnRequest>} message.
 * <p>
 * Used in two directions:
 * <ul>
 * <li>outgoing: {@link #AuthnRequest(String, String)} followed by {@link #getRequest(int)} or
 * {@link #getSignedRequest(int, InputStream, String, String, String)} builds a request message;</li>
 * <li>incoming: {@link #loadAuthnRequest(String)} parses a base64-encoded, raw-deflated request
 * (HTTP-Redirect encoding), which can then be inspected with the XPath based getters.</li>
 * </ul>
 * The XPath based getters evaluate against {@link #getAuthnRequestDocument()}, which is only set by
 * {@link #loadAuthnRequest(String)} or by a successful {@code getSignedRequest(...)} call.
 */
public class AuthnRequest extends SAMLMessage {

    private static final Logger _logger = LoggerFactory.getLogger(AuthnRequest.class);

    /** SAML 2.0 protocol namespace (prefix {@code samlp}) */
    private static final String NS_SAMLP = "urn:oasis:names:tc:SAML:2.0:protocol";

    /** SAML 2.0 assertion namespace (prefix {@code saml}) */
    private static final String NS_SAML = "urn:oasis:names:tc:SAML:2.0:assertion";

    private static final Charset UTF8 = Charset.forName("UTF-8");

    /** Parser feature that rejects any DOCTYPE declaration (protects against XXE / entity expansion) */
    private static final String FEATURE_DISALLOW_DOCTYPE = "http://apache.org/xml/features/disallow-doctype-decl";

    /** the XML document of this AuthnRequest: the parsed incoming one, or the signed outgoing one */
    private Document _authnRequestDocument;

    private String _id;
    private String _issueInstant;

    private String _acsUrl;

    /** AuthnContextClassRef to request; null means: no RequestedAuthnContext element is written */
    private String _requestedAuthnContext = null;
    /** Requested NameIDPolicy@Format; null means: no NameIDPolicy element is written */
    private String _requestedNameIdFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent";

    /**
     * Create a new outgoing AuthnRequest with a fresh random ID and the current time as IssueInstant.
     *
     * @param acsUrl value for the AssertionConsumerServiceURL attribute
     * @param issuer text content of the saml:Issuer element
     */
    public AuthnRequest(String acsUrl, String issuer) {
        _id = "_" + UUID.randomUUID().toString();

        // xs:dateTime in UTC: two-digit hour and an explicit 'Z', so the receiver does not
        // interpret the instant in its own local time zone
        SimpleDateFormat utcFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'");
        utcFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
        _issueInstant = utcFormat.format(new Date());

        _acsUrl = acsUrl;
        _issuer = issuer;
    }

    /**
     * @param requestedAuthnContext AuthnContextClassRef to request (with Comparison="exact"),
     *            or null to omit the RequestedAuthnContext element
     */
    public void setRequestedAuthnContext(String requestedAuthnContext) {
        _requestedAuthnContext = requestedAuthnContext;
    }

    /**
     * @param requestedNameIdFormat NameIDPolicy@Format to request, or null to omit the NameIDPolicy element
     */
    public void setRequestedNameIdFormat(String requestedNameIdFormat) {
        _requestedNameIdFormat = requestedNameIdFormat;
    }

    /**
     * Get String with the SAML2 AuthnRequest message.
     *
     * @param format when equal to {@code base64}, the message is raw-deflated and base64-encoded;
     *            any other value (e.g. {@code plain}) returns the XML text as-is
     * @return the serialized AuthnRequest
     * @throws XMLStreamException when writing the XML fails
     * @throws IOException when compressing the message fails
     */
    public String getRequest(int format) throws XMLStreamException, IOException {
        _logger.info("For ID: {}", _id);
        XMLOutputFactory factory = XMLOutputFactory.newInstance();

        if (format != base64) {
            StringWriter sw = new StringWriter();
            writeAuthnRequest(factory.createXMLStreamWriter(sw));
            return sw.toString();
        }

        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        // nowrap=true: raw DEFLATE data without zlib header
        Deflater deflater = new Deflater(Deflater.BEST_COMPRESSION, true);
        try {
            DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(baos, deflater);
            writeAuthnRequest(factory.createXMLStreamWriter(deflaterOutputStream, "UTF-8"));
            deflaterOutputStream.close(); // finishes the compressed stream
        } finally {
            // DeflaterOutputStream.close() does not release a Deflater that was passed in
            deflater.end();
        }

        byte[] encoded = Base64.encodeBase64(baos.toByteArray(), false);
        return new String(encoded, UTF8);
    }

    /** Write the complete samlp:AuthnRequest element to the writer, then flush and close the writer. */
    private void writeAuthnRequest(XMLStreamWriter writer) throws XMLStreamException {
        writer.writeStartElement("samlp", "AuthnRequest", NS_SAMLP);
        writer.writeNamespace("samlp", NS_SAMLP);

        writer.writeAttribute("ID", _id);
        writer.writeAttribute("Version", "2.0");
        writer.writeAttribute("IssueInstant", _issueInstant);
        writer.writeAttribute("ProtocolBinding", "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST");
        writer.writeAttribute("AssertionConsumerServiceURL", _acsUrl);

        writeIssuer(writer);
        writeNameIDPolicy(writer);
        writeRequestedAuthnContext(writer);

        writer.writeEndElement();
        writer.flush();
        writer.close(); // does not close the underlying stream/writer
    }

    /** @return AuthnRequest@ID from the AuthnRequest document */
    public String getId() {
        String xpathQuery = "/saml2p:AuthnRequest/@ID";
        return executeXPathValueQuery(xpathQuery, _authnRequestDocument);
    }

    /** @return AuthnRequest@AssertionConsumerServiceURL from the AuthnRequest document */
    public String getACSURL() {
        String xpathQuery = "/saml2p:AuthnRequest/@AssertionConsumerServiceURL";
        return executeXPathValueQuery(xpathQuery, _authnRequestDocument);
    }

    /** @return AuthnRequest@ProtocolBinding from the AuthnRequest document */
    public String getProtocolBinding() {
        String xpathQuery = "/saml2p:AuthnRequest/@ProtocolBinding";
        return executeXPathValueQuery(xpathQuery, _authnRequestDocument);
    }

    /** Write the saml:Issuer element containing the issuer of this request. */
    protected void writeIssuer(XMLStreamWriter writer) throws XMLStreamException {
        writer.writeStartElement("saml", "Issuer", NS_SAML);
        writer.writeNamespace("saml", NS_SAML);
        writer.writeCharacters(_issuer);
        writer.writeEndElement();
    }

    /**
     * Get Issuer from the AuthnRequest document
     * @return text value of /AuthnRequest/Issuer
     */
    public String getIssuer() {
        String xpathQuery = "/saml2p:AuthnRequest/saml2:Issuer";
        return executeXPathValueQuery(xpathQuery, _authnRequestDocument);
    }

    /** Write samlp:NameIDPolicy (AllowCreate="true") unless no NameID format was requested. */
    protected void writeNameIDPolicy(XMLStreamWriter writer) throws XMLStreamException {
        if (_requestedNameIdFormat == null) {
            _logger.info("Skipping NameIDPolicy in request");
            return;
        }

        _logger.info("Adding {} as NameIDPolicy@Format", _requestedNameIdFormat);

        writer.writeStartElement("samlp", "NameIDPolicy", NS_SAMLP);
        writer.writeAttribute("Format", _requestedNameIdFormat);
        writer.writeAttribute("AllowCreate", "true");
        writer.writeEndElement();
    }

    /** @return NameIDPolicy@Format from the AuthnRequest document */
    public String getNameIDPolicy() {
        String xpathQuery = "/saml2p:AuthnRequest/saml2p:NameIDPolicy/@Format";
        return executeXPathValueQuery(xpathQuery, _authnRequestDocument);
    }

    /** Write samlp:RequestedAuthnContext (Comparison="exact") unless no authn context was requested. */
    protected void writeRequestedAuthnContext(XMLStreamWriter writer) throws XMLStreamException {
        if (_requestedAuthnContext == null) {
            _logger.info("Skipping RequestedAuthnContext in request");
            return;
        }

        _logger.info("Adding {} as RequestedAuthnContext@AuthnContextClassRef", _requestedAuthnContext);

        writer.writeStartElement("samlp", "RequestedAuthnContext", NS_SAMLP);
        writer.writeAttribute("Comparison", "exact");

        writer.writeStartElement("saml", "AuthnContextClassRef", NS_SAML);
        writer.writeNamespace("saml", NS_SAML);
        writer.writeCharacters(_requestedAuthnContext);
        writer.writeEndElement();

        writer.writeEndElement();
    }

    /**
     * Remove URL-encoded carriage returns ({@code %0D}) and line feeds ({@code %0A}) from a string,
     * for example from a URL-encoded, line-wrapped base64 value.
     *
     * @param what URL-encoded string (not null)
     * @return {@code what} with every {@code %0D} and {@code %0A} removed; everything else is kept
     */
    public static String getRidOfCRLF(String what) {
        // single regex pass, so removing one sequence cannot create a new match from its neighbours
        return what.replaceAll("%0D|%0A", "");
    }

    /** Mark every "ID" attribute as XML ID, so that "#id" signature references can be resolved. */
    private void tagIdAttributes(Document xmlDoc) {
        NodeList elements = xmlDoc.getElementsByTagName("*");
        for (int i = 0; i < elements.getLength(); i++) {
            Element element = (Element) elements.item(i);
            if (element.hasAttribute("ID")) {
                element.setIdAttribute("ID", true);
            }
        }
    }

    /**
     * Open keystore
     * @param type keystore type, probably "JKS"
     * @param keystoreStream stream to load the keystore from (not closed here)
     * @param password keystore password; may be null
     * @return the loaded keystore, or null when it could not be loaded
     */
    private KeyStore getKeystore(String type, InputStream keystoreStream, String password) {
        try {
            KeyStore keystore = KeyStore.getInstance(type);
            keystore.load(keystoreStream, password == null ? null : password.toCharArray());
            return keystore;
        } catch (NoSuchAlgorithmException | CertificateException | KeyStoreException | IOException e) {
            _logger.error("Could not load {} keystore: {}", type, e.getMessage(), e);
            return null;
        }
    }

    /**
     * @return the key pair (certificate public key + private key) stored under {@code keyAlias},
     *         or null when there is no usable private key entry
     */
    private KeyPair getKeyPairFromKeystore(KeyStore keystore, String keyAlias, String keyPassword) {
        try {
            PasswordProtection passwordProtected =
                    new PasswordProtection(keyPassword == null ? null : keyPassword.toCharArray());
            Entry keyEntry = keystore.getEntry(keyAlias, passwordProtected);

            if (!(keyEntry instanceof PrivateKeyEntry)) {
                _logger.warn("Keystore entry '{}' is not a private key entry", keyAlias);
                return null;
            }
            PrivateKeyEntry pkEntry = (PrivateKeyEntry) keyEntry;
            return new KeyPair(pkEntry.getCertificate().getPublicKey(), pkEntry.getPrivateKey());

        } catch (KeyStoreException | NoSuchAlgorithmException | UnrecoverableEntryException e) {
            _logger.error("Could not get key '{}' from keystore: {}", keyAlias, e.getMessage(), e);
            return null;
        }
    }

    /**
     * Get the AuthnRequest as enveloped-signed XML document.
     * <p>
     * When the keystore has no private key entry for {@code keyAlias}, a throw-away 512-bit DSA key
     * is generated and used instead. The signed document is also stored as the AuthnRequest document.
     *
     * @param format currently ignored; the signed XML is always returned as plain text
     * @param keystoreStream JKS keystore stream (not closed here)
     * @param keystorePassword keystore password
     * @param keyAlias alias of the signing key
     * @param keyPassword password of the signing key
     * @return the signed XML, or null when creating or signing the document failed
     */
    public String getSignedRequest(int format, InputStream keystoreStream, String keystorePassword, String keyAlias,
            String keyPassword) {
        try {
            Document doc = newDocumentBuilder().parse(
                    new InputSource(new ByteArrayInputStream(getRequest(plain).getBytes(UTF8))));

            // Prepare doc by marking attributes as referenceable:
            tagIdAttributes(doc);

            KeyStore keystore = getKeystore("JKS", keystoreStream, keystorePassword);
            if (keystore == null) {
                return null;
            }

            KeyPair kp = getKeyPairFromKeystore(keystore, keyAlias, keyPassword);
            if (kp == null) {
                // Generate key, to prove that signing works...
                _logger.warn("No usable key '{}' in keystore; signing with a generated DSA key", keyAlias);
                KeyPairGenerator kpg = KeyPairGenerator.getInstance("DSA");
                kpg.initialize(512);
                kp = kpg.generateKeyPair();
            }

            XMLSignatureFactory fac = XMLSignatureFactory.getInstance("DOM");

            // Reference to the AuthnRequest element by its ID; SAML requires the enveloped transform
            Reference ref = fac.newReference("#" + _id, fac.newDigestMethod(DigestMethod.SHA1, null),
                    Collections.singletonList(fac.newTransform(Transform.ENVELOPED, (TransformParameterSpec) null)),
                    null, null);

            // SAML2: Exclusive canonicalization with or without comments is specified
            SignedInfo si = fac.newSignedInfo(
                    fac.newCanonicalizationMethod(CanonicalizationMethod.EXCLUSIVE_WITH_COMMENTS,
                            (C14NMethodParameterSpec) null),
                    fac.newSignatureMethod(signatureMethodFor(kp), null), Collections.singletonList(ref));

            // KeyInfo with the public key of the signing key pair
            KeyInfoFactory kif = fac.getKeyInfoFactory();
            KeyValue kv = kif.newKeyValue(kp.getPublic());
            KeyInfo ki = kif.newKeyInfo(Collections.singletonList(kv));

            XMLSignature signature = fac.newXMLSignature(si, ki);

            String before = _logger.isDebugEnabled() ? docToString(doc) : null;

            signature.sign(newSignContext(kp, doc));

            _authnRequestDocument = doc; // persist, as we've worked hard for it

            if (_logger.isDebugEnabled()) {
                _logger.debug("Before: {}", before);
                _logger.debug("After : {}", docToString(doc));
            }

            // No indentation: added whitespace would alter the signed content and break the signature
            return docToString(doc, false);

        } catch (ParserConfigurationException | SAXException | IOException | XMLStreamException
                | NoSuchAlgorithmException | InvalidAlgorithmParameterException | KeyException
                | MarshalException | XMLSignatureException e) {
            _logger.error("Could not create signed AuthnRequest: {}", e.getMessage(), e);
            return null;
        }
    }

    /** XML-DSig signature algorithm for the key type: RSA-SHA1 for RSA keys, DSA-SHA1 otherwise. */
    private static String signatureMethodFor(KeyPair kp) {
        if ("RSA".equalsIgnoreCase(kp.getPrivate().getAlgorithm())) {
            return SignatureMethod.RSA_SHA1;
        }
        return SignatureMethod.DSA_SHA1;
    }

    /**
     * Signing context that places ds:Signature directly after saml:Issuer, where the SAML 2.0 schema
     * (RequestAbstractType) expects it; without a following sibling it is appended to the root element.
     */
    private static DOMSignContext newSignContext(KeyPair kp, Document doc) {
        Element root = doc.getDocumentElement();
        NodeList issuers = root.getElementsByTagNameNS(NS_SAML, "Issuer");
        if (issuers.getLength() > 0) {
            Node issuer = issuers.item(0);
            Node next = issuer.getNextSibling();
            if (issuer.getParentNode() == root && next != null) {
                return new DOMSignContext(kp.getPrivate(), root, next);
            }
        }
        return new DOMSignContext(kp.getPrivate(), root);
    }

    /**
     * Create a namespace-aware DocumentBuilder that rejects DOCTYPE declarations when the parser
     * supports it; SAML messages never contain a DOCTYPE.
     */
    private static DocumentBuilder newDocumentBuilder() throws ParserConfigurationException {
        DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
        dbf.setNamespaceAware(true);
        try {
            dbf.setFeature(FEATURE_DISALLOW_DOCTYPE, true);
        } catch (ParserConfigurationException e) {
            _logger.warn("XML parser does not support {}; DOCTYPE declarations are not rejected",
                    FEATURE_DISALLOW_DOCTYPE);
        }
        return dbf.newDocumentBuilder();
    }

    /** Serialize the document as indented UTF-8 XML text; null when serialization fails. */
    private static String docToString(Document doc) {
        return docToString(doc, true);
    }

    /**
     * Serialize the document as UTF-8 XML text.
     *
     * @param indent whether to pretty-print; must be false for signed content that is sent on
     * @return the XML text, or null when serialization fails
     */
    private static String docToString(Document doc, boolean indent) {
        try {
            Transformer serializer = TransformerFactory.newInstance().newTransformer();
            serializer.setOutputProperty(OutputKeys.ENCODING, "UTF-8");
            serializer.setOutputProperty(OutputKeys.INDENT, indent ? "yes" : "no");

            StringWriter stringWriter = new StringWriter();
            serializer.transform(new DOMSource(doc), new StreamResult(stringWriter));
            return stringWriter.toString();

        } catch (TransformerException e) {
            // also covers TransformerConfigurationException
            _logger.error("Exception when serializing document: {}", e.getMessage(), e);
            return null;
        }
    }

    /**
     * Parse an incoming AuthnRequest message from HTTP Redirect bound request;<br/>
     * also takes care of base64-decoding and inflating the message
     * @param authnRequestString base64-encoded, raw-deflated AuthnRequest
     * @return AuthnRequest wrapping the parsed document, or null when it could not be decoded/parsed
     * @throws IOException declared for compatibility; parse errors are logged and result in null
     */
    public static AuthnRequest loadAuthnRequest(String authnRequestString) throws IOException {
        if (authnRequestString == null) {
            _logger.error("Could not get AuthnRequest document from parameter: no message provided");
            return null;
        }

        // Decode incoming request
        _logger.debug("B64-encoded string: {}", authnRequestString);
        byte[] decoded = Base64.decodeBase64(authnRequestString);

        // Now inflate it; nowrap=true because the zlib header is omitted
        Inflater inflater = new Inflater(true);
        InflaterInputStream inflaterInputStream = new InflaterInputStream(new ByteArrayInputStream(decoded), inflater);

        try {
            Document document = newDocumentBuilder().parse(inflaterInputStream);

            if (_logger.isDebugEnabled()) {
                _logger.debug("Incoming AuthnRequest:\n{}", docToString(document));
            }

            AuthnRequest authnRequestResult = new AuthnRequest(null, null);
            authnRequestResult._authnRequestDocument = document;
            return authnRequestResult;

        } catch (ParserConfigurationException | SAXException | IOException e) {
            _logger.error("Could not get AuthnRequest document from parameter: {}", e.getMessage(), e);
            return null;
        } finally {
            inflaterInputStream.close();
            // InflaterInputStream.close() does not release an Inflater that was passed in
            inflater.end();
        }
    }

    /** @return the parsed (incoming) or signed (outgoing) AuthnRequest document; null if neither happened */
    public Document getAuthnRequestDocument() {
        return _authnRequestDocument;
    }

}