Example usage for org.apache.mahout.vectorizer.encoders.ConstantValueEncoder#ConstantValueEncoder(String name)


Introduction

On this page you can find example usage of the org.apache.mahout.vectorizer.encoders.ConstantValueEncoder(String name) constructor.

Prototype

public ConstantValueEncoder(String name) 
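
Before the full examples below, here is a minimal, self-contained sketch (not taken from any of the projects listed under Usage) of how the constructor is typically paired with addToVector. The vector size and the feature name "bias" are arbitrary illustrative choices.

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;

public class ConstantValueEncoderSketch {
    public static void main(String[] args) {
        // The name identifies the feature and seeds the hashed slot(s) in the output vector.
        ConstantValueEncoder bias = new ConstantValueEncoder("bias");

        Vector v = new DenseVector(100);

        // ConstantValueEncoder ignores the original-form argument; only the weight
        // (here 1.0) is added at the hashed position(s), which makes it useful as an
        // intercept/bias term, as several of the examples below demonstrate.
        bias.addToVector("", 1.0, v);
    }
}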

Usage

From source file:SimpleCsvExamples.java

License:Apache License

public static void main(String[] args) throws IOException {
    FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS];
    for (int i = 0; i < FIELDS; i++) {
        encoder[i] = new ConstantValueEncoder("v" + 1);
    }

    OnlineSummarizer[] s = new OnlineSummarizer[FIELDS];
    for (int i = 0; i < FIELDS; i++) {
        s[i] = new OnlineSummarizer();
    }
    long t0 = System.currentTimeMillis();
    Vector v = new DenseVector(1000);
    if ("--generate".equals(args[0])) {
        PrintWriter out = new PrintWriter(
                new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8));
        try {
            int n = Integer.parseInt(args[1]);
            for (int i = 0; i < n; i++) {
                Line x = Line.generate();
                out.println(x);
            }
        } finally {
            Closeables.close(out, false);
        }
    } else if ("--parse".equals(args[0])) {
        BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8);
        double total = 0;
        try {
            String line = in.readLine();
            while (line != null) {
                v.assign(0);
                Line x = new Line(line);
                for (int i = 0; i < FIELDS; i++) {
                    double z = x.getDouble(i);
                    total += z;
                    //s[i].add(x.getDouble(i));
                    encoder[i].addToVector(x.get(i), v);
                }
                line = in.readLine();
            }
        } finally {
            Closeables.close(in, true);
        }
        //      String separator = "";
        //      for (int i = 0; i < FIELDS; i++) {
        //        System.out.printf("%s%.3f", separator, s[i].getMean());
        //        separator = ",";
        //      }
        System.out.println("total: " + total);
    } else if ("--fast".equals(args[0])) {
        FastLineReader in = new FastLineReader(new FileInputStream(args[1]));
        double total = 0;
        try {
            FastLine line = in.read();
            while (line != null) {
                v.assign(0);
                for (int i = 0; i < FIELDS; i++) {
                    double z = line.getDouble(i);
                    total += z;
                    //s[i].add(z);
                    encoder[i].addToVector((byte[]) null, z, v);
                }
                line = in.read();
            }
        } finally {
            Closeables.close(in, true);
        }
        //      String separator = "";
        //      for (int i = 0; i < FIELDS; i++) {
        //        System.out.printf("%s%.3f", separator, s[i].getMean());
        //        separator = ",";
        //      }
        System.out.println("total: " + total);
    }
    System.out.printf("\nElapsed time = %.3f%n", (System.currentTimeMillis() - t0) / 1000.0);
}

From source file:com.cloudera.knittingboar.records.RCV1RecordFactory.java

License:Apache License

public RCV1RecordFactory() {

    this.encoder = new ConstantValueEncoder("body_values");

}

From source file:com.cloudera.knittingboar.records.RCV1RecordFactory.java

License:Apache License

public static void ScanFile(String file, int debug_break_cnt) throws IOException {

    ConstantValueEncoder encoder_test = new ConstantValueEncoder("test");

    BufferedReader reader = null;
    // Collection<String> words
    int line_count = 0;

    Multiset<String> class_count = ConcurrentHashMultiset.create();
    Multiset<String> namespaces = ConcurrentHashMultiset.create();

    try {
        // System.out.println( newsgroup );
        reader = new BufferedReader(new FileReader(file));

        String line = reader.readLine();

        while (line != null && line.length() > 0) {

            // shard_writer.write(line + "\n");
            // out += line;

            String[] parts = line.split(" ");

            // System.out.println( "Class: " + parts[0] );

            class_count.add(parts[0]);
            namespaces.add(parts[1]);

            line = reader.readLine();
            line_count++;

            Vector v = new RandomAccessSparseVector(FEATURES);

            for (int x = 2; x < parts.length; x++) {
                // encoder_test.addToVector(parts[x], v);
                // System.out.println( parts[x] );
                String[] feature = parts[x].split(":");
                int index = Integer.parseInt(feature[0]) % FEATURES;
                double val = Double.parseDouble(feature[1]);

                // System.out.println( feature[1] + " = " + val );

                if (index < FEATURES) {
                    v.set(index, val);
                } else {

                    System.out.println("Could Hash: " + index + " to " + (index % FEATURES));

                }

            }

            Utils.PrintVectorSectionNonZero(v, 10);
            System.out.println("###");

            if (line_count > debug_break_cnt) {
                break;
            }

        }

        System.out.println("Total Rec Count: " + line_count);

        System.out.println("-------------------- ");

        System.out.println("Classes");
        for (String word : class_count.elementSet()) {
            System.out.println("Class " + word + ": " + class_count.count(word) + " ");
        }

        System.out.println("-------------------- ");

        System.out.println("NameSpaces:");
        for (String word : namespaces.elementSet()) {
            System.out.println("Namespace " + word + ": " + namespaces.count(word) + " ");
        }

        /*
         * TokenStream ts = analyzer.tokenStream("text", reader);
         * ts.addAttribute(CharTermAttribute.class);
         * 
         * // for each word in the stream, minus non-word stuff, add word to
         * collection while (ts.incrementToken()) { String s =
         * ts.getAttribute(CharTermAttribute.class).toString();
         * //System.out.print( " " + s ); //words.add(s); out += s + " "; }
         */

    } finally {
        reader.close();
    }

    // return out + "\n";

}

From source file:com.cloudera.knittingboar.records.Test20NewsgroupsBookParsing.java

License:Apache License

public void test20NewsgroupsFileScan() throws IOException {

    // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    double averageLineCount = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };
    double lineCount = 0;

    Splitter onColon = Splitter.on(":").trimResults();
    // last line on p.269
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);

    File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20-debug/");
    overallCounts = HashMultiset.create();

    // p.269 ---------------------------------------------------------
    Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();

    // encodes the text content in both the subject and the body of the email
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);

    // provides a constant offset that the model can use to encode the average frequency 
    // of each class
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);

    // used to encode the number of lines in a message
    FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
    lines.setTraceDictionary(traceDictionary);
    Dictionary newsGroups = new Dictionary();

    // bottom of p.269 ------------------------------
    // because OLR expects to get integer class IDs for the target variable during training
    // we need a dictionary to convert the target variable (the newsgroup name)
    // to an integer, which is the newsGroup object
    List<File> files = new ArrayList<File>();
    for (File newsgroup : base.listFiles()) {
        newsGroups.intern(newsgroup.getName());
        System.out.println(">> " + newsgroup.getName());
        files.addAll(Arrays.asList(newsgroup.listFiles()));
    }

    // mix up the files, helps training in OLR
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());

    // ----- p.270 ------------ "reading and tokenzing the data" ---------
    for (File file : files) {
        BufferedReader reader = new BufferedReader(new FileReader(file));

        // identify newsgroup ----------------
        // convert newsgroup name to unique id
        // -----------------------------------
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);
        Multiset<String> words = ConcurrentHashMultiset.create();

        // check for line count header -------
        String line = reader.readLine();
        while (line != null && line.length() > 0) {

            // if this is a line that has a line count, let's pull that value out ------
            if (line.startsWith("Lines:")) {
                String count = Iterables.get(onColon.split(line), 1);
                try {
                    lineCount = Integer.parseInt(count);
                    averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);
                } catch (NumberFormatException e) {
                    // if anything goes wrong in parse: just use the avg count
                    lineCount = averageLineCount;
                }
            }

            // which header words to actually count
            boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:")
                    || line.startsWith("Keywords:") || line.startsWith("Summary:"));

            // we're still looking at the header at this point
            // loop through the lines in the file, while the line starts with: " "
            do {

                // get a reader for this specific string ------
                StringReader in = new StringReader(line);

                // ---- count words in header ---------            
                if (countHeader) {
                    //System.out.println( "#### countHeader ################*************" );
                    countWords(analyzer, words, in);
                }

                // iterate to the next string ----
                line = reader.readLine();

            } while (line.startsWith(" "));

            //System.out.println("[break]");

        }

        // now we're done with the header

        //System.out.println("[break-header]");

        //  -------- count words in body ----------
        countWords(analyzer, words, reader);
        reader.close();

        /*        
                for (String word : words.elementSet()) {
                  //encoder.addToVector(word, Math.log(1 + words.count(word)), v);
                 System.out.println( "> " + word + ", " + words.count(word) );
                }        
        */
    }

}

From source file:com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory.java

License:Apache License

/**
 * Processes a single line of input into a target variable and a feature vector.
 *
 * @throws Exception
 */
public int processLine(String line, Vector v) throws Exception {

    String[] parts = line.split(this.class_id_split_string);
    if (parts.length < 2) {
        throw new Exception("wtf: line not formed well.");
    }

    String newsgroup_name = parts[0];
    String msg = parts[1];

    // p.269 ---------------------------------------------------------
    Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();

    // encodes the text content in both the subject and the body of the email
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);

    // provides a constant offset that the model can use to encode the average
    // frequency
    // of each class
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);

    int actual = newsGroups.intern(newsgroup_name);
    // newsGroups.values().contains(arg0)

    // System.out.println( "> newsgroup name: " + newsgroup_name );
    // System.out.println( "> newsgroup id: " + actual );

    Multiset<String> words = ConcurrentHashMultiset.create();
    /*
     * // System.out.println("record: "); for ( int x = 1; x < parts.length; x++
     * ) { //String s = ts.getAttribute(CharTermAttribute.class).toString(); //
     * System.out.print( " " + parts[x] ); String foo = parts[x].trim();
     * System.out.print( " " + foo ); words.add( foo );
     * 
     * } // System.out.println("\nEOR"); System.out.println( "\nwords found: " +
     * (parts.length - 1) ); System.out.println( "words in set: " + words.size()
     * + ", " + words.toString() );
     */

    StringReader in = new StringReader(msg);

    countWords(analyzer, words, in);

    // ----- p.271 -----------
    // Vector v = new RandomAccessSparseVector(FEATURES);

    // original value does nothing in a ConstantValueEncoder
    bias.addToVector("", 1, v);

    // original value does nothing in a ConstantValueEncoder
    // lines.addToVector("", lineCount / 30, v);

    // original value does nothing in a ConstantValueEncoder
    // logLines.addToVector("", Math.log(lineCount + 1), v);

    // now scan through all the words and add them
    // System.out.println( "############### " + words.toArray().length);
    for (String word : words.elementSet()) {
        encoder.addToVector(word, Math.log(1 + words.count(word)), v);
        // System.out.print( words.count(word) + " " );
    }

    // System.out.println("\nEOL\n");

    return actual;
}

From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java

License:Apache License

public void testTrainNewsGroups() throws IOException {

    File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/");
    overallCounts = HashMultiset.create();

    long startTime = System.currentTimeMillis();

    // p.269 ---------------------------------------------------------
    Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();

    // encodes the text content in both the subject and the body of the email
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);

    // provides a constant offset that the model can use to encode the average frequency 
    // of each class
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);

    // used to encode the number of lines in a message
    FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
    lines.setTraceDictionary(traceDictionary);

    FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines");
    logLines.setTraceDictionary(traceDictionary);

    Dictionary newsGroups = new Dictionary();

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    // bottom of p.269 ------------------------------
    // because OLR expects to get integer class IDs for the target variable during training
    // we need a dictionary to convert the target variable (the newsgroup name)
    // to an integer, which is the newsGroup object
    List<File> files = new ArrayList<File>();
    for (File newsgroup : base.listFiles()) {
        newsGroups.intern(newsgroup.getName());
        files.addAll(Arrays.asList(newsgroup.listFiles()));
    }

    // mix up the files, helps training in OLR
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());

    // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    double averageLineCount = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };
    double lineCount = 0;

    // last line on p.269
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);

    Splitter onColon = Splitter.on(":").trimResults();

    int input_file_count = 0;

    // ----- p.270 ------------ "reading and tokenzing the data" ---------
    for (File file : files) {
        BufferedReader reader = new BufferedReader(new FileReader(file));

        input_file_count++;

        // identify newsgroup ----------------
        // convert newsgroup name to unique id
        // -----------------------------------
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);
        Multiset<String> words = ConcurrentHashMultiset.create();

        // check for line count header -------
        String line = reader.readLine();
        while (line != null && line.length() > 0) {

            // if this is a line that has a line count, let's pull that value out ------
            if (line.startsWith("Lines:")) {
                String count = Iterables.get(onColon.split(line), 1);
                try {
                    lineCount = Integer.parseInt(count);
                    averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);
                } catch (NumberFormatException e) {
                    // if anything goes wrong in parse: just use the avg count
                    lineCount = averageLineCount;
                }
            }

            boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:")
                    || line.startsWith("Keywords:") || line.startsWith("Summary:"));

            // loop through the lines in the file, while the line starts with: " "
            do {

                // get a reader for this specific string ------
                StringReader in = new StringReader(line);

                // ---- count words in header ---------            
                if (countHeader) {
                    countWords(analyzer, words, in);
                }

                // iterate to the next string ----
                line = reader.readLine();

            } while (line.startsWith(" "));

        } // while (lines in header) {

        //  -------- count words in body ----------
        countWords(analyzer, words, reader);
        reader.close();

        // ----- p.271 -----------
        Vector v = new RandomAccessSparseVector(FEATURES);

        // original value does nothing in a ConstantValueEncoder
        bias.addToVector("", 1, v);

        // original value does nothing in a ConstantValueEncoder
        lines.addToVector("", lineCount / 30, v);

        // original value does nothing in a ConstantValueEncoder
        logLines.addToVector("", Math.log(lineCount + 1), v);

        // now scan through all the words and add them
        for (String word : words.elementSet()) {
            encoder.addToVector(word, Math.log(1 + words.count(word)), v);
        }

        //Utils.PrintVectorNonZero(v);

        // calc stats ---------

        double mu = Math.min(k + 1, 200);
        double ll = learningAlgorithm.logLikelihood(actual, v);
        averageLL = averageLL + (ll - averageLL) / mu;

        Vector p = new DenseVector(20);
        learningAlgorithm.classifyFull(p, v);
        int estimated = p.maxValueIndex();

        int correct = (estimated == actual ? 1 : 0);
        averageCorrect = averageCorrect + (correct - averageCorrect) / mu;

        learningAlgorithm.train(actual, v);

        k++;

        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

        if (k % (bump * scale) == 0) {
            step += 0.25;
            System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng,
                    newsGroups.values().get(estimated));
        }

        learningAlgorithm.close();

        /*    if (k>4) {
              break;
            }
          */

    }

    Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3);

    long endTime = System.currentTimeMillis();

    //System.out.println("That took " + (endTime - startTime) + " milliseconds");
    long duration = (endTime - startTime);

    System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms");

    ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm);
    // learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

}

From source file:org.deidentifier.arx.aggregates.classification.MultiClassLogisticRegression.java

License:Apache License

/**
 * Creates a new instance
 * @param specification
 * @param config
 */
public MultiClassLogisticRegression(ClassificationDataSpecification specification,
        ARXLogisticRegressionConfiguration config) {

    // Store
    this.config = config;
    this.specification = specification;

    // Prepare classifier
    PriorFunction prior = null;
    switch (config.getPriorFunction()) {
    case ELASTIC_BAND:
        prior = new ElasticBandPrior();
        break;
    case L1:
        prior = new L1();
        break;
    case L2:
        prior = new L2();
        break;
    case UNIFORM:
        prior = new UniformPrior();
        break;
    default:
        throw new IllegalArgumentException("Unknown prior function");
    }
    this.lr = new OnlineLogisticRegression(this.specification.classMap.size(), config.getVectorLength(), prior);

    // Configure
    this.lr.learningRate(config.getLearningRate());
    this.lr.alpha(config.getAlpha());
    this.lr.lambda(config.getLambda());
    this.lr.stepOffset(config.getStepOffset());
    this.lr.decayExponent(config.getDecayExponent());

    // Prepare encoders
    this.interceptEncoder = new ConstantValueEncoder("intercept");
    this.wordEncoder = new StaticWordValueEncoder("feature");

    // Configure
    this.lr.learningRate(1);
    this.lr.alpha(1);
    this.lr.lambda(0.000001);
    this.lr.stepOffset(10000);
    this.lr.decayExponent(0.2);
}

From source file:org.deidentifier.arx.aggregates.classification.MultiClassNaiveBayes.java

License:Apache License

/**
 * Creates a new instance
 * @param interrupt
 * @param specification
 * @param config
 * @param inputHandle
 */
public MultiClassNaiveBayes(WrappedBoolean interrupt, ClassificationDataSpecification specification,
        ClassificationConfigurationNaiveBayes config, DataHandleInternal inputHandle) {

    super(interrupt);

    // Store
    this.config = config;
    this.specification = specification;
    this.inputHandle = inputHandle;

    // Prepare classifier
    this.nb = new NaiveBayes(config.getType() == Type.BERNOULLI ? Model.BERNOULLI : Model.MULTINOMIAL,
            this.specification.classMap.size(), config.getVectorLength(), config.getSigma(), null);

    // Prepare encoders
    this.interceptEncoder = new ConstantValueEncoder("intercept");
    this.wordEncoder = new StaticWordValueEncoder("feature");
}