org.springframework.batch.core.test.step.FaultTolerantStepFactoryBeanIntegrationTests.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.batch.core.test.step.FaultTolerantStepFactoryBeanIntegrationTests.java

Source

/*
 * Copyright 2010-2014 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.batch.core.test.step;

import static org.junit.Assert.assertEquals;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import org.springframework.batch.core.BatchStatus;
import org.springframework.batch.core.JobExecution;
import org.springframework.batch.core.JobParameters;
import org.springframework.batch.core.Step;
import org.springframework.batch.core.StepExecution;
import org.springframework.batch.core.repository.JobRepository;
import org.springframework.batch.core.step.factory.FaultTolerantStepFactoryBean;
import org.springframework.batch.item.ItemProcessor;
import org.springframework.batch.item.ItemReader;
import org.springframework.batch.item.ItemWriter;
import org.springframework.batch.item.ParseException;
import org.springframework.batch.item.UnexpectedInputException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.jdbc.JdbcTestUtils;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.util.Assert;

/**
 * Tests for {@link FaultTolerantStepFactoryBean}.
 */
@ContextConfiguration(locations = "/simple-job-launcher-context.xml")
@RunWith(SpringJUnit4ClassRunner.class)
public class FaultTolerantStepFactoryBeanIntegrationTests {

    /** Number of times {@code testMultithreadedSunnyDay} builds and executes the step. */
    private static final int MAX_COUNT = 1000;

    /** Logger used to report the iteration on which a repeated run failed. */
    private final Log logger = LogFactory.getLog(getClass());

    /** Step factory under test; re-created in {@code setUp()} before every test method. */
    private FaultTolerantStepFactoryBean<String, String> factory;

    /** Processor stub that records every processed item as an {@code ERROR_LOG} row. */
    private SkipProcessorStub processor;

    /** Writer stub that records every written item as an {@code ERROR_LOG} row. */
    private SkipWriterStub writer;

    /** Job execution that owns the step executions created by {@code testMultithreadedSunnyDay}. */
    private JobExecution jobExecution;

    /** Step execution of the most recent run in {@code testMultithreadedSunnyDay}. */
    private StepExecution stepExecution;

    /** Injected data source; backs the {@code JdbcTemplate} instances used by the tests and stubs. */
    @Autowired
    private DataSource dataSource;

    /** Injected job repository; passed to the factory and used to create job/step executions. */
    @Autowired
    private JobRepository repository;

    /** Injected transaction manager; passed to the step factory. */
    @Autowired
    private PlatformTransactionManager transactionManager;

    /**
     * Task executor handed to the step factory. Held in a field (rather than a
     * local) so that {@link #tearDown()} can stop its worker threads once the
     * test method has finished.
     */
    private ThreadPoolTaskExecutor taskExecutor;

    /**
     * Prepares a fresh fixture for every test method: new writer and processor
     * stubs, a new step factory wired with the injected transaction manager and
     * job repository, a commit interval of 3 and a thread pool of 3 core / 6
     * maximum threads with no queue. Finally empties the {@code ERROR_LOG}
     * table so each test starts from zero rows.
     *
     * @throws Exception if the fixture cannot be initialised
     */
    @Before
    public void setUp() throws Exception {

        writer = new SkipWriterStub(dataSource);
        processor = new SkipProcessorStub(dataSource);

        factory = new FaultTolerantStepFactoryBean<String, String>();

        factory.setBeanName("stepName");
        factory.setTransactionManager(transactionManager);
        factory.setJobRepository(repository);
        factory.setCommitInterval(3);

        taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setCorePoolSize(3);
        taskExecutor.setMaxPoolSize(6);
        // Zero queue capacity: work is handed straight to a thread instead of being queued.
        taskExecutor.setQueueCapacity(0);
        taskExecutor.afterPropertiesSet();
        factory.setTaskExecutor(taskExecutor);

        JdbcTestUtils.deleteFromTables(new JdbcTemplate(dataSource), "ERROR_LOG");

    }

    /**
     * Shuts down the thread pool created in {@link #setUp()}. Without this a
     * new pool is initialised for every test method and none of them is ever
     * stopped, so their worker threads would pile up for the rest of the run.
     */
    @After
    public void tearDown() {
        // Guard against setUp() having failed before the executor was created.
        if (taskExecutor != null) {
            taskExecutor.shutdown();
            taskExecutor = null;
        }
    }

