Example usage for org.apache.mahout.classifier.sgd.L1

List of usage examples for org.apache.mahout.classifier.sgd.L1

Introduction

In this page you can find example usages of org.apache.mahout.classifier.sgd.L1.

Prototype

L1

Source Link

Usage

From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java

License:Apache License

public static void main(String[] args) throws Exception {
    // Parse the full bank-marketing CSV into memory once.
    List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));

    // Fraction of the data withheld from training for accuracy evaluation.
    double holdoutFraction = 0.10;
    double biggestScore = 0.0;

    for (int run = 0; run < 20; run++) {
        // Re-shuffle before every run so each run trains and tests on a different split.
        Collections.shuffle(calls);
        int holdoutSize = (int) (holdoutFraction * calls.size());
        List<TelephoneCall> evaluationSet = calls.subList(0, holdoutSize);
        List<TelephoneCall> trainingSet = calls.subList(holdoutSize, calls.size());

        // A single synthetic "unknown" call, scored after each pass below.
        List<TelephoneCall> unknownSet = new ArrayList<>();
        unknownSet.add(getUnknownTelephoneCall(trainingSet));

        OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES,
                new L1()).learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2);

        for (int pass = 0; pass < 20; pass++) {
            // One sequential sweep over the training split.
            for (TelephoneCall call : trainingSet) {
                lr.train(call.getTarget(), call.asVector());
            }

            // Evaluate AUC on the held-out split at a 0.5 decision threshold.
            Auc eval = new Auc(0.5);
            for (TelephoneCall call : evaluationSet) {
                biggestScore = evaluateTheCallAndGetBiggestScore(biggestScore, lr, eval, call);
            }
            System.out.printf("run: %-5d pass: %-5d current learning rate: %-5.4f \teval auc %-5.4f\n", run,
                    pass, lr.currentLearningRate(), eval.auc());

            // Score the unknown call with the model as trained so far.
            for (TelephoneCall call : unknownSet) {
                final double score = lr.classifyScalar(call.asVector());
                System.out.println(" score: " + score + " accuracy " + eval.auc() + " call fields: "
                        + call.getFields());
            }
        }
    }
}

From source file:chapter4.src.logistic.LogisticModelParametersPredict.java

License:Apache License

/**
 * Creates a logistic regression trainer using the parameters collected here.
 *
 * @return The newly allocated OnlineLogisticRegression object
 *//*  w  w w  . ja  v a2 s .  co m*/
/**
 * Lazily builds the logistic regression trainer configured by the parameters
 * collected on this object. Repeated calls return the same cached instance.
 *
 * @return the (possibly cached) OnlineLogisticRegression object
 */
public OnlineLogisticRegression createRegression() {
    if (lr != null) {
        return lr;
    }
    lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1())
            .lambda(getLambda()).learningRate(getLearningRate()).alpha(1 - 1.0e-3);
    return lr;
}

From source file:com.cloudera.knittingboar.records.TestTwentyNewsgroupsCustomRecordParseOLRRun.java

License:Apache License

@Test
public void testRecordFactoryOnDatasetShard() throws Exception {
    // TODO a test with assertions is not a test
    // p.270 ----- metrics to track lucene's parsing mechanics, progress,
    // performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };

    // Factory splits each record into "label<TAB>text".
    TwentyNewsgroupsRecordFactory rec_factory = new TwentyNewsgroupsRecordFactory("\t");

    JobConf job = new JobConf(defaultConf);

    long block_size = localFs.getDefaultBlockSize(workDir);

    LOG.info("default block size: " + (block_size / 1024 / 1024) + "MB");

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    FileInputFormat.setInputPaths(job, workDir);

    // try splitting the file in a variety of sizes
    TextInputFormat format = new TextInputFormat();
    format.configure(job);
    Text value = new Text();

    int numSplits = 1;

    InputSplit[] splits = format.getSplits(job, numSplits);

    LOG.info("requested " + numSplits + " splits, splitting: got =        " + splits.length);
    LOG.info("---- debug splits --------- ");
    rec_factory.Debug();
    int total_read = 0;

    for (int x = 0; x < splits.length; x++) {

        LOG.info("> Split [" + x + "]: " + splits[x].getLength());

        int count = 0;
        InputRecordsSplit custom_reader = new InputRecordsSplit(job, splits[x]);
        while (custom_reader.next(value)) {
            Vector v = new RandomAccessSparseVector(TwentyNewsgroupsRecordFactory.FEATURES);
            int actual = rec_factory.processLine(value.toString(), v);

            String ng = rec_factory.GetNewsgroupNameByID(actual);

            // calc stats ---------

            // Exponential moving average warm-up: mu ramps 1..200, then stays at 200.
            double mu = Math.min(k + 1, 200);
            double ll = learningAlgorithm.logLikelihood(actual, v);
            averageLL = averageLL + (ll - averageLL) / mu;

            // Score before training on this record so the stats are out-of-sample.
            Vector p = new DenseVector(20);
            learningAlgorithm.classifyFull(p, v);
            int estimated = p.maxValueIndex();

            int correct = (estimated == actual ? 1 : 0);
            averageCorrect = averageCorrect + (correct - averageCorrect) / mu;
            learningAlgorithm.train(actual, v);
            k++;
            int bump = bumps[(int) Math.floor(step) % bumps.length];
            int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

            // Log on a 1-2-5-10-20-50... schedule so output volume stays manageable.
            if (k % (bump * scale) == 0) {
                step += 0.25;
                LOG.info(String.format("%10d %10.3f %10.3f %10.2f %s %s", k, ll, averageLL,
                        averageCorrect * 100, ng, rec_factory.GetNewsgroupNameByID(estimated)));
            }

            count++;
        }

        LOG.info("read: " + count + " records for split " + x);
        total_read += count;
    } // for each split

    // FIX: close the learner exactly once, after all training. The original
    // called close() inside the record loop, finalizing the model after every
    // single record instead of at the end of training.
    learningAlgorithm.close();

    LOG.info("total read across all splits: " + total_read);
    rec_factory.Debug();
}

