weka.distributed.spark.WekaAttributeSelectionSparkJob.java Source code

Java tutorial

Introduction

Here is the source code for weka.distributed.spark.WekaAttributeSelectionSparkJob.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    WekaAttributeSelectionSparkJob
 *
 */

package weka.distributed.spark;

import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.Vector;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.storage.StorageLevel;

import scala.Tuple2;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.AggregateableEvaluation;
import weka.classifiers.evaluation.Evaluation;
import weka.core.Attribute;
import weka.core.CommandlineRunnable;
import weka.core.Debug;
import weka.core.Debug.SimpleLog;
import weka.core.Environment;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.converters.CSVSaver;
import weka.core.converters.Saver;
import weka.core.stats.ArffSummaryNumericMetric;
import weka.distributed.CSVToARFFHeaderMapTask;
import weka.distributed.CSVToARFFHeaderReduceTask;
import weka.distributed.DistributedWekaException;
import weka.distributed.WekaClassifierEvaluationMapTask;
import weka.distributed.WekaClassifierEvaluationReduceTask;
import weka.distributed.WekaClassifierMapTask;
import weka.distributed.WekaClassifierReduceTask;
import weka.filters.Filter;
import weka.filters.MakePreconstructedFilter;
import weka.filters.PreconstructedFilter;
import weka.filters.unsupervised.attribute.Remove;
import weka.gui.beans.InstancesProducer;
import weka.gui.beans.TextProducer;
import distributed.core.DistributedJob;
import distributed.core.DistributedJobConfig;

/**
 * Spark job for running an automated attribute selection process over a dataset. 
 * Modified from original source code of WekaClassifierEvaluationSparkJob by
 * Mark Hall.
 * 
 * @author Khoa DoBa
 * @version 1.0
 */
