org.springframework.data.keyvalue.redis.listener.RedisMessageListenerContainer.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.data.keyvalue.redis.listener.RedisMessageListenerContainer.java

Source

/*
 * Copyright 2011 the original author or authors.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.data.keyvalue.redis.listener;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanNameAware;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.SmartLifecycle;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.data.keyvalue.redis.connection.Message;
import org.springframework.data.keyvalue.redis.connection.MessageListener;
import org.springframework.data.keyvalue.redis.connection.RedisConnection;
import org.springframework.data.keyvalue.redis.connection.RedisConnectionFactory;
import org.springframework.data.keyvalue.redis.connection.Subscription;
import org.springframework.data.keyvalue.redis.connection.util.ByteArrayWrapper;
import org.springframework.data.keyvalue.redis.serializer.RedisSerializer;
import org.springframework.data.keyvalue.redis.serializer.StringRedisSerializer;
import org.springframework.scheduling.SchedulingAwareRunnable;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ErrorHandler;

/**
 * Container providing asynchronous behaviour for Redis message listeners.
 * Handles the low level details of listening, converting and message dispatching.
 * <p/>
 * As opposed to the low level Redis API (one connection per subscription), the container
 * uses only one connection that is 'multiplexed' for all registered listeners,
 * the message dispatch being done through the task executor.
 * 
 * <p/>
 * Note the container uses the connection in a lazy fashion (the connection is used only if at least one listener is configured). 
 * 
 * @author Costin Leau
 */
