io.seldon.api.state.ClientAlgorithmStore.java Source code

Java tutorial

Introduction

Here is the source code for io.seldon.api.state.ClientAlgorithmStore.java

Source

/*
 * Seldon -- open source prediction engine
 * =======================================
 *
 * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
 *
 * ********************************************************************************************
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * ********************************************************************************************
 */

package io.seldon.api.state;

import io.seldon.clustering.recommender.ItemRecommendationAlgorithm;
import io.seldon.recommendation.AlgorithmStrategy;
import io.seldon.recommendation.ClientStrategy;
import io.seldon.recommendation.ItemFilter;
import io.seldon.recommendation.ItemIncluder;
import io.seldon.recommendation.JsOverrideClientStrategy;
import io.seldon.recommendation.RecTagClientStrategy;
import io.seldon.recommendation.SimpleClientStrategy;
import io.seldon.recommendation.VariationTestingClientStrategy;
import io.seldon.recommendation.combiner.AlgorithmResultsCombiner;
import io.seldon.recommendation.filters.base.CurrentItemFilter;
import io.seldon.recommendation.filters.base.IgnoredRecsFilter;
import io.seldon.recommendation.filters.base.RecentImpressionsFilter;

import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import javax.annotation.PostConstruct;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.google.common.collect.Sets;

/**
 * Cache of which algorithms to use for which clients. Receives per-client updates via
 * ClientConfigUpdateListener and the global "default_strategy" key via GlobalConfigUpdateListener.
 *
 * <p>Strategy resolution for a client (see {@link #retrieveStrategy(String)}):
 * <ol>
 * <li>the client's variation test, if the test switch is on and a test is configured;</li>
 * <li>otherwise the client's rec-tag strategy, if one is configured;</li>
 * <li>otherwise the client's plain algorithm strategy, if one is configured;</li>
 * <li>otherwise the global default strategy, which is {@code null} until one has been set.</li>
 * </ol>
 *
 * @author firemanphil
 *         Date: 27/11/14
 *         Time: 13:55
 */
