Example usage for org.apache.mahout.classifier.sgd.ModelDissector.update

List of usage examples for org.apache.mahout.classifier.sgd.ModelDissector.update

Introduction

On this page you can find example usages of org.apache.mahout.classifier.sgd.ModelDissector.update.

Prototype


public void update(Vector features, Map<String, Set<Integer>> traceDictionary,
        AbstractVectorClassifier learner) 

Source Link

Document

Probes a model to determine the effect of a particular variable.

Usage

From source file: com.memonews.mahout.sentiment.SGDHelper.java

License:Apache License

/**
 * Prints a summary of the most heavily weighted features of the best model held by
 * {@code learningAlgorithm}.
 *
 * <p>At most 500 files, taken from the front of the list returned by
 * {@code permute(files, ...)}, are encoded and passed to a {@link ModelDissector}. The
 * first 100 entries of the dissector's summary are then written to standard output as
 * tab-separated lines.
 *
 * @param leakType not referenced by this method
 * @param dictionary its values are used as the category names printed for each feature
 * @param learningAlgorithm learner whose best payload supplies the model to dissect
 * @param files input files; fewer than 500 is allowed, in which case all are used
 * @param overallCounts passed through to the feature-vector encoder
 * @throws IOException propagated from the calls made here (presumably file reading
 *         during encoding)
 */
public static void dissect(final int leakType, final Dictionary dictionary,
        final AdaptiveLogisticRegression learningAlgorithm, final Iterable<File> files,
        final Multiset<String> overallCounts) throws IOException {
    final CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
    // NOTE(review): the learner is closed before being handed to the dissector; presumably
    // close() finalizes training state rather than invalidating the model - confirm.
    model.close();

    final Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
    final ModelDissector md = new ModelDissector();

    // Both encoders record into the same map that is later given to md.update().
    final SentimentModelHelper helper = new SentimentModelHelper();
    helper.getEncoder().setTraceDictionary(traceDictionary);
    helper.getBias().setTraceDictionary(traceDictionary);

    // Cap the sample at 500 files without requiring that many to exist; a plain
    // subList(0, 500) would throw IndexOutOfBoundsException on smaller inputs.
    final int maxFiles = 500;
    int processed = 0;
    for (final File file : permute(files, helper.getRandom())) {
        if (processed == maxFiles) {
            break;
        }
        processed++;
        // Reset the shared trace map so each update() sees only this file's entries.
        traceDictionary.clear();
        final Vector v = helper.encodeFeatureVector(file, overallCounts);
        md.update(v, traceDictionary, model);
    }

    final List<String> ngNames = Lists.newArrayList(dictionary.values());
    final List<ModelDissector.Weight> weights = md.summary(100);
    System.out.println("============");
    System.out.println("Model Dissection");
    // NOTE(review): the fourth placeholder is %.1f but receives getCategory(0), while the
    // final %s receives getWeight(0). If getCategory returns an integer index this throws
    // IllegalFormatConversionException - confirm against ModelDissector.Weight.
    for (final ModelDissector.Weight w : weights) {
        System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(),
                ngNames.get(w.getMaxImpact()), w.getCategory(0), w.getWeight(0));
    }
}

From source file: com.tdunning.ch16.train.TrainNewsGroups.java

License:Apache License

/**
 * Prints a summary of the most heavily weighted features of the best model held by
 * {@code learningAlgorithm}.
 *
 * <p>At most 500 files, taken from the front of the list returned by
 * {@code permute(files, rand)}, are encoded and passed to a {@link ModelDissector}. The
 * first 100 entries of the dissector's summary are then written to standard output as
 * tab-separated lines.
 *
 * @param newsGroups its values are used as the category names printed for each feature
 * @param learningAlgorithm learner whose best payload supplies the model to dissect
 * @param files input files; fewer than 500 is allowed, in which case all are used
 * @throws IOException propagated from the calls made here (presumably file reading
 *         during encoding)
 */
private static void dissect(Dictionary newsGroups, AdaptiveLogisticRegression learningAlgorithm,
        Iterable<File> files) throws IOException {
    CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
    // NOTE(review): the learner is closed before being handed to the dissector; presumably
    // close() finalizes training state rather than invalidating the model - confirm.
    model.close();

    Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
    ModelDissector md = new ModelDissector();

    // Both encoders record into the same map that is later given to md.update().
    encoder.setTraceDictionary(traceDictionary);
    bias.setTraceDictionary(traceDictionary);

    // Cap the sample at 500 files without requiring that many to exist; a plain
    // subList(0, 500) would throw IndexOutOfBoundsException on smaller inputs.
    final int maxFiles = 500;
    int processed = 0;
    for (File file : permute(files, rand)) {
        if (processed == maxFiles) {
            break;
        }
        processed++;
        // Reset the shared trace map so each update() sees only this file's entries.
        traceDictionary.clear();
        Vector v = encodeFeatureVector(file);
        md.update(v, traceDictionary, model);
    }

    List<String> ngNames = Lists.newArrayList(newsGroups.values());
    List<ModelDissector.Weight> weights = md.summary(100);
    // NOTE(review): the %.1f placeholders in positions four and six receive getCategory(1)
    // and getCategory(2), while the following %s placeholders receive getWeight(1) and
    // getWeight(2). If getCategory returns an integer index this throws
    // IllegalFormatConversionException - confirm against ModelDissector.Weight.
    for (ModelDissector.Weight w : weights) {
        System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(),
                ngNames.get(w.getMaxImpact() + 1), w.getCategory(1), w.getWeight(1), w.getCategory(2),
                w.getWeight(2));
    }
}