org.apache.mahout.df.mapred.partial.PartitionBugTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.df.mapred.partial.PartitionBugTest.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.df.mapred.partial;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Locale;
import java.util.Random;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.builder.TreeBuilder;
import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.data.Utils;
import org.apache.mahout.df.node.Node;

/**
 * Checks that the partial (per-partition) builder reports the correct
 * instance ids to the {@link PredictionCallback}.
 */
public class PartitionBugTest extends MahoutTestCase {
    private static final int NUM_ATTRIBUTES = 40;

    private static final int NUM_INSTANCES = 200;

    private static final int NUM_TREES = 10;

    private static final int NUM_MAPS = 5;

    /**
     * Make sure that the correct instance ids are being computed.
     * <p>
     * Every instance of a random dataset is labelled with its own index, and
     * the trees are mocked so that they "predict" the label of whatever
     * instance they are given. The callback then checks that each prediction
     * equals the label of the instance identified by the reported id.
     *
     * @throws Exception if preparing the data or running the builder fails
     */
    public void testProcessOutput() throws Exception {
        Random rng = RandomUtils.getRandom();

        // create a dataset large enough to be split up
        String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES);
        double[][] source = Utils.randomDoubles(rng, descriptor, NUM_INSTANCES);

        // each instance label is its index in the dataset
        int labelId = Utils.findLabel(descriptor);
        for (int index = 0; index < NUM_INSTANCES; index++) {
            source[index][labelId] = index;
        }

        // store the data into a file
        String[] sData = Utils.double2String(source);
        Path dataPath = Utils.writeDataToTestFile(sData);
        Dataset dataset = DataLoader.generateDataset(descriptor, sData);
        Data data = DataLoader.loadData(dataset, sData);

        JobConf jobConf = new JobConf();
        jobConf.setNumMapTasks(NUM_MAPS);

        // prepare a custom TreeBuilder that will classify each
        // instance with its own label (in this case its index in the dataset)
        TreeBuilder treeBuilder = new MockTreeBuilder();

        // disable the second step because we can test without it
        // and we won't be able to serialize the MockNode
        PartialBuilder.setStep2(jobConf, false);

        // the builder is given a fixed seed
        long seed = 1L;
        PartialSequentialBuilder builder = new PartialSequentialBuilder(treeBuilder, dataPath, dataset, seed,
                jobConf);

        // remove the output path (it's only used for testing)
        Path outputPath = builder.getOutputPath(jobConf);
        HadoopUtil.overwriteOutput(outputPath);

        builder.build(NUM_TREES, new MockCallback(data));
    }

    /**
     * Asserts that the reported instance ids are correct: every prediction
     * must equal the label of the instance stored under the reported id.
     */
    private static class MockCallback implements PredictionCallback {
        /** Reference data the reported instance ids are looked up in. */
        private final Data data;

        private MockCallback(Data data) {
            this.data = data;
        }

        @Override
        public void prediction(int treeId, int instanceId, int prediction) {
            // because of the bagging, prediction can be -1
            if (prediction == -1) {
                return;
            }

            assertEquals(String.format(Locale.ENGLISH, "treeId: %d, InstanceId: %d, Prediction: %d", treeId,
                    instanceId, prediction), data.get(instanceId).label, prediction);
        }

    }

    /**
     * Custom Leaf node that returns for each instance its own label.
     * It holds no state, so the size/depth accessors return 0 and the
     * serialization methods do nothing.
     */
    private static class MockLeaf extends Node {

        @Override
        public int classify(Instance instance) {
            return instance.label;
        }

        @Override
        protected String getString() {
            return "[MockLeaf]";
        }

        @Override
        public long maxDepth() {
            return 0;
        }

        @Override
        protected Type getType() {
            return Type.MOCKLEAF;
        }

        @Override
        public long nbNodes() {
            return 0;
        }

        @Override
        protected void writeNode(DataOutput out) throws IOException {
            // nothing to write: this node has no fields
        }

        @Override
        public void readFields(DataInput in) throws IOException {
            // nothing to read: this node has no fields
        }

    }

    /**
     * TreeBuilder that ignores its arguments and always returns a new
     * {@link MockLeaf}.
     */
    private static class MockTreeBuilder implements TreeBuilder {

        @Override
        public Node build(Random rng, Data data) {
            return new MockLeaf();
        }

    }
}