org.apache.cassandra.security.SSLFactory.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.cassandra.security.SSLFactory.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.security;

import java.io.FileInputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Date;
import java.util.Enumeration;
import java.util.Set;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

import org.apache.cassandra.config.EncryptionOptions;
import org.apache.cassandra.io.util.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Sets;

/**
 * Factory for creating client and server SSL sockets configured from
 * {@link EncryptionOptions}.
 *
 * <p>Every socket produced here has:
 * <ul>
 *   <li>its enabled cipher suites restricted to those in {@code options.cipher_suites}
 *       that the socket supports, kept in the configured order;</li>
 *   <li>its enabled protocols set to {@link #ACCEPTED_PROTOCOLS}.</li>
 * </ul>
 * If configuring a freshly created socket fails, the socket is closed before the
 * exception propagates.
 */
public final class SSLFactory {
    private static final Logger logger = LoggerFactory.getLogger(SSLFactory.class);

    /**
     * Protocols enabled on every socket created by this factory.
     * NOTE(review): this public array is mutable; callers must not modify it.
     */
    public static final String[] ACCEPTED_PROTOCOLS = new String[] { "SSLv2Hello", "TLSv1", "TLSv1.1", "TLSv1.2" };

    /**
     * Set after the first keystore whose certificate expiry dates were checked and logged.
     * Later SSL contexts skip the check.
     * This is a plain (unsynchronized) flag, so concurrent first calls may each run the check.
     */
    private static boolean checkedExpiry = false;

    /**
     * Creates an SSL server socket bound to {@code address:port}.
     *
     * <p>The socket has SO_REUSEADDR enabled and a listen backlog of 500. It requires
     * client authentication iff {@code options.require_client_auth}.
     *
     * @param options encryption settings (keystore, truststore, cipher suites, ...)
     * @param address local address to bind to
     * @param port    local port to bind to
     * @return the configured, bound server socket
     * @throws IOException if the SSL context cannot be created or the socket cannot be
     *                     bound; the socket is closed before the exception propagates
     */
    public static SSLServerSocket getServerSocket(EncryptionOptions options, InetAddress address, int port)
            throws IOException {
        SSLContext ctx = createSSLContext(options, true);
        SSLServerSocket serverSocket = (SSLServerSocket) ctx.getServerSocketFactory().createServerSocket();
        boolean bound = false;
        try {
            serverSocket.setReuseAddress(true);
            String[] suites = filterCipherSuites(serverSocket.getSupportedCipherSuites(), options.cipher_suites);
            serverSocket.setEnabledCipherSuites(suites);
            serverSocket.setNeedClientAuth(options.require_client_auth);
            serverSocket.setEnabledProtocols(ACCEPTED_PROTOCOLS);
            serverSocket.bind(new InetSocketAddress(address, port), 500);
            bound = true;
            return serverSocket;
        } finally {
            // Don't leak the socket when configuration or bind fails (e.g. port already in use).
            if (!bound) {
                FileUtils.closeQuietly(serverSocket);
            }
        }
    }

    /**
     * Creates an SSL socket and connects it to {@code address:port} from the given
     * local address and port.
     *
     * @throws IOException if the SSL context cannot be created or the connection fails
     */
    public static SSLSocket getSocket(EncryptionOptions options, InetAddress address, int port,
            InetAddress localAddress, int localPort) throws IOException {
        SSLContext ctx = createSSLContext(options, true);
        SSLSocket socket = (SSLSocket) ctx.getSocketFactory().createSocket(address, port, localAddress, localPort);
        return configureSocket(socket, options);
    }

    /**
     * Creates an SSL socket and connects it to {@code address:port}, using any local
     * address.
     *
     * @throws IOException if the SSL context cannot be created or the connection fails
     */
    public static SSLSocket getSocket(EncryptionOptions options, InetAddress address, int port) throws IOException {
        SSLContext ctx = createSSLContext(options, true);
        SSLSocket socket = (SSLSocket) ctx.getSocketFactory().createSocket(address, port);
        return configureSocket(socket, options);
    }

    /**
     * Creates an SSL socket without connecting it.
     *
     * @throws IOException if the SSL context cannot be created
     */
    public static SSLSocket getSocket(EncryptionOptions options) throws IOException {
        SSLContext ctx = createSSLContext(options, true);
        SSLSocket socket = (SSLSocket) ctx.getSocketFactory().createSocket();
        return configureSocket(socket, options);
    }

