edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers.ClassifierProvider.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers.ClassifierProvider.java

Source

/*
 * Open Advancement Question Answering (OAQA) Project Copyright 2016 Carnegie Mellon University
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations
 * under the License.
 */

package edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers;

import com.google.common.base.Charsets;
import com.google.common.collect.*;
import com.google.common.io.Files;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.resource.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.function.Function;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.*;

/**
 * The interface for wrapping a classifier.
 * It is used by both general-purpose
 * {@link edu.cmu.lti.oaqa.baseqa.learning_base.ClassifierTrainer} and
 * {@link edu.cmu.lti.oaqa.baseqa.learning_base.ClassifierPredictor} classes, and specific
 * classifier training and prediction classes, e.g.
 * {@link edu.cmu.lti.oaqa.baseqa.answer_type.AnswerTypeClassifierTrainer} and
 * {@link edu.cmu.lti.oaqa.baseqa.answer_type.AnswerTypeClassifierPredictor}.
 * <p>
 * Implementations supply {@link #infer(Map)} and {@link #train(List, List, boolean)}. The default
 * methods build on those two to provide:
 * <ul>
 * <li>top-k prediction;</li>
 * <li>multi-label training by replicating instances;</li>
 * <li>label-frequency resampling;</li>
 * <li>cross-train-and-predict helpers.</li>
 * </ul>
 * The static helpers build, save and load maps from integer ids to keys (feature or label names).
 *
 * @see edu.cmu.lti.oaqa.baseqa.learning_base.ClassifierTrainer
 * @see edu.cmu.lti.oaqa.baseqa.learning_base.ClassifierPredictor
 *
 * @author <a href="mailto:ziy@cs.cmu.edu">Zi Yang</a> created on 4/5/15
 */
public interface ClassifierProvider extends Resource {

    /**
     * Strategy used by {@link #train(List, List, ResampleType, boolean)} to rebalance label
     * frequencies before delegating to {@link #train(List, List, boolean)}.
     */
    enum ResampleType {
        /** Randomly drop instances so each label's expected count equals the rarest label's count. */
        DOWN,
        /** Randomly replicate instances so each label's expected count equals the most frequent label's count. */
        UP,
        /** Train on the data exactly as given. */
        NONE
    }

    Logger LOG = LoggerFactory.getLogger(ClassifierProvider.class);

    /**
     * Scores the labels for a single instance.
     *
     * @param features sparse feature vector, mapping feature name to value
     * @return map from label to score; the {@code predict} methods treat larger scores as better
     * @throws AnalysisEngineProcessException if inference fails
     */
    Map<String, Double> infer(Map<String, Double> features) throws AnalysisEngineProcessException;

    /**
     * Returns the score of one label for a single instance.
     *
     * @param features sparse feature vector, mapping feature name to value
     * @param label the label whose score is wanted
     * @return the score {@link #infer(Map)} assigns to {@code label}
     * @throws AnalysisEngineProcessException if inference fails
     * @throws NullPointerException if {@link #infer(Map)} returns no score for {@code label}
     *         (the missing {@code Double} is unboxed to {@code double})
     */
    default double infer(Map<String, Double> features, String label) throws AnalysisEngineProcessException {
        return infer(features).get(label);
    }

    /**
     * Returns the highest-scoring label for a single instance.
     *
     * @param features sparse feature vector, mapping feature name to value
     * @return the label with the maximum score from {@link #infer(Map)}
     * @throws AnalysisEngineProcessException if inference fails or returns no labels at all
     */
    default String predict(Map<String, Double> features) throws AnalysisEngineProcessException {
        return infer(features).entrySet().stream().max(Comparator.comparing(Map.Entry::getValue))
                .orElseThrow(AnalysisEngineProcessException::new).getKey();
    }

