org.apache.mahout.classifier.df.mapreduce.Classifier.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.classifier.df.mapreduce.Classifier.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.mapreduce;

import java.io.IOException;
import java.util.List;
import java.util.Random;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.classifier.df.DFUtils;
import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.data.DataConverter;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * MapReduce implementation that classifies the input data using a previously built decision forest.
 *
 * <p>The job is map-only: every input file is handled by a single {@link CMapper} (see
 * {@link CTextInputFormat}), whose records are then turned into one <code>.out</code> file per input
 * file by {@link #parseOutput(JobContext)}. The (label, prediction) pairs read back from the mappers
 * are kept in memory and exposed through {@link #getResults()}.</p>
 */
public class Classifier {

    private static final Logger log = LoggerFactory.getLogger(Classifier.class);

    private final Path forestPath;
    private final Path inputPath;
    private final Path datasetPath;
    private final Configuration conf;
    private final Path outputPath; // path that will contain the final output of the classifier
    private final Path mappersOutputPath; // mappers will output here
    private double[][] results;

    /**
     * Returns the pairs collected while parsing the mappers' output.
     *
     * @return one row per classified record, laid out as <code>{record key, prediction}</code>;
     *         {@code null} as long as {@link #run()} has not parsed any output. The internal array is
     *         returned as is (no defensive copy).
     */
    public double[][] getResults() {
        return results;
    }

    /**
     * @param forestPath  location of the decision forest, shipped to the mappers through the DistributedCache
     * @param inputPath   data to classify
     * @param datasetPath location of the dataset description, shipped to the mappers through the DistributedCache
     * @param outputPath  directory receiving the final output; the mappers write into its "mappers" child
     * @param conf        configuration used to build and run the job
     */
    public Classifier(Path forestPath, Path inputPath, Path datasetPath, Path outputPath, Configuration conf) {
        this.forestPath = forestPath;
        this.inputPath = inputPath;
        this.datasetPath = datasetPath;
        this.outputPath = outputPath;
        this.conf = conf;

        mappersOutputPath = new Path(outputPath, "mappers");
    }

    /**
     * Wires the input/output paths, the mapper and the input/output formats into the job.
     * No reducer is configured: the mappers' records go straight to sequence files.
     */
    private void configureJob(Job job) throws IOException {

        job.setJarByClass(Classifier.class);

        FileInputFormat.setInputPaths(job, inputPath);
        FileOutputFormat.setOutputPath(job, mappersOutputPath);

        job.setOutputKeyClass(DoubleWritable.class);
        job.setOutputValueClass(Text.class);

        job.setMapperClass(CMapper.class);
        job.setNumReduceTasks(0); // no reducers

        job.setInputFormatClass(CTextInputFormat.class);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);

    }

    /**
     * Runs the classification job, converts the mappers' output into the final output files and
     * finally deletes the intermediate mappers' output.
     *
     * @throws IOException           if the output path already exists, or on any I/O failure
     * @throws IllegalStateException if the job does not complete successfully
     */
    public void run() throws IOException, ClassNotFoundException, InterruptedException {
        // Resolve the file system from the path itself (as parseOutput does) rather than taking the
        // default one, so the existence check targets the file system that really owns outputPath.
        FileSystem fs = outputPath.getFileSystem(conf);

        // check the output
        if (fs.exists(outputPath)) {
            throw new IOException("Output path already exists : " + outputPath);
        }

        // The registration order matters: CMapper.setup reads the cached files back by position,
        // dataset first and forest second.
        log.info("Adding the dataset to the DistributedCache");
        // put the dataset into the DistributedCache
        DistributedCache.addCacheFile(datasetPath.toUri(), conf);

        log.info("Adding the decision forest to the DistributedCache");
        DistributedCache.addCacheFile(forestPath.toUri(), conf);

        Job job = new Job(conf, "decision forest classifier");

        log.info("Configuring the job...");
        configureJob(job);

        log.info("Running the job...");
        if (!job.waitForCompletion(true)) {
            throw new IllegalStateException("Job failed!");
        }

        parseOutput(job);

        HadoopUtil.delete(conf, mappersOutputPath);
    }

    /**
     * Extracts the predictions of each mapper and writes them in the corresponding output file.
     * The name of the output file is based on the name of the corresponding input file, which the
     * mapper emits as the value of its very first record.
     * Every following record is also stored as a <code>{key, prediction}</code> row in {@link #results}.
     */
    private void parseOutput(JobContext job) throws IOException {
        Configuration jobConf = job.getConfiguration();
        FileSystem fs = mappersOutputPath.getFileSystem(jobConf);

        Path[] outfiles = DFUtils.listOutputFiles(fs, mappersOutputPath);

        // read all the output
        List<double[]> resList = Lists.newArrayList();
        for (Path path : outfiles) {
            FSDataOutputStream ofile = null; // stays null until the file-name record has been seen
            try {
                for (Pair<DoubleWritable, Text> record : new SequenceFileIterable<DoubleWritable, Text>(path, true,
                        jobConf)) {
                    double key = record.getFirst().get();
                    String value = record.getSecond().toString();
                    if (ofile == null) {
                        // this is the first value, it contains the name of the input file
                        ofile = fs.create(new Path(outputPath, value).suffix(".out"));
                    } else {
                        // The key contains the correct label of the data. The value contains a prediction
                        // NOTE(review): writeChars/writeChar follow the java.io.DataOutput contract
                        // (two bytes per character); confirm that consumers of the .out files expect this.
                        ofile.writeChars(value); // write the prediction
                        ofile.writeChar('\n');

                        // parseDouble avoids the boxing/unboxing round trip of Double.valueOf
                        resList.add(new double[] { key, Double.parseDouble(value) });
                    }
                }
            } finally {
                Closeables.close(ofile, false);
            }
        }
        results = new double[resList.size()][2];
        resList.toArray(results);
    }

    /**
     * TextInputFormat that does not split the input files. This ensures that each input file is processed by one single
     * mapper.
     */
    private static class CTextInputFormat extends TextInputFormat {
        @Override
        protected boolean isSplitable(JobContext jobContext, Path path) {
            return false;
        }
    }

    /**
     * Mapper that classifies every non-empty input line with the decision forest.
     *
     * <p>The first record it writes carries the name of the input file as its value; every other
     * record carries the label read from the dataset as key and the forest's prediction as value.</p>
     */
    public static class CMapper extends Mapper<LongWritable, Text, DoubleWritable, Text> {

        /** used to convert input values to data instances */
        private DataConverter converter;
        private DecisionForest forest;
        private final Random rng = RandomUtils.getRandom();
        /** true until the record holding the input file name has been written */
        private boolean first = true;
        private final Text lvalue = new Text();
        private Dataset dataset;
        private final DoubleWritable lkey = new DoubleWritable();

        /**
         * Loads the dataset (first cached file) and the decision forest (second cached file).
         *
         * @throws IOException          if fewer than two cached files are available
         * @throws InterruptedException if the forest could not be loaded
         */
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            super.setup(context);

            Configuration conf = context.getConfiguration();

            Path[] files = HadoopUtil.getCachedFiles(conf);

            // a missing array is reported the same way as an incomplete one instead of failing with an NPE
            if (files == null || files.length < 2) {
                throw new IOException("not enough paths in the DistributedCache");
            }
            dataset = Dataset.load(conf, files[0]);
            converter = new DataConverter(dataset);

            forest = DecisionForest.load(conf, files[1]);
            if (forest == null) {
                // exception type kept as is so that the behaviour seen by the framework does not change
                throw new InterruptedException("DecisionForest not found!");
            }
        }

        /**
         * Writes the input file name once, then the (label, prediction) pair of the current line.
         * Empty lines produce no prediction record.
         */
        @Override
        protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
            if (first) {
                FileSplit split = (FileSplit) context.getInputSplit();
                Path path = split.getPath(); // current split path
                lvalue.set(path.getName());
                // only the value of this record is used when the output is parsed
                lkey.set(key.get());
                context.write(lkey, lvalue);

                first = false;
            }

            String line = value.toString();
            if (!line.isEmpty()) {
                Instance instance = converter.convert(line);
                double prediction = forest.classify(dataset, rng, instance);
                lkey.set(dataset.getLabel(instance));
                lvalue.set(Double.toString(prediction));
                context.write(lkey, lvalue);
            }
        }
    }
}