com.cloudera.knittingboar.utils.DatasetConverter.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.knittingboar.utils.DatasetConverter.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.cloudera.knittingboar.utils;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;

/**
 * Tool to convert the 20 Newsgroups dataset (and a subset of RCV1) into the input format
 * expected by Knitting Boar. Many small input files are merged into a few larger "shard"
 * files, each holding many records, one record per line.
 *
 * <p>Usage for 20 Newsgroups:
 * <ol>
 * <li>Download the canonical dataset from
 * http://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz</li>
 * <li>Extract the dataset locally.</li>
 * <li>Run {@link #ConvertNewsgroupsFromSingleFiles(String, String, int)} to merge the smaller
 * files into larger input files.</li>
 * <li>Edit the {@code workDir} setting in
 * com.cloudera.knittingboar.sgd.TestRunPOLRMasterAndNWorkers so it points at the generated
 * training data.</li>
 * <li>Run the unit test com.cloudera.knittingboar.sgd.TestRunPOLRMasterAndNWorkers.</li>
 * </ol>
 *
 * @author jpatterson
 */
public class DatasetConverter {

    /** 1-2-5 progression used to space out progress messages as the processed count grows. */
    private static final int[] PROGRESS_BUMPS = new int[] { 1, 2, 5 };

    /**
     * Tokenizes {@code in} with {@code analyzer} and appends every resulting term, in stream
     * order, to {@code words}.
     *
     * <p>Follows the Lucene TokenStream consumer workflow (reset, incrementToken, end, close).
     * Closing the stream may also close {@code in}, so callers should not read from it afterwards.
     *
     * @param analyzer analyzer that splits the text into terms
     * @param words collection that receives the terms
     * @param in text to tokenize
     * @throws IOException if reading from {@code in} fails
     */
    private static void countWords(Analyzer analyzer, Collection<String> words, Reader in) throws IOException {
        TokenStream ts = analyzer.tokenStream("text", in);
        // Fetch the term attribute once; the stream refreshes this same instance on each token.
        CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
        try {
            ts.reset();
            while (ts.incrementToken()) {
                words.add(term.toString());
            }
            ts.end();
        } finally {
            ts.close();
        }
    }

    /**
     * Reads {@code file}, tokenizes it with {@code analyzer} and renders it as one Knitting Boar
     * record line.
     *
     * <p>The record layout is: {@code newsgroup_name}, a tab, then each token followed by a
     * single space, then a newline.
     *
     * @param analyzer analyzer used to split the file contents into terms
     * @param newsgroup_name label written as the first, tab-separated column
     * @param file path of the file to read
     * @return the formatted record, terminated by {@code "\n"}
     * @throws IOException if the file cannot be opened or read
     */
    public static String ReadFullFile(Analyzer analyzer, String newsgroup_name, String file) throws IOException {
        List<String> tokens = new ArrayList<String>();

        // NOTE(review): FileReader uses the platform default charset; confirm that it matches the
        // dataset's encoding.
        BufferedReader reader = new BufferedReader(new FileReader(file));
        try {
            countWords(analyzer, tokens, reader);
        } finally {
            reader.close();
        }

        // StringBuilder avoids the quadratic cost of repeated String concatenation on long posts.
        StringBuilder out = new StringBuilder().append(newsgroup_name).append('\t');
        for (String token : tokens) {
            out.append(token).append(' ');
        }
        return out.append('\n').toString();
    }

    /**
     * Converts the 20 Newsgroups dataset from its standard layout (one directory per newsgroup,
     * one file per post, about 20,000 files in total) into a few shard files that are more
     * appropriate for Knitting Boar.
     *
     * <p>Each post is tokenized with a {@link StandardAnalyzer} and written as one line: the
     * newsgroup (directory) name, a tab, then the space-separated tokens. Posts are shuffled
     * before writing. See "TestConvert20NewsTestDataset" in the unit tests for an example.
     *
     * @param inputBaseDir directory containing one sub-directory per newsgroup
     * @param outputBaseDir directory that receives the {@code kboar-shard-N.txt} files; it is
     *        created if missing, and existing shard files are overwritten
     * @param records_per_shard maximum number of records written to each shard file
     * @return the number of posts converted
     * @throws FileNotFoundException if {@code inputBaseDir} or a newsgroup directory cannot be
     *         listed
     * @throws IOException if a post cannot be read or a shard cannot be written
     */
    public static int ConvertNewsgroupsFromSingleFiles(String inputBaseDir, String outputBaseDir,
            int records_per_shard) throws IOException {

        Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);

