jjj.asap.sas.ensemble.impl.CrossValidatedEnsemble.java Source code

Java tutorial

Introduction

Here is the source code for jjj.asap.sas.ensemble.impl.CrossValidatedEnsemble.java

Source

/*
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * Copyright (C) 2012 James Jesensky
 */

package jjj.asap.sas.ensemble.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import jjj.asap.sas.ensemble.Ensemble;
import jjj.asap.sas.ensemble.Scheme;
import jjj.asap.sas.ensemble.StrongLearner;
import jjj.asap.sas.ensemble.WeakLearner;
import jjj.asap.sas.util.Calc;
import jjj.asap.sas.util.Contest;
import jjj.asap.sas.util.Job;
import jjj.asap.sas.weka.DatasetBuilder;
import weka.core.DenseInstance;
import weka.core.Instances;

/**
 * A meta ensemble that uses cross validation to estimate the worth of the
 * wrapped ensemble.
 *
 * <p>{@link #build} still trains the wrapped ensemble on every supplied
 * learner, but the kappa stored on the returned {@link StrongLearner} is
 * replaced by one computed from held-out fold predictions.
 * {@link #classify} is a plain delegation to the wrapped ensemble.
 */
public class CrossValidatedEnsemble implements Scheme {

    /** Fixed seed so the shuffle, and therefore the fold assignment, is repeatable. */
    private static final long SHUFFLE_SEED = 1L;

    /** The ensemble whose quality is being estimated. */
    private final Ensemble ensemble;

    /** Number of cross validation folds. */
    private final int nFolds;

    /**
     * @param ensemble base ensemble
     * @param nFolds how many cv folds
     */
    public CrossValidatedEnsemble(Ensemble ensemble, int nFolds) {
        this.ensemble = ensemble;
        this.nFolds = nFolds;
    }

    /**
     * Builds the wrapped ensemble on all learners and scores it by cross
     * validation.
     *
     * <p>For each fold the wrapped ensemble is built on copies of the learners
     * with the fold's held-out ids removed, then asked to classify only those
     * held-out ids. The predictions of all folds are pooled and compared with
     * the gold standard to obtain the kappa stored on the returned learner.
     *
     * @param essaySet essay set identifier, passed through to the wrapped ensemble
     * @param ensembleName ensemble name, passed through to the wrapped ensemble
     * @param learners the weak learners to combine; the ids to fold over are
     *        taken from the predictions of the first learner
     * @return the wrapped ensemble built on all learners, with its kappa
     *         replaced by the cross validated kappa; for an empty learner
     *         list the wrapped ensemble's result is returned untouched
     * @throws IllegalStateException if an id has no gold standard score, or if
     *         the wrapped ensemble returns a learner that is not among
     *         {@code learners}
     */
    @Override
    public StrongLearner build(int essaySet, String ensembleName, List<WeakLearner> learners) {

        // can't handle empty case: there is no learner to take the ids from
        if (learners.isEmpty()) {
            return this.ensemble.build(essaySet, ensembleName, learners);
        }

        // create a dummy dataset holding only (id, class); it is used purely
        // to let Weka assign ids to folds.
        DatasetBuilder builder = new DatasetBuilder();
        builder.addVariable("id");
        builder.addNominalVariable("class", Contest.getRubrics(essaySet));
        Instances dummy = builder.getDataset("dummy");

        // add data
        Map<Double, Double> groundTruth = Contest.getGoldStandard(essaySet);
        for (double id : learners.get(0).getPreds().keySet()) {
            Double gold = groundTruth.get(id);
            if (gold == null) {
                // previously surfaced as a message-less NullPointerException on unboxing
                throw new IllegalStateException(
                        "No gold standard score for id " + id + " in essay set " + essaySet);
            }
            dummy.add(new DenseInstance(1.0, new double[] { id, gold.doubleValue() }));
        }

        // stratify; sorting by id first presumably makes the seeded shuffle
        // independent of the iteration order of the prediction map
        dummy.sort(0);
        dummy.randomize(new Random(SHUFFLE_SEED));
        dummy.setClassIndex(1);
        dummy.stratify(nFolds);

        // now evaluate each fold
        Map<Double, Double> preds = new HashMap<Double, Double>();
        for (int k = 0; k < nFolds; k++) {
            // only the held-out part is needed; the training view is derived
            // by removing these ids from the learners
            Instances test = dummy.testCV(nFolds, k);

            // train on fold
            List<WeakLearner> cvLearners = withoutHeldOut(learners, test);
            StrongLearner cv = this.ensemble.build(essaySet, ensembleName, cvLearners);

            // classify the held-out ids with the fold's model
            List<WeakLearner> testLearners = heldOutOnly(cv, learners, test);
            preds.putAll(this.ensemble.classify(essaySet, ensembleName, testLearners, cv.getContext()));
        }

        // now prepare final result: model from all data, kappa from cross validation
        StrongLearner strong = this.ensemble.build(essaySet, ensembleName, learners);
        strong.setKappa(Calc.kappa(essaySet, preds, groundTruth));
        return strong;
    }