    /**
     * Returns up to {@code k} labels ordered by decreasing score.
     * <p>
     * The sort is stable, so tied labels keep the iteration order of the map returned by
     * {@link #infer(Map)}.
     *
     * @param features sparse feature vector, mapping feature name to value
     * @param k maximum number of labels to return
     * @return at most {@code k} labels, best first
     * @throws AnalysisEngineProcessException if inference fails
     */
    default List<String> predict(Map<String, Double> features, int k) throws AnalysisEngineProcessException {
        return infer(features).entrySet().stream()
                .sorted(Comparator.comparing(Map.Entry::getValue, Comparator.reverseOrder())).map(Map.Entry::getKey)
                .limit(k).collect(toList());
    }

    /**
     * Trains the underlying model on single-label data.
     *
     * @param X instances as sparse feature vectors
     * @param Y gold labels, parallel to {@code X}
     * @param crossValidation flag forwarded to the implementation; this interface does not interpret it
     * @throws AnalysisEngineProcessException if training fails
     */
    void train(List<Map<String, Double>> X, List<String> Y, boolean crossValidation)
            throws AnalysisEngineProcessException;

    /**
     * Trains on multi-label data by flattening it into single-label data.
     * <p>
     * Each instance is repeated once per gold label, paired with that label. An instance with an
     * empty label set therefore contributes nothing.
     *
     * @param X instances as sparse feature vectors
     * @param Y gold label sets, parallel to {@code X} (checked only by a Java {@code assert})
     * @param resampleType resampling applied to the flattened data
     * @param crossValidation flag forwarded to {@link #train(List, List, boolean)}
     * @throws AnalysisEngineProcessException if training fails
     */
    default void trainMultiLabel(List<Map<String, Double>> X, List<Collection<String>> Y, ResampleType resampleType,
            boolean crossValidation) throws AnalysisEngineProcessException {
        int size = X.size();
        assert size == Y.size();
        List<Map<String, Double>> XX = new ArrayList<>();
        List<String> YY = new ArrayList<>();
        IntStream.range(0, size).forEach(i -> {
            Collection<String> y = Y.get(i);
            YY.addAll(y);
            XX.addAll(Collections.nCopies(y.size(), X.get(i)));
        });
        train(XX, YY, resampleType, crossValidation);
    }

    /**
     * Resamples the data according to {@code resampleType} and trains on the result.
     * <p>
     * Resampling uses {@link Math#random()}, so it is not reproducible between runs.
     * <ul>
     * <li>{@code DOWN}: each instance is kept with probability {@code minCount / count(label)}.</li>
     * <li>{@code UP}: each instance is repeated {@code floor(w)} times, plus one more copy with
     * probability {@code frac(w)}, where {@code w = maxCount / count(label)}.</li>
     * <li>{@code NONE}: no resampling.</li>
     * </ul>
     *
     * @param X instances as sparse feature vectors
     * @param Y gold labels, parallel to {@code X}
     * @param resampleType resampling strategy
     * @param crossValidation flag forwarded to {@link #train(List, List, boolean)}
     * @throws AnalysisEngineProcessException if training fails
     * @throws java.util.NoSuchElementException for {@code DOWN}/{@code UP} when {@code Y} is empty
     */
    default void train(List<Map<String, Double>> X, List<String> Y, ResampleType resampleType,
            boolean crossValidation) throws AnalysisEngineProcessException {
        switch (resampleType) {
        case DOWN: {
            Map<String, Long> y2count = Y.stream().collect(groupingBy(Function.identity(), counting()));
            double yMin = Collections.min(y2count.values());
            // keep-probability per label: 1.0 for the rarest label, smaller for frequent ones
            Map<String, Double> y2weight = y2count.entrySet().stream()
                    .collect(toMap(Map.Entry::getKey, entry -> yMin / entry.getValue()));
            Set<Integer> indexes = IntStream.range(0, Y.size()).filter(i -> Math.random() < y2weight.get(Y.get(i)))
                    .boxed().collect(toSet());
            List<Map<String, Double>> XS = indexes.stream().map(X::get).collect(toList());
            List<String> YS = indexes.stream().map(Y::get).collect(toList());
            train(XS, YS, crossValidation);
            break;
        }
        case UP: {
            Map<String, Long> y2count = Y.stream().collect(groupingBy(Function.identity(), counting()));
            double yMax = Collections.max(y2count.values());
            // replication factor per label: 1.0 for the most frequent label, larger for rare ones
            Map<String, Double> y2weight = y2count.entrySet().stream()
                    .collect(toMap(Map.Entry::getKey, entry -> yMax / entry.getValue()));
            Multiset<Integer> indexes = HashMultiset.create();
            // base "integer" count "up-sampling"
            IntStream.range(0, Y.size()).forEach(i -> indexes.setCount(i, y2weight.get(Y.get(i)).intValue()));
            // additional "decimal" part count
            IntStream.range(0, Y.size()).filter(i -> Math.random() < y2weight.get(Y.get(i)) % 1)
                    .forEach(indexes::add);
            // Multiset.stream() yields each index once per occurrence
            List<Map<String, Double>> XS = indexes.stream().map(X::get).collect(toList());
            List<String> YS = indexes.stream().map(Y::get).collect(toList());
            train(XS, YS, crossValidation);
            break;
        }
        case NONE: {
            train(X, Y, crossValidation);
            break;
        }
        }
    }

