com.openteach.diamond.network.waverider.session.DefaultSession.java Source code

Java tutorial

Introduction

Here is the source code for com.openteach.diamond.network.waverider.session.DefaultSession.java

Source

/**
 * Copyright 2013 openteach
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package com.openteach.diamond.network.waverider.session;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.openteach.diamond.network.waverider.SlaveWorker;
import com.openteach.diamond.network.waverider.command.Command;
import com.openteach.diamond.network.waverider.command.CommandDispatcher;
import com.openteach.diamond.network.waverider.command.CommandFactory;
import com.openteach.diamond.network.waverider.command.exception.ExecuteCommandException;
import com.openteach.diamond.network.waverider.common.WaveriderThreadFactory;
import com.openteach.diamond.network.waverider.network.NetWorkConstants;
import com.openteach.diamond.network.waverider.network.NetWorkServer;
import com.openteach.diamond.network.waverider.network.Packet;

/**
 * <p>
 * Default {@link Session} implementation bound to a single {@link SocketChannel}.
 * </p>
 * <p>
 * Bytes read from the channel by {@link #onRead()} are queued, parsed into {@link Command}s and
 * dispatched by a dedicated worker thread ({@link SessionTask}). Reply commands, and commands
 * passed to {@link #execute(Command)}, are queued and serialized to the channel by
 * {@link #onWrite()} (presumably invoked by the {@link NetWorkServer} after
 * {@code notifyWrite} — verify against the server implementation).
 * </p>
 * <p>
 * The {@code slaveWorker} field suggests one session per connected slave on the master side —
 * NOTE(review): confirm against the code that creates sessions.
 * </p>
 *
 * @author <a href="mailto:sihai@taobao.com">sihai</a>
 *
 */
public class DefaultSession implements Session {

    private static final Log logger = LogFactory.getLog(DefaultSession.class);
    private static final String SESSION_THREAD_NAME_PREFIX = "Waverider-Session";

    private final Long id; // session id, also embedded in the worker thread name
    private volatile SessionStateEnum state; // liveness state: advanced by transit(), reset by alive()
    private NetWorkServer netWorkServer; // used to request write/read readiness for the channel
    private SocketChannel channel; // connection this session reads from and writes to
    private Thread thread; // worker thread running SessionTask, created by init()
    private SlaveWorker slaveWorker; // associated slave worker (only stored and returned here)
    private final int inBufferSize; // capacity of inputBuffer, in ByteBuffer chunks
    private final int outBufferSize; // capacity of outputBuffer, in commands
    private BlockingQueue<ByteBuffer> inputBuffer; // raw chunks read from the channel, awaiting parsing
    private BlockingQueue<Command> outputBuffer; // commands awaiting serialization to the channel
    private byte[] waitMoreDataLock; // monitor used by onRead() to wake waitMoreData()
    private CommandDispatcher commandDispatcher; // routes parsed commands to their handlers

    // Run/pause gate for the worker thread: SessionTask blocks on 'run' while 'isRun' is false.
    private final ReentrantLock runLock; // guards isRun
    private final Condition run; // signalled whenever isRun changes
    private boolean isRun; // true while the worker should process commands

    private final WaveriderThreadFactory threadFactory;

    /**
     * Creates a session in the FREE state. The queues and the worker thread are only created by
     * {@link #init()}.
     *
     * @param id            session id (used in the worker thread name)
     * @param inBufferSize  capacity of the inbound chunk queue
     * @param outBufferSize capacity of the outbound command queue
     */
    public DefaultSession(Long id, int inBufferSize, int outBufferSize) {
        this.id = id;
        this.state = SessionStateEnum.WAVERIDER_SESSION_STATE_FREE;
        this.inBufferSize = inBufferSize;
        this.outBufferSize = outBufferSize;
        this.runLock = new ReentrantLock();
        this.run = runLock.newCondition();
        this.isRun = false;
        this.threadFactory = new WaveriderThreadFactory(SESSION_THREAD_NAME_PREFIX + "-" + String.valueOf(id), null,
                true);
    }

    /**
     * @return the id this session was created with
     */
    public Long getId() {
        return id;
    }

    //==============================================================
    //            LifeCycle
    //==============================================================

    /**
     * Allocates the bounded in/out queues and starts the worker thread. The worker stays idle
     * until {@link #start()} is called.
     *
     * @return always {@code true}
     */
    @Override
    public boolean init() {
        state = SessionStateEnum.WAVERIDER_SESSION_STATE_FREE;
        inputBuffer = new LinkedBlockingQueue<ByteBuffer>(inBufferSize);
        outputBuffer = new LinkedBlockingQueue<Command>(outBufferSize);
        waitMoreDataLock = new byte[0];
        thread = threadFactory.newThread(new SessionTask());
        thread.start();
        return true;
    }

