com.vmware.bdd.security.tls.SimpleSeverTrustTlsSocketFactory.java Source code

Java tutorial

Introduction

Here is the source code for com.vmware.bdd.security.tls.SimpleSeverTrustTlsSocketFactory.java

Source

/******************************************************************************
 *   Copyright (c) 2014 VMware, Inc. All Rights Reserved.
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 *****************************************************************************/
package com.vmware.bdd.security.tls;

import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;

import org.apache.commons.lang.ArrayUtils;

import com.vmware.bdd.utils.CommonUtil;

/**
 * An {@link SSLSocketFactory} decorator that delegates socket creation to a
 * wrapped factory and then applies a fixed set of {@link SSLParameters}
 * (cipher suites and protocols) plus a read timeout to every socket it hands out.
 *
 * <p>The static factory method {@link #makeSSLSocketFactory(TrustStoreConfig)}
 * builds an instance whose SSL context validates server certificates through
 * {@link SimpleServerTrustManager}.
 */
public class SimpleSeverTrustTlsSocketFactory extends SSLSocketFactory {

    /** Read timeout, in milliseconds, applied to every socket created by this factory. */
    private static final int SO_TIMEOUT_MILLIS = 30000;

    /** Configuration consumed by {@link #getDefault()}; set through {@link #init(TrustStoreConfig)}. */
    private static TrustStoreConfig trustStoreConfig = null;

    private final SSLSocketFactory defaultSSLSocketFactory;
    private final SSLParameters sslParams;

    /**
     * Wrap an existing factory
     *
     * @param factory   the SSLSocketFactory as returned by
     *                  {@link javax.net.ssl.SSLContext#getSocketFactory()}
     * @param sslParams The configuration to be set on the SSLSocket that is
     *                  created by the socket factory
     */
    public SimpleSeverTrustTlsSocketFactory(SSLSocketFactory factory, final SSLParameters sslParams) {
        this.defaultSSLSocketFactory = factory;
        this.sslParams = sslParams;
    }

    /**
     * Applies the client configuration (ciphers and protocols) and the read
     * timeout to a socket created by the wrapped factory.
     *
     * <p>If the configuration cannot be applied (for example the socket is not
     * an {@link SSLSocket}, or the configured cipher suites/protocols are
     * rejected), the socket is closed before the exception is rethrown so that
     * an already-connected socket is not leaked.
     *
     * @param sock a socket created by the wrapped {@link SSLSocketFactory}
     * @return the same socket, configured with the client specified parameters
     */
    private Socket wrapSocket(Socket sock) {
        try {
            SSLSocket sslSock = (SSLSocket) sock;
            sslSock.setSSLParameters(sslParams);
            try {
                sslSock.setSoTimeout(SO_TIMEOUT_MILLIS);
            } catch (SocketException ignored) {
                // Best effort: the timeout is a safety net only. A socket on which
                // it cannot be set is still returned; any underlying socket error
                // will surface on the caller's first I/O operation.
            }
            return sslSock;
        } catch (RuntimeException e) {
            // The delegate may have already connected this socket; release it
            // instead of leaking the connection when configuration fails.
            closeQuietly(sock);
            throw e;
        }
    }

    /**
     * Closes a socket, ignoring a null argument and any failure to close.
     * Used only on error paths where the original exception must be preserved.
     */
    private static void closeQuietly(Socket sock) {
        if (sock == null) {
            return;
        }
        try {
            sock.close();
        } catch (IOException ignored) {
            // Nothing useful can be done; the caller rethrows the original failure.
        }
    }

    @Override
    public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException {
        return wrapSocket(defaultSSLSocketFactory.createSocket(s, host, port, autoClose));
    }