    /**
     * Cross-validated scoring of one label on multi-label data.
     * <p>
     * Fold scheme:
     * <ul>
     * <li>Instances are shuffled and split into folds of {@code ceil(n / 10)} instances. This
     * gives at most ten folds, and leave-one-out when {@code n <= 10}.</li>
     * <li>For each fold, {@link #trainMultiLabel} is called on all other instances.</li>
     * <li>The held-out instances are then scored.</li>
     * </ul>
     * Each fold calls this provider's own training method. Implementations that keep the trained
     * model as state will presumably hold the last fold's model afterwards.
     *
     * @param X instances as sparse feature vectors
     * @param Y gold label sets, parallel to {@code X}
     * @param resampleType resampling applied to each fold's training data
     * @param label the label whose score is collected
     * @return held-out score of {@code label} per instance, in the order of {@code X}; empty when
     *         {@code X} is empty
     * @throws AnalysisEngineProcessException if training or inference fails
     */
    default List<Double> crossTrainInferMultiLabel(List<Map<String, Double>> X, List<Collection<String>> Y,
            ResampleType resampleType, String label) throws AnalysisEngineProcessException {
        Set<Integer> indexes = IntStream.range(0, X.size()).boxed().collect(toSet());
        List<Integer> indexList = new ArrayList<>(indexes);
        Collections.shuffle(indexList);
        // Lists.partition takes the size of each chunk, not the number of chunks. It rejects 0,
        // so clamp the size to 1 for empty input.
        int foldSize = Math.max(1, (int) Math.ceil(indexList.size() / 10.0));
        List<Double> ret = IntStream.range(0, X.size()).mapToObj(i -> Double.NaN).collect(toList());
        int fold = 1;
        for (List<Integer> cvTestIndexes : Lists.partition(indexList, foldSize)) {
            LOG.info("Train Predict Fold {}", fold++);
            List<Map<String, Double>> cvTrainX = new ArrayList<>();
            List<Collection<String>> cvTrainY = new ArrayList<>();
            Sets.difference(indexes, new HashSet<>(cvTestIndexes)).forEach(cvTrainIndex -> {
                cvTrainX.add(X.get(cvTrainIndex));
                cvTrainY.add(Y.get(cvTrainIndex));
            });
            trainMultiLabel(cvTrainX, cvTrainY, resampleType, false);
            for (int cvTestIndex : cvTestIndexes) {
                double result = infer(X.get(cvTestIndex), label);
                ret.set(cvTestIndex, result);
            }
        }
        return ret;
    }