    /**
     * Lets the worker thread begin processing and marks the session ALIVE.
     *
     * @return always {@code true}
     * @throws IllegalArgumentException if {@link #init()} has not run or the channel / command
     *                                  dispatcher have not been set
     */
    @Override
    public boolean start() {
        if (channel == null || inputBuffer == null || outputBuffer == null || commandDispatcher == null) {
            throw new IllegalArgumentException("Seesion not set corrected");
        }
        control(true);
        state = SessionStateEnum.WAVERIDER_SESSION_STATE_ALIVE;
        return true;
    }

    /**
     * Interrupts (and thereby terminates) the worker thread, discards queued data and drops the
     * channel reference.
     * <p>
     * NOTE(review): unlike {@link #free()}, this does not close the channel — confirm the caller
     * closes it.
     *
     * @return always {@code true}
     */
    @Override
    public boolean stop() {
        thread.interrupt();
        inputBuffer.clear();
        outputBuffer.clear();
        channel = null;
        state = SessionStateEnum.WAVERIDER_SESSION_STATE_FREE;
        return true;
    }

    /**
     * Pauses the worker thread, discards queued data, closes the channel (if any) and returns the
     * session to the FREE state. The worker thread itself is kept for reuse.
     */
    @Override
    public void free() {
        control(false);
        inputBuffer.clear();
        outputBuffer.clear();
        closeChannel();
        channel = null;
        state = SessionStateEnum.WAVERIDER_SESSION_STATE_FREE;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean restart() {
        throw new UnsupportedOperationException("Session not supported restart.");
    }

    /**
     * Queues {@code command} for sending (blocking while the outbound queue is full) and asks the
     * network server to schedule a write on this session's channel.
     *
     * @param command command to send
     * @throws ExecuteCommandException if interrupted while waiting for queue space; the thread's
     *                                 interrupt flag is restored
     */
    @Override
    public void execute(Command command) throws ExecuteCommandException {
        try {
            outputBuffer.put(command);
            netWorkServer.notifyWrite(channel);
        } catch (InterruptedException e) {
            // Log with the throwable so the stack trace is not lost.
            logger.error("OOPSException", e);
            Thread.currentThread().interrupt();
            throw new ExecuteCommandException("Interrupted", e);
        }
    }

    //==============================================================
    //            Session state
    //==============================================================
    @Override
    public SlaveWorker getSlaveWorker() {
        return this.slaveWorker;
    }

    @Override
    public SessionStateEnum getState() {
        return this.state;
    }

    @Override
    public SocketChannel getChannel() {
        return channel;
    }

    //==============================================================
    //            Session state transit
    //==============================================================

    /**
     * Advances the liveness state by one step:
     * ALIVE -&gt; WAITING_0 -&gt; WAITING_1 -&gt; WAITING_2 -&gt; DEAD.
     * {@link #alive()} resets it to ALIVE. Any other state (e.g. FREE, DEAD) is left unchanged.
     */
    @Override
    public void transit() {
        SessionStateEnum oldState = state;
        if (oldState == SessionStateEnum.WAVERIDER_SESSION_STATE_ALIVE) {
            state = SessionStateEnum.WAVERIDER_SESSION_STATE_WAITING_0;
        } else if (oldState == SessionStateEnum.WAVERIDER_SESSION_STATE_WAITING_0) {
            // Previously missing: without this step WAITING_1 was never assigned, so the session
            // stalled in WAITING_0 and could never reach DEAD.
            state = SessionStateEnum.WAVERIDER_SESSION_STATE_WAITING_1;
        } else if (oldState == SessionStateEnum.WAVERIDER_SESSION_STATE_WAITING_1) {
            state = SessionStateEnum.WAVERIDER_SESSION_STATE_WAITING_2;
        } else if (oldState == SessionStateEnum.WAVERIDER_SESSION_STATE_WAITING_2) {
            state = SessionStateEnum.WAVERIDER_SESSION_STATE_DEAD;
        }
    }

    /**
     * Resets the liveness state to ALIVE.
     */
    @Override
    public void alive() {
        state = SessionStateEnum.WAVERIDER_SESSION_STATE_ALIVE;
        logger.warn(String.format("Session id = %d is alived.", id));
    }

    /**
     * @return {@code true} once {@link #transit()} has advanced the state to DEAD
     */
    @Override
    public boolean isDead() {
        return state == SessionStateEnum.WAVERIDER_SESSION_STATE_DEAD;
    }

    //==============================================================
    //            Network callbacks
    //==============================================================

    /**
     * Reads whatever is currently available on the channel (up to one network buffer), queues it
     * for the worker thread and wakes any {@link #waitMoreData(long)} waiter.
     *
     * @throws IOException          on read failure, or "EOF" when the peer has closed the
     *                              connection (data read in the same call before EOF is dropped)
     * @throws InterruptedException if interrupted while waiting for space in the inbound queue
     */
    @Override
    public void onRead() throws IOException, InterruptedException {
        logger.debug("onRead");
        ByteBuffer buffer = ByteBuffer.allocate(NetWorkConstants.DEFAULT_NETWORK_BUFFER_SIZE);
        int ret = 0;
        // Stops when nothing more is available, the buffer is full (read returns 0), or EOF (-1).
        do {
            ret = channel.read(buffer);
        } while (ret > 0);

        if (ret == -1) {
            throw new IOException("EOF");
        }
        buffer.flip();
        if (buffer.hasRemaining()) {
            inputBuffer.put(buffer);
            synchronized (waitMoreDataLock) {
                waitMoreDataLock.notifyAll();
            }
        }
    }

    /**
     * Drains the outbound queue, marshalling each command into a data packet and writing it fully
     * to the channel.
     * <p>
     * NOTE(review): on a non-blocking channel the inner write loop spins while the socket send
     * buffer is full.
     *
     * @throws IOException on write failure
     */
    @Override
    public void onWrite() throws IOException {
        logger.debug("onWrite");
        Command command;
        while ((command = outputBuffer.poll()) != null) {
            ByteBuffer data = Packet.newDataPacket(command).marshall();
            while (data.hasRemaining()) {
                channel.write(data);
            }
            // No explicit flush: SocketChannel writes go straight to the socket and are not
            // buffered, so the former socket().getOutputStream().flush() call was a no-op.
        }
    }

    /**
     * Currently ignores the exception (not implemented).
     */
    @Override
    public void onException(Exception e) {
        // TODO
    }

    /**
     * Asks the network server to schedule a write on this session's channel.
     *
     * @param channel ignored; this session's own channel is always used
     * @return the network server's result
     */
    @Override
    public boolean notifyWrite(SocketChannel channel) {
        logger.debug("notifyWrite");
        return this.netWorkServer.notifyWrite(this.channel);
    }

    /**
     * Asks the network server to schedule a read on this session's channel.
     *
     * @param channel ignored; this session's own channel is always used
     * @return the network server's result
     */
    @Override
    public boolean notifyRead(SocketChannel channel) {
        logger.debug("notifyRead");
        return this.netWorkServer.notifyRead(this.channel);
    }

    /**
     * Blocks until {@link #onRead()} queues new data or {@code timeout} milliseconds elapse
     * ({@code 0} waits indefinitely, per {@link Object#wait(long)}). May also return on a spurious
     * wakeup, so callers should re-check the queue.
     *
     * @param timeout maximum wait in milliseconds
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public void waitMoreData(long timeout) throws InterruptedException {
        synchronized (waitMoreDataLock) {
            waitMoreDataLock.wait(timeout);
        }
    }

    //==============================================================
    //            DSL
    //==============================================================

    /** Sets the network server used for read/write notifications; returns {@code this}. */
    public DefaultSession withNetWorkServer(NetWorkServer netWorkServer) {
        this.netWorkServer = netWorkServer;
        return this;
    }

    /** Sets the channel this session serves; returns {@code this}. */
    public DefaultSession withChannel(SocketChannel channel) {
        this.channel = channel;
        return this;
    }

    /** Sets the associated slave worker; returns {@code this}. */
    public DefaultSession withSlaveWorker(SlaveWorker slaveWorker) {
        this.slaveWorker = slaveWorker;
        return this;
    }

    /** Sets the dispatcher that handles parsed commands; returns {@code this}. */
    public DefaultSession withCommandDispatcher(CommandDispatcher commandDispatcher) {
        this.commandDispatcher = commandDispatcher;
        return this;
    }

    //==============================================================
    //            Private internal methods
    //==============================================================

    /**
     * Parses the next packet from the inbound queue and converts it into a command bound to this
     * session.
     *
     * @return the parsed command, or {@code null} if the factory produced none
     */
    private Command _parse_() throws IOException, InterruptedException {
        Command command = CommandFactory.unmarshallCommandFromPacket(Packet.parse(inputBuffer, this, channel));
        // Guard matches the null check in _process_(); previously a null command caused an NPE here.
        if (command != null) {
            command.setSession(this);
        }
        return command;
    }

    /**
     * Processes one inbound command: parses it, dispatches it, and queues any reply for sending.
     *
     * @throws IOException          if parsing fails
     * @throws InterruptedException if interrupted while parsing or waiting for outbound queue space
     */
    private void _process_() throws IOException, InterruptedException {
        Command command = _parse_();
        if (command != null) {
            Command resultCommand = commandDispatcher.dispatch(command);
            command.getPayLoad().clear();
            if (resultCommand != null) {
                outputBuffer.put(resultCommand);
                // Ask the network server to schedule a write so the reply gets flushed out.
                netWorkServer.notifyWrite(channel);
            }
        }
    }

    /**
     * Sets the worker's run flag and wakes the worker so it observes the change.
     *
     * @param isRun {@code true} to let the worker process commands, {@code false} to pause it
     */
    private void control(boolean isRun) {
        // Acquire before try: if lock() itself failed, finally must not call unlock().
        runLock.lock();
        try {
            this.isRun = isRun;
            this.run.signalAll();
        } finally {
            runLock.unlock();
        }
    }

    /**
     * Closes the channel if one is set, logging (not propagating) any I/O error.
     */
    private void closeChannel() {
        // channel may already be null (e.g. after stop() or a previous free()).
        if (channel == null) {
            return;
        }
        try {
            channel.close();
        } catch (IOException e) {
            logger.error("OOPSException", e);
        }
    }

    //==============================================================
    //            Private thread task
    //            Session worker loop
    //==============================================================

    /**
     * Worker loop: waits while the session is paused, otherwise processes one command per
     * iteration. Exits when the thread is interrupted (see {@link #stop()}); all other failures are
     * logged and the loop continues.
     */
    private class SessionTask implements Runnable {

        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                try {
                    awaitRunning();
                    _process_();
                } catch (InterruptedException e) {
                    logger.error("OOPSException", e);
                    // Restore the flag so the loop condition terminates the thread.
                    Thread.currentThread().interrupt();
                } catch (Exception e) {
                    logger.error("OOPSException", e);
                } catch (Throwable t) {
                    logger.error("OOPS, Exception:", t);
                }
            }
            logger.info(new StringBuilder(Thread.currentThread().getName()).append(" stoped"));
        }