    /**
     * Exercises the stubs directly, outside of any step: writing two items and
     * processing one must leave three rows in {@code ERROR_LOG}, and clearing
     * both stubs must remove them again.
     *
     * @throws Exception if a stub call fails
     */
    @Test
    public void testUpdatesNoRollback() throws Exception {

        final JdbcTemplate template = new JdbcTemplate(dataSource);
        final List<String> batch = Arrays.asList("foo", "bar");

        // Two rows from the writer plus one from the processor.
        writer.write(batch);
        processor.process("spam");
        final int rowsAfterInsert = JdbcTestUtils.countRowsInTable(template, "ERROR_LOG");
        assertEquals(3, rowsAfterInsert);

        // Each stub deletes only the rows it inserted; together they empty the table.
        writer.clear();
        processor.clear();
        final int rowsAfterClear = JdbcTestUtils.countRowsInTable(template, "ERROR_LOG");
        assertEquals(0, rowsAfterClear);

    }

    /**
     * Runs the step {@code MAX_COUNT} times over the items "1".."5" with no
     * failures configured, and checks after every run that the step completed,
     * that both the writer and the processor recorded exactly those five items,
     * and that nothing was skipped. The iteration number is logged before any
     * failure is rethrown.
     *
     * @throws Throwable whatever the failing iteration threw
     */
    @Test
    public void testMultithreadedSunnyDay() throws Throwable {

        jobExecution = repository.createJobExecution("vanillaJob", new JobParameters());

        for (int run = 0; run < MAX_COUNT; run++) {

            JdbcTemplate template = new JdbcTemplate(dataSource);

            prepareFactoryForRun();

            // Both stubs were just cleared, so the table must be empty before the step starts.
            int leftoverRows = JdbcTestUtils.countRowsInTable(template, "ERROR_LOG");
            assertEquals(0, leftoverRows);

            try {
                executeStepAndVerify();
            }
            catch (Throwable failure) {
                logger.info("Failed on iteration " + run + " of " + MAX_COUNT);
                throw failure;
            }

        }

    }

    /** Installs a fresh five-item reader and the cleared writer/processor stubs on the factory. */
    private void prepareFactoryForRun() throws Exception {
        SkipReaderStub source = new SkipReaderStub();
        source.clear();
        source.setItems("1", "2", "3", "4", "5");
        factory.setItemReader(source);

        writer.clear();
        factory.setItemWriter(writer);

        processor.clear();
        factory.setItemProcessor(processor);
    }

    /** Builds and executes one step, then asserts its status, recorded items and skip count. */
    private void executeStepAndVerify() throws Exception {
        Step step = factory.getObject();

        stepExecution = jobExecution.createStepExecution(factory.getName());
        repository.add(stepExecution);
        step.execute(stepExecution);
        assertEquals(BatchStatus.COMPLETED, stepExecution.getStatus());

        // Rows may arrive in any order, hence the sort before comparing.
        List<String> writtenItems = sortedCopy(writer.getCommitted());
        assertEquals("[1, 2, 3, 4, 5]", writtenItems.toString());

        List<String> processedItems = sortedCopy(processor.getCommitted());
        assertEquals("[1, 2, 3, 4, 5]", processedItems.toString());

        assertEquals(0, stepExecution.getSkipCount());
    }

    /** Returns a new list holding the given values in natural order; the argument is not modified. */
    private static List<String> sortedCopy(List<String> values) {
        List<String> copy = new ArrayList<String>(values);
        Collections.sort(copy);
        return copy;
    }

    /**
     * {@link ItemReader} stub that hands out a fixed array of strings one at a
     * time and returns {@code null} once the array is exhausted.
     */
    private static class SkipReaderStub implements ItemReader<String> {

        /** Items to hand out; must be supplied through {@link #setItems(String...)} before reading. */
        private String[] source;

        /** Index of the item returned by the last {@link #read()}; -1 means nothing has been read. */
        private int position = -1;

        public SkipReaderStub() throws Exception {
        }

        /**
         * Replaces the items to be read.
         *
         * @param items the values to return from {@link #read()}, in order
         */
        public void setItems(String... items) {
            Assert.isTrue(position < 0, "Items cannot be set once reading has started");
            source = items;
        }

        /** Rewinds the reader so that the next {@link #read()} returns the first item again. */
        public void clear() {
            position = -1;
        }

        /**
         * Advances to the next item.
         *
         * @return the next item, or {@code null} when all items have been handed out
         */
        @Override
        public synchronized String read() throws Exception, UnexpectedInputException, ParseException {
            final int index = ++position;
            return index < source.length ? source[index] : null;
        }
    }

    /**
     * {@link ItemWriter} stub that records every item it receives twice: in an
     * in-memory list and as an {@code ERROR_LOG} row with
     * {@code STEP_NAME='written'}. The database rows are what
     * {@link #getCommitted()} reports.
     */
    private static class SkipWriterStub implements ItemWriter<String> {