    /**
     * Cross-validated top-k prediction on multi-label data.
     * <p>
     * Fold scheme:
     * <ul>
     * <li>Instances are shuffled and split into folds of {@code ceil(n / 10)} instances. This
     * gives at most ten folds, and leave-one-out when {@code n <= 10}.</li>
     * <li>For each fold, {@link #trainMultiLabel} is called on all other instances.</li>
     * <li>The held-out instances are then predicted with {@link #predict(Map, int)}.</li>
     * </ul>
     * Each fold calls this provider's own training method. Implementations that keep the trained
     * model as state will presumably hold the last fold's model afterwards.
     *
     * @param X instances as sparse feature vectors
     * @param Y gold label sets, parallel to {@code X}
     * @param resampleType resampling applied to each fold's training data
     * @param limit maximum number of labels predicted per instance
     * @return held-out predictions per instance, in the order of {@code X}; empty when {@code X}
     *         is empty
     * @throws AnalysisEngineProcessException if training or prediction fails
     */
    default List<List<String>> crossTrainPredictMultiLabel(List<Map<String, Double>> X, List<Collection<String>> Y,
            ResampleType resampleType, int limit) throws AnalysisEngineProcessException {
        Set<Integer> indexes = IntStream.range(0, X.size()).boxed().collect(toSet());
        List<Integer> indexList = new ArrayList<>(indexes);
        Collections.shuffle(indexList);
        // Lists.partition takes the size of each chunk, not the number of chunks. It rejects 0,
        // so clamp the size to 1 for empty input.
        int foldSize = Math.max(1, (int) Math.ceil(indexList.size() / 10.0));
        List<List<String>> ret = IntStream.range(0, X.size()).mapToObj(i -> new ArrayList<String>())
                .collect(toList());
        int fold = 1;
        for (List<Integer> cvTestIndexes : Lists.partition(indexList, foldSize)) {
            LOG.info("Train Predict Fold {}", fold++);
            List<Map<String, Double>> cvTrainX = new ArrayList<>();
            List<Collection<String>> cvTrainY = new ArrayList<>();
            Sets.difference(indexes, new HashSet<>(cvTestIndexes)).forEach(cvTrainIndex -> {
                cvTrainX.add(X.get(cvTrainIndex));
                cvTrainY.add(Y.get(cvTrainIndex));
            });
            trainMultiLabel(cvTrainX, cvTrainY, resampleType, false);
            for (int cvTestIndex : cvTestIndexes) {
                // copy so the result never aliases a list an implementation's predict may reuse
                List<String> result = predict(X.get(cvTestIndex), limit).stream().collect(toList());
                ret.set(cvTestIndex, result);
            }
        }
        return ret;
    }

    /**
     * Cross-validated scoring of one label on single-label data.
     * <p>
     * Fold scheme:
     * <ul>
     * <li>Instances are shuffled and split into folds of {@code ceil(n / 10)} instances. This
     * gives at most ten folds, and leave-one-out when {@code n <= 10}.</li>
     * <li>For each fold, {@link #train(List, List, ResampleType, boolean)} is called on all other
     * instances.</li>
     * <li>The held-out instances are then scored.</li>
     * </ul>
     * Each fold calls this provider's own training method. Implementations that keep the trained
     * model as state will presumably hold the last fold's model afterwards.
     *
     * @param X instances as sparse feature vectors
     * @param Y gold labels, parallel to {@code X}
     * @param resampleType resampling applied to each fold's training data
     * @param label the label whose score is collected
     * @return held-out score of {@code label} per instance, in the order of {@code X}; empty when
     *         {@code X} is empty
     * @throws AnalysisEngineProcessException if training or inference fails
     */
    default List<Double> crossTrainInfer(List<Map<String, Double>> X, List<String> Y, ResampleType resampleType,
            String label) throws AnalysisEngineProcessException {
        Set<Integer> indexes = IntStream.range(0, X.size()).boxed().collect(toSet());
        List<Integer> indexList = new ArrayList<>(indexes);
        Collections.shuffle(indexList);
        // Lists.partition takes the size of each chunk, not the number of chunks. It rejects 0,
        // so clamp the size to 1 for empty input.
        int foldSize = Math.max(1, (int) Math.ceil(indexList.size() / 10.0));
        List<Double> ret = IntStream.range(0, X.size()).mapToObj(i -> Double.NaN).collect(toList());
        int fold = 1;
        for (List<Integer> cvTestIndexes : Lists.partition(indexList, foldSize)) {
            LOG.info("Train Predict Fold {}", fold++);
            List<Map<String, Double>> cvTrainX = new ArrayList<>();
            List<String> cvTrainY = new ArrayList<>();
            Sets.difference(indexes, new HashSet<>(cvTestIndexes)).forEach(cvTrainIndex -> {
                cvTrainX.add(X.get(cvTrainIndex));
                cvTrainY.add(Y.get(cvTrainIndex));
            });
            train(cvTrainX, cvTrainY, resampleType, false);
            for (int cvTestIndex : cvTestIndexes) {
                double result = infer(X.get(cvTestIndex), label);
                ret.set(cvTestIndex, result);
            }
        }
        return ret;
    }