        List<File> files = listNewsgroupFiles(new File(inputBaseDir));

        // Mix up the posts so the newsgroups are interleaved; this helps OLR training.
        Collections.shuffle(files);
        System.out.printf("%d training files\n", files.size());

        int fileCount = 0;
        double step = 0.0;
        ShardWriter shards = new ShardWriter(outputBaseDir, "kboar-shard-", records_per_shard);
        try {
            for (File file : files) {
                fileCount++;

                // The target label of each post is the name of the directory it lives in.
                String newsgroup = file.getParentFile().getName();
                shards.write(ReadFullFile(analyzer, newsgroup, file.getPath()));

                if (fileCount % progressInterval(step) == 0) {
                    step += 0.25;
                    System.out.printf("Files Converted: %10d , %10d \n", fileCount,
                            shards.recordsInCurrentShard());
                }
            }
        } finally {
            shards.close();
        }

        printSummary(shards.shardRecordCounts(), fileCount);
        return fileCount;
    }

    /**
     * Conversion tool that breaks up the RCV1 dataset into smaller chunks for various tests.
     *
     * <p>Lines of {@code input_file} are copied verbatim (each followed by {@code "\n"}) into
     * {@code rcv1-shard-N.txt} files. Copying stops at end of file, at the first empty line, or
     * after {@code total_recs_to_extract} records, whichever comes first.
     *
     * <p>RCV1 dataset: https://github.com/JohnLangford/vowpal_wabbit/wiki/Rcv1-example
     *
     * @param input_file RCV1 file to read
     * @param outputBaseDir directory that receives the shard files; it is created if missing,
     *        and existing shard files are overwritten
     * @param total_recs_to_extract maximum number of records to copy
     * @param records_per_shard maximum number of records written to each shard file
     * @return the number of input files processed (always 1)
     * @throws IOException if the input cannot be read or a shard cannot be written
     */
    public static int ExtractSubsetofRCV1V2ForTraining(String input_file, String outputBaseDir,
            int total_recs_to_extract, int records_per_shard) throws IOException {

        System.out.println("> ExtractSubsetofRCV1V2ForTraining: " + input_file);

        // The whole subset is read from a single input file.
        int fileCount = 1;
        int lineCount = 0;
        double step = 0.0;

        ShardWriter shards = new ShardWriter(outputBaseDir, "rcv1-shard-", records_per_shard);
        try {
            System.out.println("opening file for reading: " + input_file);
            BufferedReader reader = new BufferedReader(new FileReader(input_file));
            try {
                while (lineCount < total_recs_to_extract) {
                    String line = reader.readLine();
                    // Extraction stops at end of file or at the first empty line.
                    if (line == null || line.length() == 0) {
                        break;
                    }

                    shards.write(line + "\n");
                    lineCount++;

                    if (lineCount % progressInterval(step) == 0) {
                        step += 0.25;
                        System.out.printf("Records Extracted: %10d , %10d \n", lineCount,
                                shards.recordsInCurrentShard());
                    }
                }
            } finally {
                reader.close();
            }
        } finally {
            shards.close();
        }

        printSummary(shards.shardRecordCounts(), fileCount);
        return fileCount;
    }

    /**
     * Collects every regular file found one level below {@code base}, i.e. the posts inside
     * each newsgroup directory. Entries of {@code base} that are not directories are skipped.
     *
     * @param base dataset root containing one directory per newsgroup
     * @return the posts, in no particular order
     * @throws FileNotFoundException if {@code base} or one of its sub-directories cannot be listed
     */
    private static List<File> listNewsgroupFiles(File base) throws FileNotFoundException {
        File[] newsgroups = base.listFiles();
        if (newsgroups == null) {
            throw new FileNotFoundException("Not a readable directory: " + base);
        }

        List<File> files = new ArrayList<File>();
        for (File newsgroup : newsgroups) {
            if (!newsgroup.isDirectory()) {
                continue; // e.g. a stray file in the dataset root
            }
            File[] posts = newsgroup.listFiles();
            if (posts == null) {
                throw new FileNotFoundException("Cannot list newsgroup directory: " + newsgroup);
            }
            for (File post : posts) {
                if (post.isFile()) {
                    files.add(post);
                }
            }
        }
        return files;
    }

    /**
     * Returns the current progress-report interval for {@code step}.
     *
     * <p>The base value cycles through 1, 2, 5 as {@code floor(step)} grows. It is multiplied by a
     * power of ten that increases after every full cycle. Callers advance {@code step} by 0.25
     * per report, so each interval is used for four reports.
     */
    private static int progressInterval(double step) {
        int bump = PROGRESS_BUMPS[(int) Math.floor(step) % PROGRESS_BUMPS.length];
        int scale = (int) Math.pow(10, Math.floor(step / PROGRESS_BUMPS.length));
        return bump * scale;
    }

    /** Prints the record count of each non-empty shard followed by the total input count. */
    private static void printSummary(List<Integer> shardRecordCounts, int inputCount) {
        for (int x = 0; x < shardRecordCounts.size(); x++) {
            System.out.println("> Shard " + x + " record count: " + shardRecordCounts.get(x));
        }
        System.out.printf("> Total Files Converted: %10d \n", inputCount);
    }

    /**
     * Writes records to consecutively numbered shard files ({@code prefix + index + ".txt"})
     * inside an output directory. It moves on to the next shard once the current one holds
     * {@code recordsPerShard} records.
     *
     * <p>Shard 0 is created as soon as the writer is constructed, so even an empty input yields
     * one (empty) shard. Later shards are created only when a record is actually written to
     * them, so no empty trailing shard is left behind. Existing shard files are overwritten.
     */
    private static final class ShardWriter {

        private final String outputBaseDir;
        private final String prefix;
        private final int recordsPerShard;
        /** Record counts of the finished, non-empty shards, in shard order. */
        private final List<Integer> finishedShardCounts = new ArrayList<Integer>();
        /** Writer for the current shard; null once closed. */
        private BufferedWriter writer;
        private int shardIndex = 0;
        private int recordsInShard = 0;

        ShardWriter(String outputBaseDir, String prefix, int recordsPerShard) throws IOException {
            this.outputBaseDir = outputBaseDir;
            this.prefix = prefix;
            this.recordsPerShard = recordsPerShard;

            File baseDir = new File(outputBaseDir);
            if (!baseDir.exists()) {
                baseDir.mkdirs();
            }
            this.writer = openShard();
        }

        /** Appends {@code record} verbatim; the caller supplies any line terminator. */
        void write(String record) throws IOException {
            // "recordsInShard > 0" means a recordsPerShard below 1 still yields one record per
            // shard instead of leaving shard 0 empty.
            if (recordsInShard > 0 && recordsInShard >= recordsPerShard) {
                finishShard();
                shardIndex++;
                writer = openShard();
            }
            writer.write(record);
            recordsInShard++;
        }

        /** Number of records written to the shard that is currently open. */
        int recordsInCurrentShard() {
            return recordsInShard;
        }

        /** Flushes and closes the current shard; calling it again has no effect. */
        void close() throws IOException {
            if (writer != null) {
                finishShard();
            }
        }

        /** Record counts of every non-empty shard, in shard order; complete after close(). */
        List<Integer> shardRecordCounts() {
            return finishedShardCounts;
        }

        private void finishShard() throws IOException {
            BufferedWriter current = writer;
            // Drop the reference first so a failing close() is never retried on a dead writer.
            writer = null;
            if (recordsInShard > 0) {
                finishedShardCounts.add(recordsInShard);
            }
            recordsInShard = 0;
            current.close();
        }

        private BufferedWriter openShard() throws IOException {
            String name = prefix + shardIndex + ".txt";
            // A parent/child File keeps the shard inside outputBaseDir even without a trailing
            // separator. new File("", name) would resolve against a system-dependent default
            // directory rather than the working directory, so "" is handled separately.
            File shard = outputBaseDir.length() == 0 ? new File(name) : new File(outputBaseDir, name);
            System.out.println("Starting shard: " + shard);
            // FileWriter creates the file, or truncates it if it already exists.
            return new BufferedWriter(new FileWriter(shard));
        }
    }

}