    /**
     * Applies the factory's cipher-suite and protocol restrictions to {@code socket}.
     * Closes the socket if either setting is rejected, so a connected socket is not leaked.
     */
    private static SSLSocket configureSocket(SSLSocket socket, EncryptionOptions options) {
        boolean configured = false;
        try {
            String[] suites = filterCipherSuites(socket.getSupportedCipherSuites(), options.cipher_suites);
            socket.setEnabledCipherSuites(suites);
            socket.setEnabledProtocols(ACCEPTED_PROTOCOLS);
            configured = true;
            return socket;
        } finally {
            if (!configured) {
                FileUtils.closeQuietly(socket);
            }
        }
    }

    /**
     * Builds and initializes an {@link SSLContext} for {@code options.protocol}.
     *
     * <p>The key managers come from {@code options.keystore}. On the first call in this
     * JVM, a warning is logged for every X.509 certificate in that keystore that has
     * already expired.
     *
     * @param options         encryption settings
     * @param buildTruststore if true, trust managers are built from {@code options.truststore};
     *                        if false, {@code null} is passed to {@link SSLContext#init}, which
     *                        falls back to the installed providers' default trust managers
     * @return the initialized context
     * @throws IOException wrapping any failure to read the stores or initialize the context
     */
    public static SSLContext createSSLContext(EncryptionOptions options, boolean buildTruststore)
            throws IOException {
        try {
            SSLContext ctx = SSLContext.getInstance(options.protocol);

            TrustManager[] trustManagers = null;
            if (buildTruststore) {
                KeyStore ts = loadKeyStore(options.truststore, options.store_type, options.truststore_password);
                TrustManagerFactory tmf = TrustManagerFactory.getInstance(options.algorithm);
                tmf.init(ts);
                trustManagers = tmf.getTrustManagers();
            }

            KeyStore ks = loadKeyStore(options.keystore, options.store_type, options.keystore_password);
            if (!checkedExpiry) {
                warnOnExpiredCertificates(ks);
                checkedExpiry = true;
            }
            KeyManagerFactory kmf = KeyManagerFactory.getInstance(options.algorithm);
            kmf.init(ks, options.keystore_password.toCharArray());

            ctx.init(kmf.getKeyManagers(), trustManagers, null);
            return ctx;
        } catch (Exception e) {
            throw new IOException("Error creating/initializing the SSL Context", e);
        }
    }

    /** Loads a key store of the given type from the file at {@code path}, always closing the file. */
    private static KeyStore loadKeyStore(String path, String type, String password)
            throws IOException, GeneralSecurityException {
        FileInputStream in = new FileInputStream(path);
        try {
            KeyStore store = KeyStore.getInstance(type);
            store.load(in, password.toCharArray());
            return store;
        } finally {
            FileUtils.closeQuietly(in);
        }
    }

    /**
     * Logs a warning for each X.509 certificate in {@code ks} whose notAfter date is
     * already in the past.
     * Aliases for which {@link KeyStore#getCertificate} returns null (entries without a
     * certificate) are skipped rather than causing a NullPointerException.
     */
    private static void warnOnExpiredCertificates(KeyStore ks) throws GeneralSecurityException {
        Date now = new Date();
        for (Enumeration<String> aliases = ks.aliases(); aliases.hasMoreElements();) {
            String alias = aliases.nextElement();
            Certificate cert = ks.getCertificate(alias);
            if (cert instanceof X509Certificate) {
                Date expires = ((X509Certificate) cert).getNotAfter();
                if (expires.before(now)) {
                    logger.warn("Certificate for {} expired on {}", alias, expires);
                }
            }
        }
    }

    /**
     * Returns the suites from {@code desired} that also appear in {@code supported}.
     *
     * <p>The result keeps the order of {@code desired}, so the operator's preference
     * order is not scrambled by hash-set iteration. Duplicates are dropped. All desired
     * suites that the socket does not support are reported in a single warning.
     *
     * @param supported suites the socket supports
     * @param desired   suites requested by configuration, in preference order
     * @return the supported subset of {@code desired}, possibly empty
     */
    private static String[] filterCipherSuites(String[] supported, String[] desired) {
        Set<String> supportedSet = Sets.newHashSet(supported);
        Set<String> accepted = Sets.newLinkedHashSet();
        Set<String> rejected = Sets.newLinkedHashSet();
        for (String suite : desired) {
            if (supportedSet.contains(suite)) {
                accepted.add(suite);
            } else {
                rejected.add(suite);
            }
        }
        if (!rejected.isEmpty()) {
            logger.warn("Filtering out {} as it isn't supported by the socket", StringUtils.join(rejected, ","));
        }
        return accepted.toArray(new String[accepted.size()]);
    }
}