weka.distributed.spark.RandomizedDataSparkJob.java Source code

Java tutorial

Introduction

Here is the source code for weka.distributed.spark.RandomizedDataSparkJob.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    RandomizedDataSparkJob
 *    Copyright (C) 2014 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.distributed.spark;

import distributed.core.DistributedJob;
import distributed.core.DistributedJobConfig;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;
import weka.core.Attribute;
import weka.core.CommandlineRunnable;
import weka.core.Environment;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.stats.ArffSummaryNumericMetric;
import weka.core.stats.NominalStats;
import weka.core.stats.NumericStats;
import weka.distributed.CSVToARFFHeaderMapTask;
import weka.distributed.CSVToARFFHeaderReduceTask;
import weka.distributed.DistributedWekaException;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Vector;

/**
 * Job for creating randomly shuffled (and stratified if a nominal class is set)
 * RDDs. Can also persist the shuffled splits as either binary object files or
 * CSV.
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision$
 */
public class RandomizedDataSparkJob extends SparkJob implements CommandlineRunnable {

    /** The name of the subdir in the output directory for this job */
    public static final String OUTPUT_SUBDIR = "randomized";

    /** For serialization */
    private static final long serialVersionUID = -8150037090117009707L;

    /** ARFF job */
    protected ArffHeaderSparkJob m_arffHeaderJob = new ArffHeaderSparkJob();

    /** Holds the options to the header task */
    protected String m_wekaCsvToArffMapTaskOpts = "";

    /** No output to the file system by default */
    protected boolean m_writeRandomizedDataToOutput;

    /** The number of randomly shuffled splits to create */
    protected String m_numSplits = "10";

    /** random number seed option */
    protected String m_seed = "1";

    /** class attribute (for stratification) */
    protected String m_classAttribute = "";

    /**
     * If true then assume the last attribute is the class (when
     * not explicitly set)
     */
    protected boolean m_defaultToLastAttIfClassNotSpecified;

    /**
     * Holds the randomized and stratified RDD resulting from this job if the
     * input data does not contain string or relational attributes
     */
    protected JavaRDD<Instance> m_sortedByFold;

    /** Holds the final path for the randomly shuffled data chunks */
    protected String m_finalOutputDirectory = "";

    // TODO make an option to set the number slices based on the requested number
    // of instances per slice (perhaps this should be a general SparkJob option?).
    // if this gets used then the client will need to get the final number of
    // slices from us (so that subsequent jobs are in sync with the randomized
    // partitions created by this job)
    /**
     * Number of instances per data slice (determines how many data slices are
     * created)
     */
    protected String m_numInstancesPerInputSlice = "";

    /**
     * True if the output directory should be deleted first (doing so will force
     * this job to run in the case where there are already chunk files present in
     * the output directory)
     */
    // TODO - might revisit sometime if we can work out
    // how to force the number of partitions in an RDD
    // to be equal to the number files read from a directory.
    // Otherwise, our carefully shuffled and stratified splits
    // are likely to get split up further
    // protected boolean m_cleanOutputDir;

    /**
     * Whether to save randomly shuffled/stratified splits in text rather than
     * serialized object form
     */
    protected boolean m_saveTextualSplits;

    /**
     * Constructs the job, handing its name and a short description to the
     * superclass constructor.
     */
    public RandomizedDataSparkJob() {
        // description previously read "stratifify"
        super("Randomized and stratified split job", "Randomize and stratify the input data");
    }

    /**
     * Command-line entry point: creates a job instance and hands it, together
     * with the supplied arguments, to {@code run()}.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args) {
        RandomizedDataSparkJob job = new RandomizedDataSparkJob();
        job.run(job, args);
    }

    /**
     * Lists the command-line options of this job, followed by the options of
     * an {@code ArffHeaderSparkJob} (the job used to create the ARFF header).
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();

        /*
         * result .add(new Option(
         * "\tClean output directory. Forces the job to run " +
         * "in the case where the output directory\n\talready exists and is populated "
         * + "with chunk files", "clean", 0, "-clean"));
         */

        result.add(new Option("\tNumber of shuffled folds/splits to create (default = 10)", "num-splits", 1,
                "-num-splits <integer>"));

        // setOptions() parses this option and getJobOptionsOnly() emits it,
        // so it is listed here as well to keep the three methods consistent
        result.add(new Option(
                "\tNumber of instances that each randomly shuffled data split"
                        + "\n\t(partition) should contain. If not specified then the number"
                        + "\n\tof instances per split is determined by the min-slices setting.",
                "num-instances-per-slice", 1, "-num-instances-per-slice <integer>"));

        result.add(new Option("\tClass index (1-based) or class attribute name " + "(default = last attribute).",
                "class", 1, "-class <index or name>"));

        result.add(new Option("\tRandom seed (default = 1)", "seed", 1, "-seed <integer>"));

        result.add(new Option("\tOutput randomly shuffled/stratified split files to " + "\n\tthe output directory.",
                "output", 0, "-output"));

        // help text previously read "Has no affect if -output has been not been specified."
        result.add(new Option(
                "\tPersist split files as text (CSV) instead of binary."
                        + "\n\tHas no effect if -output has not been specified.",
                "output-csv", 0, "-output-csv"));

        result.add(new Option("", "", 0, "\nOptions specific to ARFF training header creation:"));

        // a throw-away header job is used purely to obtain its option listing
        ArffHeaderSparkJob tempArffJob = new ArffHeaderSparkJob();
        Enumeration<Option> arffOpts = tempArffJob.listOptions();
        while (arffOpts.hasMoreElements()) {
            result.add(arffOpts.nextElement());
        }

