com.kixeye.chassis.transport.websocket.ActionInvokingWebSocket.java Source code

Java tutorial

Introduction

Here is the source code for com.kixeye.chassis.transport.websocket.ActionInvokingWebSocket.java

Source

package com.kixeye.chassis.transport.websocket;

/*
 * #%L
 * Chassis Transport Core
 * %%
 * Copyright (C) 2014 KIXEYE, Inc
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import javax.validation.Validator;

import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler;

import com.google.common.base.Charsets;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.cache.RemovalListener;
import com.google.common.cache.RemovalNotification;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.kixeye.chassis.transport.ExceptionServiceErrorMapper;
import com.kixeye.chassis.transport.dto.Envelope;
import com.kixeye.chassis.transport.dto.ServiceError;
import com.kixeye.chassis.transport.serde.MessageSerDe;

/**
 * Listens to websocket messages and forwards it to the correct bean.
 * 
 * @author ebahtijaragic
 */
public class ActionInvokingWebSocket implements WebSocketListener {
    private static final Logger logger = LoggerFactory.getLogger(ActionInvokingWebSocket.class);

    @Autowired
    private WebSocketMessageMappingRegistry mappingRegistry;

    @Autowired
    private WebSocketMessageRegistry messageRegistry;

    @Autowired
    private DefaultListableBeanFactory beanFactory;

    @Autowired
    private Validator messageValidator;

    @Autowired
    private WebSocketPskFrameProcessor pskFrameProcessor;

    private Session session;

    private WebSocketSession webSocketSession = new WebSocketSession(this);

    private ServletUpgradeRequest upgradeRequest;

    private ServletUpgradeResponse upgradeResponse;

    private MessageSerDe serDe;

    private ListeningExecutorService serviceExecutor = MoreExecutors
            .listeningDecorator(Executors.newSingleThreadExecutor());
    private ExecutorService responseExecutor = Executors.newSingleThreadExecutor();

    private LoadingCache<String, Object> handlerCache = CacheBuilder.newBuilder()
            .removalListener(new RemovalListener<String, Object>() {
                public void onRemoval(RemovalNotification<String, Object> notification) {
                    Class<?> handlerClass = null;

                    try {
                        handlerClass = Class.forName(notification.getKey());
                    } catch (ClassNotFoundException e) {
                        logger.error("Unexpected exception", e);
                    }

                    if (handlerClass != null) {
                        String[] beanNames = beanFactory.getBeanNamesForType(handlerClass);

                        if (beanNames != null && beanNames.length > 0) {
                            if (beanFactory.isPrototype(beanNames[0])) {
                                if (notification.getValue() instanceof WebSocketSessionAware) {
                                    WebSocketSessionAware webSocketSessionAwareHandler = (WebSocketSessionAware) notification
                                            .getValue();

                                    webSocketSessionAwareHandler.onWebSocketSessionRemoved(webSocketSession);
                                }

                                beanFactory.destroyBean(notification.getValue());
                            } // else this is a singleton and we don't do anything with singletons
                        } // this shouldn't happen
                    } // this shouldn't happen either
                }
            }).build(new CacheLoader<String, Object>() {
                public Object load(String handlerClassName) throws Exception {
                    Class<?> handlerClass = Class.forName(handlerClassName);

                    String[] beanNames = beanFactory.getBeanNamesForType(handlerClass);

                    if (beanNames != null && beanNames.length > 0) {
                        Object handler = beanFactory.getBean(beanNames[0]);

                        if (handler instanceof WebSocketSessionAware) {
                            WebSocketSessionAware webSocketSessionAwareHandler = (WebSocketSessionAware) handler;

                            webSocketSessionAwareHandler.onWebSocketSessionCreated(webSocketSession);
                        }

                        return handler;
                    } else {
                        throw new RuntimeException("No beans exist for handler: " + handlerClass);
                    }
                }
            });