    /**
     * Cross-validated top-k prediction on single-label data.
     * <p>
     * Fold scheme:
     * <ul>
     * <li>Instances are shuffled and split into folds of {@code ceil(n / 10)} instances. This
     * gives at most ten folds, and leave-one-out when {@code n <= 10}.</li>
     * <li>For each fold, {@link #train(List, List, ResampleType, boolean)} is called on all other
     * instances.</li>
     * <li>The held-out instances are then predicted with {@link #predict(Map, int)}.</li>
     * </ul>
     * Each fold calls this provider's own training method. Implementations that keep the trained
     * model as state will presumably hold the last fold's model afterwards.
     *
     * @param X instances as sparse feature vectors
     * @param Y gold labels, parallel to {@code X}
     * @param resampleType resampling applied to each fold's training data
     * @param limit maximum number of labels predicted per instance
     * @return held-out predictions per instance, in the order of {@code X}; empty when {@code X}
     *         is empty
     * @throws AnalysisEngineProcessException if training or prediction fails
     */
    default List<List<String>> crossTrainPredict(List<Map<String, Double>> X, List<String> Y,
            ResampleType resampleType, int limit) throws AnalysisEngineProcessException {
        Set<Integer> indexes = IntStream.range(0, X.size()).boxed().collect(toSet());
        List<Integer> indexList = new ArrayList<>(indexes);
        Collections.shuffle(indexList);
        // Lists.partition takes the size of each chunk, not the number of chunks. It rejects 0,
        // so clamp the size to 1 for empty input.
        int foldSize = Math.max(1, (int) Math.ceil(indexList.size() / 10.0));
        List<List<String>> ret = IntStream.range(0, X.size()).mapToObj(i -> new ArrayList<String>())
                .collect(toList());
        int fold = 1;
        for (List<Integer> cvTestIndexes : Lists.partition(indexList, foldSize)) {
            LOG.info("Train Predict Fold {}", fold++);
            List<Map<String, Double>> cvTrainX = new ArrayList<>();
            List<String> cvTrainY = new ArrayList<>();
            Sets.difference(indexes, new HashSet<>(cvTestIndexes)).forEach(cvTrainIndex -> {
                cvTrainX.add(X.get(cvTrainIndex));
                cvTrainY.add(Y.get(cvTrainIndex));
            });
            train(cvTrainX, cvTrainY, resampleType, false);
            for (int cvTestIndex : cvTestIndexes) {
                // copy so the result never aliases a list an implementation's predict may reuse
                List<String> result = predict(X.get(cvTestIndex), limit).stream().collect(toList());
                ret.set(cvTestIndex, result);
            }
        }
        return ret;
    }

