org.springframework.test.context.jdbc.SqlScriptsTestExecutionListener.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.test.context.jdbc.SqlScriptsTestExecutionListener.java.

Source

/*
 * Copyright 2002-2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.test.context.jdbc;

import java.io.UnsupportedEncodingException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Set;
import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator;
import org.springframework.lang.Nullable;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.jdbc.Sql.ExecutionPhase;
import org.springframework.test.context.jdbc.SqlConfig.ErrorMode;
import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.test.context.transaction.TestContextTransactionUtils;
import org.springframework.test.context.util.TestContextResourceUtils;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.interceptor.DefaultTransactionAttribute;
import org.springframework.transaction.interceptor.TransactionAttribute;
import org.springframework.transaction.support.TransactionTemplate;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ResourceUtils;
import org.springframework.util.StringUtils;

/**
 * {@code TestExecutionListener} that provides support for executing SQL
 * {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
 * configured via the {@link Sql @Sql} annotation.
 *
 * <p>Scripts and inlined statements will be executed {@linkplain #beforeTestMethod(TestContext) before}
 * or {@linkplain #afterTestMethod(TestContext) after} execution of the corresponding
 * {@linkplain java.lang.reflect.Method test method}, depending on the configured
 * value of the {@link Sql#executionPhase executionPhase} flag.
 *
 * <p>Scripts and inlined statements will be executed without a transaction,
 * within an existing Spring-managed transaction, or within an isolated transaction,
 * depending on the configured value of {@link SqlConfig#transactionMode} and the
 * presence of a transaction manager.
 *
 * <h3>Script Resources</h3>
 * <p>For details on default script detection and how script resource locations
 * are interpreted, see {@link Sql#scripts}.
 *
 * <h3>Required Spring Beans</h3>
 * <p>A {@link PlatformTransactionManager} <em>and</em> a {@link DataSource},
 * just a {@link PlatformTransactionManager}, or just a {@link DataSource}
 * must be defined as beans in the Spring {@link ApplicationContext} for the
 * corresponding test. Consult the javadocs for {@link SqlConfig#transactionMode},
 * {@link SqlConfig#transactionManager}, {@link SqlConfig#dataSource},
 * {@link TestContextTransactionUtils#retrieveDataSource}, and
 * {@link TestContextTransactionUtils#retrieveTransactionManager} for details
 * on permissible configuration constellations and on the algorithms used to
 * locate these beans.
 *
 * @author Sam Brannen
 * @since 4.1
 * @see Sql
 * @see SqlConfig
 * @see SqlGroup
 * @see org.springframework.test.context.transaction.TestContextTransactionUtils
 * @see org.springframework.test.context.transaction.TransactionalTestExecutionListener
 * @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator
 * @see org.springframework.jdbc.datasource.init.ScriptUtils
 */
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {

    private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class);

    /**
     * Returns {@code 5000}.
     */
    @Override
    public final int getOrder() {
        return 5000;
    }

    /**
     * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
     * {@link TestContext} <em>before</em> the current test method.
     */
    @Override
    public void beforeTestMethod(TestContext testContext) throws Exception {
        executeSqlScripts(testContext, ExecutionPhase.BEFORE_TEST_METHOD);
    }

    /**
     * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
     * {@link TestContext} <em>after</em> the current test method.
     */
    @Override
    public void afterTestMethod(TestContext testContext) throws Exception {
        executeSqlScripts(testContext, ExecutionPhase.AFTER_TEST_METHOD);
    }

    /**
     * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
     * {@link TestContext} and {@link ExecutionPhase}.
     */
    private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception {
        boolean classLevel = false;

        Set<Sql> sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(testContext.getTestMethod(),
                Sql.class, SqlGroup.class);
        if (sqlAnnotations.isEmpty()) {
            sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(testContext.getTestClass(),
                    Sql.class, SqlGroup.class);
            if (!sqlAnnotations.isEmpty()) {
                classLevel = true;
            }
        }

        for (Sql sql : sqlAnnotations) {
            executeSqlScripts(sql, executionPhase, testContext, classLevel);
        }
    }

    /**
     * Execute the SQL scripts configured via the supplied {@link Sql @Sql}
     * annotation for the given {@link ExecutionPhase} and {@link TestContext}.
     * <p>Special care must be taken in order to properly support the configured
     * {@link SqlConfig#transactionMode}.
     * @param sql the {@code @Sql} annotation to parse
     * @param executionPhase the current execution phase
     * @param testContext the current {@code TestContext}
     * @param classLevel {@code true} if {@link Sql @Sql} was declared at the class level
     */
    private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestContext testContext,
            boolean classLevel) throws Exception {

        if (executionPhase != sql.executionPhase()) {
            return;
        }

        MergedSqlConfig mergedSqlConfig = new MergedSqlConfig(sql.config(), testContext.getTestClass());
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("Processing %s for execution phase [%s] and test context %s.",
                    mergedSqlConfig, executionPhase, testContext));
        }

        final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
        populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
        populator.setSeparator(mergedSqlConfig.getSeparator());
        populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
        populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
        populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
        populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
        populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);

        String[] scripts = getScripts(sql, testContext, classLevel);
        scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
        List<Resource> scriptResources = TestContextResourceUtils
                .convertToResourceList(testContext.getApplicationContext(), scripts);
        for (String stmt : sql.statements()) {
            if (StringUtils.hasText(stmt)) {
                stmt = stmt.trim();
                scriptResources.add(new ByteArrayResource(stmt.getBytes(), "from inlined SQL statement: " + stmt));
            }
        }
        populator.setScripts(scriptResources.toArray(new Resource[scriptResources.size()]));
        if (logger.isDebugEnabled()) {
            logger.debug("Executing SQL scripts: " + ObjectUtils.nullSafeToString(scriptResources));
        }

        String dsName = mergedSqlConfig.getDataSource();
        String tmName = mergedSqlConfig.getTransactionManager();
        DataSource dataSource = TestContextTransactionUtils.retrieveDataSource(testContext, dsName);
        PlatformTransactionManager txMgr = TestContextTransactionUtils.retrieveTransactionManager(testContext,
                tmName);
        boolean newTxRequired = (mergedSqlConfig.getTransactionMode() == TransactionMode.ISOLATED);

        if (txMgr == null) {
            Assert.state(!newTxRequired,
                    () -> String.format(
                            "Failed to execute SQL scripts for test context %s: "
                                    + "cannot execute SQL scripts using Transaction Mode "
                                    + "[%s] without a PlatformTransactionManager.",
                            testContext, TransactionMode.ISOLATED));
            Assert.state(dataSource != null,
                    () -> String.format("Failed to execute SQL scripts for test context %s: "
                            + "supply at least a DataSource or PlatformTransactionManager.", testContext));
            // Execute scripts directly against the DataSource
            populator.execute(dataSource);
        } else {
            DataSource dataSourceFromTxMgr = getDataSourceFromTransactionManager(txMgr);
            // Ensure user configured an appropriate DataSource/TransactionManager pair.
            if (dataSource != null && dataSourceFromTxMgr != null && !dataSource.equals(dataSourceFromTxMgr)) {
                throw new IllegalStateException(String.format(
                        "Failed to execute SQL scripts for test context %s: "
                                + "the configured DataSource [%s] (named '%s') is not the one associated with "
                                + "transaction manager [%s] (named '%s').",
                        testContext, dataSource.getClass().getName(), dsName, txMgr.getClass().getName(), tmName));
            }
            if (dataSource == null) {
                dataSource = dataSourceFromTxMgr;
                Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for "
                        + "test context %s: could not obtain DataSource from transaction manager [%s] (named '%s').",
                        testContext, txMgr.getClass().getName(), tmName));
            }
            final DataSource finalDataSource = dataSource;
            int propagation = (newTxRequired ? TransactionDefinition.PROPAGATION_REQUIRES_NEW
                    : TransactionDefinition.PROPAGATION_REQUIRED);
            TransactionAttribute txAttr = TestContextTransactionUtils.createDelegatingTransactionAttribute(
                    testContext, new DefaultTransactionAttribute(propagation));
            new TransactionTemplate(txMgr, txAttr).execute(status -> {
                populator.execute(finalDataSource);
                return null;
            });
        }
    }

    @Nullable
    private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
        try {
            Method getDataSourceMethod = transactionManager.getClass().getMethod("getDataSource");
            Object obj = ReflectionUtils.invokeMethod(getDataSourceMethod, transactionManager);
            if (obj instanceof DataSource) {
                return (DataSource) obj;
            }
        } catch (Exception ex) {
            // ignore
        }
        return null;
    }

    private String[] getScripts(Sql sql, TestContext testContext, boolean classLevel) {
        String[] scripts = sql.scripts();
        if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
            scripts = new String[] { detectDefaultScript(testContext, classLevel) };
        }
        return scripts;
    }

    /**
     * Detect a default SQL script by implementing the algorithm defined in
     * {@link Sql#scripts}.
     */
    private String detectDefaultScript(TestContext testContext, boolean classLevel) {
        Class<?> clazz = testContext.getTestClass();
        Method method = testContext.getTestMethod();
        String elementType = (classLevel ? "class" : "method");
        String elementName = (classLevel ? clazz.getName() : method.toString());

        String resourcePath = ClassUtils.convertClassNameToResourcePath(clazz.getName());
        if (!classLevel) {
            resourcePath += "." + method.getName();
        }
        resourcePath += ".sql";

        String prefixedResourcePath = ResourceUtils.CLASSPATH_URL_PREFIX + resourcePath;
        ClassPathResource classPathResource = new ClassPathResource(resourcePath);

        if (classPathResource.exists()) {
            if (logger.isInfoEnabled()) {
                logger.info(String.format("Detected default SQL script \"%s\" for test %s [%s]",
                        prefixedResourcePath, elementType, elementName));
            }
            return prefixedResourcePath;
        } else {
            String msg = String.format("Could not detect default SQL script for test %s [%s]: "
                    + "%s does not exist. Either declare statements or scripts via @Sql or make the "
                    + "default SQL script available.", elementType, elementName, classPathResource);
            logger.error(msg);
            throw new IllegalStateException(msg);
        }
    }

}