org.apache.mahout.classifier.rbm.test.TestRBMClassifierJob.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.classifier.rbm.test.TestRBMClassifierJob.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.classifier.rbm.test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.rbm.RBMClassifier;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluates a trained {@link RBMClassifier} on labelled test data and logs the resulting confusion
 * matrix.
 *
 * <p>The input is either a single sequence file or a directory of sequence files ("batches") with
 * {@code IntWritable} labels as keys and {@code VectorWritable} inputs as values. Each batch is
 * classified either by a map/reduce job (option {@code --mapreduce}) or locally on a fixed-size
 * thread pool.
 */
public class TestRBMClassifierJob extends AbstractJob {

    /** The Constant log. */
    private static final Logger log = LoggerFactory.getLogger(TestRBMClassifierJob.class);

    /** Maximum number of records classified concurrently when the tests run locally. */
    private static final int THREAD_COUNT = 20;

    /** Name of the temp directory that receives the output of each map/reduce test job. */
    private static final String TEST_RESULTS_DIR = "testresults";

    /**
     * The main method.
     *
     * @param args the arguments
     * @throws Exception the exception
     */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new TestRBMClassifierJob(), args);
    }

    /** Least number of stable iterations in the classification layer, parsed from the command line. */
    private int iterations;

    /** Worker pool for local testing; created lazily and released at the end of {@link #run}. */
    private ExecutorService executor;

    /** Reusable classification tasks of the batch currently being tested locally. */
    List<RBMClassifierCall> tasks;

    /**
     * Parses the command line, classifies every test batch and logs the aggregated results.
     *
     * @param args the command line arguments
     * @return 0 on success, -1 on invalid arguments, missing model/input, or a failed test job
     * @throws Exception if reading data, classifying or running a job fails
     */
    @Override
    public int run(String[] args) throws Exception {
        addInputOption();
        addOption("model", "m", "The path to the model built during training", true);
        addOption("labelcount", "lc", "total count of labels existent in the training set", true);
        addOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION, "max",
                "least number of stable iterations in classification layer when classifying", "10");
        addOption(new DefaultOptionBuilder().withLongName(DefaultOptionCreator.MAPREDUCE_METHOD).withRequired(false)
                .withDescription("Run tests with map/reduce").withShortName("mr").create());

        Map<String, String> parsedArgs = parseArguments(args);
        if (parsedArgs == null) {
            return -1;
        }

        int labelcount = Integer.parseInt(getOption("labelcount"));
        if (labelcount < 1) {
            log.error("labelcount must be positive but was {}", labelcount);
            return -1;
        }
        // read the option under the same name it was registered with
        iterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));

        // check the model's existence
        Path model = new Path(getOption("model"));
        if (!model.getFileSystem(getConf()).exists(model)) {
            log.error("Model file does not exist!");
            return -1;
        }

        Path inputPath = getInputPath();
        FileSystem fs = inputPath.getFileSystem(getConf());
        if (!fs.exists(inputPath)) {
            log.error("Input path {} does not exist!", inputPath);
            return -1;
        }

        // labels are the indices 0 .. labelcount-1, matching the score vector indices
        List<String> labels = new ArrayList<String>(labelcount);
        for (int i = 0; i < labelcount; i++) {
            labels.add(String.valueOf(i));
        }
        ResultAnalyzer analyzer = new ResultAnalyzer(labels, "-1");

        Path[] batches = listBatches(fs, inputPath);
        boolean useMapReduce = hasOption(DefaultOptionCreator.MAPREDUCE_METHOD);
        try {
            if (useMapReduce) {
                // the model is shipped to the mappers via the distributed cache; doing it once suffices
                HadoopUtil.cacheFiles(model, getConf());
            }
            for (Path batch : batches) {
                if (useMapReduce) {
                    if (!runTestJob(batch, analyzer)) {
                        log.error("Test job failed for {}", batch);
                        return -1;
                    }
                } else {
                    runTestsLocally(model, analyzer, batch);
                }
            }
        } finally {
            // release the worker threads even if a batch failed
            shutdownExecutor();
        }

        // output the result of the tests
        log.info("RBMClassifier Results: {}", analyzer);
        return 0;
    }

    /**
     * Lists the test batches: the input itself if it is a file, otherwise the entries of the input
     * directory accepted by {@link PathFilters#logsCRCFilter()} (skips job bookkeeping such as
     * {@code _SUCCESS}/{@code _logs} and checksum files).
     *
     * @param fs the file system of the input
     * @param input an existing file or directory
     * @return the paths of the batches to test
     * @throws IOException if the file system cannot be queried
     */
    private static Path[] listBatches(FileSystem fs, Path input) throws IOException {
        if (fs.isFile(input)) {
            return new Path[] { input };
        }
        FileStatus[] statuses = fs.listStatus(input, PathFilters.logsCRCFilter());
        Path[] batches = new Path[statuses.length];
        for (int i = 0; i < statuses.length; i++) {
            batches[i] = statuses[i].getPath();
        }
        return batches;
    }

    /**
     * Classifies one batch with a map/reduce job and feeds the job output into the analyzer.
     *
     * @param batch the sequence file to classify
     * @param analyzer the analyzer collecting the results
     * @return {@code true} if the job succeeded and its results were analyzed
     * @throws IOException if preparing the job or reading its output fails
     * @throws InterruptedException if interrupted while waiting for the job
     * @throws ClassNotFoundException if a job class cannot be loaded
     */
    private boolean runTestJob(Path batch, ResultAnalyzer analyzer)
            throws IOException, InterruptedException, ClassNotFoundException {
        Path output = getTempPath(TEST_RESULTS_DIR);
        // all batches share one output directory: clear it so every job starts from an empty, non-existent
        // directory (file-based output formats reject an existing output directory) and only this batch's
        // results are analyzed
        HadoopUtil.delete(getConf(), output);

        // the output key is the expected label, the output value holds the scores for all the labels
        Job testJob = prepareJob(batch, output, SequenceFileInputFormat.class, TestRBMClassifierMapper.class,
                IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        testJob.getConfiguration().set("maxIter", String.valueOf(iterations));
        if (!testJob.waitForCompletion(true)) {
            return false;
        }

        analyzeResults(new SequenceFileDirIterable<IntWritable, VectorWritable>(output, PathType.LIST,
                PathFilters.partFilter(), getConf()), analyzer);
        return true;
    }

    /**
     * Classifies one batch locally, {@link #THREAD_COUNT} records at a time.
     *
     * @param model the model
     * @param analyzer the analyzer
     * @param input the input
     * @throws IOException Signals that an I/O exception has occurred.
     * @throws InterruptedException the interrupted exception
     * @throws ExecutionException the execution exception
     */
    private void runTestsLocally(Path model, ResultAnalyzer analyzer, Path input)
            throws IOException, InterruptedException, ExecutionException {
        RBMClassifier rbmCl = RBMClassifier.materialize(model, getConf());
        if (executor == null) {
            executor = Executors.newFixedThreadPool(THREAD_COUNT);
        }
        // start each batch with fresh tasks: leftovers of a previous batch would otherwise be classified
        // again and would shift the slot indices used below
        tasks = new ArrayList<RBMClassifierCall>(THREAD_COUNT);

        // number of records stored in tasks[0 .. pending-1] that have not been classified yet
        int pending = 0;
        for (Pair<IntWritable, VectorWritable> record
                : new SequenceFileIterable<IntWritable, VectorWritable>(input, getConf())) {
            Vector vector = record.getSecond().get();
            int label = record.getFirst().get();
            if (tasks.size() < THREAD_COUNT) {
                // each task gets its own copy of the classifier (presumably it is not safe to share)
                tasks.add(new RBMClassifierCall(rbmCl.clone(), vector, label, iterations));
            } else {
                RBMClassifierCall task = tasks.get(pending);
                task.input = vector;
                task.label = label;
            }
            pending++;

            if (pending == THREAD_COUNT) {
                collectResults(executor.invokeAll(tasks), analyzer);
                pending = 0;
            }
        }

        // classify the records of the last, partially filled round
        if (pending > 0) {
            collectResults(executor.invokeAll(tasks.subList(0, pending)), analyzer);
        }
    }

    /**
     * Waits for the given classification results and adds each one to the analyzer.
     *
     * @param results futures of (expected label, score vector) pairs
     * @param analyzer the analyzer collecting the results
     * @throws InterruptedException if interrupted while waiting for a result
     * @throws ExecutionException if a classification task failed
     */
    private static void collectResults(List<Future<Pair<Integer, Vector>>> results, ResultAnalyzer analyzer)
            throws InterruptedException, ExecutionException {
        for (Future<Pair<Integer, Vector>> result : results) {
            Pair<Integer, Vector> labelAndScores = result.get();
            addBestGuess(analyzer, String.valueOf(labelAndScores.getFirst()), labelAndScores.getSecond());
        }
    }

    /**
      * Analyze results of M/R job.
      *
      * @param dirIterable the directory with the results
      * @param analyzer the analyzer
      */
    private void analyzeResults(SequenceFileDirIterable<IntWritable, VectorWritable> dirIterable,
            ResultAnalyzer analyzer) {
        for (Pair<IntWritable, VectorWritable> pair : dirIterable) {
            addBestGuess(analyzer, String.valueOf(pair.getFirst().get()), pair.getSecond().get());
        }
    }

    /**
     * Records the index of the highest score as the predicted label of one instance. Nothing is recorded
     * when no score exceeds negative infinity (e.g. a zero-length or all-NaN score vector).
     *
     * @param analyzer the analyzer collecting the results
     * @param correctLabel the expected label of the instance
     * @param scores one score per label, indexed by label
     */
    private static void addBestGuess(ResultAnalyzer analyzer, String correctLabel, Vector scores) {
        int bestIdx = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Vector.Element element : scores) {
            double score = element.get();
            if (score > bestScore) {
                bestScore = score;
                bestIdx = element.index();
            }
        }
        if (bestIdx >= 0) {
            analyzer.addInstance(correctLabel, new ClassifierResult(String.valueOf(bestIdx), bestScore));
        }
    }

    /** Stops the local worker pool, if one was started, so its threads do not keep the JVM alive. */
    private void shutdownExecutor() {
        if (executor != null) {
            executor.shutdownNow();
            executor = null;
        }
    }

    /**
     * Callable that classifies one labelled record with its own {@link RBMClassifier} instance. The
     * {@code input} and {@code label} fields are reassigned by the enclosing job to reuse the task for
     * subsequent records.
     */
    static class RBMClassifierCall implements Callable<Pair<Integer, Vector>> {

        /** The classifier used by this task. */
        private final RBMClassifier rbmCl;

        /** The record to classify. */
        private Vector input;

        /** The expected label of the record. */
        private int label;

        /** The least number of stable iterations when classifying. */
        private final int iterations;

        /**
         * Instantiates a new RBM classifier call.
         *
         * @param rbmCl the classifier used by this task
         * @param input the record to classify
         * @param label the expected label of the record
         * @param iterations the number of iterations until stable
         */
        public RBMClassifierCall(RBMClassifier rbmCl, Vector input, int label, int iterations) {
            this.rbmCl = rbmCl;
            this.input = input;
            this.label = label;
            this.iterations = iterations;
        }

        /**
         * Classifies the current input.
         *
         * @return the expected label paired with the classifier's score vector
         * @throws Exception if classification fails
         */
        @Override
        public Pair<Integer, Vector> call() throws Exception {
            return new Pair<Integer, Vector>(label, rbmCl.classify(input, iterations));
        }
    }
}