meka.core.Result.java Source code

Java tutorial

Introduction

Here is the source code for meka.core.Result.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package meka.core;

import weka.core.Instance;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.Attribute;
import weka.core.Utils;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Set;

/**
 * Result - Stores predictions alongside true labels, for evaluation. 
 * For more on the evaluation and threshold selection implemented here; see: 
 * <p>
 * Jesse Read, Bernhard Pfahringer, Geoff Holmes, Eibe Frank. <i>Classifier Chains for Multi-label Classification</i>. Machine Learning Journal. Springer (2011).<br>
 * Jesse Read, <i>Scalable Multi-label Classification</i>. PhD Thesis, University of Waikato, Hamilton, New Zealand (2010).<br>
 * </p>
 * @author    Jesse Read
 * @version   March 2012 - Multi-target Compatible
 */
public class Result implements Serializable {

    private static final long serialVersionUID = 1L;

    /** The number of label (target) variables in the problem */
    public int L = 0;

    /** One entry per evaluated instance: the prediction (confidence) vector as passed to {@link #addResult}. */
    public ArrayList<double[]> predictions = null;
    // TODO, store in sparse fashion with either LabelSet or LabelVector
    /** One entry per evaluated instance: the true label values, aligned index-for-index with {@link #predictions}. */
    public ArrayList<int[]> actuals = null;

    public HashMap<String, String> info = new LinkedHashMap<String, String>(); // stores general dataset/classifier info
    public HashMap<String, Object> output = new LinkedHashMap<String, Object>();// stores predictive evaluation statistics
    public HashMap<String, Object> vals = new LinkedHashMap<String, Object>(); // stores non-predictive evaluation stats
    public HashMap<String, String> model = new LinkedHashMap<String, String>(); // stores the model itself

    /** Creates an empty Result with L = 0. */
    public Result() {
        predictions = new ArrayList<double[]>();
        actuals = new ArrayList<int[]>();
    }

    /**
     * Creates an empty Result.
     * @param L   the number of label (target) variables
     */
    public Result(int L) {
        predictions = new ArrayList<double[]>();
        actuals = new ArrayList<int[]>();
        this.L = L;
    }

    /**
     * Creates an empty Result whose internal lists are pre-sized.
     * @param N   initial capacity (expected number of instances)
     * @param L   the number of label (target) variables
     */
    public Result(int N, int L) {
        predictions = new ArrayList<double[]>(N);
        actuals = new ArrayList<int[]>(N);
        this.L = L;
    }

    /** The number of value-prediction pairs stored in this Result */
    public int size() {
        return predictions.size();
    }

    /**
     * Provides a nice textual output of all evaluation information.
     * If the "Verbosity" info entry parses to a value above 4, the individual
     * predictions are listed first.
     * @return   String representation
     */
    @Override
    public String toString() {

        StringBuilder resultString = new StringBuilder();
        if (info.containsKey("Verbosity")) {
            int V = MLUtils.getIntegerOption(info.get("Verbosity"), 1);

            if (V > 4) {
                resultString.append("== Individual Errors\n\n");
                // output everything; verbosity beyond 5 is used as the number of decimal places
                resultString.append(Result.getResultAsString(this, V - 5)).append("\n\n");
            }
        }
        // output the stats in general
        if (model.size() > 0) {
            resultString.append("== Model info\n\n").append(MLUtils.hashMapToString(model));
        }
        resultString.append("== Evaluation Info\n\n").append(MLUtils.hashMapToString(info));
        resultString.append("\n\n== Predictive Performance\n\n").append(MLUtils.hashMapToString(output, 3));
        String note = "";
        if (info.containsKey("Type") && info.get("Type").endsWith("CV")) {
            note = " (averaged across folds)";
        }
        resultString.append("\n\n== Additional Measurements").append(note).append("\n\n")
                .append(MLUtils.hashMapToString(vals, 3));

        resultString.append("\n\n");
        return resultString.toString();
    }

    /**
     * AddResult - Add an entry.
     * The prediction array is stored by reference (not copied).
     * @param pred   predictions
     * @param real  an instance containing the true label values
     */
    public void addResult(double pred[], Instance real) {
        predictions.add(pred);
        actuals.add(MLUtils.toIntArray(real, pred.length));
    }

    /**
     * RowTrue - Retrieve the true values for the i-th instance.
     */
    public int[] rowTrue(int i) {
        return actuals.get(i);
    }

    /**
     * RowConfidence - Retrieve the prediction confidences for the i-th instance.
     */
    public double[] rowConfidence(int i) {
        return predictions.get(i);
    }

