com.insightml.models.meta.VoteModel.java Source code

Java tutorial

Introduction

Here is the source code for com.insightml.models.meta.VoteModel.java

Source

/*
 * Copyright (C) 2016 Stefan Hen
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.insightml.models.meta;

import java.util.ArrayList;

import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;

import com.insightml.data.samples.ISamples;
import com.insightml.data.samples.Sample;
import com.insightml.math.statistics.Stats;
import com.insightml.models.DistributionModel;
import com.insightml.models.DistributionPrediction;
import com.insightml.models.IModel;
import com.insightml.utils.Arrays;
import com.insightml.utils.jobs.AbstractJob;
import com.insightml.utils.jobs.IJobBatch;
import com.insightml.utils.jobs.ParallelFor;
import com.insightml.utils.jobs.ThreadedClient;

/**
 * Ensemble model that combines the numeric predictions of its member models into a single
 * prediction per sample, using the configured {@link VoteStrategy}.
 *
 * <p>
 * Member models and their weights are obtained from the superclass through {@code getModels()} and
 * {@code getWeights()}.
 *
 * @param <I>
 *            the sample type accepted by the member models
 */
public final class VoteModel<I extends Sample> extends AbstractEnsembleModel<I, Double>
        implements DistributionModel<I> {

    private static final long serialVersionUID = -8515840219123634452L;

    /**
     * How the weighted member predictions of one sample are reduced to a single value.
     */
    public enum VoteStrategy {
        /** Arithmetic mean of the weighted predictions. */
        AVERAGE,
        /** 50th percentile of the weighted predictions. */
        MEDIAN,
        /** Geometric mean of the weighted predictions. */
        GEOMETRIC,
        /** Number of predictions divided by the sum of their reciprocals. */
        HARMONIC
    }

    /** Reduction used by {@link #apply}; stays {@code null} when the no-argument constructor is used. */
    private VoteStrategy strategy;

    /**
     * Package-private constructor that leaves the strategy unset.
     * NOTE(review): presumably present for a serialization framework; confirm before removing.
     */
    VoteModel() {
    }

    /**
     * Creates a voting ensemble.
     *
     * @param models
     *            the member models, passed on to the superclass
     * @param weights
     *            one weight per member model, passed on to the superclass
     * @param strategy
     *            the reduction applied to the weighted predictions of each sample
     */
    public VoteModel(final IModel<I, Double>[] models, final double[] weights, final VoteStrategy strategy) {
        super(models, weights);
        this.strategy = strategy;
    }

    /**
     * Predicts one value per sample by letting every member model predict, multiplying each
     * prediction by the weight of its model, and reducing the weighted values of every sample with
     * the configured strategy.
     *
     * <p>
     * NOTE(review): the reduction is applied to {@code prediction * weight} and is not divided by
     * the sum of the weights here; confirm that the weights are expected to be scaled accordingly.
     *
     * @param instnces
     *            the samples to predict
     * @return the combined prediction of each sample, indexed like the member predictions
     */
    @Override
    public Double[] apply(final ISamples<? extends I, ?> instnces) {
        final IModel<I, Double>[] members = getModels();
        final double[] memberWeights = getWeights();
        final Double[][] memberPredictions = collectMemberPredictions(members, instnces);

        // One statistics accumulator per sample; values are added in member-model order.
        final DescriptiveStatistics[] perSample = Arrays.fill(instnces.size(), DescriptiveStatistics.class);
        for (int member = 0; member < memberPredictions.length; ++member) {
            final Double[] column = memberPredictions[member];
            for (int sample = 0; sample < column.length; ++sample) {
                perSample[sample].addValue(column[sample] * memberWeights[member]);
            }
        }

        final Double[] votes = new Double[perSample.length];
        int slot = 0;
        for (final DescriptiveStatistics weighted : perSample) {
            votes[slot++] = resolve(weighted);
        }
        return votes;
    }

    /**
     * Submits one job per member model to a batch of a new {@link ThreadedClient}, runs the batch
     * and returns the predictions each member produced.
     *
     * @param members
     *            the member models to run
     * @param instnces
     *            the samples handed to every member model
     * @return the predictions, first index is the member model, second index is the sample
     */
    private Double[][] collectMemberPredictions(final IModel<I, Double>[] members,
            final ISamples<? extends I, ?> instnces) {
        final IJobBatch<Object> jobs = new ThreadedClient().newBatch();
        final Double[][] results = new Double[members.length][];
        for (int index = 0; index < members.length; ++index) {
            // The anonymous job can only capture an effectively final copy of the loop index.
            final int member = index;
            jobs.addJob(new AbstractJob<Object>("") {

                private static final long serialVersionUID = -2963052506505226869L;

                @Override
                public Object run() {
                    // The job reports its result through the shared array, not its return value.
                    results[member] = members[member].apply(instnces);
                    return null;
                }
            });
        }
        jobs.run();
        return results;
    }

    /**
     * Predicts a distribution per sample by merging the distribution predictions of all member
     * models. Every member model is cast to {@link DistributionModel}, so a member of another type
     * makes the cast fail.
     *
     * @param instnces
     *            the samples to predict
     * @param debug
     *            forwarded unchanged to every member model
     * @return one merged distribution prediction per sample
     */
    @Override
    public DistributionPrediction[] predictDistribution(final ISamples<? extends I, ?> instnces,
            final boolean debug) {
        final IModel<I, Double>[] members = getModels();

        // Start every sample with an empty prediction that the member results are added to.
        final DistributionPrediction[] merged = new DistributionPrediction[instnces.size()];
        for (int sample = 0; sample < merged.length; ++sample) {
            merged[sample] = new DistributionPrediction(new Stats(), new ArrayList<>());
        }

        for (final DistributionPrediction[] memberResult : ParallelFor.run(
                member -> ((DistributionModel<I>) members[member]).predictDistribution(instnces, debug), 0,
                members.length, 1)) {
            int sample = 0;
            for (final DistributionPrediction prediction : memberResult) {
                merged[sample++].add(prediction);
            }
        }
        return merged;
    }

    /**
     * Reduces the weighted predictions of one sample according to the configured strategy.
     *
     * @param stats
     *            the weighted predictions of a single sample
     * @return the combined prediction
     * @throws IllegalStateException
     *             if the strategy is not one of the handled constants
     */
    private double resolve(final DescriptiveStatistics stats) {
        switch (strategy) {
        case HARMONIC:
            return harmonicMean(stats);
        case GEOMETRIC:
            return stats.getGeometricMean();
        case MEDIAN:
            return stats.getPercentile(50);
        case AVERAGE:
            return stats.getMean();
        default:
            throw new IllegalStateException();
        }
    }

    /**
     * Computes the harmonic mean: the number of values divided by the sum of their reciprocals.
     *
     * @param stats
     *            the values to reduce
     * @return the harmonic mean of the values
     */
    private static double harmonicMean(final DescriptiveStatistics stats) {
        final double[] values = stats.getValues();
        double reciprocalSum = 0;
        for (int k = 0; k < values.length; ++k) {
            reciprocalSum += 1 / values[k];
        }
        return stats.getN() * 1.0 / reciprocalSum;
    }
}