org.openml.webapplication.predictionCounter.FoldsPredictionCounter.java Source code

Java tutorial

Introduction

Here is the source code for org.openml.webapplication.predictionCounter.FoldsPredictionCounter.java.

Source

/*
 *  Webapplication - Java library that runs on OpenML servers
 *  Copyright (C) 2014 
 *  @author Jan N. van Rijn (j.n.van.rijn@liacs.leidenuniv.nl)
 *  
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *  
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *  
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *  
 */
package org.openml.webapplication.predictionCounter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.openml.webapplication.algorithm.InstancesHelper;

import weka.core.Instance;
import weka.core.Instances;

/**
 * Checks that submitted predictions cover exactly the row ids that a task's splits file assigns
 * to each (repeat, fold, sample) cell.
 *
 * <p>The constructor reads the splits and records, per cell, the sorted row ids of rows whose
 * {@code type} attribute equals the requested type (the "expected" rows), plus the number of rows
 * whose type equals the shadow type. Predictions are registered with
 * {@link #addPrediction(int, int, int, int)} and compared against the expected rows by
 * {@link #check()}.
 */
public class FoldsPredictionCounter implements PredictionCounter {

    // Column indices of the relevant attributes in the splits file. ATT_SPLITS_SAMPLE is -1 when
    // the splits file has no sample attribute.
    private final int ATT_SPLITS_TYPE;
    private final int ATT_SPLITS_ROWID;
    private final int ATT_SPLITS_REPEAT;
    private final int ATT_SPLITS_FOLD;
    private final int ATT_SPLITS_SAMPLE;

    // Size of each dimension of the (repeat, fold, sample) grid.
    private final int NR_OF_REPEATS;
    private final int NR_OF_FOLDS;
    private final int NR_OF_SAMPLES;

    // Per cell: sorted row ids of the rows whose type is the requested type.
    private final List<Integer>[][][] expected;
    // Per cell: row ids registered through addPrediction (sorted in place by check()).
    private final List<Integer>[][][] actual;
    // Number of expected rows summed over all cells.
    private final int expectedTotal;

    // Per cell: number of rows whose type is the shadow type.
    private final int[][][] shadowTypeSize;

    // Description of the mismatch found by the last failing check(); empty initially.
    private String error_message;

    /**
     * Creates a counter that expects predictions for the {@code "TEST"} rows of {@code splits}
     * and counts the {@code "TRAIN"} rows as the shadow type.
     *
     * @param splits the task's splits data
     * @throws Exception propagated from {@code InstancesHelper.getRowIndex} (presumably when a
     *         required attribute is missing — confirm against InstancesHelper)
     */
    public FoldsPredictionCounter(Instances splits) throws Exception {
        this(splits, "TEST", "TRAIN");
    }

    /**
     * Creates a counter from a splits file.
     *
     * @param splits     the task's splits data; must contain type, row id, repeat and fold
     *                   attributes, and may contain a sample attribute
     * @param type       label of the {@code type} attribute whose rows must receive predictions
     * @param shadowType label of the {@code type} attribute whose rows are only counted per cell
     * @throws Exception propagated from {@code InstancesHelper.getRowIndex} (presumably when a
     *         required attribute is missing — confirm against InstancesHelper)
     */
    @SuppressWarnings("unchecked") // generic array creation for the per-cell row id lists
    public FoldsPredictionCounter(Instances splits, String type, String shadowType) throws Exception {
        ATT_SPLITS_TYPE = InstancesHelper.getRowIndex("type", splits);
        ATT_SPLITS_ROWID = InstancesHelper.getRowIndex(new String[] { "rowid", "row_id" }, splits);
        ATT_SPLITS_REPEAT = InstancesHelper.getRowIndex(new String[] { "repeat", "repeat_nr" }, splits);
        ATT_SPLITS_FOLD = InstancesHelper.getRowIndex(new String[] { "fold", "fold_nr" }, splits);
        ATT_SPLITS_SAMPLE = findOptionalAttribute(splits, new String[] { "sample", "sample_nr" });

        // Size every dimension from the attribute that was actually located (which may be an
        // alias such as "repeat_nr"), not from a hard-coded attribute name. A name-based check
        // yields a dimension of 1 for aliased columns, and larger indices then overflow the arrays.
        NR_OF_REPEATS = dimensionSize(splits, ATT_SPLITS_REPEAT);
        NR_OF_FOLDS = dimensionSize(splits, ATT_SPLITS_FOLD);
        NR_OF_SAMPLES = dimensionSize(splits, ATT_SPLITS_SAMPLE);

        expected = new ArrayList[NR_OF_REPEATS][NR_OF_FOLDS][NR_OF_SAMPLES];
        actual = new ArrayList[NR_OF_REPEATS][NR_OF_FOLDS][NR_OF_SAMPLES];
        shadowTypeSize = new int[NR_OF_REPEATS][NR_OF_FOLDS][NR_OF_SAMPLES];
        for (int i = 0; i < NR_OF_REPEATS; i++) {
            for (int j = 0; j < NR_OF_FOLDS; j++) {
                for (int k = 0; k < NR_OF_SAMPLES; k++) {
                    expected[i][j][k] = new ArrayList<Integer>();
                    actual[i][j][k] = new ArrayList<Integer>();
                }
            }
        }

        // Internal values of the two type labels. They are loop-invariant, so they are looked up
        // once. If a label is unknown, Weka's indexOfValue returns -1, which matches no stored value.
        final double typeValue = splits.attribute(ATT_SPLITS_TYPE).indexOfValue(type);
        final double shadowTypeValue = splits.attribute(ATT_SPLITS_TYPE).indexOfValue(shadowType);

        int total = 0;
        for (int i = 0; i < splits.numInstances(); i++) {
            Instance instance = splits.instance(i);
            double rowType = instance.value(ATT_SPLITS_TYPE);
            if (rowType != typeValue && rowType != shadowTypeValue) {
                continue; // neither the requested type nor the shadow type
            }
            int repeat = cellIndex(instance, ATT_SPLITS_REPEAT);
            int fold = cellIndex(instance, ATT_SPLITS_FOLD);
            int sample = cellIndex(instance, ATT_SPLITS_SAMPLE);
            // The requested type takes precedence when type and shadowType are the same label.
            if (rowType == typeValue) {
                // TODO: maybe we need instance.stringValue() for the row id ...
                expected[repeat][fold][sample].add((int) instance.value(ATT_SPLITS_ROWID));
                total++;
            } else {
                shadowTypeSize[repeat][fold][sample]++;
            }
        }
        expectedTotal = total;

        // Sort once here so that check() can compare with List.equals after sorting `actual`.
        for (List<Integer>[][] byFold : expected) {
            for (List<Integer>[] bySample : byFold) {
                for (List<Integer> rowids : bySample) {
                    Collections.sort(rowids);
                }
            }
        }

        error_message = "";
    }