public class RedisMessageListenerContainer
        implements InitializingBean, DisposableBean, BeanNameAware, SmartLifecycle {

    /** Logger available to subclasses */
    protected final Log logger = LogFactory.getLog(getClass());

    /**
     * Default thread name prefix: "RedisMessageListenerContainer-".
     */
    public static final String DEFAULT_THREAD_NAME_PREFIX = ClassUtils
            .getShortName(RedisMessageListenerContainer.class) + "-";

    // maximum time (in ms) start()/stop() block waiting for the subscription thread to signal
    private long initWait = TimeUnit.SECONDS.toMillis(5);

    private Executor subscriptionExecutor;

    private Executor taskExecutor;

    private RedisConnectionFactory connectionFactory;

    private String beanName;

    private ErrorHandler errorHandler;

    private final Object monitor = new Object();
    // whether the container is running (or not)
    private volatile boolean running = false;
    // whether the container has been initialized
    private volatile boolean initialized = false;
    // whether the container uses a connection or not
    // (as the container might be running but w/o listeners, it won't use any resources)
    private volatile boolean listening = false;

    private volatile boolean manageExecutor = false;

    // lookup maps
    // to avoid creation of hashes for each message, the maps use raw byte arrays (wrapped to respect the equals/hashcode contract)

    // lookup map between patterns and listeners
    private final ConcurrentMap<ByteArrayWrapper, Collection<MessageListener>> patternMapping = new ConcurrentHashMap<ByteArrayWrapper, Collection<MessageListener>>();
    // lookup map between channels and listeners
    private final ConcurrentMap<ByteArrayWrapper, Collection<MessageListener>> channelMapping = new ConcurrentHashMap<ByteArrayWrapper, Collection<MessageListener>>();

    private final SubscriptionTask subscriptionTask = new SubscriptionTask();

    private volatile RedisSerializer<String> serializer = new StringRedisSerializer();

    @Override
    public void afterPropertiesSet() {
        if (taskExecutor == null) {
            manageExecutor = true;
            taskExecutor = createDefaultTaskExecutor();
        }

        if (subscriptionExecutor == null) {
            subscriptionExecutor = taskExecutor;
        }

        initialized = true;

        start();
    }

    /**
     * Creates a default TaskExecutor. Called if no explicit TaskExecutor has been specified.
     * <p>The default implementation builds a {@link org.springframework.core.task.SimpleAsyncTaskExecutor}
     * with the specified bean name (or the class name, if no bean name specified) as thread name prefix.
     * @see org.springframework.core.task.SimpleAsyncTaskExecutor#SimpleAsyncTaskExecutor(String)
     */
    protected TaskExecutor createDefaultTaskExecutor() {
        String threadNamePrefix = (beanName != null ? beanName + "-" : DEFAULT_THREAD_NAME_PREFIX);
        return new SimpleAsyncTaskExecutor(threadNamePrefix);
    }

    @Override
    public void destroy() throws Exception {
        initialized = false;

        stop();

        if (manageExecutor) {
            if (taskExecutor instanceof DisposableBean) {
                ((DisposableBean) taskExecutor).destroy();

                if (logger.isDebugEnabled()) {
                    logger.debug("Stopped internally-managed task executor");
                }
            }
        }
    }

    @Override
    public boolean isAutoStartup() {
        return true;
    }

    @Override
    public void stop(Runnable callback) {
        stop();
        callback.run();
    }

    @Override
    public int getPhase() {
        // start the latest (highest possible phase value)
        return Integer.MAX_VALUE;
    }

    @Override
    public boolean isRunning() {
        return running;
    }

    /**
     * Starts the container. If listeners are registered, the subscription thread is scheduled
     * and this method blocks (up to the configured wait time) until that thread signals it
     * has obtained its connection. Without listeners, the call returns immediately.
     */
    @Override
    public void start() {
        if (!running) {
            running = true;
            // wait for the subscription to start before returning
            // technically speaking we can only be notified right before the subscription starts
            synchronized (monitor) {
                // only block when a subscription was actually scheduled: otherwise nothing
                // would ever notify the monitor and start() would stall for the full initWait
                if (lazyListen()) {
                    try {
                        // wait up to initWait ms (5 seconds by default)
                        monitor.wait(initWait);
                    } catch (InterruptedException ex) {
                        // stop waiting, but keep the interrupt status visible to the caller
                        Thread.currentThread().interrupt();
                    }
                }
            }

            if (logger.isDebugEnabled()) {
                logger.debug("Started RedisMessageListenerContainer");
            }
        }
    }

    /**
     * Stops the container: cancels the active subscription (if any) and waits, up to the
     * configured wait time, for the subscription thread to release its connection.
     */
    @Override
    public void stop() {
        if (isRunning()) {
            running = false;
            synchronized (monitor) {
                subscriptionTask.cancel();
                if (listening) {
                    try {
                        monitor.wait(initWait);
                    } catch (InterruptedException ex) {
                        // stop waiting, but keep the interrupt status visible to the caller
                        Thread.currentThread().interrupt();
                    }
                }
            }

            if (logger.isDebugEnabled()) {
                logger.debug("Stopped RedisMessageListenerContainer");
            }
        }
    }

    /**
     * Process a message received from the provider.
     * 
     * @param listener the listener the message is dispatched to
     * @param message the received message
     * @param pattern the matching pattern, or {@code null} for a plain channel match
     */
    protected void processMessage(MessageListener listener, Message message, byte[] pattern) {
        executeListener(listener, message, pattern);
    }

    /**
     * Execute the specified listener; any {@link Throwable} it raises is passed to
     * {@link #handleListenerException(Throwable)}.
     * 
     * @see #handleListenerException
     */
    protected void executeListener(MessageListener listener, Message message, byte[] pattern) {
        try {
            listener.onMessage(message, pattern);
        } catch (Throwable ex) {
            handleListenerException(ex);
        }
    }

    /**
     * Return whether this container is currently active,
     * that is, whether it has been set up but not shut down yet.
     */
    public final boolean isActive() {
        return initialized;
    }

    /**
     * Handle the given exception that arose during listener execution.
     * <p>While the container is active, the exception is passed to
     * {@link #invokeErrorHandler(Throwable)}; after shutdown it is only logged at debug level.
     * This can be overridden in subclasses.
     * @param ex the exception to handle
     */
    protected void handleListenerException(Throwable ex) {
        if (isActive()) {
            // Regular case: failed while active.
            // Invoke ErrorHandler if available.
            invokeErrorHandler(ex);
        } else {
            // Rare case: listener thread failed after container shutdown.
            // Log at debug level, to avoid spamming the shutdown logger.
            logger.debug("Listener exception after container shutdown", ex);
        }
    }

    /**
     * Invoke the registered ErrorHandler, if any. Log at warn level otherwise.
     * @param ex the uncaught error that arose during message processing.
     * @see #setErrorHandler
     */
    protected void invokeErrorHandler(Throwable ex) {
        if (this.errorHandler != null) {
            this.errorHandler.handleError(ex);
        } else if (logger.isWarnEnabled()) {
            logger.warn("Execution of message listener failed, and no ErrorHandler has been set.", ex);
        }
    }

    /**
     * Returns the connectionFactory.
     *
     * @return Returns the connectionFactory
     */
    public RedisConnectionFactory getConnectionFactory() {
        return connectionFactory;
    }

    /**
     * @param connectionFactory The connectionFactory to set.
     */
    public void setConnectionFactory(RedisConnectionFactory connectionFactory) {
        this.connectionFactory = connectionFactory;
    }

    @Override
    public void setBeanName(String name) {
        this.beanName = name;
    }

    /**
     * Sets the task executor used for running the message listeners when messages are received.
     * If no task executor is set, an instance of {@link SimpleAsyncTaskExecutor} will be used by default.
     * The task executor can be adjusted depending on the work done by the listeners and the number of 
     * messages coming in.
     * 
     * @param taskExecutor The taskExecutor to set.
     */
    public void setTaskExecutor(Executor taskExecutor) {
        this.taskExecutor = taskExecutor;
    }

    /**
     * Sets the task executor used for subscribing to Redis channels. By default, if no executor is set,
     * the {@link #setTaskExecutor(Executor)} will be used. In some cases, this might be undesired as
     * listening on the connection is a long running task.
     *
     * <p/>Note: This implementation uses at most one long running thread (depending on whether there are any listeners registered or not)
     * and up to two threads during the initial registration. 
     * 
     * @param subscriptionExecutor The subscriptionExecutor to set.
     */
    public void setSubscriptionExecutor(Executor subscriptionExecutor) {
        this.subscriptionExecutor = subscriptionExecutor;
    }

    /**
     * Sets the serializer for converting the {@link Topic}s into low-level channels and patterns.
     * By default, {@link StringRedisSerializer} is used.
     * 
     * @param serializer The serializer to set.
     */
    public void setTopicSerializer(RedisSerializer<String> serializer) {
        this.serializer = serializer;
    }

    /**
     * Set an ErrorHandler to be invoked in case of any uncaught exceptions thrown
     * while processing a Message. By default there will be <b>no</b> ErrorHandler
     * so that warn-level logging is the only result.
     */
    public void setErrorHandler(ErrorHandler errorHandler) {
        this.errorHandler = errorHandler;
    }

    /**
     * Attaches the given listeners (and their topics) to the container.
     * 
     * <p/>
     * Note: it's possible to call this method while the container is running forcing a reinitialization
     * of the container. Note however that this might cause some messages to be lost (while the container
     * reinitializes) - hence calling this method at runtime is considered advanced usage.
     * 
     * @param listeners map of message listeners and their associated topics
     */
    public void setMessageListeners(Map<? extends MessageListener, Collection<? extends Topic>> listeners) {
        initMapping(listeners);
    }

    /**
     * Adds a message listener to the (potentially running) container. If the container is running,
     * the listener starts receiving (matching) messages as soon as possible.
     * 
     * @param listener message listener
     * @param topics message listener topic
     */
    public void addMessageListener(MessageListener listener, Collection<? extends Topic> topics) {
        addListener(listener, topics);
        lazyListen();
    }

    /**
     * Adds a message listener to the (potentially running) container. If the container is running,
     * the listener starts receiving (matching) messages as soon as possible.
     * 
     * @param listener message listener
     * @param topic message topic
     */
    public void addMessageListener(MessageListener listener, Topic topic) {
        addMessageListener(listener, Collections.singleton(topic));
    }

    private void initMapping(Map<? extends MessageListener, Collection<? extends Topic>> listeners) {
        // stop the listener if currently running
        if (isRunning()) {
            stop();
        }

        patternMapping.clear();
        channelMapping.clear();

        if (!CollectionUtils.isEmpty(listeners)) {
            for (Map.Entry<? extends MessageListener, Collection<? extends Topic>> entry : listeners.entrySet()) {
                addListener(entry.getKey(), entry.getValue());
            }
        }

        // resume activity
        if (initialized) {
            start();
        }
    }

    /**
     * Method inspecting whether listening for messages (and thus using a thread) is actually needed and triggering it.
     *
     * @return {@code true} if this call scheduled the subscription task, {@code false} otherwise
     */
    private boolean lazyListen() {
        boolean debug = logger.isDebugEnabled();
        boolean started = false;

        if (isRunning()) {
            if (!listening) {
                synchronized (monitor) {
                    if (!listening) {
                        if (!channelMapping.isEmpty() || !patternMapping.isEmpty()) {
                            // raise the flag before scheduling: the task resets it in its finally
                            // block, so setting it afterwards could overwrite the reset of a task
                            // that already failed and leave the container stuck in 'listening'
                            listening = true;
                            try {
                                subscriptionExecutor.execute(subscriptionTask);
                            } catch (RuntimeException ex) {
                                // task rejected - nothing is listening
                                listening = false;
                                throw ex;
                            }
                            started = true;
                        }
                    }
                }
                if (debug) {
                    if (started) {
                        logger.debug("Started listening for Redis messages");
                    } else {
                        logger.debug("Postpone listening for Redis messages until actual listeners are added");
                    }
                }
            }
        }
        return started;
    }

    private void addListener(MessageListener listener, Collection<? extends Topic> topics) {
        List<byte[]> channels = new ArrayList<byte[]>(topics.size());
        List<byte[]> patterns = new ArrayList<byte[]>(topics.size());

        boolean trace = logger.isTraceEnabled();

        for (Topic topic : topics) {

            ByteArrayWrapper holder = new ByteArrayWrapper(serializer.serialize(topic.getTopic()));

            if (topic instanceof ChannelTopic) {
                listenersFor(channelMapping, holder).add(listener);
                channels.add(holder.getArray());

                if (trace) {
                    logger.trace("Adding listener '" + listener + "' on channel '" + topic.getTopic() + "'");
                }
            } else if (topic instanceof PatternTopic) {
                listenersFor(patternMapping, holder).add(listener);
                patterns.add(holder.getArray());

                if (trace) {
                    logger.trace("Adding listener '" + listener + "' for pattern '" + topic.getTopic() + "'");
                }
            } else {
                throw new IllegalArgumentException("Unknown topic type '" + topic.getClass() + "'");
            }
        }

        // check the current listening state
        if (listening) {
            subscriptionTask.subscribeChannel(channels.toArray(new byte[channels.size()][]));
            subscriptionTask.subscribePattern(patterns.toArray(new byte[patterns.size()][]));
        }
    }

    /**
     * Returns the listener collection registered under the given key, atomically creating
     * (and registering) an empty one if absent, so concurrent registrations are not lost.
     */
    private static Collection<MessageListener> listenersFor(
            ConcurrentMap<ByteArrayWrapper, Collection<MessageListener>> mapping, ByteArrayWrapper key) {
        Collection<MessageListener> listeners = mapping.get(key);
        if (listeners == null) {
            Collection<MessageListener> created = new CopyOnWriteArraySet<MessageListener>();
            listeners = mapping.putIfAbsent(key, created);
            if (listeners == null) {
                listeners = created;
            }
        }
        return listeners;
    }

    /**
     * Runnable used for Redis subscription. Implemented as a dedicated class to provide as many hints
     * as possible to the underlying thread pool.
     * 
     * @author Costin Leau
     */
    private class SubscriptionTask implements SchedulingAwareRunnable {

        /**
         * Runnable used, on a parallel thread, to do the initial pSubscribe.
         * This is required since, during initialization, both subscribe and pSubscribe
         * might be needed but since the first call is blocking, the second call needs to
         * be executed in parallel.
         *  
         * @author Costin Leau
         */
        private class PatternSubscriptionTask implements SchedulingAwareRunnable {

            // pause (in ms) between two checks of the subscription state
            private static final long WAIT = 500;
            // maximum number of checks before giving up
            private static final int ROUNDS = 3;

            @Override
            public boolean isLongLived() {
                return false;
            }

            @Override
            public void run() {
                boolean done = false;
                // poll up to ROUNDS times for the channel subscription to be established
                for (int i = 0; i < ROUNDS && !done; i++) {
                    synchronized (localMonitor) {
                        if (connection != null && connection.isSubscribed()) {
                            done = true;
                            connection.getSubscription().pSubscribe(unwrap(patternMapping.keySet()));
                        }
                    }
                    if (!done) {
                        // sleep outside the lock so cancel()/subscribe*() are not blocked meanwhile
                        try {
                            Thread.sleep(WAIT);
                        } catch (InterruptedException ex) {
                            Thread.currentThread().interrupt();
                            done = true;
                        }
                    }
                }
            }
        }

        private volatile RedisConnection connection;
        private final Object localMonitor = new Object();

        @Override
        public boolean isLongLived() {
            return true;
        }

        @Override
        public void run() {
            try {
                // obtained inside the try so a failure still resets 'listening' and wakes start()
                connection = connectionFactory.getConnection();

                if (connection.isSubscribed()) {
                    throw new IllegalStateException(
                            "Retrieved connection is already subscribed; aborting listening");
                }

                // NB: each Xsubscribe call blocks

                synchronized (monitor) {
                    monitor.notifyAll();
                }

                // subscribe one way or the other
                // and schedule the rest
                if (!channelMapping.isEmpty()) {
                    // schedule the rest of the subscription
                    if (!patternMapping.isEmpty()) {
                        subscriptionExecutor.execute(new PatternSubscriptionTask());
                    }
                    connection.subscribe(new DispatchMessageListener(), unwrap(channelMapping.keySet()));
                } else {
                    connection.pSubscribe(new DispatchMessageListener(), unwrap(patternMapping.keySet()));
                }

            } finally {
                // this block is executed once the subscription has ended
                // meaning cleanup is required

                listening = false;

                if (connection != null) {
                    synchronized (localMonitor) {
                        if (connection != null) {
                            connection.close();
                            connection = null;
                        }
                    }
                }

                // done with the thread, app can be destroyed
                synchronized (monitor) {
                    monitor.notifyAll();
                }
            }
        }

        private byte[][] unwrap(Collection<ByteArrayWrapper> holders) {
            if (CollectionUtils.isEmpty(holders)) {
                return new byte[0][];
            }

            byte[][] unwrapped = new byte[holders.size()][];

            int index = 0;
            for (ByteArrayWrapper arrayHolder : holders) {
                unwrapped[index++] = arrayHolder.getArray();
            }

            return unwrapped;
        }

        /**
         * Returns the subscription of the current connection, or {@code null} if there is no
         * connection or no active subscription. Must be called while holding {@code localMonitor}.
         */
        private Subscription activeSubscription() {
            RedisConnection current = connection;
            return (current != null ? current.getSubscription() : null);
        }

        void cancel() {
            if (connection != null) {
                synchronized (localMonitor) {
                    Subscription sub = activeSubscription();
                    if (sub != null) {
                        sub.pUnsubscribe();
                        sub.unsubscribe();
                    }
                }
            }
        }

        void subscribeChannel(byte[]... channels) {
            if (channels != null && channels.length > 0 && connection != null) {
                synchronized (localMonitor) {
                    Subscription sub = activeSubscription();
                    if (sub != null) {
                        sub.subscribe(channels);
                    }
                }
            }
        }

        void subscribePattern(byte[]... patterns) {
            if (patterns != null && patterns.length > 0 && connection != null) {
                synchronized (localMonitor) {
                    Subscription sub = activeSubscription();
                    if (sub != null) {
                        sub.pSubscribe(patterns);
                    }
                }
            }
        }

        void unsubscribeChannel(byte[]... channels) {
            if (channels != null && channels.length > 0 && connection != null) {
                synchronized (localMonitor) {
                    Subscription sub = activeSubscription();
                    if (sub != null) {
                        sub.unsubscribe(channels);
                    }
                }
            }
        }

        void unsubscribePattern(byte[]... patterns) {
            if (patterns != null && patterns.length > 0 && connection != null) {
                synchronized (localMonitor) {
                    Subscription sub = activeSubscription();
                    if (sub != null) {
                        sub.pUnsubscribe(patterns);
                    }
                }
            }
        }
    }

    /**
     * Actual message dispatcher/multiplexer: looks up the listeners registered for the
     * message channel and/or matching pattern and hands each delivery to the task executor.
     * 
     * @author Costin Leau
     */
    private class DispatchMessageListener implements MessageListener {

        @Override
        public void onMessage(Message message, byte[] pattern) {
            // do channel matching first
            byte[] channel = message.getChannel();

            Collection<MessageListener> ch = channelMapping.get(new ByteArrayWrapper(channel));
            Collection<MessageListener> pt = null;

            // followed by pattern matching
            if (pattern != null && pattern.length > 0) {
                pt = patternMapping.get(new ByteArrayWrapper(pattern));
            }

            if (!CollectionUtils.isEmpty(ch)) {
                dispatchChannels(ch, message);
            }

            if (!CollectionUtils.isEmpty(pt)) {
                dispatchPatterns(pt, message, pattern);
            }
        }

        private void dispatchChannels(Collection<MessageListener> ch, final Message message) {
            for (final MessageListener messageListener : ch) {
                taskExecutor.execute(new Runnable() {
                    @Override
                    public void run() {
                        processMessage(messageListener, message, null);
                    }
                });
            }
        }

        private void dispatchPatterns(Collection<MessageListener> pt, final Message message, final byte[] pattern) {
            for (final MessageListener messageListener : pt) {
                taskExecutor.execute(new Runnable() {
                    @Override
                    public void run() {
                        // each listener gets its own copy so one cannot corrupt another's pattern
                        processMessage(messageListener, message, pattern.clone());
                    }
                });
            }
        }
    }
}