From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java

License:Apache License

/**
 * Trains a 20-category OnlineLogisticRegression on the 20-newsgroups training
 * set (setup follows Mahout in Action p.269-271), logging running
 * log-likelihood/accuracy, then serializes the model to /tmp.
 *
 * @throws IOException if a training file cannot be read
 */
public void testTrainNewsGroups() throws IOException {

    // NOTE(review): hard-coded local dataset path — this test only runs on a
    // machine with the 20news-bydate corpus at this location.
    File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/");
    overallCounts = HashMultiset.create();

    long startTime = System.currentTimeMillis();

    // p.269 ---------------------------------------------------------
    Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();

    // encodes the text content in both the subject and the body of the email
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);

    // provides a constant offset that the model can use to encode the average frequency
    // of each class
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);

    // used to encode the number of lines in a message
    FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
    lines.setTraceDictionary(traceDictionary);

    FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines");
    logLines.setTraceDictionary(traceDictionary);

    Dictionary newsGroups = new Dictionary();

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    // bottom of p.269 ------------------------------
    // because OLR expects to get integer class IDs for the target variable during training
    // we need a dictionary to convert the target variable (the newsgroup name)
    // to an integer, which is the newsGroup object
    List<File> files = new ArrayList<File>();
    for (File newsgroup : base.listFiles()) {
        newsGroups.intern(newsgroup.getName());
        files.addAll(Arrays.asList(newsgroup.listFiles()));
    }

    // mix up the files, helps training in OLR
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());

    // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    double averageLineCount = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };
    double lineCount = 0;

    // last line on p.269
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);

    Splitter onColon = Splitter.on(":").trimResults();

    int input_file_count = 0;

    // ----- p.270 ------------ "reading and tokenzing the data" ---------
    for (File file : files) {
        BufferedReader reader = new BufferedReader(new FileReader(file));

        input_file_count++;

        // identify newsgroup ----------------
        // convert newsgroup name to unique id
        // -----------------------------------
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);
        Multiset<String> words = ConcurrentHashMultiset.create();

        try {
            // check for line count header -------
            String line = reader.readLine();
            while (line != null && line.length() > 0) {

                // if this is a line that has a line count, let's pull that value out ------
                if (line.startsWith("Lines:")) {
                    String count = Iterables.get(onColon.split(line), 1);
                    try {
                        lineCount = Integer.parseInt(count);
                        averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);
                    } catch (NumberFormatException e) {
                        // if anything goes wrong in parse: just use the avg count
                        lineCount = averageLineCount;
                    }
                }

                boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:")
                        || line.startsWith("Keywords:") || line.startsWith("Summary:"));

                // loop through the lines in the file, while the line starts with: " "
                do {

                    // get a reader for this specific string ------
                    StringReader in = new StringReader(line);

                    // ---- count words in header ---------
                    if (countHeader) {
                        countWords(analyzer, words, in);
                    }

                    // iterate to the next string ----
                    line = reader.readLine();

                    // FIX: readLine() returns null at EOF; the original dereferenced
                    // line unconditionally and could NPE on a file that ends mid-header.
                } while (line != null && line.startsWith(" "));

            } // while (lines in header)

            //  -------- count words in body ----------
            countWords(analyzer, words, reader);
        } finally {
            // FIX: close the reader even if header parsing or tokenizing throws.
            reader.close();
        }

        // ----- p.271 -----------
        Vector v = new RandomAccessSparseVector(FEATURES);

        // original value does nothing in a ContantValueEncoder
        bias.addToVector("", 1, v);

        // original value does nothing in a ContantValueEncoder
        lines.addToVector("", lineCount / 30, v);

        // original value does nothing in a ContantValueEncoder
        logLines.addToVector("", Math.log(lineCount + 1), v);

        // now scan through all the words and add them
        for (String word : words.elementSet()) {
            encoder.addToVector(word, Math.log(1 + words.count(word)), v);
        }

        // calc stats ---------

        // Exponential moving average warm-up: mu ramps 1..200, then stays at 200.
        double mu = Math.min(k + 1, 200);
        double ll = learningAlgorithm.logLikelihood(actual, v);
        averageLL = averageLL + (ll - averageLL) / mu;

        // Score before training on this document so the stats are out-of-sample.
        Vector p = new DenseVector(20);
        learningAlgorithm.classifyFull(p, v);
        int estimated = p.maxValueIndex();

        int correct = (estimated == actual ? 1 : 0);
        averageCorrect = averageCorrect + (correct - averageCorrect) / mu;

        learningAlgorithm.train(actual, v);

        k++;

        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

        // Log on a 1-2-5-10-20-50... schedule so output volume stays manageable.
        if (k % (bump * scale) == 0) {
            step += 0.25;
            System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng,
                    newsGroups.values().get(estimated));
        }

    }

    // FIX: close the learner exactly once, after all training. The original
    // called close() inside the per-file loop, finalizing the model after
    // every training document instead of at the end of training.
    learningAlgorithm.close();

    Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3);

    long endTime = System.currentTimeMillis();

    long duration = (endTime - startTime);

    System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms");

    ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm);

}

