com.crudetech.junit.categories.Categories.java Source code

Java tutorial

Introduction

Here is the source code for com.crudetech.junit.categories.Categories.java

Source

////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2011, Andreas Mueller.
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// which accompanies this distribution, and is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// Contributors:
//      Andreas Mueller - initial API and implementation
////////////////////////////////////////////////////////////////////////////////
package com.crudetech.junit.categories;

import org.junit.runner.Description;
import org.junit.runner.manipulation.Filter;
import org.junit.runner.manipulation.NoTestsRemainException;
import org.junit.runners.Suite;
import org.junit.runners.model.InitializationError;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;

import java.lang.reflect.Modifier;
import java.util.*;

import static com.crudetech.junit.categories.If.isNull;
import static com.crudetech.junit.categories.If.isNullOrEmpty;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;

/**
 * A suite runner that collects all concrete test classes from the class path
 * whose names match the pattern specified in the {@link TestNamePattern}
 * annotation, and then filters the collected tests by the categories named in
 * the optional {@link IncludeCategory} and {@link ExcludeCategory} annotations.
 * <p>
 * Without an {@link IncludeCategory} annotation (or with an empty value) every
 * category is included; without an {@link ExcludeCategory} annotation (or with
 * an empty value) nothing is excluded.
 */
public class Categories extends Suite {

    /**
     * Sentinel inclusion meaning "include every category"; used when the suite
     * class does not restrict inclusions.
     */
    private static final Class<?> ALL = Void.class;

    /**
     * Creates the suite for the given annotated suite class.
     *
     * @param testClass the suite class carrying the {@link TestNamePattern} and
     *                  optional category annotations
     * @throws InitializationError for example if no usable test name pattern is
     *                             specified, the class path cannot be scanned,
     *                             or no test remains after category filtering
     */
    public Categories(Class<?> testClass) throws InitializationError {
        super(testClass, allTestClassesInClassPathMatchingPattern(testClass));

        try {
            filter(new CategoryFilter(getInclusions(testClass), getExclusions(testClass)));
        } catch (NoTestsRemainException e) {
            // An empty suite after filtering is reported as an initialization problem.
            throw new InitializationError(e);
        }
    }

    /**
     * Reads the excluded categories from the suite class.
     *
     * @param testClass the suite class
     * @return the categories listed in {@link ExcludeCategory}, or an empty
     *         collection when the annotation is absent or empty
     */
    private static Collection<Class<?>> getExclusions(Class<?> testClass) {
        ExcludeCategory exc = testClass.getAnnotation(ExcludeCategory.class);
        if (isNull(exc) || isNullOrEmpty(exc.value())) {
            return emptyList();
        }
        return asList(exc.value());
    }

    /**
     * Reads the included categories from the suite class.
     *
     * @param testClass the suite class
     * @return the categories listed in {@link IncludeCategory}, or a collection
     *         holding only the {@link #ALL} sentinel when the annotation is
     *         absent or empty
     */
    private static Collection<Class<?>> getInclusions(Class<?> testClass) {
        IncludeCategory inc = testClass.getAnnotation(IncludeCategory.class);
        if (isNull(inc) || isNullOrEmpty(inc.value())) {
            return Arrays.<Class<?>>asList(ALL);
        }
        return asList(inc.value());
    }

    /**
     * JUnit filter that keeps a test when none of its categories is excluded
     * and, unless every category is included, at least one of its categories is
     * included. A description is also kept when any of its children is kept.
     */
    static class CategoryFilter extends Filter {

        private final Collection<Class<?>> inclusions;
        private final Collection<Class<?>> exclusions;

        /**
         * @param inclusions categories to include; containing the "all" sentinel
         *                   disables the inclusion check
         * @param exclusions categories to exclude
         */
        public CategoryFilter(Collection<Class<?>> inclusions, Collection<Class<?>> exclusions) {
            this.inclusions = inclusions;
            this.exclusions = exclusions;
        }

        @Override
        public boolean shouldRun(Description description) {
            // Keep a parent (e.g. a test class) if any of its children (e.g. test
            // methods) passes, so that method-level categories are honoured.
            for (Description child : description.getChildren()) {
                if (shouldRun(child)) {
                    return true;
                }
            }

            Collection<Class<?>> categories = getCategories(description);

            if (allCategoriesAreIncluded()) {
                return categoriesAreNotExcluded(categories);
            }

            return categoriesAreNotExcluded(categories) && categoriesAreIncluded(categories);
        }

        private boolean categoriesAreIncluded(Collection<Class<?>> categories) {
            return atLeastOneIsIncluded(categories, inclusions);
        }

        private boolean categoriesAreExcluded(Collection<Class<?>> categories) {
            return atLeastOneIsIncluded(categories, exclusions);
        }

        /**
         * Returns {@code true} if any element of {@code candidates} is contained
         * in {@code categories}. Matching is by exact class; category
         * inheritance is not considered.
         */
        private static boolean atLeastOneIsIncluded(Collection<Class<?>> categories,
                                                     Collection<Class<?>> candidates) {
            for (Class<?> candidate : candidates) {
                if (categories.contains(candidate)) {
                    return true;
                }
            }
            return false;
        }