    /**
     * Looks up an optional attribute.
     *
     * @return its column index, or -1 when the lookup fails
     */
    private static int findOptionalAttribute(Instances splits, String[] names) {
        try {
            return InstancesHelper.getRowIndex(names, splits);
        } catch (Exception ignored) {
            // Deliberately ignored: the attribute is optional, and -1 marks it as absent.
            return -1;
        }
    }

    /**
     * Returns the size of one grid dimension.
     *
     * <p>The size is one more than the maximum value of the attribute at {@code attIndex}, or 1
     * when the attribute is absent ({@code attIndex < 0}).
     *
     * <p>NOTE(review): this assumes the attribute is numeric so that {@code numericStats} is
     * non-null, as the original name-based sizing also did.
     */
    private static int dimensionSize(Instances splits, int attIndex) {
        if (attIndex < 0) {
            return 1;
        }
        return (int) splits.attributeStats(attIndex).numericStats.max + 1;
    }

    /**
     * Returns the cell coordinate stored in the given attribute.
     *
     * @return the value truncated to an int, or 0 when the attribute is absent ({@code attIndex < 0})
     */
    private static int cellIndex(Instance instance, int attIndex) {
        return attIndex < 0 ? 0 : (int) instance.value(attIndex);
    }

    /**
     * Registers a prediction for {@code rowid} in the given (repeat, fold, sample) cell.
     *
     * @throws RuntimeException if the repeat, fold or sample index lies outside the task's grid
     *         (including negative indices)
     */
    public void addPrediction(int repeat, int fold, int sample, int rowid) {
        if (repeat < 0 || repeat >= NR_OF_REPEATS) {
            throw new RuntimeException("Repeat #" + repeat + " not defined by task. ");
        }
        if (fold < 0 || fold >= NR_OF_FOLDS) {
            throw new RuntimeException("Fold #" + fold + " not defined by task. ");
        }
        if (sample < 0 || sample >= NR_OF_SAMPLES) {
            throw new RuntimeException("Sample #" + sample + " not defined by task. ");
        }
        actual[repeat][fold][sample].add(rowid);
    }

    /**
     * Compares the registered predictions with the expected row ids, cell by cell.
     *
     * <p>Sorts each cell's registered row ids in place. Duplicates matter: the multiset of row ids
     * must match exactly. On the first mismatching cell, stores a description retrievable through
     * {@link #getErrorMessage()} and returns {@code false}.
     *
     * @return {@code true} if every cell matches
     */
    public boolean check() {
        for (int i = 0; i < NR_OF_REPEATS; i++) {
            for (int j = 0; j < NR_OF_FOLDS; j++) {
                for (int k = 0; k < NR_OF_SAMPLES; k++) {
                    Collections.sort(actual[i][j][k]);
                    if (!actual[i][j][k].equals(expected[i][j][k])) {
                        error_message = "Repeat " + i + " fold " + j + " sample " + k
                                + " expected predictions with row id's " + expected[i][j][k]
                                + " , but got predictions with row id's " + actual[i][j][k];
                        return false;
                    }
                }
            }
        }
        return true;
    }

    /**
     * Returns the sorted expected row ids of a cell.
     *
     * <p>This is the internal list, not a copy, so modifying it affects {@link #check()}.
     */
    public List<Integer> getExpectedRowids(int i, int j, int k) {
        return expected[i][j][k];
    }

    /** Returns the number of shadow-type rows in a cell. */
    public int getShadowTypeSize(int i, int j, int k) {
        return shadowTypeSize[i][j][k];
    }

    /** Returns the message set by the last failing {@link #check()}, or "" if none failed yet. */
    public String getErrorMessage() {
        return error_message;
    }

    /** Returns the number of repeats in the grid. */
    public int getRepeats() {
        return NR_OF_REPEATS;
    }

    /** Returns the number of folds in the grid. */
    public int getFolds() {
        return NR_OF_FOLDS;
    }

    /** Returns the number of samples in the grid (1 when the splits file has no sample attribute). */
    public int getSamples() {
        return NR_OF_SAMPLES;
    }

    /** Returns the total number of expected rows over all cells. */
    public int getExpectedTotal() {
        return expectedTotal;
    }
}