    /**
     * RowPrediction - Retrieve the predicted values for the i-th instance according to threshold t.
     */
    public int[] rowPrediction(int i, double t) {
        return A.toIntArray(rowConfidence(i), t);
    }

    /**
     * RowPrediction - Retrieve the predicted values for the i-th instance according to pre-calibrated/chosen threshold.
     * The threshold is read from the "Threshold" info entry; if none is set, the
     * confidences are converted to integers without a threshold.
     */
    public int[] rowPrediction(int i) {
        String t = info.get("Threshold");
        if (t != null) {
            // For multi-label data, should know about a threshold first
            return ThresholdUtils.threshold(rowConfidence(i), t);
        } else {
            // Probably multi-target data (no threshold allowed)
            return A.toIntArray(rowConfidence(i));
        }
    }

    /**
     * ColConfidence - Retrieve the prediction confidences for the j-th label (column).
     * Similar to M.getCol(Y,j)
     * @return   an array with one entry per stored instance
     */
    public double[] colConfidence(int j) {
        int n = predictions.size();
        double y[] = new double[n];
        for (int i = 0; i < n; i++) {
            y[i] = rowConfidence(i)[j];
        }
        return y;
    }

    /**
     * AllPredictions - Retrieve all prediction confidences as a 2d array with one row per instance.
     * The rows are the stored arrays themselves (no copies are made).
     */
    public double[][] allPredictions() {
        int n = predictions.size();
        double Y[][] = new double[n][];
        for (int i = 0; i < n; i++) {
            Y[i] = rowConfidence(i);
        }
        return Y;
    }

    /**
     * AllPredictions - Retrieve all predictions (according to threshold t) as a 2d array with one row per instance.
     */
    public int[][] allPredictions(double t) {
        int n = predictions.size();
        int Y[][] = new int[n][];
        for (int i = 0; i < n; i++) {
            Y[i] = rowPrediction(i, t);
        }
        return Y;
    }

    /**
     * AllTrueValues - Retrieve all true values as a 2d array with one row per instance.
     * The rows are the stored arrays themselves (no copies are made).
     */
    public int[][] allTrueValues() {
        int n = actuals.size();
        int Y[][] = new int[n][];
        for (int i = 0; i < n; i++) {
            Y[i] = rowTrue(i);
        }
        return Y;
    }

    /*
     * AddValue.
     * Add v to an existing metric value.
    public void addValue(String metric, double v) {
       Double freq = (Double)vals.get(metric);
       vals.put(metric,(freq == null) ? v : freq + v);
    }
    */

    /**
     * Return the set of metrics for which measurements are available.
     * This is the live key set of {@link #output}.
     */
    public Set<String> availableMetrics() {
        return output.keySet();
    }

    /**
     * Set the measurement for metric 'metric'.
     */
    public void setMeasurement(String metric, Object stat) {
        output.put(metric, stat);
    }

    /**
     * Retrieve the measurement for metric 'metric'.
     * @return   the stored measurement, or null if there is none
     */
    public Object getMeasurement(String metric) {
        return output.get(metric);
    }

    /**
     * SetValue.
     * Add an evaluation metric and a value for it.
     */
    public void setValue(String metric, double v) {
        vals.put(metric, v);
    }

    /**
     * GetValue.
     * Retrieve the value for metric 'metric'
     * @return   the stored value, or null if there is none
     */
    public Object getValue(String metric) {
        return vals.get(metric);
    }

    /**
     * SetInfo.
     * Set a String value to an information category.
     */
    public void setInfo(String cat, String val) {
        info.put(cat, val);
    }

    /**
     * GetInfo.
     * Get the String value of category 'cat'.
     * @return   the stored value, or null if the category has not been set
     */
    public String getInfo(String cat) {
        return info.get(cat);
    }

    /**
     * Set a model string.
     */
    public void setModel(String key, String val) {
        model.put(key, val);
    }

    /**
     * Get the model value.
     * @return   the stored value, or null if the key has not been set
     */
    public String getModel(String key) {
        return model.get(key);
    }

    // ********************************************************************************************************
    //                     STATIC METHODS
    // ********************************************************************************************************
    //