        /** Blocks while the session is paused ({@code isRun == false}). */
        private void awaitRunning() throws InterruptedException {
            runLock.lock();
            try {
                while (!isRun) {
                    logger.info(new StringBuilder(Thread.currentThread().getName()).append(" idle"));
                    run.await();
                    logger.info(new StringBuilder(Thread.currentThread().getName()).append(" started"));
                }
            } finally {
                runLock.unlock();
            }
        }
    }

    //==============================================================
    //            Test
    //==============================================================

    /**
     * Ad-hoc manual check of packet reassembly: splits two marshalled packets across two queued
     * buffers, reassembles the first packet and prints its payload.
     */
    public static void main(String[] args) {

        BlockingQueue<ByteBuffer> inputBuffer = new LinkedBlockingQueue<ByteBuffer>();
        ByteBuffer byteBuffer = ByteBuffer.allocate(1024);
        byteBuffer.put(makePacket().marshall());
        byteBuffer.put(makePacket().marshall());
        byteBuffer.flip();
        byte[] b = new byte[8];
        ByteBuffer halfBuf0 = ByteBuffer.allocate(8);
        byteBuffer.get(b);
        halfBuf0.put(b);
        halfBuf0.flip();
        inputBuffer.add(halfBuf0);
        inputBuffer.add(byteBuffer);

        int size = 0;
        int oldSize = size;
        long length = Packet.getHeaderSize();
        ByteBuffer buffer = ByteBuffer.allocate(NetWorkConstants.DEFAULT_NETWORK_BUFFER_SIZE);
        ByteBuffer currentBuffer = null;

        while (size < length) {
            currentBuffer = inputBuffer.peek();
            oldSize = size;
            int position = currentBuffer.position();
            size += currentBuffer.remaining();
            buffer.put(currentBuffer);
            if (size >= Packet.getHeaderSize()) {
                length = buffer.getLong(Packet.getLengthPosition());
            }

            if (size <= length) {
                inputBuffer.remove();
            } else {
                // Over-read into the next packet: rewind and copy only the bytes that belong here.
                currentBuffer.position(position);
                buffer.position(buffer.position() - currentBuffer.remaining());
                byte[] buf = new byte[(int) (length - oldSize)];
                currentBuffer.get(buf);
                buffer.put(buf);
            }
        }

        buffer.flip();
        Packet packet = Packet.unmarshall(buffer);

        Command command = CommandFactory.createCommand(packet.getType(), packet.getPayLoad());

        String str = new String(command.getPayLoad().array());

        System.out.println(str);

    }

    /** Builds a data packet wrapping {@link #makeGreetCommand()}. */
    public static Packet makePacket() {
        return Packet.newDataPacket(makeGreetCommand());
    }

    /** Builds a command of type {@code 1L} whose payload is the bytes of "Hello Master". */
    public static Command makeGreetCommand() {
        String HELLO = "Hello Master";
        ByteBuffer buffer = ByteBuffer.allocate(HELLO.getBytes().length);
        buffer.put(HELLO.getBytes());
        buffer.flip();
        Command command = new Command(1L, buffer);
        return command;
    }
}