Example usage for org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl: the TaskAttemptContextImpl(Configuration, TaskAttemptID) constructor

Introduction

On this page you can find example usages of the org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl constructor TaskAttemptContextImpl(Configuration, TaskAttemptID).

Prototype

public TaskAttemptContextImpl(Configuration conf, TaskAttemptID taskId) 
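
A minimal, self-contained sketch of how the constructor is typically used (the empty Configuration and the default TaskAttemptID here are illustrative choices, not taken from the examples below): the context wraps a configuration and a task attempt id so that RecordReaders, RecordWriters and OutputCommitters can be driven outside of a running MapReduce job.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl;

public class TaskAttemptContextExample {
    public static void main(String[] args) {
        // Configuration carrying whatever settings the reader/writer will need.
        Configuration conf = new Configuration();

        // A synthetic attempt id; new TaskAttemptID() is sufficient for local testing,
        // as several examples below do, or a fully specified id can be constructed.
        TaskAttemptID attemptId = new TaskAttemptID();

        // Build the task attempt context from the configuration and the attempt id.
        TaskAttemptContext context = new TaskAttemptContextImpl(conf, attemptId);

        System.out.println(context.getTaskAttemptID());
    }
}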

Usage

From source file: io.druid.data.input.parquet.DruidParquetInputFormatTest.java

License: Apache License

@Test
public void test() throws IOException, InterruptedException {
    Configuration conf = new Configuration();
    Job job = Job.getInstance(conf);

    HadoopDruidIndexerConfig config = HadoopDruidIndexerConfig
            .fromFile(new File("example/wikipedia_hadoop_parquet_job.json"));

    config.intoConfiguration(job);

    File testFile = new File("example/wikipedia_list.parquet");
    Path path = new Path(testFile.getAbsoluteFile().toURI());
    FileSplit split = new FileSplit(path, 0, testFile.length(), null);

    InputFormat inputFormat = ReflectionUtils.newInstance(DruidParquetInputFormat.class,
            job.getConfiguration());

    TaskAttemptContext context = new TaskAttemptContextImpl(job.getConfiguration(), new TaskAttemptID());
    RecordReader reader = inputFormat.createRecordReader(split, context);

    reader.initialize(split, context);

    reader.nextKeyValue();

    GenericRecord data = (GenericRecord) reader.getCurrentValue();

    // field not read, should return null
    assertEquals(data.get("added"), null);

    assertEquals(data.get("page"), new Utf8("Gypsy Danger"));

    reader.close();
}

From source file: io.druid.data.input.parquet.DruidParquetInputTest.java

License: Apache License

private GenericRecord getFirstRecord(Job job, String parquetPath) throws IOException, InterruptedException {
    File testFile = new File(parquetPath);
    Path path = new Path(testFile.getAbsoluteFile().toURI());
    FileSplit split = new FileSplit(path, 0, testFile.length(), null);

    DruidParquetInputFormat inputFormat = ReflectionUtils.newInstance(DruidParquetInputFormat.class,
            job.getConfiguration());
    TaskAttemptContext context = new TaskAttemptContextImpl(job.getConfiguration(), new TaskAttemptID());

    try (RecordReader reader = inputFormat.createRecordReader(split, context)) {

        reader.initialize(split, context);
        reader.nextKeyValue();
        return (GenericRecord) reader.getCurrentValue();
    }
}

From source file: io.druid.data.input.parquet.DruidParquetInputTest.java

License: Apache License

private List<InputRow> getAllRows(String configPath) throws IOException, InterruptedException {
    HadoopDruidIndexerConfig config = HadoopDruidIndexerConfig.fromFile(new File(configPath));
    Job job = Job.getInstance(new Configuration());
    config.intoConfiguration(job);

    File testFile = new File(((StaticPathSpec) config.getPathSpec()).getPaths());
    Path path = new Path(testFile.getAbsoluteFile().toURI());
    FileSplit split = new FileSplit(path, 0, testFile.length(), null);

    DruidParquetInputFormat inputFormat = ReflectionUtils.newInstance(DruidParquetInputFormat.class,
            job.getConfiguration());
    TaskAttemptContext context = new TaskAttemptContextImpl(job.getConfiguration(), new TaskAttemptID());

    try (RecordReader reader = inputFormat.createRecordReader(split, context)) {
        List<InputRow> records = Lists.newArrayList();
        InputRowParser parser = config.getParser();

        reader.initialize(split, context);
        while (reader.nextKeyValue()) {
            GenericRecord data = (GenericRecord) reader.getCurrentValue();
            records.add(parser.parse(data));
        }

        return records;
    }
}

From source file: it.crs4.pydoop.mapreduce.pipes.TestPipeApplication.java

License: Apache License

/**
 * Test PipesMapRunner: check that data is transferred from the reader.
 *
 * @throws Exception
 */
@Test
public void testRunner() throws Exception {
    // clean old password files
    File[] psw = cleanTokenPasswordFile();
    try {
        JobID jobId = new JobID("201408272347", 0);
        TaskID taskId = new TaskID(jobId, TaskType.MAP, 0);
        TaskAttemptID taskAttemptid = new TaskAttemptID(taskId, 0);

        Job job = new Job(new Configuration());
        job.setJobID(jobId);
        Configuration conf = job.getConfiguration();
        conf.set(Submitter.IS_JAVA_RR, "true");
        conf.set(MRJobConfig.TASK_ATTEMPT_ID, taskAttemptid.toString());
        job.setInputFormatClass(DummyInputFormat.class);
        FileSystem fs = new RawLocalFileSystem();
        fs.setConf(conf);

        DummyInputFormat input_format = new DummyInputFormat();
        List<InputSplit> isplits = input_format.getSplits(job);

        InputSplit isplit = isplits.get(0);

        TaskAttemptContextImpl tcontext = new TaskAttemptContextImpl(conf, taskAttemptid);

        RecordReader<FloatWritable, NullWritable> rReader = input_format.createRecordReader(isplit, tcontext);

        TestMapContext context = new TestMapContext(conf, taskAttemptid, rReader, null, null, null, isplit);
        // stub for client
        File fCommand = getFileCommand("it.crs4.pydoop.mapreduce.pipes.PipeApplicationRunnableStub");
        conf.set(MRJobConfig.CACHE_LOCALFILES, fCommand.getAbsolutePath());
        // token for authorization
        Token<AMRMTokenIdentifier> token = new Token<AMRMTokenIdentifier>("user".getBytes(),
                "password".getBytes(), new Text("kind"), new Text("service"));
        TokenCache.setJobToken(token, job.getCredentials());
        conf.setBoolean(MRJobConfig.SKIP_RECORDS, true);
        PipesMapper<FloatWritable, NullWritable, IntWritable, Text> mapper = new PipesMapper<FloatWritable, NullWritable, IntWritable, Text>(
                context);

        initStdOut(conf);
        mapper.run(context);
        String stdOut = readStdOut(conf);

        // check part of the transferred data; client and test share the same stdout file
        // check protocol version
        assertTrue(stdOut.contains("CURRENT_PROTOCOL_VERSION:0"));
        // check key and value classes
        assertTrue(stdOut.contains("Key class:org.apache.hadoop.io.FloatWritable"));
        assertTrue(stdOut.contains("Value class:org.apache.hadoop.io.NullWritable"));
        // check that all data from the reader has been sent
        assertTrue(stdOut.contains("value:0.0"));
        assertTrue(stdOut.contains("value:9.0"));

    } finally {
        if (psw != null) {
            // remove password files
            for (File file : psw) {
                file.deleteOnExit();
            }
        }
    }
}

From source file: it.crs4.pydoop.mapreduce.pipes.TestPipeApplication.java

License: Apache License

/**
 * Test org.apache.hadoop.mapreduce.pipes.Application.
 * Tests internal functions:
 *     MessageType.REGISTER_COUNTER, INCREMENT_COUNTER, STATUS, PROGRESS...
 *
 * @throws Throwable
 */

@Test
public void testApplication() throws Throwable {

    System.err.println("testApplication");

    File[] psw = cleanTokenPasswordFile();
    try {
        JobID jobId = new JobID("201408272347", 0);
        TaskID taskId = new TaskID(jobId, TaskType.MAP, 0);
        TaskAttemptID taskAttemptid = new TaskAttemptID(taskId, 0);

        Job job = new Job(new Configuration());
        job.setJobID(jobId);
        Configuration conf = job.getConfiguration();
        conf.set(MRJobConfig.TASK_ATTEMPT_ID, taskAttemptid.toString());
        FileSystem fs = new RawLocalFileSystem();
        fs.setConf(conf);

        File fCommand = getFileCommand("it.crs4.pydoop.mapreduce.pipes.PipeApplicationStub");
        //getFileCommand("it.crs4.pydoop.mapreduce.pipes.PipeApplicationRunnableStub");
        conf.set(MRJobConfig.CACHE_LOCALFILES, fCommand.getAbsolutePath());
        System.err.println("fCommand" + fCommand.getAbsolutePath());

        Token<AMRMTokenIdentifier> token = new Token<AMRMTokenIdentifier>("user".getBytes(),
                "password".getBytes(), new Text("kind"), new Text("service"));
        TokenCache.setJobToken(token, job.getCredentials());
        conf.setBoolean(MRJobConfig.SKIP_RECORDS, true);

        TestReporter reporter = new TestReporter();
        DummyInputFormat input_format = new DummyInputFormat();
        List<InputSplit> isplits = input_format.getSplits(job);
        InputSplit isplit = isplits.get(0);
        TaskAttemptContextImpl tcontext = new TaskAttemptContextImpl(conf, taskAttemptid);

        DummyRecordReader reader = (DummyRecordReader) input_format.createRecordReader(isplit, tcontext);

        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Text.class);

        RecordWriter<IntWritable, Text> writer = new TestRecordWriter(
                new FileOutputStream(workSpace.getAbsolutePath() + File.separator + "outfile"));

        MapContextImpl<IntWritable, Text, IntWritable, Text> context = new MapContextImpl<IntWritable, Text, IntWritable, Text>(
                conf, taskAttemptid, null, writer, null, reporter, null);

        System.err.println("ready to launch application");
        Application<IntWritable, Text, IntWritable, Text> application = new Application<IntWritable, Text, IntWritable, Text>(
                context, reader);
        System.err.println("done");

        application.getDownlink().flush();
        application.getDownlink().mapItem(new IntWritable(3), new Text("txt"));
        application.getDownlink().flush();
        application.waitForFinish();

        // test getDownlink().mapItem();
        String stdOut = readStdOut(conf);
        assertTrue(stdOut.contains("key:3"));
        assertTrue(stdOut.contains("value:txt"));

        assertEquals(0.0, context.getProgress(), 0.01);
        assertNotNull(context.getCounter("group", "name"));

        // test status MessageType.STATUS
        assertEquals(context.getStatus(), "PROGRESS");
        // check MessageType.PROGRESS
        assertEquals(0.55f, reader.getProgress(), 0.001);
        application.getDownlink().close();
        // test MessageType.OUTPUT
        stdOut = readFile(new File(workSpace.getAbsolutePath() + File.separator + "outfile"));
        assertTrue(stdOut.contains("key:123"));
        assertTrue(stdOut.contains("value:value"));
        try {
            // try to abort
            application.abort(new Throwable());
            fail();
        } catch (IOException e) {
            // abort works ?
            assertEquals("pipe child exception", e.getMessage());
        }
    } finally {
        if (psw != null) {
            // remove password files
            for (File file : psw) {
                file.deleteOnExit();
            }
        }
    }
}

From source file: it.crs4.pydoop.mapreduce.pipes.TestPipeApplication.java

License: Apache License

/**
 * Test org.apache.hadoop.mapreduce.pipes.PipesReducer:
 * check the transfer of data (key and value).
 *
 * @throws Exception
 */
@Test
public void testPipesReducer() throws Exception {
    System.err.println("testPipesReducer");

    File[] psw = cleanTokenPasswordFile();
    try {
        JobID jobId = new JobID("201408272347", 0);
        TaskID taskId = new TaskID(jobId, TaskType.MAP, 0);
        TaskAttemptID taskAttemptid = new TaskAttemptID(taskId, 0);

        Job job = new Job(new Configuration());
        job.setJobID(jobId);
        Configuration conf = job.getConfiguration();
        conf.set(MRJobConfig.TASK_ATTEMPT_ID, taskAttemptid.toString());
        FileSystem fs = new RawLocalFileSystem();
        fs.setConf(conf);

        File fCommand = getFileCommand("it.crs4.pydoop.mapreduce.pipes.PipeReducerStub");
        conf.set(MRJobConfig.CACHE_LOCALFILES, fCommand.getAbsolutePath());
        System.err.println("fCommand" + fCommand.getAbsolutePath());

        Token<AMRMTokenIdentifier> token = new Token<AMRMTokenIdentifier>("user".getBytes(),
                "password".getBytes(), new Text("kind"), new Text("service"));
        TokenCache.setJobToken(token, job.getCredentials());
        conf.setBoolean(MRJobConfig.SKIP_RECORDS, true);

        TestReporter reporter = new TestReporter();
        DummyInputFormat input_format = new DummyInputFormat();
        List<InputSplit> isplits = input_format.getSplits(job);
        InputSplit isplit = isplits.get(0);
        TaskAttemptContextImpl tcontext = new TaskAttemptContextImpl(conf, taskAttemptid);

        RecordWriter<IntWritable, Text> writer = new TestRecordWriter(
                new FileOutputStream(workSpace.getAbsolutePath() + File.separator + "outfile"));

        BooleanWritable bw = new BooleanWritable(true);
        List<Text> texts = new ArrayList<Text>();
        texts.add(new Text("first"));
        texts.add(new Text("second"));
        texts.add(new Text("third"));

        DummyRawKeyValueIterator kvit = new DummyRawKeyValueIterator();

        ReduceContextImpl<BooleanWritable, Text, IntWritable, Text> context = new ReduceContextImpl<BooleanWritable, Text, IntWritable, Text>(
                conf, taskAttemptid, kvit, null, null, writer, null, null, null, BooleanWritable.class,
                Text.class);

        PipesReducer<BooleanWritable, Text, IntWritable, Text> reducer = new PipesReducer<BooleanWritable, Text, IntWritable, Text>();
        reducer.setup(context);

        initStdOut(conf);
        reducer.reduce(bw, texts, context);
        reducer.cleanup(context);
        String stdOut = readStdOut(conf);

        // test data: key
        assertTrue(stdOut.contains("reducer key :true"));
        // and values
        assertTrue(stdOut.contains("reduce value  :first"));
        assertTrue(stdOut.contains("reduce value  :second"));
        assertTrue(stdOut.contains("reduce value  :third"));

    } finally {
        if (psw != null) {
            // remove password files
            for (File file : psw) {
                file.deleteOnExit();
            }
        }
    }

}

From source file: it.crs4.pydoop.mapreduce.pipes.TestPipesNonJavaInputFormat.java

License: Apache License

/**
 * Test PipesNonJavaInputFormat.
 */

@Test
public void testFormat() throws IOException, InterruptedException {
    JobID jobId = new JobID("201408272347", 0);
    TaskID taskId = new TaskID(jobId, TaskType.MAP, 0);
    TaskAttemptID taskAttemptid = new TaskAttemptID(taskId, 0);

    Job job = new Job(new Configuration());
    job.setJobID(jobId);
    Configuration conf = job.getConfiguration();

    TaskAttemptContextImpl tcontext = new TaskAttemptContextImpl(conf, taskAttemptid);

    PipesNonJavaInputFormat input_format = new PipesNonJavaInputFormat();

    DummyRecordReader reader = (DummyRecordReader) input_format.createRecordReader(new FileSplit(), tcontext);
    assertEquals(0.0f, reader.getProgress(), 0.001);

    // input and output files
    File input1 = new File(workSpace + File.separator + "input1");
    if (!input1.getParentFile().exists()) {
        Assert.assertTrue(input1.getParentFile().mkdirs());
    }

    if (!input1.exists()) {
        Assert.assertTrue(input1.createNewFile());
    }

    File input2 = new File(workSpace + File.separator + "input2");
    if (!input2.exists()) {
        Assert.assertTrue(input2.createNewFile());
    }

    // This will fail without HDFS support.
    // // set data for splits
    // conf.set(org.apache.hadoop.mapreduce.lib.input.FileInputFormat.INPUT_DIR,
    //          StringUtils.escapeString(input1.getAbsolutePath()) + ","
    //          + StringUtils.escapeString(input2.getAbsolutePath()));
    // List<InputSplit> splits = input_format.getSplits(job);
    // assertTrue(splits.size() >= 2);

    PipesNonJavaInputFormat.PipesDummyRecordReader dummyRecordReader = new PipesNonJavaInputFormat.PipesDummyRecordReader(
            new FileSplit(), tcontext);
    // empty dummyRecordReader
    assertEquals(0.0, dummyRecordReader.getProgress(), 0.001);
    // test method next
    assertTrue(dummyRecordReader.next(new FloatWritable(2.0f), NullWritable.get()));
    assertEquals(2.0, dummyRecordReader.getProgress(), 0.001);
    dummyRecordReader.close();
}

From source file: mlbench.bayes.test.BayesTest.java

License: Apache License

@SuppressWarnings("deprecation")
public static void main(String[] args) throws MPI_D_Exception, IOException, MPIException {
    parseArgs(args);
    HashMap<String, String> conf = new HashMap<String, String>();
    initConf(conf);
    MPI_D.Init(args, MPI_D.Mode.Common, conf);

    if (MPI_D.COMM_BIPARTITE_O != null) {
        rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_O);
        size = MPI_D.Comm_size(MPI_D.COMM_BIPARTITE_O);
        NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, config);
        classifier = new StandardNaiveBayesClassifier(model);

        MPI_D.COMM_BIPARTITE_O.Barrier();
        FileSplit[] inputs = DataMPIUtil.HDFSDataLocalLocator.getTaskInputs(MPI_D.COMM_BIPARTITE_O,
                (JobConf) config, inDir, rank);

        for (int i = 0; i < inputs.length; i++) {
            FileSplit fsplit = inputs[i];
            SequenceFileRecordReader<Text, VectorWritable> kvrr = new SequenceFileRecordReader<>(config,
                    fsplit);
            Text key = kvrr.createKey();
            VectorWritable value = kvrr.createValue();
            while (kvrr.next(key, value)) {
                Vector result = classifier.classifyFull(value.get());
                MPI_D.Send(new Text(SLASH.split(key.toString())[1]), new VectorWritable(result));
            }
        }
    } else if (MPI_D.COMM_BIPARTITE_A != null) {
        int rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_A);
        config.set(MAPRED_OUTPUT_DIR, outDir);
        config.set("mapred.task.id", DataMPIUtil.getHadoopTaskAttemptID().toString().toString());
        ((JobConf) config).setOutputKeyClass(Text.class);
        ((JobConf) config).setOutputValueClass(VectorWritable.class);
        TaskAttemptContext taskContext = new TaskAttemptContextImpl(config,
                DataMPIUtil.getHadoopTaskAttemptID());
        SequenceFileOutputFormat<Text, VectorWritable> outfile = new SequenceFileOutputFormat<>();
        FileSystem fs = FileSystem.get(config);

        Path output = new Path(config.get(MAPRED_OUTPUT_DIR));
        FileOutputCommitter fcommitter = new FileOutputCommitter(output, taskContext);
        RecordWriter<Text, VectorWritable> outrw = null;
        try {
            fcommitter.setupJob(taskContext);
            outrw = outfile.getRecordWriter(fs, (JobConf) config, getOutputName(rank), null);
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("ERROR: Please set the HDFS configuration properly\n");
            System.exit(-1);
        }
        Text key = null;
        VectorWritable point = null;
        Vector vector = null;
        Object[] vals = MPI_D.Recv();
        while (vals != null) {
            key = (Text) vals[0];
            point = (VectorWritable) vals[1];
            if (key != null && point != null) {
                vector = point.get();
                outrw.write(key, new VectorWritable(vector));
            }
            vals = MPI_D.Recv();
        }
        outrw.close(null);
        if (fcommitter.needsTaskCommit(taskContext)) {
            fcommitter.commitTask(taskContext);
        }

        MPI_D.COMM_BIPARTITE_A.Barrier();
        if (rank == 0) {
            // load the labels
            Map<Integer, String> labelMap = BayesUtils.readLabelIndex(config, labPath);
            // loop over the results and create the confusion matrix
            SequenceFileDirIterable<Text, VectorWritable> dirIterable = new SequenceFileDirIterable<Text, VectorWritable>(
                    output, PathType.LIST, PathFilters.partFilter(), config);
            ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
            analyzeResults(labelMap, dirIterable, analyzer);
        }
    }
    MPI_D.Finalize();
}

From source file: mlbench.bayes.train.IndexInstances.java

License: Apache License

@SuppressWarnings({ "deprecation" })
public static void main(String[] args) throws MPI_D_Exception, IOException, MPIException {
    parseArgs(args);
    HashMap<String, String> conf = new HashMap<String, String>();
    initConf(conf);
    MPI_D.Init(args, MPI_D.Mode.Common, conf);
    if (MPI_D.COMM_BIPARTITE_O != null) {
        rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_O);

        if (rank == 0) {
            System.out.println(IndexInstances.class.getSimpleName() + " O start.");
            createLabelIndex(labPath);
        }

        HadoopUtil.cacheFiles(labPath, config);

        MPI_D.COMM_BIPARTITE_O.Barrier();

        OpenObjectIntHashMap<String> labelIndex = BayesUtils.readIndexFromCache(config);

        if (MPI_D.COMM_BIPARTITE_O != null) {
            // O communicator
            int rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_O);
            int size = MPI_D.Comm_size(MPI_D.COMM_BIPARTITE_O);
            FileSplit[] inputs = DataMPIUtil.HDFSDataLocalLocator.getTaskInputs(MPI_D.COMM_BIPARTITE_O,
                    (JobConf) config, inDir, rank);
            for (int i = 0; i < inputs.length; i++) {
                FileSplit fsplit = inputs[i];
                SequenceFileRecordReader<Text, VectorWritable> kvrr = new SequenceFileRecordReader<>(config,
                        fsplit);
                Text labelText = kvrr.createKey();
                VectorWritable instance = kvrr.createValue();
                while (kvrr.next(labelText, instance)) {
                    String label = SLASH.split(labelText.toString())[1];
                    if (labelIndex.containsKey(label)) {
                        MPI_D.Send(new IntWritable(labelIndex.get(label)), instance);
                    }
                }
            }
        }
    } else if (MPI_D.COMM_BIPARTITE_A != null) {
        int rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_A);
        config.set(MAPRED_OUTPUT_DIR, outDir);
        config.set("mapred.task.id", DataMPIUtil.getHadoopTaskAttemptID().toString().toString());
        ((JobConf) config).setOutputKeyClass(IntWritable.class);
        ((JobConf) config).setOutputValueClass(VectorWritable.class);
        TaskAttemptContext taskContext = new TaskAttemptContextImpl(config,
                DataMPIUtil.getHadoopTaskAttemptID());
        SequenceFileOutputFormat<IntWritable, VectorWritable> outfile = new SequenceFileOutputFormat<>();
        FileSystem fs = FileSystem.get(config);

        Path output = new Path(config.get(MAPRED_OUTPUT_DIR));
        FileOutputCommitter fcommitter = new FileOutputCommitter(output, taskContext);
        RecordWriter<IntWritable, VectorWritable> outrw = null;
        try {
            fcommitter.setupJob(taskContext);
            outrw = outfile.getRecordWriter(fs, (JobConf) config, getOutputName(rank), null);
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("ERROR: Please set the HDFS configuration properly\n");
            System.exit(-1);
        }

        IntWritable key = null, newKey = null;
        VectorWritable point = null, newPoint = null;
        Vector vector = null;
        Object[] vals = MPI_D.Recv();
        while (vals != null) {
            newKey = (IntWritable) vals[0];
            newPoint = (VectorWritable) vals[1];
            if (key == null && point == null) {
            } else if (!key.equals(newKey)) {
                outrw.write(key, new VectorWritable(vector));
                vector = null;
            }
            if (vector == null) {
                vector = newPoint.get();
            } else {
                vector.assign(newPoint.get(), Functions.PLUS);
            }

            key = newKey;
            point = newPoint;
            vals = MPI_D.Recv();
        }
        if (newKey != null && newPoint != null) {
            outrw.write(key, new VectorWritable(vector));
        }

        outrw.close(null);
        if (fcommitter.needsTaskCommit(taskContext)) {
            fcommitter.commitTask(taskContext);
        }
    }

    MPI_D.Finalize();
}

From source file: mlbench.bayes.train.WeightSummer.java

License: Apache License

@SuppressWarnings("deprecation")
public static void main(String[] args) throws MPI_D_Exception, IOException, MPIException {
    parseArgs(args);
    HashMap<String, String> conf = new HashMap<String, String>();
    initConf(conf);
    MPI_D.Init(args, MPI_D.Mode.Common, conf);
    if (MPI_D.COMM_BIPARTITE_O != null) {

        int rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_O);
        int size = MPI_D.Comm_size(MPI_D.COMM_BIPARTITE_O);
        FileSplit[] inputs = DataMPIUtil.HDFSDataLocalLocator.getTaskInputs(MPI_D.COMM_BIPARTITE_O,
                (JobConf) config, inDir, rank);
        Vector weightsPerFeature = null;
        Vector weightsPerLabel = new DenseVector(labNum);

        for (int i = 0; i < inputs.length; i++) {
            FileSplit fsplit = inputs[i];
            SequenceFileRecordReader<IntWritable, VectorWritable> kvrr = new SequenceFileRecordReader<>(config,
                    fsplit);
            IntWritable index = kvrr.createKey();
            VectorWritable value = kvrr.createValue();
            while (kvrr.next(index, value)) {
                Vector instance = value.get();
                if (weightsPerFeature == null) {
                    weightsPerFeature = new RandomAccessSparseVector(instance.size(),
                            instance.getNumNondefaultElements());
                }

                int label = index.get();
                weightsPerFeature.assign(instance, Functions.PLUS);
                weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
            }
        }
        if (weightsPerFeature != null) {
            MPI_D.Send(new Text(WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature));
            MPI_D.Send(new Text(WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel));
        }
    } else if (MPI_D.COMM_BIPARTITE_A != null) {
        int rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_A);
        config.set(MAPRED_OUTPUT_DIR, outDirW);
        config.set("mapred.task.id", DataMPIUtil.getHadoopTaskAttemptID().toString().toString());
        ((JobConf) config).setOutputKeyClass(Text.class);
        ((JobConf) config).setOutputValueClass(VectorWritable.class);
        TaskAttemptContext taskContext = new TaskAttemptContextImpl(config,
                DataMPIUtil.getHadoopTaskAttemptID());
        SequenceFileOutputFormat<Text, VectorWritable> outfile = new SequenceFileOutputFormat<>();
        FileSystem fs = FileSystem.get(config);

        Path output = new Path(config.get(MAPRED_OUTPUT_DIR));
        FileOutputCommitter fcommitter = new FileOutputCommitter(output, taskContext);
        RecordWriter<Text, VectorWritable> outrw = null;
        try {
            fcommitter.setupJob(taskContext);
            outrw = outfile.getRecordWriter(fs, (JobConf) config, getOutputName(rank), null);
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("ERROR: Please set the HDFS configuration properly\n");
            System.exit(-1);
        }

        Text key = null, newKey = null;
        VectorWritable point = null, newPoint = null;
        Vector vector = null;
        Object[] vals = MPI_D.Recv();
        while (vals != null) {
            newKey = (Text) vals[0];
            newPoint = (VectorWritable) vals[1];
            if (key == null && point == null) {
            } else if (!key.equals(newKey)) {
                outrw.write(key, new VectorWritable(vector));
                vector = null;
            }
            if (vector == null) {
                vector = newPoint.get();
            } else {
                vector.assign(newPoint.get(), Functions.PLUS);
            }

            key = newKey;
            point = newPoint;
            vals = MPI_D.Recv();
        }
        if (newKey != null && newPoint != null) {
            outrw.write(key, new VectorWritable(vector));
        }

        outrw.close(null);
        if (fcommitter.needsTaskCommit(taskContext)) {
            fcommitter.commitTask(taskContext);
        }

        MPI_D.COMM_BIPARTITE_A.Barrier();
        if (rank == 0) {
            Path resOut = new Path(outDir);
            NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(new Path(outDir), config);
            naiveBayesModel.serialize(resOut, config);
        }
    }

    MPI_D.Finalize();
}