    /**
     * GetStats.
     * Return the evaluation statistics given predictions and real values stored in r.
     * In the multi-label case, a Threshold category must exist, containing a string defining the type of threshold we want to use/calibrate.
     * The "Type" info entry selects the routine: values starting with "MT" use the
     * multi-target statistics, anything else the multi-label statistics.
     * @param r     the Result to evaluate; its "Type" info entry must be set
     * @param vop   passed through unchanged to the MLEvalUtils statistics routine
     * @return      the statistics computed by MLEvalUtils
     * @throws NullPointerException if r has no "Type" info entry
     */
    public static HashMap<String, Object> getStats(Result r, String vop) {
        if (r.getInfo("Type").startsWith("MT")) {
            return MLEvalUtils.getMTStats(r.allPredictions(), r.allTrueValues(), vop);
        } else {
            return MLEvalUtils.getMLStats(r.allPredictions(), r.allTrueValues(), r.getInfo("Threshold"), vop);
        }
    }

    /**
     * GetResultAsString - print out each prediction in a Result along with its true labelset.
     * Uses 3 decimal places for the predictions.
     */
    public static String getResultAsString(Result s) {
        return getResultAsString(s, 3);
    }

    /**
     * WriteResultToFile -- write a Result 'result' out in plain text format to file 'fname'.
     * The writer is closed even if writing fails.
     * @param result Result
     * @param fname file name
     * @throws Exception if the file cannot be opened or written
     */
    public static void writeResultToFile(Result result, String fname) throws Exception {
        // try-with-resources guarantees the file handle is released when write() throws
        try (PrintWriter outer = new PrintWriter(new BufferedWriter(new FileWriter(fname)))) {
            outer.write(result.toString());
        }
    }

    /**
     * Convert a list of metric maps (one per Result) into an Instances.
     * The attributes are the keys of the first map whose values are Doubles; every
     * map then contributes one row. A row's cell is left as a missing value when
     * that map has no Double stored under the attribute's name.
     * @param metrics An ArrayList of metric-name to value maps, one per Result
     * @return   Instances (without rows or attributes if 'metrics' is empty)
     */
    public static Instances getResultsAsInstances(ArrayList<HashMap<String, Object>> metrics) {

        ArrayList<Attribute> attInfo = new ArrayList<Attribute>();
        if (metrics.isEmpty()) {
            // nothing to derive a header from
            return new Instances("Results", attInfo, 0);
        }

        // The first map defines the header: only Double-valued metrics become attributes
        HashMap<String, Object> o_master = metrics.get(0);
        for (String key : o_master.keySet()) {
            if (o_master.get(key) instanceof Double) {
                attInfo.add(new Attribute(key));
            }
        }

        Instances resultInstances = new Instances("Results", attInfo, metrics.size());

        for (HashMap<String, Object> o : metrics) {
            // DenseInstance(int) starts with all values missing
            Instance rx = new DenseInstance(attInfo.size());
            for (Attribute att : attInfo) {
                Object value = o.get(att.name());
                if (value instanceof Double) {
                    rx.setValue(att, ((Double) value).doubleValue());
                }
                // otherwise (absent key or non-Double value): keep the cell missing
            }
            resultInstances.add(rx);
        }

        return resultInstances;

    }

    /**
     * GetResultAsString - print out each prediction in a Result (to a certain number of decimal points) along with its true labelset.
     * With adp == 0 and a "Type" info entry other than "MT" (or none at all), each
     * row shows the true and predicted label sets; otherwise it shows the raw true
     * values and the prediction vector formatted to 'adp' decimal places.
     * @param result   the Result to print
     * @param adp      number of decimal places for the predictions
     * @return         the formatted listing
     */
    public static String getResultAsString(Result result, int adp) {
        StringBuilder sb = new StringBuilder();
        int n = result.predictions.size();
        // The count has always been printed as a double (e.g. "N=10.0"); kept for output compatibility
        sb.append("|==== PREDICTIONS (N=" + (double) n + ") =====>\n");
        // Loop-invariant; the constant-first comparison tolerates a missing "Type" entry
        boolean asLabelSets = adp == 0 && !"MT".equalsIgnoreCase(result.getInfo("Type"));
        for (int i = 0; i < n; i++) {
            sb.append("|");
            sb.append(Utils.doubleToString((i + 1), 5, 0));
            sb.append(" ");
            if (asLabelSets) {
                LabelSet y = new LabelSet(MLUtils.toIndicesSet(result.actuals.get(i)));
                sb.append(y).append(" ");
                LabelSet ypred = new LabelSet(MLUtils.toIndicesSet(result.rowPrediction(i)));
                sb.append(ypred).append("\n");
            } else {
                sb.append(A.toString(result.actuals.get(i))).append(" ");
                sb.append(A.toString(result.predictions.get(i), adp)).append("\n");
            }
        }
        sb.append("|==============================<\n");
        return sb.toString();
    }

}