        return result.elements();
    }

    /**
     * Returns the current option settings: the job-specific options followed
     * by the stored CSV map task (header job) options, if there are any.
     *
     * @return the options as an array of strings
     */
    @Override
    public String[] getOptions() {
        List<String> collected = new ArrayList<String>();

        String[] jobOnly = getJobOptionsOnly();
        for (int i = 0; i < jobOnly.length; i++) {
            collected.add(jobOnly[i]);
        }

        // nothing more to add when no header-task options have been stored
        if (DistributedJobConfig.isEmpty(getCSVMapTaskOptions())) {
            return collected.toArray(new String[collected.size()]);
        }

        try {
            for (String part : Utils.splitOptions(getCSVMapTaskOptions())) {
                collected.add(part);
            }
        } catch (Exception e) {
            // a failure to split the stored options is reported but does not
            // prevent the job-specific options from being returned
            e.printStackTrace();
        }

        return collected.toArray(new String[collected.size()]);
    }

    /**
     * Parses the supplied options. Job-specific options are read first; a
     * copy of what remains is then given to the ARFF header job after the
     * superclass has processed the original array.
     *
     * @param options the options to parse
     * @throws Exception if option parsing fails
     */
    @Override
    public void setOptions(String[] options) throws Exception {

        setNumInstancesPerShuffledSplit(Utils.getOption("num-instances-per-slice", options));

        String splitCount = Utils.getOption("num-splits", options);
        setNumRandomlyShuffledSplits(splitCount);

        setClassAttribute(Utils.getOption("class", options));
        setRandomSeed(Utils.getOption("seed", options));

        // setCleanOutputDirectory(Utils.getFlag("clean", options));

        boolean writeSplits = Utils.getFlag("output", options);
        setWriteSplitsToOutput(writeSplits);
        boolean splitsAsCsv = Utils.getFlag("output-csv", options);
        setPersistSplitFilesAsCSV(splitsAsCsv);

        // snapshot the options now, before the superclass sees them, so that
        // the ARFF header job can be given the same general configuration
        String[] forHeaderJob = options.clone();

        super.setOptions(options);

        // configure the ARFF header job and remember its resulting options
        m_arffHeaderJob.setOptions(forHeaderJob);
        String headerJobOpts = Utils.joinOptions(m_arffHeaderJob.getOptions());
        if (DistributedJobConfig.isEmpty(headerJobOpts)) {
            return;
        }
        setCSVMapTaskOptions(headerJobOpts);
    }

    /**
     * Returns only the options that belong to this job (i.e. without the
     * stored CSV map task options).
     *
     * @return the job-specific options as an array of strings
     */
    public String[] getJobOptionsOnly() {
        List<String> opts = new ArrayList<String>();

        /*
         * if (getCleanOutputDirectory()) { options.add("-clean"); }
         */

        // options that carry a value, in the order they are emitted
        String[][] valued = { { "-class", getClassAttribute() }, { "-seed", getRandomSeed() },
                { "-num-splits", getNumRandomlyShuffledSplits() },
                { "-num-instances-per-slice", getNumInstancesPerShuffledSplit() } };

        for (String[] flagAndValue : valued) {
            if (DistributedJobConfig.isEmpty(flagAndValue[1])) {
                continue;
            }
            opts.add(flagAndValue[0]);
            opts.add(flagAndValue[1]);
        }

        // simple flags come last
        if (getWriteSplitsToOutput()) {
            opts.add("-output");
        }
        if (getPersistSplitFilesAsCSV()) {
            opts.add("-output-csv");
        }

        return opts.toArray(new String[opts.size()]);
    }

    /**
     * Get the value of the "default to last attribute if class not specified"
     * flag. Note that, despite the "Dont" in the method name, the stored flag
     * is returned as-is (it is not negated).
     *
     * @return the current value of the flag
     */
    public boolean getDontDefaultToLastAttIfClassNotSpecified() {
        return this.m_defaultToLastAttIfClassNotSpecified;
    }

    /**
     * Set whether the last attribute is assumed to be the class when no class
     * attribute has been explicitly set.
     *
     * @param d true to default to the last attribute
     */
    public void setDefaultToLastAttIfClassNotSpecified(boolean d) {
        this.m_defaultToLastAttIfClassNotSpecified = d;
    }

    /**
     * Tip text for the number-of-splits property.
     *
     * @return a short description of the property
     */
    public String numRandomlyShuffledSplitsTipText() {
        final String tip = "The number of randomly shuffled splits to create";
        return tip;
    }

    /**
     * Get the number of randomly shuffled splits to make (if randomly
     * shuffling the data). The value is held as a string.
     *
     * @return the number of randomly shuffled splits to make
     */
    public String getNumRandomlyShuffledSplits() {
        return this.m_numSplits;
    }

    /**
     * Set the number of randomly shuffled splits to make (if randomly
     * shuffling the data). The value is stored as given.
     *
     * @param s the number of randomly shuffled splits to make
     */
    public void setNumRandomlyShuffledSplits(String s) {
        this.m_numSplits = s;
    }

    /**
     * Tip text for the write-splits-to-output property.
     *
     * @return a short description of the property
     */
    public String writeSplitsToOutputTipText() {
        final String tip = "Output shuffled/stratified split files?";
        return tip;
    }

    /**
     * Get whether randomly shuffled split files are to be written to the
     * output.
     *
     * @return true if splits are to be output
     */
    public boolean getWriteSplitsToOutput() {
        return this.m_writeRandomizedDataToOutput;
    }

    /**
     * Set whether randomly shuffled split files are to be written to the
     * output.
     *
     * @param d true if splits are to be output
     */
    public void setWriteSplitsToOutput(boolean d) {
        this.m_writeRandomizedDataToOutput = d;
    }

    /**
     * Tip text for the persistSplitFilesAsCSV property. Named with the
     * {@code TipText} suffix so that it matches the naming pattern used by
     * every other tip-text method in this class.
     *
     * @return the tip text for this property
     */
    public String persistSplitFilesAsCSVTipText() {
        return "Persist randomly shuffled split files in textual CSV form instead of binary";
    }

    /**
     * Tip text for the persistSplitFilesAsCSV property. Retained so that
     * existing callers of the original (suffix-less) method name keep working.
     *
     * @return the tip text for this property
     * @deprecated use {@link #persistSplitFilesAsCSVTipText()} instead
     */
    @Deprecated
    public String persistSplitFilesAsCSV() {
        return persistSplitFilesAsCSVTipText();
    }

    /**
     * Get whether randomly shuffled data chunks are persisted as CSV files
     * instead of serialized objects.
     *
     * @return true if data chunks should be persisted as CSV rather than
     *         object files
     */
    public boolean getPersistSplitFilesAsCSV() {
        return this.m_saveTextualSplits;
    }

    /**
     * Set whether randomly shuffled data chunks are persisted as CSV files
     * instead of serialized objects.
     *
     * @param s true if data chunks should be persisted as CSV rather than
     *          object files
     */
    public void setPersistSplitFilesAsCSV(boolean s) {
        this.m_saveTextualSplits = s;
    }

    /**
     * Tip text for the number-of-instances-per-split property.
     *
     * @return tip text for this property
     */
    public String numInstancesPerShuffledSplitTipText() {
        // a space was missing between "chunk" and "is", which produced
        // "per chunkis determined" in the concatenated text
        return "The number of instances that each randomly shuffled data split (partition) "
                + "should contain. If not specified then the number of instances per chunk "
                + "is determined by the min-slices setting";
    }

    /**
     * Get the number of instances that each randomly shuffled data chunk
     * should have. The value is held as a string.
     *
     * @return the number of instances that each randomly shuffled data chunk
     *         should contain
     */
    public String getNumInstancesPerShuffledSplit() {
        return this.m_numInstancesPerInputSlice;
    }

    /**
     * Set the number of instances that each randomly shuffled data chunk
     * should have. The value is stored as given.
     *
     * @param insts the number of instances that each randomly shuffled data
     *          chunk should contain
     */
    public void setNumInstancesPerShuffledSplit(String insts) {
        this.m_numInstancesPerInputSlice = insts;
    }

    /**
     * Tip text for this property
     *
     * @return tip text for this property
     */
    /*
     * public String cleanOutputDirectoryTipText() { return
     * "Set to true to delete any existing output directory and force this job to run"
     * ; }
     */

    /**
     * Get whether to blow away the output directory before running. If an output
     * directory exists (and is populated with chunk files) then deleting this
     * prior to running will force the job to run.
     * 
     * @return true if the output directory should be deleted before first (thus
     *         forcing the job to run if there was a populated output directory
     *         already).
     */
    /*
     * public boolean getCleanOutputDirectory() { return m_cleanOutputDir; }
     */

    /**
     * Set whether to blow away the output directory before running. If an output
     * directory exists (and is populated with chunk files) then deleting this
     * prior to running will force the job to run.
     *
     * @param clean true if the output directory should be deleted before first
     *          (thus forcing the job to run if there was a populated output
     *          directory already).
     */
    /*
     * public void setCleanOutputDirectory(boolean clean) { m_cleanOutputDir =
     * clean; }
     */

    /**
     * Tip text for the random seed property.
     *
     * @return a short description of the property
     */
    public String randomSeedTipText() {
        final String tip = "The random seed to use when randomly shuffling the data";
        return tip;
    }

    /**
     * Get the random seed used for shuffling the data. The value is held as
     * a string.
     *
     * @return the random seed to use
     */
    public String getRandomSeed() {
        return this.m_seed;
    }

    /**
     * Set the random seed used for shuffling the data. The value is stored
     * as given.
     *
     * @param seed the random seed to use
     */
    public void setRandomSeed(String seed) {
        this.m_seed = seed;
    }

    /**
     * Tip text for the class attribute property.
     *
     * @return a short description of the property
     */
    public String classAttributeTipText() {
        StringBuilder tip = new StringBuilder();
        tip.append("The name or index of the class attribute. 'first' and ");
        tip.append("'last' may also be used. If set, and the class is nominal, ");
        tip.append("then output chunk files will be stratified.");
        return tip.toString();
    }

    /**
     * Get the name or index of the class attribute ("first" and "last" can
     * also be used).
     *
     * @return the name or index of the class attribute
     */
    public String getClassAttribute() {
        return this.m_classAttribute;
    }

    /**
     * Set the name or index of the class attribute ("first" and "last" can
     * also be used). The value is stored as given.
     *
     * @param c the name or index of the class attribute
     */
    public void setClassAttribute(String c) {
        this.m_classAttribute = c;
    }

    /**
     * Get the options for the header job, held as a single option string.
     *
     * @return options to the header job
     */
    public String getCSVMapTaskOptions() {
        return this.m_wekaCsvToArffMapTaskOpts;
    }

    /**
     * Set the options for the header job, supplied as a single option string.
     *
     * @param opts options to the header job
     */
    public void setCSVMapTaskOptions(String opts) {
        this.m_wekaCsvToArffMapTaskOpts = opts;
    }

    /**
     * Perform the randomization (and, when the class is nominal,
     * stratification) of the input data. The resulting RDD of instances is
     * stored in {@code m_sortedByFold}, registered as the training dataset and,
     * if writing of splits has been requested, written to the output path.
     * <p>
     * Phase 1 passes every row through {@code PhaseOneRandomization}, then
     * sorts and repartitions the keyed rows into {@code numFoldSlices}
     * partitions. If there is no class or the class is numeric, that result
     * is converted directly to the final RDD. Otherwise phase 2 passes the
     * rows through {@code PhaseTwoStratification} and sorts/repartitions them
     * again before conversion.
     *
     * @param input the input data
     * @param outputPath the path to write the shuffled splits to (only used
     *          when splits are to be written to the output)
     * @param numFoldSlices the number of partitions to create
     * @param random the random number generator handed to phase 1
     * @param headerWithSummary the header of the data, including summary
     *          attributes
     * @param classIndex the classIndex to use
     * @throws IOException if a problem occurs
     * @throws DistributedWekaException if a problem occurs
     */
    protected void performRandomShuffle(JavaRDD<Instance> input, String outputPath, final int numFoldSlices,
            final Random random, final Instances headerWithSummary, int classIndex)
            throws IOException, DistributedWekaException {

        final Instances headerNoSummary = CSVToARFFHeaderReduceTask.stripSummaryAtts(headerWithSummary);
        headerNoSummary.setClassIndex(classIndex);

        boolean containsStringOrRelational = headerNoSummary.checkForStringAttributes()
                || headerNoSummary.checkForAttributeType(Attribute.RELATIONAL);

        final PhaseOneRandomization phaseOne = new PhaseOneRandomization(headerNoSummary, getCSVMapTaskOptions(),
                containsStringOrRelational, random, numFoldSlices);

        // Phase 1 - randomly shuffle the data
        logMessage("[Randomly shuffle data] Starting phase 1...");
        JavaPairRDD<Integer, Object> mapResults = input
                .mapPartitionsToPair(new PairFlatMapFunction<Iterator<Instance>, Integer, Object>() {

                    /** For serialization */
                    private static final long serialVersionUID = -5351850875358513817L;

                    @Override
                    public Iterable<Tuple2<Integer, Object>> call(Iterator<Instance> split)
                            throws IOException, DistributedWekaException {

                        // kept local so that no rows are carried in the state
                        // of the (serialized) function object
                        List<Tuple2<Integer, Object>> randomizedRows = new ArrayList<Tuple2<Integer, Object>>();
                        while (split.hasNext()) {
                            randomizedRows.add(phaseOne.process(split.next()));
                        }

                        return randomizedRows;
                    }
                }).persist(getCachingStrategy().getStorageLevel());

        // Now sort into ascending order of random assignment number
        JavaPairRDD<Integer, Object> sortedByAssignment = mapResults.sortByKey(true)
                .partitionBy(new IntegerKeyPartitioner(numFoldSlices))
                .persist(getCachingStrategy().getStorageLevel());
        sortedByAssignment.count(); // materialize before discarding phase 1 output

        // discard mapResults
        mapResults.unpersist();

        if (headerNoSummary.classIndex() < 0 || headerNoSummary.classAttribute().isNumeric()) {

            // No need for the second phase of dealing classes out to splits
            // if there is no class or a numeric class
            JavaRDD<Instance> finalDataSet = keyedRowsToInstances(sortedByAssignment, headerNoSummary,
                    containsStringOrRelational);

            finalDataSet.count(); // materialize this RDD

            logMessage("[Randomly shuffle data] Unpersisting sorted phase 1 RDD");
            sortedByAssignment.unpersist();

            m_sortedByFold = finalDataSet;

        } else {
            // phase 2 - deal classes out to splits
            final PhaseTwoStratification phaseTwo = new PhaseTwoStratification(headerNoSummary, numFoldSlices,
                    false);

            logMessage("[Randomly shuffle data] Starting phase 2 (deal to folds/stratification)...");
            JavaPairRDD<Integer, Object> dealtToFolds = sortedByAssignment.mapPartitionsToPair(
                    new PairFlatMapFunction<Iterator<Tuple2<Integer, Object>>, Integer, Object>() {

                        /**
                         * For serialization
                         */
                        private static final long serialVersionUID = -5903374381393577497L;

                        @Override
                        public Iterable<Tuple2<Integer, Object>> call(Iterator<Tuple2<Integer, Object>> split) {

                            // local for the same reason as in phase 1
                            List<Tuple2<Integer, Object>> dealtRows = new ArrayList<Tuple2<Integer, Object>>();
                            while (split.hasNext()) {
                                Tuple2<Integer, Object> current = split.next();
                                dealtRows.add(phaseTwo.process(current._2));
                            }

                            phaseTwo.checkForMinorityClassCases(dealtRows);

                            return dealtRows;
                        }
                    }).persist(getCachingStrategy().getStorageLevel());

            logMessage("[Randomly shuffle data] Repartitioning phase 2 RDD according to fold number");
            JavaPairRDD<Integer, Object> tmpSortedByFold = dealtToFolds.sortByKey()
                    .partitionBy(new IntegerKeyPartitioner(numFoldSlices))
                    .persist(getCachingStrategy().getStorageLevel());

            tmpSortedByFold.count(); // materialize before discarding its inputs

            sortedByAssignment.unpersist();
            dealtToFolds.unpersist();

            logMessage("[Randomly shuffle data] Creating and persisting final dataset (RDD<Instance>)...");
            JavaRDD<Instance> finalDataSet = keyedRowsToInstances(tmpSortedByFold, headerNoSummary,
                    containsStringOrRelational);

            logMessage("[Randomly shuffle data] forcing materialization of final shuffled data");
            finalDataSet.count();

            logMessage("[Randomly shuffle data] Unpersisting intermediate RDDs");

            tmpSortedByFold.unpersist();

            m_sortedByFold = finalDataSet;
            logMessage("[Randomly shuffle data] Finished shuffling/stratifying RDD. Number of partitions: "
                    + m_sortedByFold.partitions().size());
        }

        setDataset(TRAINING_DATA, new Dataset<Instance>(m_sortedByFold, headerWithSummary));

        if (m_writeRandomizedDataToOutput) {
            writeRandomizedSplits(outputPath, m_sortedByFold);
        }
    }

    /**
     * Converts keyed rows into a persisted (but not yet materialized) RDD of
     * instances, preserving the partitioning of the input.
     * <p>
     * When the data contains string or relational attributes the rows are
     * converted with {@code CSVToInstancePairFlatMapFunction}; otherwise each
     * value is cast to {@code Instance} and given a reference to the header.
     * NOTE(review): previously only the nominal-class path made this
     * distinction, while the no-class/numeric-class path always cast to
     * {@code Instance}. This assumes phase 1 emits the same kind of value on
     * both paths - confirm against {@code PhaseOneRandomization}.
     *
     * @param keyedRows the keyed rows to convert
     * @param headerNoSummary the header (without summary attributes) to
     *          attach to each instance
     * @param containsStringOrRelational true if the data contains string or
     *          relational attributes
     * @return the persisted RDD of instances
     * @throws IOException if a problem occurs
     * @throws DistributedWekaException if a problem occurs
     */
    private JavaRDD<Instance> keyedRowsToInstances(JavaPairRDD<Integer, Object> keyedRows,
            final Instances headerNoSummary, boolean containsStringOrRelational)
            throws IOException, DistributedWekaException {

        if (containsStringOrRelational) {
            CSVToInstancePairFlatMapFunction instanceFunction = new CSVToInstancePairFlatMapFunction(
                    headerNoSummary, getCSVMapTaskOptions());

            // , true here because we preserve the partitions from keyedRows
            return keyedRows.mapPartitions(instanceFunction, true)
                    .persist(getCachingStrategy().getStorageLevel());
        }

        // , true here because we preserve the partitions from keyedRows
        return keyedRows.mapPartitions(new FlatMapFunction<Iterator<Tuple2<Integer, Object>>, Instance>() {

            /**
             * For serialization
             */
            private static final long serialVersionUID = 5425826829981136102L;

            @Override
            public Iterable<Instance> call(Iterator<Tuple2<Integer, Object>> split) {
                List<Instance> instances = new ArrayList<Instance>();
                while (split.hasNext()) {

                    // make sure that each instance has access to the header
                    Instance nextI = (Instance) split.next()._2();
                    nextI.setDataset(headerNoSummary);
                    instances.add(nextI);
                }

                return instances;
            }
        }, true).persist(getCachingStrategy().getStorageLevel());
    }

    // TODO we might remove this entire mechanism
    /**
     * Loads previously written shuffled data from the given directory into
     * {@code m_sortedByFold}, after first replacing the caching strategy with
     * one constructed for that directory.
     *
     * @param outputDir the directory holding the shuffled data
     * @param sparkContext the Spark context to use
     * @param headerNoSummary the header of the data (used when loading CSV)
     * @param minSlices passed on when loading serialized instance files
     * @throws IOException if a problem occurs
     * @throws DistributedWekaException if a problem occurs
     */
    protected void loadShuffledDataFiles(String outputDir, JavaSparkContext sparkContext, Instances headerNoSummary,
            int minSlices) throws IOException, DistributedWekaException {

        // update caching strategy based on shuffled data...
        setCachingStrategy(new CachingStrategy(outputDir, m_sjConfig.getAvailableClusterMemory(),
                m_sjConfig.getInMemoryDataOverheadFactor(), sparkContext));

        String strategyDescription = getCachingStrategy().getStorageLevel().description();
        logMessage("[Randomly shuffle data] Setting caching strategy to: " + strategyDescription);

        if (!getPersistSplitFilesAsCSV()) {
            // splits were saved as serialized instance objects
            m_sortedByFold = loadInstanceObjectFile(outputDir, sparkContext, getCachingStrategy(), minSlices, true);
            return;
        }

        // splits were saved as CSV text
        m_sortedByFold = loadCSVFile(outputDir, headerNoSummary, getCSVMapTaskOptions(), sparkContext,
                getCachingStrategy(), -1, true);
    }

    @SuppressWarnings("unchecked")
    @Override
    public boolean runJobWithContext(JavaSparkContext sparkContext) throws IOException, DistributedWekaException {

        m_currentContext = sparkContext;
        setJobStatus(JobStatus.RUNNING);
        boolean success = false;

        if (m_env == null) {
            m_env = Environment.getSystemWide();
        }

        // Make sure that we save out to a subdirectory of the output
        // directory
        String outputPath = environmentSubstitute(m_sjConfig.getOutputDir());
        outputPath = addSubdirToPath(outputPath, OUTPUT_SUBDIR);

        JavaRDD<Instance> inputData = null;
        Instances headerWithSummary = null;
        if (getDataset(TRAINING_DATA) != null) {
            inputData = (JavaRDD<Instance>) getDataset(TRAINING_DATA).getDataset();
            headerWithSummary = getDataset(TRAINING_DATA).getHeaderWithSummary();
            logMessage("RDD<Instance> dataset provided: " + inputData.partitions().size() + " partitions.");
        }

        if (inputData == null && headerWithSummary == null) {
            logMessage("[Randomly shuffle data] Invoking ARFF Job...");
            m_arffHeaderJob.setEnvironment(m_env);
            m_arffHeaderJob.setLog(getLog());
            m_arffHeaderJob.setStatusMessagePrefix(m_statusMessagePrefix);
            m_arffHeaderJob.setCachingStrategy(getCachingStrategy());

            // header job necessary?
            success = m_arffHeaderJob.runJobWithContext(sparkContext);

            if (!success) {
                setJobStatus(JobStatus.FAILED);
                statusMessage("Unable to continue - creating the ARFF header failed!");
                logMessage("[Randomly shuffle data] Unable to continue - creating the ARFF header failed!");
                return false;
            }

            Dataset d = m_arffHeaderJob.getDataset(TRAINING_DATA);

            headerWithSummary = d.getHeaderWithSummary();
            inputData = (JavaRDD<Instance>) d.getDataset();
            logMessage("Fetching RDD<Instance> dataset from ARFF job: " + inputData.partitions().size()
                    + " partitions.");
        }

        /*
         * // TODO revisit at some stage... Current assumption: if you // have
         * output from this job as serialized instances then you // are happy with
         * the shuffling and will not want to re-shuffle if
         * (m_sjConfig.getSerializedInput()) { throw new DistributedWekaException(
         * "Randomly shuffling serialized Instance " +
         * "input is not supported yet."); }
         */

        // clean the output directory
        SparkJob.deleteDirectory(outputPath);

        int seed = 1;
        if (!DistributedJobConfig.isEmpty(getRandomSeed())) {
            seed = Integer.parseInt(environmentSubstitute(getRandomSeed()));
        }
        final Instances headerNoSummary = CSVToARFFHeaderReduceTask.stripSummaryAtts(headerWithSummary);

        try {
            WekaClassifierSparkJob.setClassIndex(environmentSubstitute(m_classAttribute), headerNoSummary,
                    m_defaultToLastAttIfClassNotSpecified);

        } catch (Exception e) {
            logMessage(e);
            throw new DistributedWekaException(e);
        }

        // find summary attribute for class (if set, otherwise just use the first
        // numeric on nominal attribute). We're using this simply to find out
        // the total number of instances in the dataset
        String className = null;
        if (headerNoSummary.classIndex() >= 0) {
            className = headerNoSummary.classAttribute().name();
        } else {
            for (int i = 0; i < headerNoSummary.numAttributes(); i++) {
                if (headerNoSummary.attribute(i).isNumeric() || headerNoSummary.attribute(i).isNominal()) {
                    className = headerNoSummary.attribute(i).name();
                    break;
                }
            }
        }
        Attribute summaryClassAtt = headerWithSummary
                .attribute(CSVToARFFHeaderMapTask.ARFF_SUMMARY_ATTRIBUTE_PREFIX + className);
        if (summaryClassAtt == null) {
            throw new DistributedWekaException(
                    "Was unable to find the summary attribute for " + "the class: " + className);
        }

        int totalNumInstances = 0;
        int numFoldSlices = 10;
        // summary attribute for getting the total number of instances
        Attribute summaryAttOrig = null;
        for (int i = 0; i < headerNoSummary.numAttributes(); i++) {
            if (headerNoSummary.attribute(i).isNumeric() || headerNoSummary.attribute(i).isNominal()) {
                summaryAttOrig = headerNoSummary.attribute(i);
                break;
            }
        }

        if (summaryAttOrig == null) {
            throw new DistributedWekaException("Was unable to find a summary attribute for "
                    + "deterimining the total number of instances in the dataset");
        }

        String summaryName = summaryAttOrig.name();
        Attribute summaryAtt = headerWithSummary
                .attribute(CSVToARFFHeaderMapTask.ARFF_SUMMARY_ATTRIBUTE_PREFIX + summaryName);
        if (summaryAtt == null) {
            logMessage("[RandomizedDataSparkJob] Was unable to find the summary " + "attribute for attribute: "
                    + summaryName);
            throw new DistributedWekaException(
                    "Was unable to find the summary " + "attribute for attribute: " + summaryName);
        }

        if (summaryAttOrig.isNominal()) {
            NominalStats stats = NominalStats.attributeToStats(summaryAtt);
            for (String label : stats.getLabels()) {
                totalNumInstances += stats.getCount(label);
            }
        } else {
            NumericStats stats = NumericStats.attributeToStats(summaryAtt);
            totalNumInstances = (int) stats.getStats()[ArffSummaryNumericMetric.COUNT.ordinal()];
        }

        if (DistributedJobConfig.isEmpty(getNumRandomlyShuffledSplits())
                && DistributedJobConfig.isEmpty(getNumInstancesPerShuffledSplit())) {
            logMessage("[RandomizedDataSparkJob] Must specify either the number of "
                    + "splits or the number of instances per split");
            throw new DistributedWekaException(
                    "Must specify either the number of " + "splits or the number of instances per split");
        }

        if (!DistributedJobConfig.isEmpty(getNumRandomlyShuffledSplits())) {
            numFoldSlices = Integer.parseInt(environmentSubstitute(getNumRandomlyShuffledSplits()));
        } else {
            int numInsts = 0;
            try {
                numInsts = Integer.parseInt(environmentSubstitute(getNumInstancesPerShuffledSplit()));
            } catch (NumberFormatException ex) {
                throw new DistributedWekaException(ex);
            }

            if (numInsts <= 0) {
                throw new DistributedWekaException("Number of instances per split must " + "be > 0");
            }

            if (numInsts > totalNumInstances) {
                throw new DistributedWekaException(
                        "Can't have more instances per split " + "than there are instances in the dataset!");
            }
            double nc = (double) totalNumInstances / numInsts;
            nc = Math.ceil(nc);
            numFoldSlices = (int) nc;
        }
        logMessage("[Randomly shuffle data] creating " + numFoldSlices + " splits.");

        if (headerNoSummary.attribute(className).isNominal()) {
            NominalStats stats = NominalStats.attributeToStats(summaryClassAtt);
            for (String label : stats.getLabels()) {
                totalNumInstances += stats.getCount(label);
            }
        } else {
            NumericStats stats = NumericStats.attributeToStats(summaryClassAtt);
            totalNumInstances = (int) stats.getStats()[ArffSummaryNumericMetric.COUNT.ordinal()];
        }

        logMessage("[Randomly shuffle data] Num slices = " + numFoldSlices);

        final Random random = new Random(seed);
        for (int i = 0; i < 20; i++) {
            random.nextInt();
        }

        performRandomShuffle(inputData, outputPath, numFoldSlices, random, headerWithSummary,
                headerNoSummary.classIndex());

        setJobStatus(JobStatus.FINISHED);

        return true;
    }

    /**
     * Persists the shuffled dataset and remembers where it was written.
     * <p>
     * The output location is resolved against the file system configuration
     * derived from {@code outputPath}. When the fully qualified form of the path
     * starts with {@code file:/} the unqualified path is used as the final output
     * directory; otherwise the qualified form is used. The RDD is then saved as
     * text or as an object file depending on {@code m_saveTextualSplits}.
     *
     * @param outputPath the path to write to
     * @param sortedByFold the RDD of instances to save
     * @throws IOException if a problem occurs
     */
    protected void writeRandomizedSplits(String outputPath, JavaRDD<Instance> sortedByFold) throws IOException {
        // getFSConfigurationForPath() hands the bare path back through this holder
        final String[] barePath = new String[1];
        final Configuration fsConf = getFSConfigurationForPath(outputPath, barePath);

        final FileSystem fileSystem = FileSystem.get(fsConf);
        final String qualifiedPath = new Path(barePath[0]).makeQualified(fileSystem).toString();

        // paths on the local file system keep their unqualified form
        final boolean onLocalFileSystem = qualifiedPath.toLowerCase().startsWith("file:/");
        m_finalOutputDirectory = onLocalFileSystem ? barePath[0] : qualifiedPath;

        if (!m_saveTextualSplits) {
            sortedByFold.saveAsObjectFile(m_finalOutputDirectory);
            return;
        }
        sortedByFold.saveAsTextFile(m_finalOutputDirectory);
    }

    /**
     * Persists a keyed version of the shuffled dataset via the new Hadoop API
     * using {@code TextOutputFormat} with {@code IntWritable} keys and
     * {@code Text} values.
     * <p>
     * The final output directory is recorded in the same way as for the
     * un-keyed variant: the unqualified path is kept when the qualified path
     * starts with {@code file:/}, otherwise the qualified path is stored. Note
     * that the data itself is always saved to the unqualified path.
     *
     * @param outputPath the path to write to
     * @param sortedByFold the pair RDD to save
     * @throws IOException if a problem occurs
     */
    protected void writeRandomizedSplits(String outputPath, JavaPairRDD<Integer, Object> sortedByFold)
            throws IOException {
        // getFSConfigurationForPath() hands the bare path back through this holder
        final String[] barePath = new String[1];
        final Configuration fsConf = getFSConfigurationForPath(outputPath, barePath);

        final FileSystem fileSystem = FileSystem.get(fsConf);
        final String qualifiedPath = new Path(barePath[0]).makeQualified(fileSystem).toString();

        // paths on the local file system keep their unqualified form
        final boolean onLocalFileSystem = qualifiedPath.toLowerCase().startsWith("file:/");
        m_finalOutputDirectory = onLocalFileSystem ? barePath[0] : qualifiedPath;

        sortedByFold.saveAsNewAPIHadoopFile(barePath[0], IntWritable.class, Text.class, TextOutputFormat.class,
                fsConf);
    }

    /**
     * Get the randomized (and potentially stratified) RDD held by this job.
     * The reference stored in {@code m_sortedByFold} is returned as-is; no
     * persistence level is applied here.
     *
     * @return the RDD stored in {@code m_sortedByFold}
     */
    public JavaRDD<Instance> getRandomizedStratifiedRDD() {
        return this.m_sortedByFold;
    }

    /**
     * Get the directory recorded as the final output location when the
     * randomized splits were written.
     *
     * @return the value of {@code m_finalOutputDirectory}
     */
    public String getFinalOutputDirectory() {
        // TODO adjust this with respect to protocol?? Problem here because
        // RDD.input() doesn't have a version where we can pass in a
        // path + Configuration (thus allowing us to set up paths relative to
        // the user's home)
        return this.m_finalOutputDirectory;
    }

    /**
     * Command line entry point: configures the supplied job from the options
     * and executes it. If the {@code -h} flag is present the option summary is
     * printed to standard error and the JVM exits with status 1. Any exception
     * raised while parsing options or running the job is printed to standard
     * error and not propagated.
     *
     * @param toRun the job to execute; must be a {@code RandomizedDataSparkJob}
     * @param options the command line options
     * @throws IllegalArgumentException if {@code toRun} is not a
     *           {@code RandomizedDataSparkJob}
     */
    @Override
    public void run(Object toRun, String[] options) throws IllegalArgumentException {
        final boolean isShuffleJob = toRun instanceof RandomizedDataSparkJob;
        if (!isShuffleJob) {
            throw new IllegalArgumentException("Object to run is not a RandomizedDataSparkJob");
        }

        final RandomizedDataSparkJob job = (RandomizedDataSparkJob) toRun;
        try {
            final boolean helpRequested = Utils.getFlag('h', options);
            if (helpRequested) {
                System.err.println(DistributedJob.makeOptionsStr(job));
                System.exit(1);
            }

            job.setOptions(options);
            job.runJob();
        } catch (Exception failure) {
            failure.printStackTrace();
        }
    }

    /**
     * Phase one of the two phase random shuffle and stratify process. Each
     * row/instance is paired with a randomly drawn chunk number in the range
     * {@code [0, numFoldSlices)}; instances whose class value is missing are
     * always paired with chunk 0. The resulting key can then be used to sort and
     * re-partition the data.
     *
     * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
     */
    protected static class PhaseOneRandomization implements Serializable {

        /** For serialization */
        private static final long serialVersionUID = 3355852963548447505L;

        /** The header of the data (sans summary attributes) */
        protected Instances m_headerNoSummary;

        /** CSV row parser; created on the first call to {@link #process(String)} */
        protected CSVToARFFHeaderMapTask m_rowHelper;

        /** Options to the CSV parser */
        protected String m_csvMapTaskOpts = "";

        /** Source of the random chunk numbers */
        protected Random m_random;

        /** Whether to emit instances in their string form rather than as objects */
        protected boolean m_instancesAsStrings;

        /** Number of folds (splits) of the data */
        protected int m_numFoldSlices;

        /**
         * Constructor
         *
         * @param headerNoSummary the header of the data (sans summary attributes)
         * @param csvMapTaskOpts options for the CSV parser
         * @param instancesAsStrings true if instances are to be kept as string
         *          objects
         * @param random for random number generation
         * @param numFoldSlices number of splits
         */
        public PhaseOneRandomization(final Instances headerNoSummary, final String csvMapTaskOpts,
                boolean instancesAsStrings, final Random random, final int numFoldSlices) {
            this.m_headerNoSummary = headerNoSummary;
            this.m_csvMapTaskOpts = csvMapTaskOpts;
            this.m_instancesAsStrings = instancesAsStrings;
            this.m_random = random;
            this.m_numFoldSlices = numFoldSlices;
        }

        /**
         * Creates and configures the CSV row parser the first time it is needed.
         *
         * @throws DistributedWekaException if the parser can't be configured
         */
        private void initRowHelperIfNeeded() throws DistributedWekaException {
            if (m_rowHelper != null) {
                return;
            }

            m_rowHelper = new CSVToARFFHeaderMapTask();
            try {
                m_rowHelper.setOptions(Utils.splitOptions(m_csvMapTaskOpts));
                m_rowHelper.initParserOnly(
                        CSVToARFFHeaderMapTask.instanceHeaderToAttributeNameList(m_headerNoSummary));
            } catch (Exception cause) {
                throw new DistributedWekaException(cause);
            }
        }

        /**
         * Parse a textual row into an instance and assign it to a chunk.
         *
         * @param toProcess the row to process
         * @return either a key,string tuple or a key,instance tuple
         * @throws DistributedWekaException if a problem occurs
         * @throws IOException if a problem occurs
         */
        public Tuple2<Integer, Object> process(String toProcess) throws DistributedWekaException, IOException {
            initRowHelperIfNeeded();

            final Instance parsed = DistributedJob.parseInstance(toProcess, m_rowHelper, m_headerNoSummary,
                    m_instancesAsStrings);
            return process(parsed);
        }

        /**
         * Assign an instance to a chunk.
         * <p>
         * In string mode the instance is converted via {@code toString()} and, when
         * the header has a nominal class, the integer class value is appended after
         * the separator {@code "@:@"}. In object mode the instance's dataset
         * reference is cleared and the instance itself is emitted.
         *
         * @param currentI the instance to process
         * @return either a key,string tuple or a key,instance tuple
         */
        public Tuple2<Integer, Object> process(Instance currentI) {
            // no random number is consumed for instances with a missing class value
            final boolean classMissing = currentI.classIndex() >= 0 && currentI.isMissing(currentI.classIndex());
            final int chunk = classMissing ? 0 : m_random.nextInt(m_numFoldSlices);

            if (!m_instancesAsStrings) {
                // drop the reference to the dataset so that Spark does not
                // serialize a copy of the header along with every instance
                // when sorting
                currentI.setDataset(null);
                return new Tuple2<Integer, Object>(chunk, currentI);
            }

            String text = currentI.toString();
            final boolean nominalClass = m_headerNoSummary.classIndex() >= 0
                    && m_headerNoSummary.classAttribute().isNominal();
            if (nominalClass) {
                text = text + "@:@" + (int) currentI.classValue();
            }
            return new Tuple2<Integer, Object>(chunk, text);
        }
    }

    /**
     * Inner class for phase two of the random shuffle and stratify process. This
     * implements the stratification process if there is a class set in the data
     * and it is nominal: instances of each class are dealt out to the folds in
     * round-robin order. When there is no nominal class all instances are treated
     * as belonging to a single class and no minority class buffering is done.
     *
     * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
     */
    protected static class PhaseTwoStratification implements Serializable {

        /** For serialization */
        private static final long serialVersionUID = 6969249334933544891L;

        /**
         * Separator between the textual instance and the class value index that is
         * appended to it when instances are kept as strings
         */
        private static final String CLASS_VALUE_SEPARATOR = "@:@";

        /** Header of the data (sans summary attributes) */
        protected Instances m_headerNoSummary;

        /**
         * True if instances are to remain in string form (rather than Instance
         * object form)
         */
        protected boolean m_instancesAsStrings;

        /** Number of slices/folds/chunks */
        protected int m_numFoldSlices;

        /** Keeps track of the number of instances seen in each class */
        protected int[] m_countsPerClass;

        /**
         * Maintains a buffer for instances from minority classes. Null when the
         * data has no nominal class.
         */
        protected List<List<Object>> m_classInstancesBuffer;

        /** For randomization */
        protected Random m_random = new Random(42);

        /**
         * Constructor
         *
         * @param headerNoSummary the header of the data (sans summary attributes)
         * @param numFoldSlices number of slices/folds/chunks
         * @param instancesAsStrings true if instances are in string form
         */
        public PhaseTwoStratification(Instances headerNoSummary, int numFoldSlices, boolean instancesAsStrings) {
            m_headerNoSummary = headerNoSummary;
            m_instancesAsStrings = instancesAsStrings;
            m_numFoldSlices = numFoldSlices;

            if (hasNominalClass()) {
                int numClasses = headerNoSummary.classAttribute().numValues();
                m_countsPerClass = new int[numClasses];
                m_classInstancesBuffer = new ArrayList<List<Object>>(numClasses);
                for (int i = 0; i < numClasses; i++) {
                    m_classInstancesBuffer.add(new ArrayList<Object>());
                }
            } else {
                // everything is counted under a single pseudo-class (index 0)
                m_countsPerClass = new int[1];
            }
        }

        /**
         * @return true if the header has a class attribute set and it is nominal
         */
        private boolean hasNominalClass() {
            return m_headerNoSummary.classIndex() >= 0 && m_headerNoSummary.classAttribute().isNominal();
        }

        /**
         * Assign an instance to a fold in round-robin order for its class and
         * update the per-class count.
         *
         * @param toProcess the instance (string or object form)
         * @param classIndex the index of the instance's class value (0 when there
         *          is no nominal class)
         * @return the fold the instance is assigned to
         */
        protected int assignToFold(Object toProcess, int classIndex) {
            int fold = m_countsPerClass[classIndex] % m_numFoldSlices;

            m_countsPerClass[classIndex]++;

            // add to the minority class buffers. Once we've seen at least
            // m_numFoldSlices instances for this class each fold will have at
            // least one instance of this class, so buffering stops. There is
            // no buffer at all when the data has no nominal class.
            if (m_classInstancesBuffer != null && m_countsPerClass[classIndex] < m_numFoldSlices) {
                m_classInstancesBuffer.get(classIndex).add(toProcess);
            }
            return fold;
        }

        /**
         * Process an instance: determine its class value and assign it to a fold.
         * <p>
         * In string mode with a nominal class the input is expected to end with
         * the separator {@code "@:@"} followed by the integer class value index;
         * this suffix is removed from the emitted string. The suffix is located
         * from the end of the string so that instance text which itself contains
         * the separator is handled correctly.
         *
         * @param toProcess the instance (string or object form)
         * @return a tuple of fold number and instance
         * @throws IllegalArgumentException if a string instance lacks the expected
         *           class value suffix
         */
        public Tuple2<Integer, Object> process(Object toProcess) {

            int classVal = 0;
            Object inst = toProcess;
            boolean nominalClass = hasNominalClass();

            if (m_instancesAsStrings) {
                String text = toProcess.toString();
                inst = text;

                if (nominalClass) {
                    int sepPos = text.lastIndexOf(CLASS_VALUE_SEPARATOR);
                    if (sepPos < 0) {
                        throw new IllegalArgumentException(
                                "Instance string is missing the class value suffix: " + text);
                    }
                    classVal = Integer.parseInt(text.substring(sepPos + CLASS_VALUE_SEPARATOR.length()));
                    inst = text.substring(0, sepPos);
                }
            } else if (nominalClass) {
                // read the value by index from the header; the instance itself
                // may no longer hold a reference to a dataset
                classVal = (int) ((Instance) inst).value(m_headerNoSummary.classIndex());
            }

            int fold = assignToFold(inst, classVal);

            return new Tuple2<Integer, Object>(fold, inst);
        }

        /**
         * For every class that has been seen at least once but fewer than
         * m_numFoldSlices times, adds randomly chosen buffered instances of that
         * class for the folds that have not yet received one. The per-class count
         * is advanced to m_numFoldSlices in the process. Does nothing when the data
         * has no nominal class.
         *
         * @param toAddTo the list to append the additional fold,instance tuples to
         */
        public void checkForMinorityClassCases(List<Tuple2<Integer, Object>> toAddTo) {
            if (m_classInstancesBuffer == null) {
                // no nominal class - nothing was buffered
                return;
            }

            // folds that have not received minority class instances?
            for (int i = 0; i < m_countsPerClass.length; i++) {
                if (m_countsPerClass[i] < m_numFoldSlices && m_countsPerClass[i] > 0) {
                    List<Object> buffered = m_classInstancesBuffer.get(i);
                    while (m_countsPerClass[i] < m_numFoldSlices) {
                        int instIndex = m_random.nextInt(buffered.size());

                        // the current count doubles as the index of the next
                        // fold without an instance of this class
                        toAddTo.add(new Tuple2<Integer, Object>(m_countsPerClass[i], buffered.get(instIndex)));

                        m_countsPerClass[i]++;
                    }
                }
            }
        }
    }
}