Java tutorial
/** * Copyright (C) 2009 kiy0taka.org * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.kiy0taka.dbunit; import static org.kiy0taka.dbunit.DataSetBuilder.dataSet; import java.io.FileNotFoundException; import java.io.IOException; import java.lang.reflect.Field; import java.net.URL; import java.security.AccessController; import java.security.PrivilegedAction; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.MissingResourceException; import java.util.Properties; import java.util.PropertyResourceBundle; import java.util.ResourceBundle; import java.util.Set; import javax.sql.DataSource; import org.apache.commons.dbcp.BasicDataSource; import org.dbunit.Assertion; import org.dbunit.DatabaseUnitException; import org.dbunit.database.DatabaseConfig; import org.dbunit.database.DatabaseConfig.ConfigProperty; import org.dbunit.database.DatabaseDataSourceConnection; import org.dbunit.database.IDatabaseConnection; import org.dbunit.dataset.DataSetException; import org.dbunit.dataset.IDataSet; import org.dbunit.dataset.excel.XlsDataSet; import org.dbunit.dataset.xml.FlatXmlDataSet; import org.dbunit.dataset.xml.FlatXmlProducer; import org.junit.runners.BlockJUnit4ClassRunner; import org.junit.runners.model.FrameworkField; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.InitializationError; import org.junit.runners.model.Statement; import org.xml.sax.InputSource; /** * JUnit Runner implementation for DbUnit. * @author kiy0taka */ public class DbUnitRunner extends BlockJUnit4ClassRunner { private static final ResourceBundle BUNDLE; static { BUNDLE = PropertyResourceBundle.getBundle("dbunit-runner"); loadDriver(BUNDLE.getString("driver")); } protected static void loadDriver(String driverName) { try { Class.forName(driverName); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } } private enum DataSetType { xml() { public IDataSet createDataSet(URL url) throws DataSetException, IOException { return new FlatXmlDataSet(new FlatXmlProducer(new InputSource(url.openStream()))); } }, xls() { public IDataSet createDataSet(URL url) throws DataSetException, IOException { return new XlsDataSet(url.openStream()); } }; public abstract IDataSet createDataSet(URL url) throws DataSetException, IOException; } protected DataSource dataSource; protected Connection testConnection; protected String jdbcUrl = BUNDLE.getString("url"); protected String username = BUNDLE.getString("username"); protected String password = BUNDLE.getString("password"); protected String schema = optionalValue(BUNDLE, "schema"); protected Properties configProperties = new Properties(); /** * Constract Runner for DbUnit. * @param testClass Test Class * @throws InitializationError Initialization error */ public DbUnitRunner(Class<?> testClass) throws InitializationError { super(testClass); for (ConfigProperty cp : DatabaseConfig.ALL_PROPERTIES) { try { configProperties.put(cp.getProperty(), BUNDLE.getString(cp.getProperty())); } catch (MissingResourceException ignore) { // NOP } } } protected Statement methodBlock(final FrameworkMethod method) { Statement stmt = super.methodBlock(method); DbUnitTest ann = method.getAnnotation(DbUnitTest.class); return ann == null ? stmt : new DbUnitStatement(ann, stmt); } protected List<FrameworkMethod> computeTestMethods() { Set<FrameworkMethod> set = new HashSet<FrameworkMethod>(super.computeTestMethods()); set.addAll(getTestClass().getAnnotatedMethods(DbUnitTest.class)); return new ArrayList<FrameworkMethod>(set); } protected Object createTest() throws Exception { Object result = super.createTest(); dataSource = createDataSource(); List<FrameworkField> connFields = getTestClass().getAnnotatedFields(TestConnection.class); if (!connFields.isEmpty()) { testConnection = dataSource.getConnection(); for (FrameworkField ff : connFields) { final Field f = ff.getField(); AccessController.doPrivileged(new SetAccessibleAction(f)); f.set(result, testConnection); } } List<FrameworkField> dsFields = getTestClass().getAnnotatedFields(TestDataSource.class); if (!dsFields.isEmpty()) { for (FrameworkField ff : dsFields) { final Field f = ff.getField(); AccessController.doPrivileged(new SetAccessibleAction(f)); f.set(result, dataSource); } } return result; } protected DataSource createDataSource() { BasicDataSource result = new BasicDataSource(); result.setUsername(username); result.setPassword(password); result.setUrl(jdbcUrl); return result; } protected static String optionalValue(ResourceBundle bundle, String key) { try { return bundle.getString(key); } catch (MissingResourceException ignore) { return null; } } private static class SetAccessibleAction implements PrivilegedAction<Object> { private Field field; public SetAccessibleAction(Field field) { this.field = field; } public Object run() { field.setAccessible(true); return null; } } protected class DbUnitStatement extends Statement { private DbUnitTest ann; private Statement statement; protected DbUnitStatement(DbUnitTest ann, Statement statement) { this.ann = ann; this.statement = statement; } public void evaluate() throws Throwable { IDatabaseConnection conn = createDatabaseConnection(); try { executeUpdate(conn, ann.sql()); IDataSet initData = dataSet(load(ann.init())).nullValue(ann.nullValue()).toDataSet(); ann.operation().toDatabaseOperation().execute(conn, initData); statement.evaluate(); if (testConnection != null) { testConnection.commit(); } } catch (Throwable e) { if (testConnection != null) { testConnection.rollback(); } throw e; } finally { if (testConnection != null) { testConnection.close(); } conn.close(); } if (!ann.expected().isEmpty()) { assertTables(); } } protected void assertTables() { IDatabaseConnection conn = createDatabaseConnection(); try { IDataSet expected = dataSet(load(ann.expected())).excludeColumns(ann.excludeColumns()) .nullValue(ann.nullValue()).rtrim(ann.rtrim()).toDataSet(); IDataSet actual = dataSet(conn.createDataSet(expected.getTableNames())) .excludeColumns(ann.excludeColumns()).rtrim(ann.rtrim()).toDataSet(); Assertion.assertEquals(expected, actual); } catch (SQLException e) { throw new RuntimeException(e); } catch (DatabaseUnitException e) { throw new RuntimeException(e); } finally { try { conn.close(); } catch (SQLException e) { throw new RuntimeException(e); } } } protected IDataSet load(String path) { URL url = getTestClass().getJavaClass().getResource(path); if (url == null) { throw new RuntimeException(new FileNotFoundException(path)); } String suffix = path.substring(path.lastIndexOf('.') + 1).toLowerCase(Locale.getDefault()); try { return DataSetType.valueOf(suffix).createDataSet(url); } catch (Exception e) { throw new RuntimeException(e); } } protected IDatabaseConnection createDatabaseConnection() { try { DatabaseDataSourceConnection result = new DatabaseDataSourceConnection(dataSource, schema); DatabaseConfig config = result.getConfig(); config.setPropertiesByString(configProperties); return result; } catch (SQLException e) { throw new RuntimeException(e); } catch (DatabaseUnitException e) { throw new RuntimeException(e); } } protected void executeUpdate(IDatabaseConnection conn, String... sql) throws SQLException { for (String s : sql) { if (s.isEmpty()) { continue; } PreparedStatement stmt = null; try { stmt = conn.getConnection().prepareStatement(s); stmt.executeUpdate(); } finally { if (stmt != null) { stmt.close(); } } } } } }