    @Override
    public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
        return wrapSocket(defaultSSLSocketFactory.createSocket(host, port));
    }

    @Override
    public Socket createSocket(InetAddress host, int port) throws IOException {
        return wrapSocket(defaultSSLSocketFactory.createSocket(host, port));
    }

    @Override
    public Socket createSocket(String host, int port, InetAddress localHost, int localPort)
            throws IOException, UnknownHostException {
        return wrapSocket(defaultSSLSocketFactory.createSocket(host, port, localHost, localPort));
    }

    @Override
    public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort)
            throws IOException {
        return wrapSocket(defaultSSLSocketFactory.createSocket(address, port, localAddress, localPort));
    }

    @Override
    public Socket createSocket() throws IOException {
        return wrapSocket(defaultSSLSocketFactory.createSocket());
    }

    /**
     * Returns the cipher suites configured in this factory's
     * {@link SSLParameters}, falling back to the wrapped factory's defaults
     * when none were configured.
     */
    @Override
    public String[] getDefaultCipherSuites() {
        String[] configuredSuites = this.sslParams.getCipherSuites();
        if (configuredSuites != null) {
            return configuredSuites;
        }
        return defaultSSLSocketFactory.getDefaultCipherSuites();
    }

    @Override
    public String[] getSupportedCipherSuites() {
        return defaultSSLSocketFactory.getSupportedCipherSuites();
    }

    /**
     * Builds a new factory from the configuration previously supplied to
     * {@link #init(TrustStoreConfig)}; call init(..) before calling this method.
     *
     * <p>A new factory is created on every call.
     *
     * @return a factory built by {@link #makeSSLSocketFactory(TrustStoreConfig)}.
     * @throws TlsInitException if init(..) was not called or the configuration is incomplete
     */
    public static SocketFactory getDefault() {
        return makeSSLSocketFactory(trustStoreConfig);
    }

    /**
     * init required parameters for getDefault()
     *
     * @param trustStoreConfig1 the trust store configuration used by subsequent
     *                          {@link #getDefault()} calls
     */
    public static void init(TrustStoreConfig trustStoreConfig1) {
        trustStoreConfig = trustStoreConfig1;
    }

    /**
     * Validates that the trust store configuration carries everything needed:
     * a non-empty password, a non-blank path and a non-blank type.
     *
     * @param trustStoreCfg the configuration to validate
     * @throws TlsInitException naming the first missing parameter
     */
    private static void check(TrustStoreConfig trustStoreCfg) {
        if (trustStoreCfg == null) {
            throw new TlsInitException("SIMPLE_TLS_SOCK_FACTORY.PARAMS_REQUIRED", null,
                    "trust store config object.");
        }

        if (trustStoreCfg.getPassword() == null) {
            throw new TlsInitException("SIMPLE_TLS_SOCK_FACTORY.PARAMS_REQUIRED", null, "PasswordProvider");
        } else if (ArrayUtils.isEmpty(trustStoreCfg.getPassword().getPlainChars())) {
            throw new TlsInitException("SIMPLE_TLS_SOCK_FACTORY.PARAMS_REQUIRED", null, "Password");
        }

        if (CommonUtil.isBlank(trustStoreCfg.getPath())) {
            throw new TlsInitException("SIMPLE_TLS_SOCK_FACTORY.PARAMS_REQUIRED", null, "Trust Store Path");
        }

        if (CommonUtil.isBlank(trustStoreCfg.getType())) {
            throw new TlsInitException("SIMPLE_TLS_SOCK_FACTORY.PARAMS_REQUIRED", null, "Trust Store Type");
        }
    }

    /**
     * factory method for custom usage.
     *
     * @param trustStoreCfg trust store configuration handed to the
     *                      {@link SimpleServerTrustManager}; must pass validation
     * @return a factory whose sockets use the custom trust manager and the
     *         cipher suites/protocols from {@link TlsClientConfiguration}
     * @throws TlsInitException if the configuration is incomplete or the SSL
     *                          context cannot be created or initialized
     */
    public static SSLSocketFactory makeSSLSocketFactory(TrustStoreConfig trustStoreCfg) {
        check(trustStoreCfg);

        // Initialize our own trust manager
        SimpleServerTrustManager simpleServerTrustManager = new SimpleServerTrustManager();
        simpleServerTrustManager.setTrustStoreConfig(trustStoreCfg);
        TrustManager[] trustManagers = new TrustManager[] { simpleServerTrustManager };

        SSLContext customSSLContext;
        try {
            // Instantiate a context that implements the family of TLS protocols
            customSSLContext = SSLContext.getInstance("TLS");

            // Initialize SSL context. Default instances of KeyManager and
            // SecureRandom are used.
            customSSLContext.init(null, trustManagers, null);
        } catch (NoSuchAlgorithmException e) {
            throw new TlsInitException("SSLContext_INIT_ERR", e);
        } catch (KeyManagementException e) {
            throw new TlsInitException("SSLContext_INIT_ERR", e);
        }

        // Build connection configuration and pass to socket
        TlsClientConfiguration tlsClientConfiguration = new TlsClientConfiguration();
        SSLParameters params = new SSLParameters();
        params.setCipherSuites(tlsClientConfiguration.getCipherSuites());
        params.setProtocols(tlsClientConfiguration.getSslProtocols());
        // NOTE(review): no endpoint identification algorithm is set on the
        // parameters, so hostname verification is not requested here. Confirm
        // that server identity checking is covered elsewhere (presumably by
        // SimpleServerTrustManager) before relying on it.

        // Use the SSLSocketFactory generated by the SSLContext and wrap it to
        // enable custom cipher suites and protocols
        return new SimpleSeverTrustTlsSocketFactory(customSSLContext.getSocketFactory(), params);
    }
}