@Component
public class ClientAlgorithmStore
        implements ApplicationContextAware, ClientConfigUpdateListener, GlobalConfigUpdateListener {

    private static final String RECTAG = "alg_rectags";
    protected static final Logger logger = Logger.getLogger(ClientAlgorithmStore.class.getName());

    private static final String ALG_KEY = "algs";
    private static final String TESTING_SWITCH_KEY = "alg_test_switch";
    private static final String TEST = "alg_test";
    /** Global config key this store subscribes to for the fallback strategy. */
    private static final String DEFAULT_STRATEGY_KEY = "default_strategy";
    /** Label given to strategies that are not a named test variation. */
    private static final String UNLABELLED = "-";

    private final ClientConfigHandler configHandler;
    private final GlobalConfigHandler globalConfigHandler;
    /** Filters added to every algorithm on top of its configured filter list. */
    private final Set<ItemFilter> alwaysOnFilters;
    private ApplicationContext applicationContext;
    /** client -> strategy built from the ALG_KEY config. */
    private final ConcurrentMap<String, ClientStrategy> store = new ConcurrentHashMap<>();
    /** client -> (algorithm bean name -> algorithm strategy) built from the ALG_KEY config. */
    private final ConcurrentMap<String, Map<String, AlgorithmStrategy>> storeMap = new ConcurrentHashMap<>();
    /** client -> whether the variation test (TEST key) is switched on. */
    private final ConcurrentMap<String, Boolean> testingOnOff = new ConcurrentHashMap<>();
    /** client -> variation-testing strategy built from the TEST config. */
    private final ConcurrentMap<String, ClientStrategy> tests = new ConcurrentHashMap<>();
    /** client -> rec-tag strategy built from the RECTAG config. */
    private final ConcurrentMap<String, ClientStrategy> recTagStrategies = new ConcurrentHashMap<>();
    // volatile so a strategy published from the global config callback is visible to threads
    // calling retrieveStrategy (assumes callbacks may arrive on a different thread than requests).
    private volatile ClientStrategy defaultStrategy = null;
    // Currently never populated: the "named_strategies" subscription in init() is disabled.
    private final ConcurrentMap<String, ClientStrategy> namedStrategies = new ConcurrentHashMap<>();

    @Autowired
    public ClientAlgorithmStore(ClientConfigHandler configHandler, GlobalConfigHandler globalConfigHandler,
            CurrentItemFilter currentItemFilter, IgnoredRecsFilter ignoredRecsFilter,
            RecentImpressionsFilter recentImpressionsFilter) {
        this.configHandler = configHandler;
        this.globalConfigHandler = globalConfigHandler;
        Set<ItemFilter> set = new HashSet<>();
        set.add(currentItemFilter);
        set.add(ignoredRecsFilter);
        set.add(recentImpressionsFilter);
        alwaysOnFilters = Collections.unmodifiableSet(set);
    }

    /** Registers this store for per-client config updates and the global default strategy. */
    @PostConstruct
    private void init() {
        logger.info("Initializing...");
        configHandler.addListener(this);
        globalConfigHandler.addSubscriber(DEFAULT_STRATEGY_KEY, this);
        //        globalConfigHandler.addSubscriber("named_strategies", this);
    }

    /**
     * Returns the client's strategy wrapped in a {@link JsOverrideClientStrategy} built from the
     * given algorithm names (presumably overriding the configured algorithms; see that class).
     *
     * @param client the client key
     * @param algs   algorithm names supplied by the caller
     * @return the wrapping strategy
     */
    public ClientStrategy retrieveStrategy(String client, Collection<String> algs) {
        ClientStrategy originalStrat = retrieveStrategy(client);
        return new JsOverrideClientStrategy(originalStrat, algs, applicationContext);
    }

    /**
     * Resolves the strategy to use for a client, in the order described on the class.
     *
     * @param client the client key
     * @return the resolved strategy; may be {@code null} if nothing matches and no default strategy
     *         has been configured yet
     */
    public ClientStrategy retrieveStrategy(String client) {
        if (testRunning(client)) {
            ClientStrategy strategy = tests.get(client);
            if (strategy != null) {
                return strategy;
            }
            // Only reachable if the test was removed between testRunning() and the lookup above.
            logger.warn("Testing was switched on for client " + client + " but no test was specified."
                    + " Returning default strategy");
            return defaultStrategy;
        }
        // ConcurrentHashMap never stores null values, so a null result means "not configured".
        ClientStrategy strategy = recTagStrategies.get(client);
        if (strategy == null) {
            strategy = store.get(client);
        }
        return strategy != null ? strategy : defaultStrategy;
    }

    /**
     * Drops the cached state for one of this store's config keys. Unknown keys are logged and ignored.
     */
    @Override
    public void configRemoved(String client, String configKey) {
        logger.info("Received config remove for " + client + " with key " + configKey);
        switch (configKey) {
        case ALG_KEY:
            store.remove(client);
            storeMap.remove(client);
            break;
        case TESTING_SWITCH_KEY:
            testingOnOff.remove(client);
            break;
        case TEST:
            tests.remove(client);
            break;
        case RECTAG:
            recTagStrategies.remove(client);
            break;
        default:
            logger.warn("Ignored unknown config remove for " + client + " with key " + configKey);
            return;
        }
        logger.info("Successfully removed " + client + " from " + configKey);
    }

    /**
     * Rebuilds the cached state for one of this store's per-client config keys. Parse and
     * bean-lookup failures are logged and leave the previously cached value in place.
     */
    @Override
    public void configUpdated(String client, String configKey, String configValue) {
        switch (configKey) {
        case ALG_KEY:
            updateAlgorithms(client, configValue);
            break;
        case TESTING_SWITCH_KEY:
            updateTestingSwitch(client, configValue);
            break;
        case TEST:
            updateTest(client, configValue);
            break;
        case RECTAG:
            updateRecTags(client, configValue);
            break;
        default:
            // Other keys are not handled by this store.
            break;
        }
    }

    /** Handles ALG_KEY: JSON {@link AlgorithmConfig} -> {@link SimpleClientStrategy}. */
    private void updateAlgorithms(String client, String configValue) {
        logger.info("Received new algorithm config for " + client + ": " + configValue);
        try {
            AlgorithmConfig config = newStrategyMapper().readValue(configValue, AlgorithmConfig.class);
            Map<String, AlgorithmStrategy> byName = new HashMap<>();
            SimpleClientStrategy strategy = toSimpleStrategy(config, ClientStrategy.DEFAULT_NAME, byName);
            store.put(client, strategy);
            storeMap.put(client, Collections.unmodifiableMap(byName));
            logger.info("Successfully added new algorithm config for " + client);
        } catch (IOException | BeansException | IllegalArgumentException e) {
            // IllegalArgumentException also covers NumberFormatException from action-weight types.
            logger.error("Couldn't update algorithms for client " + client, e);
        }
    }

    /** Handles TESTING_SWITCH_KEY: a plain boolean string parsed with BooleanUtils, not JSON. */
    private void updateTestingSwitch(String client, String configValue) {
        Boolean onOff = BooleanUtils.toBooleanObject(configValue);
        if (onOff == null) {
            logger.error("Couldn't set testing switch for client " + client + ", input was " + configValue);
            return;
        }
        logger.info("Testing switch for client " + client + " moving from '"
                + BooleanUtils.toStringOnOff(testingOnOff.get(client)) + "' to '"
                + BooleanUtils.toStringOnOff(onOff) + "'");
        testingOnOff.put(client, onOff);
    }

    /** Handles TEST: JSON {@link TestConfig} -> variation-testing strategy. */
    private void updateTest(String client, String configValue) {
        logger.info("Received new testing config for " + client + ":" + configValue);
        try {
            TestConfig config = newStrategyMapper().readValue(configValue, TestConfig.class);
            tests.put(client, toTestStrategy(config));
            logger.info("Successfully added " + config.variations.size() + " variation test for " + client);
        } catch (IOException | BeansException | IllegalArgumentException e) {
            // IllegalArgumentException also covers NumberFormatException from ratios / weight types.
            logger.error("Couldn't add test for client " + client, e);
        }
    }

    /** Handles RECTAG: JSON {@link RecTagConfig} -> {@link RecTagClientStrategy}. */
    private void updateRecTags(String client, String configValue) {
        logger.info("Received new rectag config for " + client + ": " + configValue);
        try {
            RecTagConfig config = newStrategyMapper().readValue(configValue, RecTagConfig.class);
            if (config == null || config.defaultStrategy == null) {
                logger.error("Couldn't read rectag config as there was no default alg");
                return;
            }
            ClientStrategy defStrategy = toStrategy(config.defaultStrategy);
            Map<String, ClientStrategy> recTagStrats = new HashMap<>();
            // A missing tag map means every rec tag falls back to the default strategy.
            if (config.recTagToStrategy != null) {
                for (Map.Entry<String, Strategy> entry : config.recTagToStrategy.entrySet()) {
                    recTagStrats.put(entry.getKey(), toStrategy(entry.getValue()));
                }
            }
            recTagStrategies.put(client, new RecTagClientStrategy(defStrategy, recTagStrats));
            logger.info("Successfully added rec tag strategy for " + client);
        } catch (IOException | BeansException | IllegalArgumentException e) {
            logger.error("Couldn't add rectag strategy for client " + client, e);
        }
    }

    /** Builds a mapper that can read the polymorphic {@link Strategy} type via StrategyDeserializer. */
    private static ObjectMapper newStrategyMapper() {
        SimpleModule module = new SimpleModule("StrategyDeserializerModule");
        module.addDeserializer(Strategy.class, new StrategyDeserializer());
        ObjectMapper mapper = new ObjectMapper();
        mapper.registerModule(module);
        return mapper;
    }

    /**
     * Converts a deserialized JSON strategy into a runtime strategy.
     *
     * @throws IllegalArgumentException if the strategy is null or of an unsupported type
     */
    private ClientStrategy toStrategy(Strategy jsonStrategy) {
        if (jsonStrategy instanceof TestConfig) {
            return toTestStrategy((TestConfig) jsonStrategy);
        }
        if (jsonStrategy instanceof AlgorithmConfig) {
            return toSimpleStrategy((AlgorithmConfig) jsonStrategy, UNLABELLED);
        }
        throw new IllegalArgumentException("Unsupported strategy type: "
                + (jsonStrategy == null ? "null" : jsonStrategy.getClass().getName()));
    }

    /**
     * Builds a variation-testing strategy, one labelled {@link SimpleClientStrategy} per variation.
     *
     * @throws IllegalArgumentException if the config or its variation list is missing, or a ratio
     *                                  is not a valid decimal
     */
    private ClientStrategy toTestStrategy(TestConfig config) {
        if (config == null || config.variations == null) {
            throw new IllegalArgumentException("Test config has no variations");
        }
        List<VariationTestingClientStrategy.Variation> variations = new ArrayList<>();
        for (TestVariation variation : config.variations) {
            variations.add(new VariationTestingClientStrategy.Variation(
                    toSimpleStrategy(variation.config, variation.label),
                    new BigDecimal(variation.ratio)));
        }
        return VariationTestingClientStrategy.build(variations);
    }

    /** Same as {@link #toSimpleStrategy(AlgorithmConfig, String, Map)} without collecting names. */
    private SimpleClientStrategy toSimpleStrategy(AlgorithmConfig config, String label) {
        return toSimpleStrategy(config, label, new HashMap<String, AlgorithmStrategy>());
    }

    /**
     * Builds a {@link SimpleClientStrategy} from an algorithm config.
     *
     * @param config the parsed config
     * @param label  the strategy label
     * @param byName receives each built algorithm strategy keyed by its algorithm bean name
     * @throws IllegalArgumentException if the config or its algorithm list is missing, or an
     *                                  action-weight type is not an integer
     * @throws BeansException           if a referenced bean cannot be found
     */
    private SimpleClientStrategy toSimpleStrategy(AlgorithmConfig config, String label,
            Map<String, AlgorithmStrategy> byName) {
        if (config == null || config.algorithms == null) {
            throw new IllegalArgumentException("Algorithm config has no algorithms");
        }
        List<AlgorithmStrategy> strategies = new ArrayList<>();
        for (Algorithm algorithm : config.algorithms) {
            AlgorithmStrategy strategy = toAlgorithmStrategy(algorithm);
            strategies.add(strategy);
            byName.put(algorithm.name, strategy);
        }
        AlgorithmResultsCombiner combiner = applicationContext.getBean(config.combiner,
                AlgorithmResultsCombiner.class);
        Map<Integer, Double> actionWeightMap = toActionWeightMap(config.actionWeights);
        return new SimpleClientStrategy(Collections.unmodifiableList(strategies), combiner,
                config.diversityLevel, label, actionWeightMap);
    }

    /** Resolves an algorithm's bean, includers, filters (plus always-on filters) and config map. */
    private AlgorithmStrategy toAlgorithmStrategy(Algorithm algorithm) {
        Set<ItemIncluder> includers = retrieveIncluders(algorithm.includers);
        Set<ItemFilter> filters = retrieveFilters(algorithm.filters);
        ItemRecommendationAlgorithm alg = applicationContext.getBean(algorithm.name,
                ItemRecommendationAlgorithm.class);
        // toConfigMap already returns an empty map for a null config list.
        return new AlgorithmStrategy(alg, includers, filters, toConfigMap(algorithm.config), algorithm.name);
    }

    /** Flattens name/value config items into a mutable map; null yields an empty map. */
    private Map<String, String> toConfigMap(List<ConfigItem> config) {
        Map<String, String> configMap = new HashMap<>();
        if (config == null) {
            return configMap;
        }
        for (ConfigItem item : config) {
            configMap.put(item.name, item.value);
        }
        return configMap;
    }

    /**
     * Converts action-weight items into an unmodifiable map keyed by integer action type.
     *
     * @throws NumberFormatException if an item's type is not an integer
     */
    private Map<Integer, Double> toActionWeightMap(List<ActionWeightItem> weights) {
        Map<Integer, Double> weightMap = new HashMap<>();
        if (weights != null) {
            for (ActionWeightItem item : weights) {
                weightMap.put(Integer.parseInt(item.type), item.value);
            }
        }
        return Collections.unmodifiableMap(weightMap);
    }

    /** Looks up includer beans by name; null yields an empty set. */
    private Set<ItemIncluder> retrieveIncluders(List<String> includers) {
        Set<ItemIncluder> includerSet = new HashSet<>();
        if (includers == null) {
            return includerSet;
        }
        for (String includer : includers) {
            includerSet.add(applicationContext.getBean(includer, ItemIncluder.class));
        }
        return includerSet;
    }

    /** Looks up filter beans by name and unions them with the always-on filters. */
    private Set<ItemFilter> retrieveFilters(List<String> filters) {
        if (filters == null) {
            return alwaysOnFilters;
        }
        Set<ItemFilter> filterSet = new HashSet<>();
        for (String filter : filters) {
            filterSet.add(applicationContext.getBean(filter, ItemFilter.class));
        }
        return Sets.union(filterSet, alwaysOnFilters);
    }

    /** Stores the context and logs the available algorithm, includer, filter and combiner beans. */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
        logAvailableBeans(applicationContext, "Available algorithms: \n", ItemRecommendationAlgorithm.class);
        logAvailableBeans(applicationContext, "Available includers: \n", ItemIncluder.class);
        logAvailableBeans(applicationContext, "Available filters: \n", ItemFilter.class);
        logAvailableBeans(applicationContext, "Available combiners: \n", AlgorithmResultsCombiner.class);
    }

    /** Logs the heading followed by one tab-indented line per bean class of the given type. */
    private static void logAvailableBeans(ApplicationContext context, String heading, Class<?> beanType) {
        StringBuilder builder = new StringBuilder(heading);
        for (Object bean : context.getBeansOfType(beanType).values()) {
            builder.append('\t').append(bean.getClass()).append('\n');
        }
        logger.info(builder.toString());
    }

    /** True when the client's test switch is on and a test strategy is configured. */
    private boolean testRunning(String client) {
        return Boolean.TRUE.equals(testingOnOff.get(client)) && tests.containsKey(client);
    }

    /**
     * Replaces the global default strategy from JSON {@link AlgorithmConfig}. A blank value only
     * logs a warning and keeps the previous default; failures are logged and also keep it.
     */
    @Override
    public void configUpdated(String configKey, String configValue) {
        configValue = StringUtils.strip(configValue);
        logger.info("KEY WAS " + configKey);
        logger.info("Received new default strategy: " + configValue);

        if (StringUtils.length(configValue) == 0) {
            logger.warn("*WARNING* no default strategy is set!");
            return;
        }
        try {
            AlgorithmConfig config = new ObjectMapper().readValue(configValue, AlgorithmConfig.class);
            defaultStrategy = toSimpleStrategy(config, UNLABELLED);
            logger.info("Successfully changed default strategy.");
        } catch (IOException | BeansException | IllegalArgumentException e) {
            logger.error("Problem changing default strategy ", e);
        }
    }

    // classes for json translation (field names are the JSON property names)

    /** JSON form of a variation test: a list of weighted, labelled algorithm configs. */
    public static class TestConfig extends Strategy {
        public List<TestVariation> variations;
    }

    /** JSON form of a rec-tag config: per-tag strategies plus a required default strategy. */
    public static class RecTagConfig {
        public Map<String, Strategy> recTagToStrategy;
        public Strategy defaultStrategy;
    }

    /** One variation of a test; {@code ratio} is parsed as a {@link BigDecimal}. */
    private static class TestVariation {
        public String label;
        public String ratio;
        public AlgorithmConfig config;
    }

    /** Base JSON strategy type, read polymorphically by StrategyDeserializer. */
    public abstract static class Strategy {
    }

    /** JSON form of an algorithm strategy: algorithms, combiner bean name, diversity and weights. */
    public static class AlgorithmConfig extends Strategy {
        public List<Algorithm> algorithms;
        public String combiner;
        public Double diversityLevel;
        public List<ActionWeightItem> actionWeights;
    }

    /** A single name/value algorithm config entry. */
    public static class ConfigItem {
        public String name;
        public String value;
    }

    /** A weight for an action type; {@code type} must parse as an integer. */
    public static class ActionWeightItem {
        public String type;
        public Double value;
    }

    /** JSON form of one algorithm: bean name, includer/filter bean names and config entries. */
    public static class Algorithm {

        public String name;
        public List<String> includers;
        public List<String> filters;
        public List<ConfigItem> config;
    }

}