public class WekaAttributeSelectionSparkJob extends SparkJob
        implements TextProducer, InstancesProducer, CommandlineRunnable {

    /**
     * The subdirectory of the output directory that this job saves its results to
     */
    protected static final String OUTPUT_SUBDIR = "AttrSel";

    /** For serialization */
    private static final long serialVersionUID = 8099932783096025201L;

    /**
     * Wrapped classifier job. Used for option parsing/serialization and as the
     * source of the classifier map task options and the number of training
     * iterations.
     */
    protected WekaClassifierSparkJob m_classifierJob = new WekaClassifierSparkJob();

    /** Textual evaluation results if job successful */
    protected String m_textEvalResults;

    /** Instances version of the evaluation results */
    protected Instances m_evalResults;

    /**
     * Debug log. NOTE(review): not referenced in the visible part of this file;
     * confirm where it is initialised and used.
     */
    protected SimpleLog m_debuglog;

    /**
     * Path to a separate test set (if not doing cross-validation or test on
     * training)
     */
    protected String m_separateTestSetPath = "";

    /**
     * Fraction (between 0 and 1) of predictions to retain in order to compute
     * auc/auprc, kept as an unparsed string. Predictions are not retained if this
     * is unspecified or the fraction is set <= 0
     */
    protected String m_predFrac = "";

    /**
     * Optional user-supplied subdirectory of [output_dir]/eval in which to store
     * results
     */
    protected String m_optionalOutputSubDir = "";

    /**
     * Number of runs (from different random starting locations) for the SLS
     * search algorithm, kept as an unparsed string; empty when not set
     */
    protected String m_runsCount = "";

    /**
     * Number of flips per run for the SLS search algorithm, kept as an unparsed
     * string; empty when not set
     */
    protected String m_flipCount = "";

    /**
     * Number of greedy neighbors per run for the SLS search algorithm, kept as an
     * unparsed string; empty when not set
     */
    protected String m_neighborsCount = "";

    /**
     * Random step chance for the SLS search algorithm, kept as an unparsed
     * string; empty when not set
     */
    protected String m_randomChance = "";

    /**
     * Creates the job, passing its name and description strings to the
     * {@code SparkJob} constructor.
     */
    public WekaAttributeSelectionSparkJob() {
        super("Weka attribute selection job", "Search for relevant attributes");
    }

    /**
     * Command-line entry point. Creates a job instance and hands it, together
     * with the raw arguments, to {@code run}.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args) {
        final WekaAttributeSelectionSparkJob job = new WekaAttributeSelectionSparkJob();
        job.run(job, args);
    }

    /**
     * Returns a one-line description of what this job does.
     *
     * @return the description of this job
     */
    public String globalInfo() {
        final String description = "Search for relevant attributes using SLS search "
                + "and Classifier Evaluation";
        return description;
    }

    /**
     * Lists the command-line options understood by this job: first the options
     * specific to attribute selection and evaluation, then every option of
     * {@link WekaClassifierSparkJob}.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();

        // the classifier job has all the options we need once the default
        // mapper and reducer have been replaced with the fold-based equivalents
        result.add(new Option("", "", 0, "\nNote: the -fold-number option is ignored by this job."));

        result.add(new Option("", "", 0, "\nOptions specific to model building and evaluation:"));

        result.add(new Option("\tPath to a separate test set. Set either this or\n\t"
                + "total-folds for a cross-validation (note that setting total-folds\n\t"
                + "to 1 will perform testing on training)", "test-set-path", 1, "-test-set-path <path>"));
        result.add(new Option(
                "\tCompute AUC and AUPRC. Note that this requires individual\n\t"
                        + "predictions to be retained - specify a fraction of\n\t"
                        + "predictions to sample (e.g. 0.5) in order to save resources.",
                "auc", 1, "-auc <fraction of predictions to sample>"));
        result.add(new Option("\tOptional sub-directory of <output-dir>/eval " + "in which to store results.",
                "output-subdir", 1, "-output-subdir <directory name>"));

        // SLS search parameters (descriptions use the same "\t" prefix as above)
        result.add(new Option("\tNumber of SLS search runs from different random starting locations",
                "runs-count", 1, "-runs-count <number>"));

        result.add(new Option("\tNumber of flips for a search run", "flips-count", 1, "-flips-count <number>"));

        result.add(new Option("\tNumber of neighbors for a search run", "neighbors-count", 1,
                "-neighbors-count <number>"));

        result.add(new Option("\tRandom step chance for a search run", "random-chance", 1,
                "-random-chance <number>"));

        // append the options of a freshly constructed classifier job
        WekaClassifierSparkJob tempClassifierJob = new WekaClassifierSparkJob();

        Enumeration<Option> cOpts = tempClassifierJob.listOptions();
        while (cOpts.hasMoreElements()) {
            result.add(cOpts.nextElement());
        }

        return result.elements();
    }

    /**
     * Gets the current option settings: the base job options, followed by the
     * options specific to this job (as produced by {@link #getJobOptionsOnly()}),
     * followed by the options of the wrapped classifier job.
     *
     * @return the current options as a command-line style array
     */
    @Override
    public String[] getOptions() {
        List<String> options = new ArrayList<String>();

        for (String o : super.getOptions()) {
            options.add(o);
        }

        // the job-specific options are serialized in one place only, so that
        // getOptions() and getJobOptionsOnly() cannot drift apart
        for (String o : getJobOptionsOnly()) {
            options.add(o);
        }

        for (String o : m_classifierJob.getOptions()) {
            options.add(o);
        }

        return options.toArray(new String[options.size()]);
    }

    /**
     * Parses the options for this job. The job-specific options are extracted
     * first. The remaining arguments are given to the base job and, via a
     * separate copy, to the wrapped classifier job.
     *
     * @param options the command-line style options to parse
     * @throws Exception if an option cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        // Utils.getOption removes each matched flag/value pair from the array,
        // so these calls leave only the options not owned by this job
        setSeparateTestSetPath(Utils.getOption("test-set-path", options));
        setSampleFractionForAUC(Utils.getOption("auc", options));
        setOutputSubdir(Utils.getOption("output-subdir", options));
        setRunsCount(Utils.getOption("runs-count", options));
        setFlipCount(Utils.getOption("flips-count", options));
        setNeighborsCount(Utils.getOption("neighbors-count", options));
        setRandomChance(Utils.getOption("random-chance", options));

        // snapshot taken before super.setOptions(), which presumably consumes
        // entries from the array it is given
        final String[] remainingForClassifierJob = options.clone();

        super.setOptions(options);
        m_classifierJob.setOptions(remainingForClassifierJob);
    }

    /**
     * Returns only the options that belong to this job (not those of the base
     * job or of the wrapped classifier job). A flag is emitted, followed by its
     * value, only when DistributedJobConfig.isEmpty reports the value as
     * non-empty.
     *
     * @return the options for this job only
     */
    public String[] getJobOptionsOnly() {
        // flag/value pairs, in the order they are emitted
        String[][] flagValuePairs = {
                { "-test-set-path", getSeparateTestSetPath() },
                { "-auc", getSampleFractionForAUC() },
                { "-output-subdir", getOutputSubdir() },
                { "-runs-count", getRunsCount() },
                { "-flips-count", getFlipCount() },
                { "-neighbors-count", getNeighborsCount() },
                { "-random-chance", getRandomChance() } };

        List<String> jobOptions = new ArrayList<String>();
        for (String[] pair : flagValuePairs) {
            String value = pair[1];
            if (DistributedJobConfig.isEmpty(value)) {
                continue;
            }
            jobOptions.add(pair[0]);
            jobOptions.add(value);
        }

        return jobOptions.toArray(new String[jobOptions.size()]);
    }

    /**
     * Returns the tip text describing the separate test set path property.
     *
     * @return the tip text for this property
     */
    public String separateTestSetPathTipText() {
        final String tip = "The path (in HDFS) to a separate test set to use";
        return tip;
    }

    /**
     * Returns the path (in HDFS) of a separate test set to evaluate on. Either
     * this or the total number of folds should be specified, but not both.
     *
     * @return the path of the separate test set
     */
    public String getSeparateTestSetPath() {
        return this.m_separateTestSetPath;
    }

    /**
     * Sets the path (in HDFS) of a separate test set to evaluate on. Either this
     * or the total number of folds should be specified, but not both.
     *
     * @param path the path of the separate test set
     */
    public void setSeparateTestSetPath(String path) {
        this.m_separateTestSetPath = path;
    }

    /**
     * Tip text for this property
     *
     * @return the tip text for this property
     */
    public String sampleFractionForAUCTipText() {
        // the value is a fraction (e.g. 0.5), not a percentage
        return "The fraction (between 0 and 1) of all predictions (randomly sampled) to retain "
                + "for computing AUC and AUPRC. If not specified, then these metrics are not "
                + "computed and no predictions are kept. Use this option to keep the number of "
                + "predictions retained under control when computing AUC/AUPRC.";
    }

    /**
     * Get the fraction of predictions to retain (via uniform random sampling)
     * for computing AUC and AUPRC. If not specified, then no predictions are
     * retained and these metrics are not computed.
     *
     * @return the fraction (between 0 and 1) of all predictions to retain for
     *         computing AUC/AUPRC, as an unparsed string (may be empty)
     */
    public String getSampleFractionForAUC() {
        return m_predFrac;
    }

    /**
     * Set the fraction of predictions to retain (via uniform random sampling)
     * for computing AUC and AUPRC. If not specified, then no predictions are
     * retained and these metrics are not computed.
     *
     * @param f the fraction (between 0 and 1) of all predictions to retain for
     *          computing AUC/AUPRC, as a string
     */
    public void setSampleFractionForAUC(String f) {
        m_predFrac = f;
    }

    /**
     * Returns the tip text describing the output subdirectory property.
     *
     * @return the tool tip text for this property
     */
    public String outputSubdirTipText() {
        return "An optional subdirectory of <output-dir>/eval in which to store the results";
    }

    /**
     * Returns the optional subdirectory of [output-dir]/eval used for results.
     *
     * @return the subdirectory name (empty when not set)
     */
    public String getOutputSubdir() {
        return this.m_optionalOutputSubDir;
    }

    /**
     * Sets the optional subdirectory of [output-dir]/eval used for results.
     *
     * @param subdir the subdirectory name; may be empty
     */
    public void setOutputSubdir(String subdir) {
        this.m_optionalOutputSubDir = subdir;
    }

    /**
     * Tip text for this property
     *
     * @return the tip text for this property
     */
    public String runsCountTipText() {
        return "Number of SLS search runs from different random starting locations";
    }

    /**
     * Get the number of search runs for the SLS algorithm
     *
     * @return the number of runs for the search algorithm, as an unparsed string
     *         (may be empty)
     */
    public String getRunsCount() {
        return m_runsCount;
    }

    /**
     * Set the number of search runs for the SLS algorithm
     *
     * @param runsCount the number of runs for the search algorithm
     */
    public void setRunsCount(String runsCount) {
        m_runsCount = runsCount;
    }

    /**
     * Tip text for this property
     *
     * @return the tip text for this property
     */
    public String flipCountTipText() {
        return "Number of flips for a search run";
    }

    /**
     * Get the number of flips per run for the SLS algorithm
     *
     * @return the number of flips for the search algorithm, as an unparsed
     *         string (may be empty)
     */
    public String getFlipCount() {
        return m_flipCount;
    }

    /**
     * Set the number of flips per run for the SLS algorithm
     *
     * @param flipCount the number of flips for the search algorithm
     */
    public void setFlipCount(String flipCount) {
        m_flipCount = flipCount;
    }

    /**
     * Tip text for this property
     *
     * @return the tip text for this property
     */
    public String neighborsCountTipText() {
        return "Number of neighbors for a search run";
    }

    /**
     * Get the number of neighbors per run for the SLS algorithm
     *
     * @return the number of neighbors for the search algorithm, as an unparsed
     *         string (may be empty)
     */
    public String getNeighborsCount() {
        return m_neighborsCount;
    }

    /**
     * Set the number of neighbors per run for the SLS algorithm
     *
     * @param neighborsCount the number of neighbors for the search algorithm
     */
    public void setNeighborsCount(String neighborsCount) {
        m_neighborsCount = neighborsCount;
    }

    /**
     * Tip text for this property
     *
     * @return the tip text for this property
     */
    public String randomChanceTipText() {
        return "Random step chance for a search run";
    }

    /**
     * Get the random step chance for the SLS algorithm
     *
     * @return the random step chance for the search algorithm, as an unparsed
     *         string (may be empty)
     */
    public String getRandomChance() {
        return m_randomChance;
    }

    /**
     * Set the random step chance for the SLS algorithm
     *
     * @param chance the random step chance for the search algorithm
     */
    public void setRandomChance(String chance) {
        m_randomChance = chance;
    }

    /**
     * Phase one of the evaluation: trains one classifier per fold for every
     * attribute subset.
     * <p>
     * Each record of {@code dataset} pairs an attribute subset (a {@link BitSet})
     * with a group of instances. For every record, one
     * {@link WekaClassifierMapTask} per fold is configured with the filter
     * returned by {@code GetFilterFromBitSet} for that subset. Each task is
     * trained on all of the record's instances and then finalized. The partial
     * models sharing a (subset, fold) key are then aggregated with
     * {@link WekaClassifierReduceTask}.
     * <p>
     * This is repeated {@code m_classifierJob.getNumIterations()} times. Each
     * iteration passes the models aggregated so far to
     * {@code WekaClassifierSparkJob.configureClassifierMapTask}.
     *
     * @param dataset the training data, keyed by attribute subset
     * @param subsetList the attribute subsets to build classifiers for.
     *          NOTE(review): presumably must contain every key of {@code dataset}
     * @param headerNoSummary the header without summary attributes; its class
     *          index is applied to the training data
     * @return a map from each subset in {@code subsetList} to its per-fold
     *         classifiers (array length = total folds, 1 if not specified)
     * @throws Exception if option parsing or the Spark job fails
     */
    protected Map<BitSet, Classifier[]> phaseOneBuildClassifiers(JavaPairRDD<BitSet, Iterable<Instance>> dataset,
            BitSet[] subsetList, final Instances headerNoSummary) throws Exception {

        int totalFolds = 1;
        final String classifierMapTaskOptions = environmentSubstitute(
                m_classifierJob.getClassifierMapTaskOptions());
        String[] cOpts = Utils.splitOptions(classifierMapTaskOptions);
        // query copies so the parsed options array is left untouched
        String numFolds = Utils.getOption("total-folds", cOpts.clone());
        final boolean forceVote = Utils.getFlag("force-vote", cOpts.clone());
        if (!DistributedJobConfig.isEmpty(numFolds)) {
            totalFolds = Integer.parseInt(numFolds);
        }
        final int tFolds = totalFolds;

        // one slot per fold for each subset; filled in after every iteration
        final Map<BitSet, Classifier[]> foldClassifiers = new HashMap<BitSet, Classifier[]>();
        for (BitSet subset : subsetList) {
            foldClassifiers.put(subset, new Classifier[tFolds]);
        }

        // just use headerNoSummary for class index
        final int classIndex = headerNoSummary.classIndex();
        // used both for the TOTAL_NUMBER_OF_MAPS environment variable and as the
        // number of splits when configuring each map task
        final int numPartitions = dataset.partitions().size();

        int numIterations = m_classifierJob.getNumIterations();

        for (int i = 0; i < numIterations; i++) {
            final int iterationNum = i;
            logMessage("[WekaClassifierEvaluation] Phase 1 (map), iteration " + (i + 1));

            JavaPairRDD<Tuple2<BitSet, Integer>, Classifier> mapFolds = dataset.flatMapToPair(
                    new PairFlatMapFunction<Tuple2<BitSet, Iterable<Instance>>, Tuple2<BitSet, Integer>, Classifier>() {

                        /** For serialization */
                        private static final long serialVersionUID = -1906414304952140395L;

                        @Override
                        public Iterable<Tuple2<Tuple2<BitSet, Integer>, Classifier>> call(
                                Tuple2<BitSet, Iterable<Instance>> arg0)
                                throws IOException, DistributedWekaException {

                            // Spark invokes this once per record and reuses this
                            // function object for every record of a partition, so
                            // results go in a fresh local list; an instance field
                            // would accumulate, and re-emit, earlier records' models
                            List<Tuple2<Tuple2<BitSet, Integer>, Classifier>> classifiersForFolds = new ArrayList<Tuple2<Tuple2<BitSet, Integer>, Classifier>>();

                            BitSet subset = arg0._1();
                            Classifier[] previousClassifiers = foldClassifiers.get(subset);
                            if (previousClassifiers == null) {
                                throw new DistributedWekaException("[WekaClassifierEvaluation] Phase 1: "
                                        + "no classifier slots for attribute subset " + subset);
                            }

                            PreconstructedFilter preconstructedFilter = GetFilterFromBitSet(subset,
                                    headerNoSummary);
                            Iterator<Instance> split = arg0._2().iterator();

                            // Iterator.next() throws on an empty iterator rather than
                            // returning null, so check hasNext() explicitly
                            Instance current = split.hasNext() ? split.next() : null;
                            if (current == null) {
                                throw new IOException("No data in this partition!!");
                            }

                            Instances header = current.dataset();
                            header.setClassIndex(classIndex);

                            WekaClassifierMapTask[] tasks = new WekaClassifierMapTask[tFolds];
                            for (int j = 0; j < tFolds; j++) {
                                try {
                                    tasks[j] = new WekaClassifierMapTask();
                                    WekaClassifierSparkJob.configureClassifierMapTask(tasks[j],
                                            previousClassifiers[j], classifierMapTaskOptions,
                                            iterationNum, preconstructedFilter, numPartitions);

                                    // set fold number and total folds
                                    tasks[j].setFoldNumber(j + 1);
                                    tasks[j].setTotalNumFolds(tFolds);
                                    Environment env = new Environment();
                                    env.addVariable(WekaClassifierMapTask.TOTAL_NUMBER_OF_MAPS, "" + numPartitions);
                                    tasks[j].setEnvironment(env);
                                } catch (Exception ex) {
                                    logMessage(ex);
                                    throw new DistributedWekaException(ex);
                                }

                                // initialize
                                tasks[j].setup(headerNoSummary);
                            }

                            // the first instance was consumed above to obtain the
                            // header, but it is training data too (phase two
                            // evaluates every instance, including this one)
                            for (int j = 0; j < tFolds; j++) {
                                tasks[j].processInstance(current);
                            }

                            while (split.hasNext()) {
                                current = split.next();

                                for (int j = 0; j < tFolds; j++) {
                                    tasks[j].processInstance(current);
                                }
                            }

                            for (int j = 0; j < tFolds; j++) {
                                tasks[j].finalizeTask();
                                classifiersForFolds.add(new Tuple2<Tuple2<BitSet, Integer>, Classifier>(
                                        new Tuple2<BitSet, Integer>(subset, j), tasks[j].getClassifier()));
                            }

                            return classifiersForFolds;
                        }

                    });
            // memory and disk here for fast access and to avoid
            // recomputing partial classifiers if all partial classifiers
            // can't fit in memory
            mapFolds = mapFolds.persist(StorageLevel.MEMORY_AND_DISK());

            // reduce fold models: aggregate all partial models sharing a
            // (subset, fold) key into a single classifier
            logMessage("[WekaClassifierEvaluation] Phase 1 (reduce), iteration " + (i + 1));
            JavaPairRDD<Tuple2<BitSet, Integer>, Classifier> reducedByFold = mapFolds.groupByKey().mapToPair(
                    new PairFunction<Tuple2<Tuple2<BitSet, Integer>, Iterable<Classifier>>, Tuple2<BitSet, Integer>, Classifier>() {
                        /** For serialization */
                        private static final long serialVersionUID = 2481672301097842496L;

                        @Override
                        public Tuple2<Tuple2<BitSet, Integer>, Classifier> call(
                                Tuple2<Tuple2<BitSet, Integer>, Iterable<Classifier>> arg0) throws Exception {

                            List<Classifier> classifiers = new ArrayList<Classifier>();
                            for (Classifier partial : arg0._2()) {
                                classifiers.add(partial);
                            }

                            WekaClassifierReduceTask reduceTask = new WekaClassifierReduceTask();
                            Classifier intermediateClassifier = reduceTask.aggregate(classifiers, null, forceVote);

                            return new Tuple2<Tuple2<BitSet, Integer>, Classifier>(arg0._1(),
                                    intermediateClassifier);
                        }

                    });

            // store each aggregated model in the slot for its subset and fold index
            for (Tuple2<Tuple2<BitSet, Integer>, Classifier> t : reducedByFold.collect()) {
                BitSet subset = t._1()._1();
                int fold = t._1()._2();
                foldClassifiers.get(subset)[fold] = t._2();
            }

            mapFolds.unpersist();
            reducedByFold.unpersist();
        }

        return foldClassifiers;
    }

    protected List<Tuple2<BitSet, Evaluation>> phaseTwoEvaluateClassifiers(
            JavaPairRDD<BitSet, Iterable<Instance>> dataSet, final Instances headerWithSummary,
            final Instances headerNoSummary, final Map<BitSet, Classifier[]> foldClassifiersMap) throws Exception {

        int totalFolds = 1;
        final String classifierMapTaskOptions = environmentSubstitute(
                m_classifierJob.getClassifierMapTaskOptions());
        String[] cOpts = Utils.splitOptions(classifierMapTaskOptions);
        String numFolds = Utils.getOption("total-folds", cOpts);
        if (!DistributedJobConfig.isEmpty(numFolds)) {
            totalFolds = Integer.parseInt(numFolds);
        }
        final int tFolds = totalFolds;
        final boolean forceBatch = Utils.getFlag("force-batch", cOpts);
        String sSeed = Utils.getOption("seed", cOpts);
        long seed = 1L;
        if (!DistributedJobConfig.isEmpty(sSeed)) {
            try {
                sSeed = m_env.substitute(sSeed);
            } catch (Exception ex) {
            }

            try {
                seed = Long.parseLong(sSeed);
            } catch (NumberFormatException ex) {
            }
        }
        final long fseed = seed;

        // an sample size > 0 indicates that we will be retaining
        // predictions in order to compute auc/auprc
        String predFracS = getSampleFractionForAUC();
        double predFrac = 0;
        if (!DistributedJobConfig.isEmpty(predFracS)) {
            try {
                predFrac = Double.parseDouble(predFracS);
            } catch (NumberFormatException ex) {
                System.err.println("Unable to parse the fraction of predictions to retain: " + predFracS);
            }
        }
        final double fpredFrac = predFrac;

        Attribute classAtt = headerNoSummary.classAttribute();
        String classAttSummaryName = CSVToARFFHeaderMapTask.ARFF_SUMMARY_ATTRIBUTE_PREFIX + classAtt.name();
        Attribute summaryClassAtt = headerWithSummary.attribute(classAttSummaryName);
        if (summaryClassAtt == null) {
            throw new DistributedWekaException("[WekaClassifierEvaluation] evaluate "
                    + "classifiers: was unable to find the summary metadata attribute for the "
                    + "class attribute in the header");
        }

        double priorsCount = 0;
        double[] priors = new double[classAtt.isNominal() ? classAtt.numValues() : 1];
        if (classAtt.isNominal()) {
            for (int i = 0; i < classAtt.numValues(); i++) {
                String label = classAtt.value(i);
                String labelWithCount = summaryClassAtt.value(i).replace(label + "_", "").trim();

                try {
                    priors[i] = Double.parseDouble(labelWithCount);
                } catch (NumberFormatException n) {
                    throw new Exception(n);
                }
            }

            priorsCount = classAtt.numValues();
        } else {

            double count = ArffSummaryNumericMetric.COUNT.valueFromAttribute(summaryClassAtt);
            double sum = ArffSummaryNumericMetric.SUM.valueFromAttribute(summaryClassAtt);

            priors[0] = sum;
            priorsCount = count;
        }
        final double[] fpriors = priors;
        final double fpriorsCount = priorsCount;

        // map phase
        logMessage("[WekaClassifierEvaluation] Phase 2 (map)");
        JavaPairRDD<BitSet, Iterable<Evaluation>> mapFolds = dataSet
                .mapToPair(new PairFunction<Tuple2<BitSet, Iterable<Instance>>, BitSet, Iterable<Evaluation>>() {

                    /** For serialization */
                    private static final long serialVersionUID = 5800617408839460876L;

                    protected List<Evaluation> m_evaluationForPartition = new ArrayList<Evaluation>();

                    @Override
                    public Tuple2<BitSet, Iterable<Evaluation>> call(Tuple2<BitSet, Iterable<Instance>> arg0)
                            throws IOException, DistributedWekaException {

                        Iterator<Instance> split = arg0._2().iterator();
                        Classifier[] foldClassifiers = foldClassifiersMap.get(arg0._1());

                        // setup base tasks
                        WekaClassifierEvaluationMapTask[] evalTasks = new WekaClassifierEvaluationMapTask[tFolds];
                        for (int i = 0; i < tFolds; i++) {
                            evalTasks[i] = new WekaClassifierEvaluationMapTask();
                            evalTasks[i].setClassifier(foldClassifiers[i]);
                            evalTasks[i].setTotalNumFolds(tFolds);
                            evalTasks[i].setFoldNumber(i + 1);
                            evalTasks[i].setBatchTrainedIncremental(forceBatch);
                            try {
                                evalTasks[i].setup(headerNoSummary, fpriors, fpriorsCount, fseed, fpredFrac);
                            } catch (Exception ex) {
                                throw new DistributedWekaException(ex);
                            }
                        }

                        try {
                            while (split.hasNext()) {
                                Instance current = split.next();
                                for (WekaClassifierEvaluationMapTask t : evalTasks) {
                                    t.processInstance(current);
                                }
                            }

                            AggregateableEvaluation agg = null;
                            // finalize
                            for (int i = 0; i < tFolds; i++) {
                                evalTasks[i].finalizeTask();
                                Evaluation eval = evalTasks[i].getEvaluation();

                                // save memory
                                evalTasks[i] = null;
                                foldClassifiers[i] = null;

                                if (agg == null) {
                                    agg = new AggregateableEvaluation(eval);
                                }
                                agg.aggregate(eval);
                            }

                            if (agg != null) {
                                m_evaluationForPartition.add(agg);
                            }

                        } catch (Exception ex) {
                            throw new DistributedWekaException(ex);
                        }

                        return new Tuple2<BitSet, Iterable<Evaluation>>(arg0._1(), m_evaluationForPartition);
                    }

                });

        // reduce locally
        logMessage("[WekaClassifierEvaluation] Phase 2 (reduce)");
        JavaPairRDD<BitSet, Iterable<Evaluation>> mapFoldsReduced = mapFolds
                .reduceByKey(new Function2<Iterable<Evaluation>, Iterable<Evaluation>, Iterable<Evaluation>>() {

                    /**
                     * 
                     */
                    private static final long serialVersionUID = -6011262863159057209L;

                    @Override
                    public Iterable<Evaluation> call(Iterable<Evaluation> arg0, Iterable<Evaluation> arg1)
                            throws Exception {

                        List<Evaluation> returnEvals = new ArrayList<Evaluation>();
                        Iterator<Evaluation> left = arg0.iterator();
                        Iterator<Evaluation> right = arg0.iterator();

                        while (left.hasNext() && right.hasNext()) {
                            AggregateableEvaluation aggEval = new AggregateableEvaluation(left.next());
                            aggEval.aggregate(right.next());

                            returnEvals.add(aggEval);
                        }

                        return returnEvals;

                    }

                });
        JavaRDD<Tuple2<BitSet, Evaluation>> aggEval = mapFoldsReduced
                .map(new Function<Tuple2<BitSet, Iterable<Evaluation>>, Tuple2<BitSet, Evaluation>>() {

                    /**
                     * 
                     */
                    private static final long serialVersionUID = -4122870509656187814L;

                    @Override
                    public Tuple2<BitSet, Evaluation> call(Tuple2<BitSet, Iterable<Evaluation>> arg0)
                            throws Exception {

                        Iterator<Evaluation> evals = arg0._2().iterator();

                        AggregateableEvaluation aggEval = new AggregateableEvaluation(evals.next());
                        while (evals.hasNext()) {
                            aggEval.aggregate(evals.next());
                        }

                        return new Tuple2<BitSet, Evaluation>(arg0._1(), aggEval);
                    }

                });
        return aggEval.collect();
    }

    /**
     * Runs the distributed attribute-selection job using a stochastic local
     * search (SLS) over attribute subsets.
     * <p>
     * The job proceeds in these steps:
     * <ol>
     * <li>Obtain the training RDD and header. If no training dataset was
     * supplied, the ARFF header job is run first.</li>
     * <li>Optionally randomize and stratify the data.</li>
     * <li>Optionally load and register a separate test set.</li>
     * <li>Perform {@code maxTries} independent searches. Each search starts from
     * a random subset and takes {@code maxFlips} steps:
     * <ul>
     * <li>With probability {@code noise}, a step moves to a random neighbour
     * (one bit flipped) that is preferably not yet visited.</li>
     * <li>Otherwise, {@code greedyNeighbors} such neighbours are evaluated and
     * the search moves to the best of them.</li>
     * </ul>
     * </li>
     * </ol>
     * The best subset seen over all searches is written out by
     * {@link #storeAndWriteEvalResults(BitSet, Instances, String)}.
     *
     * @param sparkContext the Spark context to run in
     * @return true on success; false if the ARFF header job or the
     *         randomization/stratification job reported failure (the job
     *         status is set to FAILED in both cases)
     * @throws IOException declared for interface compatibility
     * @throws DistributedWekaException if the configuration is invalid or any
     *           step of the job fails (the job status is set to FAILED)
     */
    @Override
    public boolean runJobWithContext(JavaSparkContext sparkContext) throws IOException, DistributedWekaException {

        m_currentContext = sparkContext;
        boolean success;
        setJobStatus(JobStatus.RUNNING);

        m_debuglog = new Debug.SimpleLog();

        if (m_env == null) {
            m_env = Environment.getSystemWide();
        }

        try {
            // Make sure that we save out to a subdirectory of the output
            // directory
            String outputPath = environmentSubstitute(m_sjConfig.getOutputDir());
            outputPath = addSubdirToPath(outputPath, OUTPUT_SUBDIR);
            if (!DistributedJobConfig.isEmpty(m_optionalOutputSubDir)) {
                outputPath = addSubdirToPath(outputPath, environmentSubstitute(m_optionalOutputSubDir));
            }

            String classifierMapTaskOptions = environmentSubstitute(m_classifierJob.getClassifierMapTaskOptions());
            String[] cOpts = Utils.splitOptions(classifierMapTaskOptions);
            int totalFolds = 1;
            String numFolds = Utils.getOption("total-folds", cOpts.clone());
            if (!DistributedJobConfig.isEmpty(numFolds)) {
                totalFolds = Integer.parseInt(numFolds);
            }

            if (totalFolds > 1 && !DistributedJobConfig.isEmpty(getSeparateTestSetPath())) {
                throw new DistributedWekaException("Total folds is > 1 and a separate test set "
                        + "has been specified - can only perform one or the other out "
                        + "of a cross-validation or separate test set evaluation");
            }

            if (totalFolds < 1) {
                throw new DistributedWekaException("Total folds can't be less than 1!");
            }

            JavaRDD<Instance> dataSet = null;
            Instances headerWithSummary = null;
            if (getDataset(TRAINING_DATA) != null) {
                dataSet = getDataset(TRAINING_DATA).getDataset();
                headerWithSummary = getDataset(TRAINING_DATA).getHeaderWithSummary();
                logMessage("RDD<Instance> dataset provided: " + dataSet.partitions().size() + " partitions.");
            }

            if (dataSet == null && headerWithSummary == null) {
                // Run the ARFF job if necessary
                logMessage("Invoking ARFF Job...");
                m_classifierJob.m_arffHeaderJob.setEnvironment(m_env);
                m_classifierJob.m_arffHeaderJob.setLog(getLog());
                m_classifierJob.m_arffHeaderJob.setStatusMessagePrefix(m_statusMessagePrefix);
                m_classifierJob.m_arffHeaderJob.setCachingStrategy(getCachingStrategy());
                success = m_classifierJob.m_arffHeaderJob.runJobWithContext(sparkContext);

                if (!success) {
                    setJobStatus(JobStatus.FAILED);
                    statusMessage("Unable to continue - creating the ARFF header failed!");
                    logMessage("Unable to continue - creating the ARFF header failed!");
                    return false;
                }

                Dataset d = m_classifierJob.m_arffHeaderJob.getDataset(TRAINING_DATA);
                headerWithSummary = d.getHeaderWithSummary();
                dataSet = d.getDataset();
                setDataset(TRAINING_DATA, d);
                logMessage("Fetching RDD<Instance> dataset from ARFF job: " + dataSet.partitions().size()
                        + " partitions.");
            }

            Instances headerNoSummary = CSVToARFFHeaderReduceTask.stripSummaryAtts(headerWithSummary);
            String classAtt = "";
            if (!DistributedJobConfig.isEmpty(m_classifierJob.getClassAttribute())) {
                classAtt = environmentSubstitute(m_classifierJob.getClassAttribute());
            }
            WekaClassifierSparkJob.setClassIndex(classAtt, headerNoSummary, true);

            if (m_classifierJob.getRandomizeAndStratify()) {
                m_classifierJob.m_randomizeSparkJob.setEnvironment(m_env);
                m_classifierJob.m_randomizeSparkJob.setDefaultToLastAttIfClassNotSpecified(true);
                m_classifierJob.m_randomizeSparkJob.setStatusMessagePrefix(m_statusMessagePrefix);
                m_classifierJob.m_randomizeSparkJob.setLog(getLog());
                m_classifierJob.m_randomizeSparkJob.setCachingStrategy(getCachingStrategy());
                m_classifierJob.m_randomizeSparkJob.setDataset(TRAINING_DATA,
                        new Dataset(dataSet, headerWithSummary));

                // Pass the random seed configured for the underlying classifier
                // map task on to the randomization job; problems here are
                // logged but not fatal.
                try {
                    String[] classifierOpts = Utils.splitOptions(classifierMapTaskOptions);
                    String seedS = Utils.getOption("seed", classifierOpts);
                    if (!DistributedJobConfig.isEmpty(seedS)) {
                        seedS = environmentSubstitute(seedS);
                        m_classifierJob.m_randomizeSparkJob.setRandomSeed(seedS);
                    }
                } catch (Exception ex) {
                    logMessage(ex);
                }

                if (!m_classifierJob.m_randomizeSparkJob.runJobWithContext(sparkContext)) {
                    // keep the job status consistent with the ARFF-failure path above
                    setJobStatus(JobStatus.FAILED);
                    statusMessage("Unable to continue - randomization/stratification of input data failed!");
                    logMessage("Unable to continue - randomization/stratification of input data failed!");
                    return false;
                }

                Dataset d = m_classifierJob.m_randomizeSparkJob.getDataset(TRAINING_DATA);
                dataSet = d.getDataset();
                headerWithSummary = d.getHeaderWithSummary();
                setDataset(TRAINING_DATA, d);
            }
            // clean the output directory
            SparkJob.deleteDirectory(outputPath);

            // separate test set?
            // NOTE(review): the test set is registered as TEST_DATA but the
            // subset search below evaluates on the training data only -
            // confirm this is intended.
            if (!DistributedJobConfig.isEmpty(getSeparateTestSetPath())) {
                int minSlices = 1;
                if (!DistributedJobConfig.isEmpty(m_sjConfig.getMinInputSlices())) {
                    try {
                        minSlices = Integer.parseInt(environmentSubstitute(m_sjConfig.getMinInputSlices()));
                    } catch (NumberFormatException e) {
                        // best effort: fall back to a single slice, but say so
                        logMessage("Unable to parse minimum input slices setting - using 1");
                    }
                }
                String separateTestSetS = environmentSubstitute(getSeparateTestSetPath());
                JavaRDD<Instance> testSet = loadInput(separateTestSetS, headerNoSummary,
                        m_classifierJob.getCSVMapTaskOptions(), sparkContext, getCachingStrategy(), minSlices, true);

                setDataset(TEST_DATA, new Dataset(testSet, headerWithSummary));
            }

            // parse SLS algorithm arguments
            int maxTries = 1;
            int maxFlips = 6;
            int greedyNeighbors = 4;
            float noise = 0.2f;

            if (!DistributedJobConfig.isEmpty(getRunsCount())) {
                logMessage("Runs Count = " + getRunsCount());
                maxTries = Integer.parseInt(getRunsCount());
            }

            if (!DistributedJobConfig.isEmpty(getFlipCount())) {
                logMessage("Flip Count = " + getFlipCount());
                maxFlips = Integer.parseInt(getFlipCount());
            }

            if (!DistributedJobConfig.isEmpty(getNeighborsCount())) {
                logMessage("Neighbors Count = " + getNeighborsCount());
                greedyNeighbors = Integer.parseInt(getNeighborsCount());
            }

            if (!DistributedJobConfig.isEmpty(getRandomChance())) {
                logMessage("Random Chance = " + getRandomChance());
                noise = Float.parseFloat(getRandomChance());
            }

            // NOTE(review): bits 0..numAttributes-2 are searched, so every
            // non-class attribute is reachable only when the class is the last
            // attribute - confirm for data with a non-final class attribute.
            int featureCount = headerNoSummary.numAttributes() - 1;

            // Validate up front: these would otherwise surface as a
            // Random.nextInt(0) failure or a null optimal state further down.
            if (maxTries < 1) {
                throw new DistributedWekaException("Runs count can't be less than 1!");
            }
            if (greedyNeighbors < 1) {
                throw new DistributedWekaException("Neighbors count can't be less than 1!");
            }
            if (featureCount < 1) {
                throw new DistributedWekaException(
                        "Attribute selection needs at least one attribute besides the class!");
            }

            final Random rand = new Random();
            final int maxNeighborTries = 100;

            // Best score over all runs. EvaluateSubset scores are negated
            // errors, so higher is better.
            double maxGoal = Double.NEGATIVE_INFINITY;
            BitSet optimalState = null;

            // Tabu list: subsets already proposed, shared across all runs.
            Set<BitSet> visitedStates = new HashSet<BitSet>();

            for (int t = 0; t < maxTries; t++) {
                BitSet currentState = generateInitialState(featureCount, rand);

                Tuple2<Double, BitSet> result = EvaluateSubset(new BitSet[] { currentState }, dataSet,
                        headerNoSummary, headerWithSummary);
                if (result._1() > maxGoal) {
                    maxGoal = result._1();
                    optimalState = result._2();
                }
                visitedStates.add(currentState);

                for (int r = 0; r < maxFlips; r++) {
                    if (rand.nextFloat() < noise) {
                        // noise step: move to a random neighbour regardless of its score
                        BitSet state = randomUnvisitedNeighbor(currentState, visitedStates, featureCount, rand,
                                maxNeighborTries);

                        logMessage("Noise step, state = " + state.toString());

                        Tuple2<Double, BitSet> goal = EvaluateSubset(new BitSet[] { state }, dataSet,
                                headerNoSummary, headerWithSummary);
                        visitedStates.add(state);
                        if (goal._1() > maxGoal) {
                            maxGoal = goal._1();
                            optimalState = goal._2();
                        }

                        currentState = state;
                    } else {
                        // greedy step: evaluate several neighbours and move to the best one
                        BitSet[] neighborStates = new BitSet[greedyNeighbors];
                        for (int k = 0; k < greedyNeighbors; k++) {
                            BitSet neighborState = randomUnvisitedNeighbor(currentState, visitedStates,
                                    featureCount, rand, maxNeighborTries);
                            // mark visited immediately so later neighbours in
                            // this step avoid it where possible
                            visitedStates.add(neighborState);
                            neighborStates[k] = neighborState;
                        }

                        logMessage("greedy step, neighborStates = ");
                        for (BitSet n : neighborStates) {
                            logMessage(n.toString());
                        }

                        Tuple2<Double, BitSet> neighborOptimalState = EvaluateSubset(neighborStates, dataSet,
                                headerNoSummary, headerWithSummary);

                        currentState = neighborOptimalState._2();
                        if (currentState == null) {
                            throw new DistributedWekaException("Greedy step produced no next state");
                        }
                        if (neighborOptimalState._1() > maxGoal) {
                            maxGoal = neighborOptimalState._1();
                            optimalState = neighborOptimalState._2();
                        }
                        logMessage("greedy step done, next state = " + currentState.toString());
                    }
                }
            }

            if (optimalState == null) {
                throw new DistributedWekaException("Attribute subset search did not produce a result");
            }

            logMessage("all steps done, optimal state = " + optimalState.toString());

            storeAndWriteEvalResults(optimalState, headerNoSummary, outputPath);

        } catch (Exception ex) {
            logMessage(ex);
            setJobStatus(JobStatus.FAILED);
            throw new DistributedWekaException(ex);
        }
        setJobStatus(JobStatus.FINISHED);

        return true;
    }

    /**
     * Returns a copy of {@code current} with one randomly chosen bit in
     * [0, featureCount) flipped, preferring a state not already in
     * {@code visited}.
     * <p>
     * Up to {@code maxTries} random bits are tried. If every candidate was
     * already visited, an unmodified copy of {@code current} is returned.
     * {@code visited} is only read, never modified.
     *
     * @param current the state to move away from (not modified)
     * @param visited states already proposed
     * @param featureCount number of searchable bits; must be positive
     * @param rand source of randomness
     * @param maxTries maximum number of random flips to attempt
     * @return the chosen neighbour (a new BitSet instance)
     */
    private static BitSet randomUnvisitedNeighbor(BitSet current, Set<BitSet> visited, int featureCount,
            Random rand, int maxTries) {
        BitSet neighbor = (BitSet) current.clone();
        for (int i = 0; i < maxTries; i++) {
            int index = rand.nextInt(featureCount);
            neighbor.flip(index);
            if (!visited.contains(neighbor)) {
                break;
            }
            // already visited: undo and try another bit
            neighbor.flip(index);
        }
        return neighbor;
    }

    /**
     * Builds a filter that keeps only the attributes selected in
     * {@code subset}, plus the class attribute.
     *
     * @param subset bit i set means attribute i is selected; the bit at the
     *          class index (if any) is ignored
     * @param headerNoSummary header without summary attributes; its class index
     *          identifies the attribute that is always retained
     * @return a preconstructed {@link Remove} filter in inverted-selection mode
     *         (the listed indexes are kept), or {@code null} if no non-class
     *         attribute is selected
     */
    protected PreconstructedFilter GetFilterFromBitSet(BitSet subset, Instances headerNoSummary) {
        int numAttribs = headerNoSummary.numAttributes();
        int classIndex = headerNoSummary.classIndex();

        int numFilteredAttribs = 0;
        for (int i = 0; i < numAttribs; i++) {
            if (subset.get(i) && i != classIndex) {
                numFilteredAttribs++;
            }
        }

        if (numFilteredAttribs <= 0) {
            return null;
        }

        // Only append the class when one is set; a class index of -1 must not
        // end up in the filter's index array.
        boolean keepClass = classIndex >= 0;
        int[] featureArray = new int[numFilteredAttribs + (keepClass ? 1 : 0)];
        int j = 0;
        for (int i = 0; i < numAttribs; i++) {
            if (subset.get(i) && i != classIndex) {
                featureArray[j++] = i;
            }
        }
        if (keepClass) {
            featureArray[j] = classIndex;
        }

        Remove f = new Remove();
        // inverted selection: the listed indexes are retained, all others removed
        f.setInvertSelection(true);
        f.setAttributeIndicesArray(featureArray);

        return new MakePreconstructedFilter(f);
    }

    /**
     * Evaluates every candidate attribute subset on the given data and returns
     * the best one.
     * <p>
     * Each instance is replicated once per candidate subset, keyed by that
     * subset, and the data is grouped by key. The grouped data is then passed to
     * phaseOneBuildClassifiers and phaseTwoEvaluateClassifiers, which yield one
     * Evaluation per subset.
     * <p>
     * A subset's score is computed as follows:
     * <ul>
     * <li>For a nominal class, the negated error rate.</li>
     * <li>Otherwise, the negated mean absolute error.</li>
     * </ul>
     * Higher scores are better, and a NaN score is replaced by -1e8.
     *
     * @param subsetList the candidate subsets to evaluate
     * @param dataSet the training data
     * @param headerNoSummary header without summary attributes (class index set)
     * @param headerWithSummary header including summary attributes
     * @return (score, subset) for the highest-scoring subset
     * @throws Exception if the Spark evaluation fails or yields no results
     */
    protected Tuple2<Double, BitSet> EvaluateSubset(BitSet[] subsetList, JavaRDD<Instance> dataSet,
            Instances headerNoSummary, Instances headerWithSummary) throws Exception {

        logMessage("Evaluate Subsets: ");
        for (BitSet n : subsetList) {
            logMessage(n.toString());
        }

        final BitSet[] finalSubsets = subsetList;

        // pair every instance with every candidate subset
        JavaPairRDD<BitSet, Instance> bitsetInstanceData = dataSet
                .mapPartitionsToPair(new PairFlatMapFunction<Iterator<Instance>, BitSet, Instance>() {

                    /**
                     * 
                     */
                    private static final long serialVersionUID = -7672702819274482635L;

                    @Override
                    public Iterable<Tuple2<BitSet, Instance>> call(Iterator<Instance> split) throws Exception {

                        List<Tuple2<BitSet, Instance>> returnValue = new ArrayList<Tuple2<BitSet, Instance>>();

                        while (split.hasNext()) {
                            Instance current = split.next();

                            for (BitSet s : finalSubsets) {
                                returnValue.add(new Tuple2<BitSet, Instance>(s, current));
                            }
                        }

                        return returnValue;
                    }

                }, true);

        JavaPairRDD<BitSet, Iterable<Instance>> groupedData = bitsetInstanceData.groupByKey();

        Map<BitSet, Classifier[]> foldClassifiers = phaseOneBuildClassifiers(groupedData, subsetList,
                headerNoSummary);

        logMessage("Phase 1 done");
        for (Map.Entry<BitSet, Classifier[]> entry : foldClassifiers.entrySet()) {
            logMessage("Bitset: " + entry.getKey().toString() + " Classifiers: " + entry.getValue().length);
        }

        List<Tuple2<BitSet, Evaluation>> results = phaseTwoEvaluateClassifiers(groupedData, headerWithSummary,
                headerNoSummary, foldClassifiers);

        // pick the best result (scores are negated errors: higher is better)
        boolean nominalClass = headerNoSummary.classAttribute().isNominal();
        double bestEval = Double.NEGATIVE_INFINITY;
        BitSet bestSubset = null;

        logMessage("Result count = " + results.size());

        for (Tuple2<BitSet, Evaluation> result : results) {
            double eval = nominalClass ? -result._2().errorRate() : -result._2().meanAbsoluteError();

            // treat an undefined metric as a very poor (but comparable) score
            if (Double.isNaN(eval)) {
                eval = -100000000d;
            }

            logMessage("Result = " + result._1().toString() + " ==> " + eval);

            if (eval > bestEval) {
                bestEval = eval;
                bestSubset = result._1();
            }
        }

        if (bestSubset == null) {
            throw new DistributedWekaException("No evaluation results were produced for the candidate subsets");
        }

        logMessage("Best Result = " + bestSubset.toString() + " ==> " + bestEval);

        return new Tuple2<Double, BitSet>(bestEval, bestSubset);
    }

    /**
     * Draws a random starting subset for the local search.
     * <p>
     * Each of the first {@code featureCount} bits is set independently when
     * {@code RAND.nextFloat()} is below 0.5. Exactly one float is drawn per bit,
     * in ascending bit order.
     *
     * @param featureCount number of bits to consider
     * @param RAND source of randomness
     * @return a new BitSet with roughly half of the bits set
     */
    protected BitSet generateInitialState(int featureCount, Random RAND) {
        final BitSet initial = new BitSet(featureCount);
        int bit = 0;
        while (bit < featureCount) {
            boolean include = RAND.nextFloat() < 0.5;
            initial.set(bit, include);
            bit++;
        }
        return initial;
    }

    /**
     * Stores a textual summary of the selected attribute subset and writes it to
     * {@code evaluation.txt} under the output path.
     * <p>
     * The summary contains, in order:
     * <ul>
     * <li>the raw bit set;</li>
     * <li>the class ("Label") attribute;</li>
     * <li>each selected non-class attribute, one per line.</li>
     * </ul>
     * It is also kept in {@code m_textEvalResults}, which {@link #getText()}
     * returns.
     *
     * @param state the selected attribute subset (bit i set means attribute i is
     *          selected)
     * @param headerNoSummary the header of the training data without summary
     *          attributes
     * @param outputPath the output directory; "/" is used as the separator for
     *          URL-style paths (containing "://"), otherwise the platform
     *          separator
     * @throws IOException if a problem occurs while writing
     */
    protected void storeAndWriteEvalResults(BitSet state, Instances headerNoSummary, String outputPath)
            throws IOException {

        StringBuilder buff = new StringBuilder();

        // terminate the bit-set line so it does not run into the "Label:" line
        buff.append(state.toString()).append("\n");

        int classIndex = headerNoSummary.classIndex();
        int numAttribs = headerNoSummary.numAttributes();

        buff.append("Label: ").append(headerNoSummary.classAttribute().toString()).append("\n");
        buff.append("Selected Attributes: \n");
        for (int i = 0; i < numAttribs; i++) {
            if (state.get(i) && i != classIndex) {
                buff.append(headerNoSummary.attribute(i).toString()).append("\n");
            }
        }

        String evalOutputPath = outputPath + (outputPath.contains("://") ? "/" : File.separator)
                + "evaluation.txt";
        m_textEvalResults = buff.toString();
        PrintWriter writer = null;
        try {
            writer = openTextFileForWrite(evalOutputPath);
            writer.println(m_textEvalResults);
        } finally {
            if (writer != null) {
                writer.flush();
                writer.close();
            }
        }
    }

    /**
     * Returns the evaluation results held in {@code m_evalResults}.
     * <p>
     * NOTE(review): nothing in this part of the class appears to assign
     * {@code m_evalResults}, so this may return null - confirm.
     *
     * @return the stored results instances
     */
    @Override
    public Instances getInstances() {
        final Instances stored = this.m_evalResults;
        return stored;
    }

    /**
     * Returns the textual result summary kept in {@code m_textEvalResults}.
     * This is the same text that storeAndWriteEvalResults writes to
     * evaluation.txt.
     *
     * @return the stored result text
     */
    @Override
    public String getText() {
        final String summary = this.m_textEvalResults;
        return summary;
    }

    /**
     * Command-line entry point.
     * <p>
     * If {@code -h} is present, prints the option help to stderr and exits
     * with status 1. Otherwise, applies {@code options} to the job and runs it.
     * Exceptions raised while configuring or running the job are printed to
     * stderr rather than propagated.
     *
     * @param toRun the job instance; must be a WekaAttributeSelectionSparkJob
     * @param options command-line options
     * @throws IllegalArgumentException if {@code toRun} is of the wrong type
     */
    @Override
    public void run(Object toRun, String[] options) throws IllegalArgumentException {
        if (!(toRun instanceof WekaAttributeSelectionSparkJob)) {
            // message previously named the wrong job class (copy-paste)
            throw new IllegalArgumentException("Object to run is not a WekaAttributeSelectionSparkJob!");
        }

        try {
            WekaAttributeSelectionSparkJob job = (WekaAttributeSelectionSparkJob) toRun;
            if (Utils.getFlag('h', options)) {
                String help = DistributedJob.makeOptionsStr(job);
                System.err.println(help);
                System.exit(1);
            }

            job.setOptions(options);
            job.runJob();
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}