io.seldon.api.state.PredictionAlgorithmStore.java Source code

Java tutorial

Introduction

Here is the source code for io.seldon.api.state.PredictionAlgorithmStore.java

Source

/*
 * Seldon -- open source prediction engine
 * =======================================
 * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
 *
 **********************************************************************************************
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at       
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 ********************************************************************************************** 
*/
package io.seldon.api.state;

import io.seldon.prediction.FeatureTransformer;
import io.seldon.prediction.FeatureTransformerStrategy;
import io.seldon.prediction.PredictionAlgorithm;
import io.seldon.prediction.PredictionAlgorithmStrategy;
import io.seldon.prediction.PredictionStrategy;
import io.seldon.prediction.SimplePredictionStrategy;
import io.seldon.prediction.VariationPredictionStrategy;

import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import javax.annotation.PostConstruct;

import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;

@Component
public class PredictionAlgorithmStore
        implements ApplicationContextAware, ClientConfigUpdateListener, GlobalConfigUpdateListener {
    /** Class-wide log4j logger, named after this class. */
    protected static Logger logger = Logger.getLogger(PredictionAlgorithmStore.class.getName());
    /** Client config key whose updates and removals this store reacts to. */
    public static final String ALG_KEY = "predict_algs";

    /** Per-client prediction strategies, keyed by client name; filled and cleared by the client config callbacks. */
    private ConcurrentMap<String, PredictionStrategy> predictionStore = new ConcurrentHashMap<>();
    /** Not referenced by any method visible in this file. */
    private ConcurrentMap<String, FeatureTransformerStrategy> transformerStore = new ConcurrentHashMap<>();
    /**
     * Fallback returned by {@code retrieveStrategy} when a client has no entry in {@code predictionStore};
     * stays {@code null} until a global config update has been parsed successfully.
     * NOTE(review): plain (non-volatile) field although the stores next to it are concurrent maps --
     * confirm whether config updates and lookups can run on different threads.
     */
    private PredictionStrategy defaultStrategy;

    /** Handler this store registers with (in {@code init}) as a client config listener. */
    private final ClientConfigHandler configHandler;
    /** Handler this store subscribes to (in {@code init}) for the global default strategy key. */
    private final GlobalConfigHandler globalConfigHandler;
    /** Spring context used to look up algorithm and transformer beans by name; set in {@code setApplicationContext}. */
    private ApplicationContext applicationContext;

    /**
     * Creates the store and remembers the two config handlers; registration with them
     * happens later, in {@code init()}.
     *
     * @param configHandler       handler for per-client config changes
     * @param globalConfigHandler handler for global config changes
     */
    @Autowired
    public PredictionAlgorithmStore(ClientConfigHandler configHandler, GlobalConfigHandler globalConfigHandler) {
        this.globalConfigHandler = globalConfigHandler;
        this.configHandler = configHandler;
    }

    /**
     * Post-construction hook: registers this store as a listener for client config
     * changes and as a subscriber for the global default prediction strategy key.
     */
    @PostConstruct
    private void init() {
        final String defaultStrategyKey = "default_prediction_strategy";

        logger.info("Initializing...");
        configHandler.addListener(this);
        globalConfigHandler.addSubscriber(defaultStrategyKey, this);
    }

    /**
     * Returns the prediction strategy to use for a client.
     *
     * @param client client name used as the lookup key
     * @return the strategy stored for {@code client}, or the default strategy (which may
     *         itself be {@code null} if none has been configured) when the client has no entry
     */
    public PredictionStrategy retrieveStrategy(String client) {
        final PredictionStrategy clientStrategy = predictionStore.get(client);
        return clientStrategy == null ? defaultStrategy : clientStrategy;
    }

    /**
     * Global config callback: parses {@code configValue} as an {@link AlgorithmConfig} JSON
     * document and, on success, replaces the default prediction strategy.
     *
     * <p>The previous default strategy is kept when the value is empty, cannot be parsed,
     * names a bean that cannot be resolved, or carries no {@code algorithms} list. A missing
     * {@code transformers} list is treated as "no transformers", matching the per-client
     * handler in this class.
     *
     * @param configKey   key that changed (only logged)
     * @param configValue JSON text of the new default strategy; may be {@code null} or blank
     */
    @Override
    public void configUpdated(String configKey, String configValue) {
        configValue = StringUtils.strip(configValue);
        logger.info("KEY WAS " + configKey);
        logger.info("Received new default strategy: " + configValue);

        if (StringUtils.length(configValue) == 0) {
            logger.warn("*WARNING* no default strategy is set!");
            return;
        }

        try {
            ObjectMapper mapper = new ObjectMapper();
            AlgorithmConfig config = mapper.readValue(configValue, AlgorithmConfig.class);
            // readValue yields null for the JSON literal "null"; Jackson leaves absent fields null.
            if (config == null || config.algorithms == null) {
                logger.error("Couldn't update default prediction strategy: config has no algorithms");
                return;
            }

            List<PredictionAlgorithmStrategy> strategies = new ArrayList<>();
            for (Algorithm algorithm : config.algorithms) {
                strategies.add(toAlgorithmStrategy(algorithm));
            }

            List<FeatureTransformerStrategy> featureTransformerStrategies = new ArrayList<>();
            // Transformers are optional, so an absent list must not abort the update.
            if (config.transformers != null) {
                for (Transformer transformer : config.transformers) {
                    featureTransformerStrategies.add(toFeatureTransformerStrategy(transformer));
                }
            }

            // Assign only after everything was built, so a failure above leaves the old default in place.
            defaultStrategy = new SimplePredictionStrategy(PredictionStrategy.DEFAULT_NAME,
                    Collections.unmodifiableList(featureTransformerStrategies),
                    Collections.unmodifiableList(strategies));
            logger.info("Successfully added new default prediction strategy");
        } catch (IOException | BeansException e) {
            logger.error("Couldn't update default prediction strategy", e);
        }
    }

    /**
     * Client config callback: when {@code configKey} is {@link #ALG_KEY}, parses
     * {@code configValue} and stores the resulting prediction strategy for {@code client}.
     *
     * <p>The JSON is deserialized as a {@link Strategy}. An {@link AlgorithmConfig} becomes a
     * single simple strategy; a {@link TestConfig} becomes a variation strategy with one
     * simple strategy per variation. Any other key is ignored. If the config cannot be
     * parsed, names an unresolvable bean, or is structurally incomplete (no algorithms, no
     * variations, missing or non-numeric ratio), the failure is logged and the client's
     * existing entry is left untouched.
     *
     * @param client      client the config belongs to
     * @param configKey   key that changed; only {@link #ALG_KEY} is handled
     * @param configValue JSON text of the new config
     */
    @Override
    public void configUpdated(String client, String configKey, String configValue) {
        // Constant-first comparison tolerates a null key; bail out before any Jackson setup.
        if (!ALG_KEY.equals(configKey)) {
            return;
        }
        logger.info("Received new algorithm config for " + client + ": " + configValue);

        SimpleModule module = new SimpleModule("PredictionStrategyDeserializerModule");
        module.addDeserializer(Strategy.class, new PredictionStrategyDeserializer());
        ObjectMapper mapper = new ObjectMapper();
        mapper.registerModule(module);

        try {
            Strategy configStrategy = mapper.readValue(configValue, Strategy.class);
            if (configStrategy instanceof AlgorithmConfig) {
                predictionStore.put(client,
                        buildSimpleStrategy(PredictionStrategy.DEFAULT_NAME, (AlgorithmConfig) configStrategy));
                logger.info("Successfully added new algorithm config for " + client);
            } else if (configStrategy instanceof TestConfig) {
                List<VariationPredictionStrategy.Variation> variations =
                        buildVariations((TestConfig) configStrategy);
                predictionStore.put(client, VariationPredictionStrategy.build(variations));
                logger.info("Succesfully added " + variations.size() + " variation test for " + client);
            } else {
                logger.error("Unknown type for algorithm config");
            }
        } catch (IOException | BeansException | IllegalArgumentException e) {
            // IllegalArgumentException also covers NumberFormatException from a malformed ratio.
            logger.error("Couldn't update algorithms for client " + client, e);
        }
    }

    /**
     * Builds one simple strategy named {@code name} from a parsed algorithm config.
     * Transformers are optional; algorithms are required.
     *
     * @throws IllegalArgumentException if {@code config} or its algorithm list is missing
     */
    private SimplePredictionStrategy buildSimpleStrategy(String name, AlgorithmConfig config) {
        if (config == null || config.algorithms == null) {
            throw new IllegalArgumentException("No algorithms configured for strategy " + name);
        }
        List<PredictionAlgorithmStrategy> algorithmStrategies = new ArrayList<>();
        for (Algorithm algorithm : config.algorithms) {
            algorithmStrategies.add(toAlgorithmStrategy(algorithm));
        }
        List<FeatureTransformerStrategy> transformerStrategies = new ArrayList<>();
        if (config.transformers != null) {
            for (Transformer transformer : config.transformers) {
                transformerStrategies.add(toFeatureTransformerStrategy(transformer));
            }
        }
        // TODO: the configured combiner is not applied yet.
        return new SimplePredictionStrategy(name,
                Collections.unmodifiableList(transformerStrategies),
                Collections.unmodifiableList(algorithmStrategies));
    }

    /**
     * Converts every variation of a test config into a weighted variation of simple strategies.
     *
     * @throws IllegalArgumentException if the variation list, a variation, or its ratio is
     *         missing, or if a ratio is not a valid decimal number
     */
    private List<VariationPredictionStrategy.Variation> buildVariations(TestConfig config) {
        if (config.variations == null) {
            throw new IllegalArgumentException("Test config has no variations");
        }
        List<VariationPredictionStrategy.Variation> variations = new ArrayList<>();
        for (TestVariation variation : config.variations) {
            if (variation == null || variation.ratio == null) {
                throw new IllegalArgumentException("Test variation is missing its ratio");
            }
            variations.add(new VariationPredictionStrategy.Variation(
                    buildSimpleStrategy(variation.label, variation.config),
                    new BigDecimal(variation.ratio)));
        }
        return variations;
    }

    /**
     * Client config callback for removals: drops the client's stored prediction strategy
     * when the removed key is {@link #ALG_KEY}; other keys are ignored.
     *
     * @param client    client whose config entry was removed
     * @param configKey key that was removed
     */
    @Override
    public void configRemoved(String client, String configKey) {
        if (!configKey.equals(ALG_KEY)) {
            return;
        }
        predictionStore.remove(client);
        logger.info("Removed client " + client + " from " + ALG_KEY);
    }

    /**
     * Stores the Spring application context and logs, at info level, the classes of all
     * {@link PredictionAlgorithm} and {@link FeatureTransformer} beans it contains.
     *
     * @param applicationContext context later used to resolve algorithm and transformer beans by name
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;

        StringBuilder summary = new StringBuilder("Available algorithms: \n");
        for (PredictionAlgorithm algorithm : applicationContext.getBeansOfType(PredictionAlgorithm.class).values()) {
            summary.append('\t').append(algorithm.getClass()).append('\n');
        }

        summary.append("Available feature transformers: \n");
        for (FeatureTransformer transformer : applicationContext.getBeansOfType(FeatureTransformer.class).values()) {
            summary.append('\t').append(transformer.getClass()).append('\n');
        }

        logger.info(summary.toString());
    }

    /**
     * Resolves the bean named {@code algorithm.name} as a {@link PredictionAlgorithm} and
     * wraps it, together with its config entries as a map, in a strategy object.
     * A missing config list yields an empty map.
     */
    private PredictionAlgorithmStrategy toAlgorithmStrategy(Algorithm algorithm) {
        final String beanName = algorithm.name;
        PredictionAlgorithm bean = applicationContext.getBean(beanName, PredictionAlgorithm.class);
        // toConfigMap already returns a fresh empty map for a null list, so no extra null branch is needed.
        Map<String, String> settings = toConfigMap(algorithm.config);
        return new PredictionAlgorithmStrategy(bean, settings, beanName);
    }

    /**
     * Resolves the bean named {@code transformer.name} as a {@link FeatureTransformer} and
     * wraps it with the transformer's input columns, output columns and config map.
     */
    private FeatureTransformerStrategy toFeatureTransformerStrategy(Transformer transformer) {
        FeatureTransformer bean = applicationContext.getBean(transformer.name, FeatureTransformer.class);
        Map<String, String> settings = toConfigMap(transformer.config);
        List<String> inputs = transformer.inputCols;
        List<String> outputs = transformer.outputCols;
        return new FeatureTransformerStrategy(bean, inputs, outputs, settings);
    }

    /**
     * Flattens a list of name/value config items into a map. A {@code null} list gives an
     * empty map; when a name occurs more than once the last value wins.
     */
    private Map<String, String> toConfigMap(List<ConfigItem> config) {
        Map<String, String> settings = new HashMap<>();
        if (config != null) {
            for (ConfigItem entry : config) {
                settings.put(entry.name, entry.value);
            }
        }
        return settings;
    }

    /**
     * Common base type of the JSON config shapes. The per-client config handler deserializes
     * a config as a {@code Strategy} and then branches on whether the result is an
     * {@link AlgorithmConfig} or a {@link TestConfig}.
     */
    public abstract static class Strategy {
    }

    // classes for json translation
    /** JSON shape of a test config: a list of variations, each turned into its own strategy. */
    public static class TestConfig extends Strategy {
        /** Variations of the test; iterated by the per-client config handler. */
        public List<TestVariation> variations;
    }

    /** JSON shape of one variation inside a {@link TestConfig}. */
    public static class TestVariation {
        /** Used as the name of the strategy built for this variation. */
        public String label;
        /** Textual number; converted with {@code new BigDecimal(ratio)} when the variation is built. */
        public String ratio;
        /** Algorithms and transformers that make up this variation's strategy. */
        public AlgorithmConfig config;
    }

    /** JSON shape of a plain (non-test) prediction config. */
    public static class AlgorithmConfig extends Strategy {
        /** Algorithms of the strategy, each resolved to a bean by name. */
        public List<Algorithm> algorithms;
        /** Feature transformers of the strategy; the per-client handler accepts {@code null} here. */
        public List<Transformer> transformers;
        /** Parsed from JSON but not read by any code visible in this file. */
        public String combiner;
    }

    /** One name/value setting; lists of these are flattened into a map by {@code toConfigMap}. */
    public static class ConfigItem {
        /** Setting name; becomes the map key. */
        public String name;
        /** Setting value; becomes the map value. */
        public String value;
    }

    /** JSON shape of one algorithm entry. */
    public static class Algorithm {
        /** Bean name looked up in the application context as a {@code PredictionAlgorithm}. */
        public String name;
        /** Optional settings for the algorithm; {@code null} is handled as "no settings". */
        public List<ConfigItem> config;
    }

    /** JSON shape of one feature transformer entry. */
    public static class Transformer {
        /** Bean name looked up in the application context as a {@code FeatureTransformer}. */
        public String name;
        /** Passed through unchanged to the {@code FeatureTransformerStrategy} constructor. */
        public List<String> inputCols;
        /** Passed through unchanged to the {@code FeatureTransformerStrategy} constructor. */
        public List<String> outputCols;
        /** Optional settings for the transformer; {@code null} is handled as "no settings". */
        public List<ConfigItem> config;
    }

}