rapture.repo.jdbc.TransactionAwareDataSource.java Source code

Java tutorial

Introduction

Here is the source code for rapture.repo.jdbc.TransactionAwareDataSource.java

Source

/**
 * Copyright (C) 2011-2015 Incapture Technologies LLC
 *
 * This is an autogenerated license statement. When copyright notices appear below
 * this one, that copyright supersedes this statement.
 *
 * Unless required by applicable law or agreed to in writing, software is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 *
 * Unless explicit permission obtained in writing this software cannot be distributed.
 */
package rapture.repo.jdbc;

import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.util.logging.Logger;

import javax.sql.DataSource;

import org.apache.commons.lang3.StringUtils;
import org.springframework.jdbc.datasource.SmartDataSource;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Maps;

import rapture.kernel.TransactionManager;

/**
 * A {@link SmartDataSource} that binds one JDBC {@link Connection} to each active Rapture transaction.
 *
 * <p>When {@link TransactionManager#getActiveTransaction()} returns a non-blank id, every
 * {@code getConnection} call made under that id returns the same connection, opened with auto-commit
 * disabled. That connection is kept open ({@link #shouldClose(Connection)} returns {@code false} for it)
 * until {@link #commit(String)} or {@link #rollback(String)} is called for the transaction id. With no
 * active transaction, calls are delegated directly to the underlying data source.
 *
 * <p>NOTE(review): the lookup-then-register in {@code getConnection} is not atomic; this presumably
 * assumes a given transaction id is only used from one thread at a time — TODO confirm.
 *
 * Created by yanwang on 4/20/15.
 */
public class TransactionAwareDataSource implements SmartDataSource {

    private final DataSource underlyingDataSource;
    // txId -> Connection; a BiMap so shouldClose() can also look entries up by connection
    private final BiMap<String, Connection> connections;

    /**
     * @param dataSource the data source that actually opens connections
     */
    public TransactionAwareDataSource(DataSource dataSource) {
        this.underlyingDataSource = dataSource;
        connections = Maps.synchronizedBiMap(HashBiMap.<String, Connection>create());
    }

    /**
     * Returns the connection bound to the active transaction (opening and binding one if needed), or a
     * plain connection from the underlying data source when no transaction is active.
     */
    @Override
    public Connection getConnection() throws SQLException {
        String txId = TransactionManager.getActiveTransaction();
        if (StringUtils.isBlank(txId)) { // no active transaction
            return underlyingDataSource.getConnection();
        }
        return getTransactionalConnection(txId, false, null, null);
    }

    /**
     * Like {@link #getConnection()}, but opens new connections with the given credentials. If the active
     * transaction already has a bound connection, that connection is returned as-is.
     */
    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        String txId = TransactionManager.getActiveTransaction();
        if (StringUtils.isBlank(txId)) { // no active transaction
            // forward the credentials instead of silently falling back to the default ones
            return underlyingDataSource.getConnection(username, password);
        }
        return getTransactionalConnection(txId, true, username, password);
    }

    /**
     * Returns the connection bound to {@code txId}, opening one with auto-commit disabled and binding it
     * if none exists yet.
     */
    private Connection getTransactionalConnection(String txId, boolean useCredentials, String username,
            String password) throws SQLException {
        Connection existing = connections.get(txId);
        if (existing != null) {
            return existing;
        }
        Connection connection = useCredentials
                ? underlyingDataSource.getConnection(username, password)
                : underlyingDataSource.getConnection();
        try {
            connection.setAutoCommit(false);
        } catch (SQLException e) {
            // don't leak a connection we never managed to register
            closeQuietly(connection, e);
            throw e;
        }
        connections.put(txId, connection);
        return connection;
    }

    /**
     * Commits the connection bound to {@code txId}, then restores auto-commit and closes it.
     *
     * <p>Does nothing if the transaction never obtained a connection from this data source. If the commit
     * itself fails, the connection stays bound so that {@link #rollback(String)} can still be called.
     *
     * @param txId the transaction id
     * @throws SQLException if committing, resetting or closing the connection fails
     */
    public void commit(String txId) throws SQLException {
        Connection connection = connections.get(txId);
        if (connection == null) {
            return; // nothing was done through this data source
        }
        connection.commit();
        releaseConnection(txId, connection);
    }

    /**
     * Rolls back the connection bound to {@code txId}, then restores auto-commit and closes it.
     *
     * <p>Does nothing if the transaction never obtained a connection from this data source. The connection
     * is unbound and closed even if the rollback fails.
     *
     * @param txId the transaction id
     * @throws SQLException if rolling back, resetting or closing the connection fails
     */
    public void rollback(String txId) throws SQLException {
        Connection connection = connections.get(txId);
        if (connection == null) {
            return; // nothing was done through this data source
        }
        try {
            connection.rollback();
        } catch (SQLException e) {
            // State is unknown: unbind and close without re-enabling auto-commit, since switching
            // auto-commit on during a transaction commits the pending work.
            connections.remove(txId);
            closeQuietly(connection, e);
            throw e;
        }
        releaseConnection(txId, connection);
    }

    /**
     * Connections bound to an open transaction must not be closed by callers; all others may be.
     */
    @Override
    public boolean shouldClose(Connection connection) {
        return !connections.containsValue(connection);
    }

    /**
     * Unbinds the connection from {@code txId}, restores auto-commit and closes it. The connection is
     * closed even if resetting auto-commit fails.
     */
    private void releaseConnection(String txId, Connection connection) throws SQLException {
        connections.remove(txId);
        //TODO reset other attributes of a connection, eg, isolation level etc
        try {
            connection.setAutoCommit(true);
        } catch (SQLException e) {
            closeQuietly(connection, e);
            throw e;
        }
        connection.close();
    }

    /**
     * Closes {@code connection}, recording any close failure as suppressed on {@code primary}.
     */
    private static void closeQuietly(Connection connection, SQLException primary) {
        try {
            connection.close();
        } catch (SQLException closeFailure) {
            primary.addSuppressed(closeFailure);
        }
    }

    @Override
    public PrintWriter getLogWriter() throws SQLException {
        return underlyingDataSource.getLogWriter();
    }

    @Override
    public void setLogWriter(PrintWriter out) throws SQLException {
        underlyingDataSource.setLogWriter(out);
    }

    @Override
    public void setLoginTimeout(int seconds) throws SQLException {
        underlyingDataSource.setLoginTimeout(seconds);
    }

    @Override
    public int getLoginTimeout() throws SQLException {
        return underlyingDataSource.getLoginTimeout();
    }

    @Override
    public Logger getParentLogger() throws SQLFeatureNotSupportedException {
        return underlyingDataSource.getParentLogger();
    }

    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return underlyingDataSource.unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return underlyingDataSource.isWrapperFor(iface);
    }

}