        /*
         * The enclosing test configures the step with a task executor, so
         * write(...) may run on several threads at once. A bare ArrayList is
         * not safe for concurrent add() calls (lost elements or
         * ArrayIndexOutOfBoundsException), hence the synchronized wrapper.
         */
        private final List<String> written = Collections.synchronizedList(new ArrayList<String>());

        /** Items that trigger a planned failure; empty here, so no write ever fails on purpose. */
        private final Collection<String> failures = Collections.emptySet();

        private final JdbcTemplate jdbcTemplate;

        /**
         * @param dataSource data source holding the {@code ERROR_LOG} table
         */
        public SkipWriterStub(DataSource dataSource) {
            jdbcTemplate = new JdbcTemplate(dataSource);
        }

        /**
         * @return the {@code MESSAGE} column of every {@code ERROR_LOG} row
         * inserted by this writer, in whatever order the query returns them
         */
        public List<String> getCommitted() {
            return jdbcTemplate.query("SELECT MESSAGE from ERROR_LOG where STEP_NAME='written'",
                    new RowMapper<String>() {
                        @Override
                        public String mapRow(ResultSet rs, int rowNum) throws SQLException {
                            return rs.getString(1);
                        }
                    });
        }

        /** Forgets the in-memory record and deletes this writer's rows from {@code ERROR_LOG}. */
        public void clear() {
            written.clear();
            jdbcTemplate.update("DELETE FROM ERROR_LOG where STEP_NAME='written'");
        }

        /**
         * Records each item in memory and in {@code ERROR_LOG}, then fails if
         * the item is one of the planned failures.
         *
         * @param items the chunk to write
         * @throws Exception a {@link RuntimeException} for a planned failure,
         * or whatever the database update throws
         */
        @Override
        public void write(List<? extends String> items) throws Exception {
            for (String item : items) {
                written.add(item);
                jdbcTemplate.update("INSERT INTO ERROR_LOG (MESSAGE, STEP_NAME) VALUES (?, ?)", item, "written");
                // Checked after the insert, so the row of a failing item is part of the same write attempt.
                checkFailure(item);
            }
        }

        /** Throws the planned failure when {@code item} is listed in {@link #failures}. */
        private void checkFailure(String item) {
            if (failures.contains(item)) {
                throw new RuntimeException("Planned failure");
            }
        }
    }

    /**
     * {@link ItemProcessor} stub that passes items through unchanged while
     * recording each one twice: in an in-memory list and as an
     * {@code ERROR_LOG} row with {@code STEP_NAME='processed'}. The database
     * rows are what {@link #getCommitted()} reports.
     */
    private static class SkipProcessorStub implements ItemProcessor<String, String> {

        private final Log logger = LogFactory.getLog(getClass());

        /*
         * The enclosing test configures the step with a task executor, so
         * process(...) may run on several threads at once. A bare ArrayList is
         * not safe for concurrent add() calls (lost elements or
         * ArrayIndexOutOfBoundsException), hence the synchronized wrapper.
         */
        private final List<String> processed = Collections.synchronizedList(new ArrayList<String>());

        private final JdbcTemplate jdbcTemplate;

        /**
         * @param dataSource data source holding the {@code ERROR_LOG} table
         */
        public SkipProcessorStub(DataSource dataSource) {
            jdbcTemplate = new JdbcTemplate(dataSource);
        }

        /**
         * @return the {@code MESSAGE} column of every {@code ERROR_LOG} row
         * inserted by this processor, in whatever order the query returns them
         */
        public List<String> getCommitted() {
            return jdbcTemplate.query("SELECT MESSAGE from ERROR_LOG where STEP_NAME='processed'",
                    new RowMapper<String>() {
                        @Override
                        public String mapRow(ResultSet rs, int rowNum) throws SQLException {
                            return rs.getString(1);
                        }
                    });
        }

        /** Forgets the in-memory record and deletes this processor's rows from {@code ERROR_LOG}. */
        public void clear() {
            processed.clear();
            jdbcTemplate.update("DELETE FROM ERROR_LOG where STEP_NAME='processed'");
        }

        /**
         * Records the item in memory and in {@code ERROR_LOG}, then returns it
         * unchanged.
         *
         * @param item the item to process
         * @return the same {@code item}
         * @throws Exception whatever the database update throws
         */
        @Override
        public String process(String item) throws Exception {
            processed.add(item);
            logger.debug("Processed item: " + item);
            jdbcTemplate.update("INSERT INTO ERROR_LOG (MESSAGE, STEP_NAME) VALUES (?, ?)", item, "processed");
            return item;
        }
    }

}