        private boolean categoriesAreNotExcluded(Collection<Class<?>> categories) {
            return !categoriesAreExcluded(categories);
        }

        private boolean allCategoriesAreIncluded() {
            return inclusions.contains(ALL);
        }

        /**
         * Collects the categories of a description: those of its own
         * {@link Category} annotation plus those of its test class hierarchy.
         */
        private Set<Class<?>> getCategories(Description description) {
            Set<Class<?>> categories = getCategories(description.getAnnotation(Category.class));
            if (description.getTestClass() != null) {
                categories.addAll(getCategories(description.getTestClass()));
            }
            return categories;
        }

        /**
         * Collects the categories of a class, its superclasses and its
         * interfaces, recursively.
         * <p>
         * NOTE(review): the superclass walk always reaches {@code Object}, which
         * has no {@link Category} annotation and therefore contributes
         * {@link StandardCategory}; every class hierarchy thus carries
         * {@code StandardCategory}. Confirm this is intended.
         */
        private Set<Class<?>> getCategories(Class<?> clazz) {
            if (clazz == null) {
                return Collections.emptySet();
            }
            Set<Class<?>> categories = getCategories(clazz.getAnnotation(Category.class));
            categories.addAll(getCategories(clazz.getSuperclass()));
            for (Class<?> implemented : clazz.getInterfaces()) {
                categories.addAll(getCategories(implemented));
            }
            return categories;
        }

        /**
         * Converts a {@link Category} annotation into a mutable set of
         * categories. A missing annotation maps to {@link StandardCategory}.
         */
        private Set<Class<?>> getCategories(Category cat) {
            if (isNull(cat)) {
                return new HashSet<Class<?>>(Arrays.<Class<?>>asList(StandardCategory.class));
            }
            if (isNull(cat.value())) {
                return new HashSet<Class<?>>();
            }
            return new HashSet<Class<?>>(asList(cat.value()));
        }

        @Override
        public String describe() {
            return "category";
        }
    }

    /**
     * Collects every non-abstract class on the class path whose fully qualified
     * name matches the given pattern. Dots in the pattern are turned into path
     * separators, e.g. {@code "com.example.**.*Test"}.
     * <p>
     * The pattern is resolved with the {@code classpath*:} prefix, so every
     * class path root containing the pattern's base package is searched, not
     * just the first one. A class found under several roots is returned once.
     *
     * @param pattern the class name pattern
     * @return the matching non-abstract classes, in discovery order
     * @throws InitializationError if resources cannot be resolved or read, or a
     *                             matching class cannot be loaded
     */
    static Class<?>[] allTestClassesInClassPathMatchingPattern(String pattern) throws InitializationError {
        // Insertion-ordered set: keeps discovery order but drops duplicates that
        // appear when the same class is present under more than one root.
        Set<Class<?>> classes = new LinkedHashSet<Class<?>>();

        PathMatchingResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
        MetadataReaderFactory metaDataReaderFactory = new CachingMetadataReaderFactory();
        try {
            // "classpath*:" searches all class path roots. With a plain "classpath:"
            // prefix, a wildcard pattern only searches the first root that holds the
            // base package, silently missing test classes in other roots.
            String classPattern = "classpath*:" + pattern.replace('.', '/') + ".class";
            Resource[] resources = resolver.getResources(classPattern);
            for (Resource resource : resources) {
                if (!resource.isReadable()) {
                    continue;
                }
                MetadataReader reader = metaDataReaderFactory.getMetadataReader(resource);
                Class<?> candidate = Class.forName(reader.getClassMetadata().getClassName());

                // Abstract classes (and interfaces, which are implicitly abstract)
                // cannot be instantiated as tests.
                if (Modifier.isAbstract(candidate.getModifiers())) {
                    continue;
                }
                classes.add(candidate);
            }
            return classes.toArray(new Class<?>[classes.size()]);
        } catch (Exception e) {
            // Report any scanning or class-loading failure as a suite initialization error.
            throw new InitializationError(e);
        }
    }

    /**
     * Collects the test classes selected by the {@link TestNamePattern} of the
     * given suite class.
     *
     * @throws InitializationError if the pattern is missing or scanning fails
     */
    private static Class<?>[] allTestClassesInClassPathMatchingPattern(Class<?> testClass)
            throws InitializationError {
        return allTestClassesInClassPathMatchingPattern(getPatternFrom(testClass));
    }

    /**
     * Reads the test name pattern from the suite class.
     *
     * @throws InitializationError if the annotation is missing or its value is
     *                             null or empty
     */
    private static String getPatternFrom(Class<?> clazz) throws InitializationError {
        TestNamePattern pattern = clazz.getAnnotation(TestNamePattern.class);
        if (isNull(pattern) || isNullOrEmpty(pattern.value())) {
            throw new InitializationError("No proper test name pattern specified!");
        }
        return pattern.value();
    }
}