com.teradata.tempto.internal.context.GuiceTestContext.java Source code

Java tutorial

Introduction

Here is the source code for the Java class `com.teradata.tempto.internal.context.GuiceTestContext` (file `GuiceTestContext.java`).

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.teradata.tempto.internal.context;

import com.google.common.collect.Lists;
import com.google.inject.Binder;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Module;
import com.teradata.tempto.context.State;
import com.teradata.tempto.context.TestContext;
import com.teradata.tempto.context.TestContextCloseCallback;
import org.slf4j.Logger;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.ImmutableList.copyOf;
import static com.google.common.collect.Lists.newArrayList;
import static com.google.common.collect.Maps.newHashMap;
import static com.google.inject.Guice.createInjector;
import static com.google.inject.name.Names.named;
import static com.google.inject.util.Modules.combine;
import static com.google.inject.util.Modules.override;
import static java.util.Collections.synchronizedList;
import static org.assertj.core.util.Lists.emptyList;
import static org.slf4j.LoggerFactory.getLogger;

public class GuiceTestContext implements TestContext {
    private final static Logger LOGGER = getLogger(GuiceTestContext.class);

    private final Optional<GuiceTestContext> parent;
    private final List<GuiceTestContext> children = synchronizedList(newArrayList());
    private final Module baseModule;
    private final Map<Key<State>, State> states;
    private final Injector injector;
    private final List<TestContextCloseCallback> closeCallbacks = newArrayList();

    public GuiceTestContext(Module... baseModules) {
        this(Optional.<GuiceTestContext>empty(), combine(baseModules), newHashMap());
    }

    private GuiceTestContext(Optional<GuiceTestContext> parent, Module baseModule, Map<Key<State>, State> states) {
        this.parent = parent;
        this.baseModule = baseModule;
        this.states = states;
        this.injector = buildInjector();
    }

    @Override
    public <T> T getDependency(Class<T> dependencyClass) {
        return getDependency(Key.get(dependencyClass));
    }

    @Override
    public <T> T getDependency(Class<T> dependencyClass, String dependencyName) {
        return getDependency(Key.get(dependencyClass, named(dependencyName)));
    }

    private <T> T getDependency(Key key) {
        return (T) injector.getInstance(key);
    }

    @Override
    public <T> Optional<T> getOptionalDependency(Class<T> dependencyClass) {
        return getOptionalDependency(Key.get(dependencyClass));
    }

    @Override
    public <T> Optional<T> getOptionalDependency(Class<T> dependencyClass, String dependencyName) {
        return getOptionalDependency(Key.get(dependencyClass, named(dependencyName)));
    }

    public <T> Optional<T> getOptionalDependency(Key key) {
        if (injector.getExistingBinding(key) != null) {
            return Optional.of(getDependency(key));
        } else {
            return Optional.empty();
        }
    }

    @Override
    public GuiceTestContext createChildContext(Iterable<State> newStatesIterable) {
        return createChildContext(newStatesIterable, emptyList());
    }

    public GuiceTestContext createChildContext(Iterable<State> newStatesIterable,
            Iterable<Module> overrideModules) {
        LOGGER.debug("Creating new test context from " + this);

        Map<Key<State>, State> newStates = newHashMap(states);
        for (State newState : newStatesIterable) {
            newStates.put(getKeyFor(newState), newState);
        }

        GuiceTestContext childTestContext = new GuiceTestContext(Optional.of(this),
                override(baseModule).with(overrideModules), newStates);
        children.add(childTestContext);
        return childTestContext;
    }

    @Override
    public void registerCloseCallback(TestContextCloseCallback callback) {
        closeCallbacks.add(callback);
    }

    @Override
    public void close() {
        copyOf(children).forEach(GuiceTestContext::close);
        Lists.reverse(closeCallbacks).forEach(callback -> callback.testContextClosed(this));

        if (parent.isPresent()) {
            parent.get().children.remove(this);
        }
    }

    public void injectMembers(Object instance) {
        injector.injectMembers(instance);
    }

    private Key<State> getKeyFor(State state) {
        if (state.getName().isPresent()) {
            return Key.get((Class<State>) state.getClass(), named(state.getName().get()));
        } else {
            return Key.get((Class<State>) state.getClass());
        }
    }

    private Injector buildInjector() {
        return createInjector(combine(baseModule, statesModule(), testContextModule()));
    }

    private Module statesModule() {
        return (Binder binder) -> {
            for (Map.Entry<Key<State>, State> stateEntry : states.entrySet()) {
                binder.bind(stateEntry.getKey()).toInstance(stateEntry.getValue());
            }
        };
    }

    private Module testContextModule() {
        return (Binder binder) -> {
            binder.bind(TestContext.class).toInstance(this);
        };
    }
}