    /**
     * Copies every learner and removes the held-out ids from each copy's
     * predictions and probabilities.
     *
     * @param learners the full learners; not modified
     * @param test the held-out instances, whose attribute 0 is the id
     * @return one stripped copy per input learner, in the same order
     */
    private static List<WeakLearner> withoutHeldOut(List<WeakLearner> learners, Instances test) {
        List<WeakLearner> stripped = new ArrayList<WeakLearner>(learners.size());
        for (WeakLearner learner : learners) {
            WeakLearner copy = learner.copyOf();
            for (int i = 0; i < test.numInstances(); i++) {
                double id = test.instance(i).value(0);
                copy.getPreds().remove(id);
                copy.getProbs().remove(id);
            }
            stripped.add(copy);
        }
        return stripped;
    }

    /**
     * Copies every learner kept by a fold's model and fills each copy with the
     * original predictions and probabilities of the held-out ids only.
     *
     * @param cv the strong learner built for the fold
     * @param learners the full learners, used as the source of held-out values
     * @param test the held-out instances, whose attribute 0 is the id
     * @return one copy per learner of {@code cv}, restricted to the held-out ids
     * @throws IllegalStateException if a learner of {@code cv} has no
     *         counterpart with the same name in {@code learners}
     */
    private List<WeakLearner> heldOutOnly(StrongLearner cv, List<WeakLearner> learners, Instances test) {
        List<WeakLearner> restricted = new ArrayList<WeakLearner>();
        for (WeakLearner learner : cv.getLearners()) {
            WeakLearner copy = learner.copyOf();
            copy.getPreds().clear();
            copy.getProbs().clear();

            WeakLearner source = find(copy.getName(), learners);
            if (source == null) {
                // previously surfaced as a message-less NullPointerException
                throw new IllegalStateException("Ensemble returned learner '" + copy.getName()
                        + "' that is not among the supplied learners");
            }

            for (int i = 0; i < test.numInstances(); i++) {
                double id = test.instance(i).value(0);
                copy.getPreds().put(id, source.getPreds().get(id));
                copy.getProbs().put(id, source.getProbs().get(id));
            }
            restricted.add(copy);
        }
        return restricted;
    }

    /**
     * Looks up a learner by name.
     *
     * @param name the name to look for
     * @param learners the learners to search
     * @return the first learner whose name equals {@code name}, or
     *         {@code null} if there is none
     */
    private WeakLearner find(String name, List<WeakLearner> learners) {
        for (WeakLearner learner : learners) {
            if (learner.getName().equals(name)) {
                return learner;
            }
        }
        return null;
    }

    /**
     * Delegates classification to the wrapped ensemble without modification.
     *
     * @param essaySet essay set identifier
     * @param ensembleName ensemble name
     * @param learners the weak learners to combine
     * @param context the context passed through to the wrapped ensemble
     * @return whatever the wrapped ensemble returns
     */
    @Override
    public Map<Double, Double> classify(int essaySet, String ensembleName, List<WeakLearner> learners,
            Object context) {
        return this.ensemble.classify(essaySet, ensembleName, learners, context);
    }

}