org.apache.mahout.df.mapreduce.partial.Step2Mapper.java Source code


Introduction

Here is the source code for org.apache.mahout.df.mapreduce.partial.Step2Mapper.java.

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.df.mapreduce.partial;

import java.io.IOException;
import java.net.URI;

import org.apache.commons.lang.ArrayUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.df.callback.SingleTreePredictions;
import org.apache.mahout.df.data.DataConverter;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.mapreduce.Builder;
import org.apache.mahout.df.mapreduce.MapredOutput;
import org.apache.mahout.df.node.Node;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Second step of PartialBuilder. Using the trees built during the first step, computes the
 * out-of-bag (oob) predictions of each tree on all the instances of the partition, except for
 * the trees that were built from this partition's data.
 */
public class Step2Mapper extends Mapper<LongWritable, Text, TreeID, MapredOutput> {

    private static final Logger log = LoggerFactory.getLogger(Step2Mapper.class);

    private TreeID[] keys;

    private Node[] trees;

    private SingleTreePredictions[] callbacks;

    private DataConverter converter;

    private int partition = -1;

    /** current instance id; also the number of instances processed so far */
    private int instanceId;

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();

        // get the cached files' paths
        URI[] files = DistributedCache.getCacheFiles(conf);

        log.info("DistributedCache.getCacheFiles(): {}", ArrayUtils.toString(files));

        if ((files == null) || (files.length < 2)) {
            throw new IllegalArgumentException("missing paths from the DistributedCache");
        }

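        // by convention, files[0] is the dataset and files[1] the forest
        // written by the first step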
        Path datasetPath = new Path(files[0].getPath());
        Dataset dataset = Dataset.load(conf, datasetPath);

        int numMaps = Builder.getNumMaps(conf);
        int p = conf.getInt("mapred.task.partition", -1);

        // total number of trees in the forest
        int numTrees = Builder.getNbTrees(conf);
        if (numTrees == -1) {
            throw new IllegalArgumentException("numTrees not found!");
        }

        int nbConcerned = nbConcerned(numMaps, numTrees, p);
        keys = new TreeID[nbConcerned];
        trees = new Node[nbConcerned];

        Path forestPath = new Path(files[1].getPath());
        FileSystem fs = forestPath.getFileSystem(conf);
        int numInstances = InterResults.load(fs, forestPath, numMaps, numTrees, p, keys, trees);

        log.debug("partition: {} numInstances: {}", p, numInstances);
        configure(p, dataset, keys, trees, numInstances);
    }

    /**
     * Computes the number of trees that need to classify the instances of this mapper's partition
     * 
     * @param numMaps
     *          total number of map tasks
     * @param numTrees
     *          total number of trees in the forest
     * @param partition
     *          mapper's partition
     * @return number of trees, built by the other partitions, that must classify this partition's instances
     */
    public static int nbConcerned(int numMaps, int numTrees, int partition) {
        if (partition < 0) {
            throw new IllegalArgumentException("partition < 0");
        }
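        // Example (assuming Step1Mapper.nbTrees() splits the forest evenly):
        // with numMaps = 2 and numTrees = 10, this partition builds 5 trees,
        // so 10 - 5 = 5 trees from the other partition are concerned.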
        // the trees of the mapper's partition are not concerned
        return numTrees - Step1Mapper.nbTrees(numMaps, numTrees, partition);
    }

    /**
     * Useful for testing. Configures the mapper without using a JobConf<br>
     * TODO we don't need the keys partitions, the tree ids should suffice
     * 
     * @param partition
     *          mapper's partition
     * @param dataset
     *          dataset used to convert the input lines into instances
     * @param keys
     *          keys returned by the first step
     * @param trees
     *          trees returned by the first step
     * @param numInstances
     *          number of instances in the mapper's partition
     */
    public void configure(int partition, Dataset dataset, TreeID[] keys, Node[] trees, int numInstances) {
        if (partition < 0) {
            throw new IllegalArgumentException("Wrong partition id: " + partition);
        }
        this.partition = partition;

        converter = new DataConverter(dataset);

        if (keys.length != trees.length) {
            throw new IllegalArgumentException("keys.length != trees.length");
        }
        int nbConcerned = keys.length;

        this.keys = keys;
        this.trees = trees;

        // make sure the trees are not from this partition
        for (TreeID key : keys) {
            if (key.partition() == partition) {
                throw new IllegalArgumentException("a tree from this partition was found!");
            }
        }

        // init the callbacks
        callbacks = new SingleTreePredictions[nbConcerned];
        for (int index = 0; index < nbConcerned; index++) {
            callbacks[index] = new SingleTreePredictions(numInstances);
        }
    }

    @Override
    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
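        // Instances are numbered in the order they are read from the input;
        // instanceId indexes each tree's predictions so they line up with the
        // partition's instances.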

        Instance instance = converter.convert(instanceId, value.toString());

        for (int index = 0; index < keys.length; index++) {
            int prediction = trees[index].classify(instance);
            callbacks[index].prediction(index, instanceId, prediction);
        }

        instanceId++;
    }

    @Override
    protected void cleanup(Context context) throws IOException, InterruptedException {
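        // Emit one record per tree: the key combines this mapper's partition
        // with the original tree id, and the value holds that tree's
        // predictions for every instance of the partition.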
        for (int index = 0; index < keys.length; index++) {
            TreeID key = new TreeID(partition, keys[index].treeId());
            context.write(key, new MapredOutput(callbacks[index].getPredictions()));
        }
    }

}
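
Example

The configure() method above exists so the mapper can be driven in tests without a running job, while setup() shows the contract a driver has to satisfy: the DistributedCache must hold the dataset path first and the first step's forest second, and the configuration must carry the number of map tasks and trees. Below is a minimal driver-side sketch of that contract; the class name and HDFS paths are hypothetical, and in a real run PartialBuilder performs this wiring itself.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;

public class Step2CacheSetup {

    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();

        // Hypothetical HDFS locations; Step2Mapper.setup() only relies on the
        // ordering: files[0] must be the dataset, files[1] the forest written
        // by the first step.
        Path datasetPath = new Path("/mahout/data.info");
        Path forestPath = new Path("/mahout/forest");

        DistributedCache.addCacheFile(datasetPath.toUri(), conf);
        DistributedCache.addCacheFile(forestPath.toUri(), conf);

        // setup() also reads the number of map tasks and the total number of
        // trees through Builder.getNumMaps(conf) and Builder.getNbTrees(conf);
        // in a real job, PartialBuilder stores both in the configuration
        // before submission.
    }
}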