    public void onWebSocketBinary(byte[] payload, int offset, int length) {
        try {
            // don't accept empty frames
            if (payload == null || length < 1) {
                throw new WebSocketServiceException(new ServiceError("EMPTY_ENVELOPE", "Empty envelope!"),
                        "UNKNOWN", null);
            }

            // check if we need to do psk encryption
            byte[] processedPayload = pskFrameProcessor.processIncoming(payload, offset, length);

            if (processedPayload != payload) {
                payload = processedPayload;
                offset = 0;
                length = payload.length;
            }

            // get the envelope
            final WebSocketEnvelope envelope = new WebSocketEnvelope(
                    serDe.deserialize(payload, offset, length, Envelope.class));

            // gets all the actions
            Collection<WebSocketAction> actions = mappingRegistry.getActionMethods(envelope.getAction());

            final AtomicInteger invokedActions = new AtomicInteger(0);

            // invokes them
            for (final WebSocketAction action : actions) {
                // get and validate type ID
                Class<?> messageClass = null;

                if (StringUtils.isNotBlank(envelope.getTypeId())) {
                    messageClass = messageRegistry.getClassByTypeId(envelope.getTypeId());
                }

                // validate if action has a payload class that it needs
                if (action.getPayloadClass() != null && messageClass == null) {
                    throw new WebSocketServiceException(new ServiceError("INVALID_TYPE_ID", "Unknown type ID!"),
                            envelope.getAction(), envelope.getTransactionId());
                }

                // invoke this action if allowed
                if (action.canInvoke(webSocketSession, messageClass)) {
                    invokedActions.incrementAndGet();

                    final Object handler = handlerCache.get(action.getHandlerClass().getName());
                    final Class<?> finalMessageClass = messageClass;

                    ListenableFuture<DeferredResult<?>> invocation = serviceExecutor
                            .submit(new Callable<DeferredResult<?>>() {
                                @Override
                                public DeferredResult<?> call() throws Exception {
                                    // then invoke
                                    return action.invoke(
                                            handler, new RawWebSocketMessage<>(envelope.getPayload(),
                                                    finalMessageClass, messageValidator, serDe),
                                            envelope, webSocketSession);
                                }
                            });

                    Futures.addCallback(invocation, new FutureCallback<DeferredResult<?>>() {
                        public void onSuccess(DeferredResult<?> result) {
                            if (result != null) {
                                result.setResultHandler(new DeferredResultHandler() {
                                    @Override
                                    public void handleResult(Object result) {
                                        if (result instanceof Exception) {
                                            onFailure((Exception) result);
                                            return;
                                        }

                                        sendResponse(result);
                                    }
                                });
                            }
                        }

                        public void onFailure(Throwable t) {
                            if (t instanceof InvocationTargetException) {
                                t = ((InvocationTargetException) t).getTargetException();
                            }

                            ServiceError error = ExceptionServiceErrorMapper.mapException(t);

                            if (error != null
                                    && !ExceptionServiceErrorMapper.VALIDATION_ERROR_CODE.equals(error.code)) {
                                logger.error("Unexpected exception throw while executing action [{}]",
                                        envelope.getAction(), t);
                            }

                            sendResponse(error);
                        }

                        public Future<Void> sendResponse(Object response) {
                            try {
                                return sendMessage(envelope.getAction(), envelope.getTransactionId(), response);
                            } catch (IOException | GeneralSecurityException e) {
                                logger.error("Unable to send message to channel", e);

                                return Futures.immediateFuture(null);
                            }
                        }

                    }, responseExecutor);
                }
            }

            // make sure we actually invoked something
            if (invokedActions.get() < 1) {
                throw new WebSocketServiceException(
                        new ServiceError("INVALID_ACTION_MAPPING", "No actions invoked."), envelope.getAction(),
                        envelope.getTransactionId());
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public void onWebSocketConnect(Session session) {
        logger.info(this.toString() + " - Session connected [{}].", session.toString());
        this.session = session;
    }

    public void onWebSocketError(Throwable cause) {
        logger.error("Unexpected socket error", cause);

        if (session.isOpen()) {
            try {
                String action = null;
                String txId = null;

                WebSocketServiceException serviceException = null;
                Throwable currentCause = cause;

                while (currentCause != null) {
                    if (currentCause instanceof WebSocketServiceException) {
                        serviceException = (WebSocketServiceException) currentCause;
                    }

                    currentCause = currentCause.getCause();
                }

                if (serviceException != null) {
                    action = serviceException.action;
                    txId = serviceException.transactionId;
                } else {
                    action = "UNKNOWN";
                }

                sendMessage(action, txId, ExceptionServiceErrorMapper
                        .mapException(serviceException == null ? cause : serviceException)).get();
            } catch (Exception e) {
                logger.error("Unexpected error", e);
            }
        }
    }

    public void onWebSocketText(String message) {
        byte[] data = message.getBytes(Charsets.UTF_8);

        onWebSocketBinary(data, 0, data.length);
    }

    public synchronized void onWebSocketClose(int statusCode, String reason) {
        logger.info(this.toString() + " - Session disconnected [{}]. Reason: [{}]", session.toString(), reason);

        try {
            handlerCache.invalidateAll();
        } finally {
            beanFactory.destroyBean(this);
        }
    }

    /**
     * Gets the websocket session.
     * 
     * @return
     * @throws IOException 
     * @throws GeneralSecurityException 
     */
    protected Future<Void> sendMessage(String action, String transactionId, Object obj)
            throws IOException, GeneralSecurityException {
        String typeId = messageRegistry.getTypeIdByClass(obj.getClass());

        if (typeId == null) {
            throw new RuntimeException("Unable to determine type ID for class: " + obj.getClass());
        }

        byte[] payload = serDe.serialize(obj);

        return sendMessage(action, transactionId, typeId, ByteBuffer.wrap(payload));
    }

    /**
     * Gets the websocket session.
     * 
     * @return
     * @throws IOException 
     * @throws GeneralSecurityException 
     */
    protected Future<Void> sendMessage(String action, String transactionId, String typeId, ByteBuffer payload)
            throws IOException, GeneralSecurityException {
        Envelope envelope = new Envelope(action, typeId, transactionId, payload);

        // generate blob
        byte[] envelopeBlob = serDe.serialize(envelope);

        // check if we need to do psk encryption
        envelopeBlob = pskFrameProcessor.processOutgoing(envelopeBlob, 0, envelopeBlob.length);

        return session.getRemote().sendBytesByFuture(ByteBuffer.wrap(envelopeBlob));
    }

    /**
     * Gets the websocket session.
     * 
     * @return
     * @throws IOException 
     * @throws GeneralSecurityException 
     */
    protected Future<Void> sendContent(InputStream inputStream) throws IOException, GeneralSecurityException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();

        IOUtils.copyLarge(inputStream, baos);

        // generate blob
        byte[] contentBlob = serDe.serialize(baos.toByteArray());

        // check if we need to do psk encryption
        contentBlob = pskFrameProcessor.processOutgoing(contentBlob, 0, contentBlob.length);

        return session.getRemote().sendBytesByFuture(ByteBuffer.wrap(contentBlob));
    }

    /**
     * @return the serDe
     */
    public MessageSerDe getSerDe() {
        return serDe;
    }

    /**
     * @param serDe the serDe to set
     */
    public void setSerDe(MessageSerDe serDe) {
        this.serDe = serDe;
    }

    /**
     * @return the upgradeRequest
     */
    public ServletUpgradeRequest getUpgradeRequest() {
        return upgradeRequest;
    }

    /**
     * @param upgradeRequest the upgradeRequest to set
     */
    public void setUpgradeRequest(ServletUpgradeRequest upgradeRequest) {
        this.upgradeRequest = upgradeRequest;
    }

    /**
     * @return the upgradeResponse
     */
    public ServletUpgradeResponse getUpgradeResponse() {
        return upgradeResponse;
    }

    /**
     * @param upgradeResponse the upgradeResponse to set
     */
    public void setUpgradeResponse(ServletUpgradeResponse upgradeResponse) {
        this.upgradeResponse = upgradeResponse;
    }

    /**
     * Returns true if we're connected.
     * 
     * @return
     */
    public boolean isConnected() {
        return session != null && session.isOpen();
    }
}