From source file:com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java

License:Apache License

/**
 * Verifies that a lambda value passed through the POLR fluent builder is
 * stored and readable back via getLambda().
 */
public void testCreateLR() {

    int categories = 2;
    int numFeatures = 5;
    double lambda = 1.0e-4;
    double learning_rate = 50;

    ParallelOnlineLogisticRegression plr = new ParallelOnlineLogisticRegression(categories, numFeatures,
            new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3);

    // FIX: expected value first, and use the (expected, actual, delta) overload —
    // the original autoboxed to assertEquals(Object, Object) with the arguments
    // in actual/expected order.
    assertEquals(lambda, plr.getLambda(), 0.0);
}

From source file:com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java

License:Apache License

/**
 * Smoke test: a freshly configured POLR instance must accept repeated
 * train() calls on the same input without throwing.
 */
public void testTrainMechanics() {

    int categories = 2;
    int numFeatures = 5;
    double lambda = 1.0e-4;
    double learningRate = 10;

    ParallelOnlineLogisticRegression plr = new ParallelOnlineLogisticRegression(categories, numFeatures,
            new L1()).lambda(lambda).learningRate(learningRate).alpha(1 - 1.0e-3);

    // One sparse example whose feature i carries value i.
    Vector input = new RandomAccessSparseVector(numFeatures);
    for (int i = 0; i < numFeatures; i++) {
        input.set(i, i);
    }

    // Three consecutive updates against category 0.
    plr.train(0, input);
    plr.train(0, input);
    plr.train(0, input);
}

From source file:com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java

License:Apache License

/**
 * Prints the beta and gamma buffers of a POLR agent before and after a single
 * training step so buffer changes can be inspected side by side.
 */
public void testPOLRInternalBuffers() {

    System.out.println("testPOLRInternalBuffers --------------");

    int categories = 2;
    int numFeatures = 5;
    double lambda = 1.0e-4;
    double learningRate = 10;

    // Build a one-example training set whose feature i carries value i.
    ArrayList<Vector> trainingSet = new ArrayList<Vector>();
    Vector example = new RandomAccessSparseVector(numFeatures);
    for (int i = 0; i < numFeatures; i++) {
        example.set(i, i);
    }
    trainingSet.add(example);

    ParallelOnlineLogisticRegression plrAgent = new ParallelOnlineLogisticRegression(categories, numFeatures,
            new L1()).lambda(lambda).learningRate(learningRate).alpha(1 - 1.0e-3);

    // Buffers before training.
    System.out.println("Beta: ");
    Utils.PrintVectorNonZero(plrAgent.getBeta().viewRow(0));

    System.out.println("\nGamma: ");
    Utils.PrintVectorNonZero(plrAgent.gamma.getMatrix().viewRow(0));

    plrAgent.train(0, trainingSet.get(0));

    // Buffers after a single training step.
    System.out.println("Beta: ");
    Utils.PrintVectorNonZero(plrAgent.noReallyGetBeta().viewRow(0));

    System.out.println("\nGamma: ");
    Utils.PrintVectorNonZero(plrAgent.gamma.getMatrix().viewRow(0));

}

From source file:com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java

License:Apache License

/**
 * Verifies that FlushGamma() zeroes out the locally accumulated gradient
 * (gamma) after a training step has populated it.
 */