    /**
     * Collects the distinct feature names used by any instance.
     * <p>
     * Names appear in first-encounter order: instances in list order, and within an instance the
     * key iteration order of its map.
     *
     * @param X instances as sparse feature vectors
     * @return distinct feature names
     */
    static List<String> featureNames(List<Map<String, Double>> X) {
        return X.stream().map(Map::keySet).flatMap(Set::stream).distinct().collect(toList());
    }

    /**
     * Assigns consecutive ids, starting at 1, to the names returned by {@link #featureNames(List)}.
     *
     * @param X instances as sparse feature vectors
     * @return map from feature id (1-based) to feature name
     */
    static Map<Integer, String> createFeatureIdKeyMap(List<Map<String, Double>> X) {
        List<String> feats = featureNames(X);
        return IntStream.range(0, feats.size()).boxed().collect(toMap(i -> i + 1, feats::get));
    }

    /**
     * Collects the feature names that occur in at least {@code frequencyThreshold} instances.
     * <p>
     * Frequency is the number of instances whose map contains the name, regardless of the value.
     * The order of the result is unspecified.
     *
     * @param X instances as sparse feature vectors
     * @param frequencyThreshold minimum number of instances a feature must occur in
     * @return feature names meeting the threshold
     */
    static List<String> featureNames(List<Map<String, Double>> X, int frequencyThreshold) {
        Map<String, Long> feat2count = X.stream().map(Map::entrySet).flatMap(Set::stream)
                .collect(groupingBy(Map.Entry::getKey, counting()));
        return feat2count.entrySet().stream().filter(e -> e.getValue() >= frequencyThreshold).map(Map.Entry::getKey)
                .collect(toList());
    }

    /**
     * Collects the distinct labels in first-occurrence order.
     *
     * @param Y gold labels
     * @return distinct labels
     */
    static List<String> labelNames(List<String> Y) {
        return Y.stream().distinct().collect(toList());
    }

    /**
     * Assigns consecutive ids, starting at 1, to the labels returned by {@link #labelNames(List)}.
     *
     * @param Y gold labels
     * @return bidirectional map from label id (1-based) to label
     */
    static BiMap<Integer, String> createLabelIdKeyMap(List<String> Y) {
        List<String> labels = labelNames(Y);
        BiMap<Integer, String> lid2label = HashBiMap.create();
        IntStream.range(0, labels.size()).forEach(i -> lid2label.put(i + 1, labels.get(i)));
        return lid2label;
    }

    /**
     * Writes an id-to-key map as UTF-8 text, one {@code id<TAB>key} line per entry.
     *
     * @param id2key map to save
     * @param idKeyMapFile destination file (overwritten)
     * @throws IOException if writing fails
     */
    static void saveIdKeyMap(Map<Integer, String> id2key, File idKeyMapFile) throws IOException {
        String lines = id2key.entrySet().stream().map(entry -> entry.getKey() + "\t" + entry.getValue())
                .collect(joining("\n"));
        Files.write(lines, idKeyMapFile, Charsets.UTF_8);
    }

    /**
     * Reads a map written by {@link #saveIdKeyMap(Map, File)}.
     * <p>
     * Each non-empty line is split at its first tab. The remainder of the line, which may be empty
     * or contain further tabs, is the key. Blank lines are ignored.
     *
     * @param idKeyMapFile UTF-8 file of {@code id<TAB>key} lines
     * @return map from id to key
     * @throws IOException if reading fails
     * @throws NumberFormatException if an id is not an integer
     * @throws IllegalStateException if an id appears more than once
     */
    static Map<Integer, String> loadIdKeyMap(File idKeyMapFile) throws IOException {
        // limit 2: split only at the first tab, and keep a trailing empty key, which a plain
        // split("\t") would drop
        return Files.readLines(idKeyMapFile, Charsets.UTF_8).stream().filter(line -> !line.isEmpty())
                .map(line -> line.split("\t", 2))
                .collect(toMap(segs -> Integer.parseInt(segs[0]), segs -> segs[1]));
    }

}