public void testLocalGradientFlush() {

    System.out.println("\n\n\ntestLocalGradientFlush --------------");

    int categories = 2;
    int numFeatures = 5;
    double lambda = 1.0e-4;
    double learning_rate = 10;

    // Build a one-example training set whose feature x carries value x.
    ArrayList<Vector> trainingSet_0 = new ArrayList<Vector>();
    for (int s = 0; s < 1; s++) {
        Vector input = new RandomAccessSparseVector(numFeatures);
        for (int x = 0; x < numFeatures; x++) {
            input.set(x, x);
        }
        trainingSet_0.add(input);
    }

    ParallelOnlineLogisticRegression plr_agent_0 = new ParallelOnlineLogisticRegression(categories, numFeatures,
            new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3);

    // One training step accumulates a non-trivial gradient into gamma.
    plr_agent_0.train(0, trainingSet_0.get(0));

    System.out.println("\nGamma: ");
    Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0));

    plr_agent_0.FlushGamma();

    System.out.println("Flushing Gamma ...... ");

    System.out.println("\nGamma: ");
    Utils.PrintVector(plr_agent_0.gamma.getMatrix().viewRow(0));

    // FIX: expected value first, and use the (expected, actual, delta) overload —
    // the original autoboxed to assertEquals(Object, Object) with the arguments
    // in actual/expected order.
    for (int x = 0; x < numFeatures; x++) {
        assertEquals(0.0, plr_agent_0.gamma.getMatrix().get(0, x), 0.0);
    }

}

From source file:com.memonews.mahout.sentiment.SentimentModelTrainer.java

License:Apache License

/**
 * Trains an adaptive sentiment classifier from a directory tree where each
 * subdirectory name is a category label, serializes the best model, and
 * reports word-frequency statistics.
 *
 * args[0] — base directory of training files; args[1] (optional) — model
 * output path, defaulting to "target/model".
 */
public static void main(final String[] args) throws IOException {
    final File base = new File(args[0]);
    final String modelPath = args.length > 1 ? args[1] : "target/model";

    final Multiset<String> overallCounts = HashMultiset.create();
    final Dictionary newsGroups = new Dictionary();

    final SentimentModelHelper helper = new SentimentModelHelper();
    helper.getEncoder().setProbes(2);

    final AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(2,
            SentimentModelHelper.FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // Collect every file under each category directory; the directory name is the label.
    final List<File> files = Lists.newArrayList();
    for (final File category : base.listFiles()) {
        if (category.isDirectory()) {
            newsGroups.intern(category.getName());
            files.addAll(Arrays.asList(category.listFiles()));
        }
    }
    // Shuffling mixes the categories, which helps online training.
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());

    final SGDInfo info = new SGDInfo();
    int k = 0;
    for (final File file : files) {
        final String ng = file.getParentFile().getName();
        final int actual = newsGroups.intern(ng);

        final Vector v = helper.encodeFeatureVector(file, overallCounts);
        learningAlgorithm.train(actual, v);
        k++;

        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
        SGDHelper.analyzeState(info, 0, k, best);
    }
    learningAlgorithm.close();
    SGDHelper.dissect(0, newsGroups, learningAlgorithm, files, overallCounts);
    System.out.println("exiting main");

    ModelSerializer.writeBinary(modelPath,
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

    // Report per-word occurrence counts, most frequent first, capped at 1001 rows.
    final List<Integer> counts = Lists.newArrayList();
    System.out.printf("Word counts\n");
    for (final String word : overallCounts.elementSet()) {
        counts.add(overallCounts.count(word));
    }
    Collections.sort(counts, Ordering.natural().reverse());
    k = 0;
    for (final Integer count : counts) {
        System.out.printf("%d\t%d\n", k, count);
        k++;
        if (k > 1000) {
            break;
        }
    }
}

From source file:com.ml.ira.algos.AdaptiveLogisticModelParameters.java

License:Apache License

/**
 * Maps a prior-function name from the command line to a PriorFunction instance.
 *
 * @param cmd         case-insensitive prior name: L1, L2, UP, TP, or EBP
 *                    (surrounding whitespace is ignored)
 * @param priorOption numeric parameter used only by the TP and EBP priors
 * @return the matching prior, or null when cmd is null or unrecognized
 */
private static PriorFunction createPrior(String cmd, double priorOption) {
    if (cmd == null) {
        return null;
    }
    // Normalize once instead of re-computing toUpperCase().trim() per comparison.
    switch (cmd.toUpperCase(Locale.ENGLISH).trim()) {
    case "L1":
        return new L1();
    case "L2":
        return new L2();
    case "UP":
        return new UniformPrior();
    case "TP":
        return new TPrior(priorOption);
    case "EBP":
        return new ElasticBandPrior(priorOption);
    default:
        return null;
    }
}