Example usage for the org.apache.mahout.classifier.sgd.OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) constructor

Introduction

This page collects example usages of the org.apache.mahout.classifier.sgd.OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) constructor, drawn from open-source projects.

Prototype

public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) 
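
Before the project examples below, here is a minimal, self-contained sketch of what this constructor typically looks like in use. The category count, feature count, hyper-parameter values, and sample data are illustrative assumptions, not taken from any of the projects listed.

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class OlrConstructorSketch {
    public static void main(String[] args) {
        // 2 target categories, 3 features, L1 prior; all values are illustrative.
        OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 3, new L1())
                .learningRate(1)
                .lambda(1.0e-4);

        // One training example: target category 1 with a dense feature vector.
        Vector features = new DenseVector(new double[] { 1.0, 0.5, -0.2 });
        lr.train(1, features);

        // For the binary case, classifyScalar returns the score for category 1.
        double score = lr.classifyScalar(features);
        System.out.println("score = " + score);

        // close() applies any pending regularization updates.
        lr.close();
    }
}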

Usage

From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java

License:Apache License

public static void main(String[] args) throws Exception {
    List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));

    double heldOutPercentage = 0.10;

    double biggestScore = 0.0;

    for (int run = 0; run < 20; run++) {
        Collections.shuffle(calls);
        int cutoff = (int) (heldOutPercentage * calls.size());
        List<TelephoneCall> testAccuracyData = calls.subList(0, cutoff);
        List<TelephoneCall> trainData = calls.subList(cutoff, calls.size());

        List<TelephoneCall> testUnknownData = new ArrayList<>();

        testUnknownData.add(getUnknownTelephoneCall(trainData));

        OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES,
                new L1()).learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2);

        for (int pass = 0; pass < 20; pass++) {
            for (TelephoneCall observation : trainData) {
                lr.train(observation.getTarget(), observation.asVector());
            }
            Auc eval = new Auc(0.5);
            for (TelephoneCall testCall : testAccuracyData) {
                biggestScore = evaluateTheCallAndGetBiggestScore(biggestScore, lr, eval, testCall);
            }
            System.out.printf("run: %-5d pass: %-5d current learning rate: %-5.4f \teval auc %-5.4f\n", run,
                    pass, lr.currentLearningRate(), eval.auc());

            for (TelephoneCall testCall : testUnknownData) {
                final double score = lr.classifyScalar(testCall.asVector());
                System.out.println(" score: " + score + " accuracy " + eval.auc() + " call fields: "
                        + testCall.getFields());
            }
        }
    }
}
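
The helper evaluateTheCallAndGetBiggestScore is not shown in this excerpt. A plausible sketch of what it does, assuming the TelephoneCall type and the imports from the example above, is to score the call, feed the true label and score into Mahout's Auc accumulator, and keep the largest score seen so far:

// Hypothetical reconstruction of the unshown helper (name and behavior assumed).
private static double evaluateTheCallAndGetBiggestScore(double biggestScore, OnlineLogisticRegression lr,
        Auc eval, TelephoneCall testCall) {
    double score = lr.classifyScalar(testCall.asVector());
    eval.add(testCall.getTarget(), score); // Auc.add expects a 0/1 target and a score
    return Math.max(biggestScore, score);
}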

From source file:chapter4.src.logistic.LogisticModelParametersPredict.java

License:Apache License

/**
 * Creates a logistic regression trainer using the parameters collected here.
 *
 * @return The newly allocated OnlineLogisticRegression object
 */
public OnlineLogisticRegression createRegression() {
    if (lr == null) {
        lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1())
                .lambda(getLambda()).learningRate(getLearningRate()).alpha(1 - 1.0e-3);
    }
    return lr;
}

From source file:com.cloudera.knittingboar.records.TestTwentyNewsgroupsCustomRecordParseOLRRun.java

License:Apache License

@Test
public void testRecordFactoryOnDatasetShard() throws Exception {
    // TODO a test without assertions is not a test
    // p.270 ----- metrics to track lucene's parsing mechanics, progress,
    // performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };

    TwentyNewsgroupsRecordFactory rec_factory = new TwentyNewsgroupsRecordFactory("\t");
    // rec_factory.setClassSplitString("\t");

    JobConf job = new JobConf(defaultConf);

    long block_size = localFs.getDefaultBlockSize(workDir);

    LOG.info("default block size: " + (block_size / 1024 / 1024) + "MB");

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    @SuppressWarnings("resource")
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    FileInputFormat.setInputPaths(job, workDir);

    // try splitting the file in a variety of sizes
    TextInputFormat format = new TextInputFormat();
    format.configure(job);
    Text value = new Text();

    int numSplits = 1;

    InputSplit[] splits = format.getSplits(job, numSplits);

    LOG.info("requested " + numSplits + " splits, splitting: got =        " + splits.length);
    LOG.info("---- debug splits --------- ");
    rec_factory.Debug();
    int total_read = 0;

    for (int x = 0; x < splits.length; x++) {

        LOG.info("> Split [" + x + "]: " + splits[x].getLength());

        int count = 0;
        InputRecordsSplit custom_reader = new InputRecordsSplit(job, splits[x]);
        while (custom_reader.next(value)) {
            Vector v = new RandomAccessSparseVector(TwentyNewsgroupsRecordFactory.FEATURES);
            int actual = rec_factory.processLine(value.toString(), v);

            String ng = rec_factory.GetNewsgroupNameByID(actual);

            // calc stats ---------

            double mu = Math.min(k + 1, 200);
            double ll = learningAlgorithm.logLikelihood(actual, v);
            averageLL = averageLL + (ll - averageLL) / mu;

            Vector p = new DenseVector(20);
            learningAlgorithm.classifyFull(p, v);
            int estimated = p.maxValueIndex();

            int correct = (estimated == actual ? 1 : 0);
            averageCorrect = averageCorrect + (correct - averageCorrect) / mu;
            learningAlgorithm.train(actual, v);
            k++;
            int bump = bumps[(int) Math.floor(step) % bumps.length];
            int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

            if (k % (bump * scale) == 0) {
                step += 0.25;
                LOG.info(String.format("%10d %10.3f %10.3f %10.2f %s %s", k, ll, averageLL,
                        averageCorrect * 100, ng, rec_factory.GetNewsgroupNameByID(estimated)));
            }

            learningAlgorithm.close();
            count++;
        }

        LOG.info("read: " + count + " records for split " + x);
        total_read += count;
    } // for each split
    LOG.info("total read across all splits: " + total_read);
    rec_factory.Debug();
}

From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java

License:Apache License

public void testTrainNewsGroups() throws IOException {

    File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/");
    overallCounts = HashMultiset.create();

    long startTime = System.currentTimeMillis();

    // p.269 ---------------------------------------------------------
    Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();

    // encodes the text content in both the subject and the body of the email
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);

    // provides a constant offset that the model can use to encode the average frequency 
    // of each class
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);

    // used to encode the number of lines in a message
    FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
    lines.setTraceDictionary(traceDictionary);

    FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines");
    logLines.setTraceDictionary(traceDictionary);

    Dictionary newsGroups = new Dictionary();

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    // bottom of p.269 ------------------------------
    // because OLR expects to get integer class IDs for the target variable during training
    // we need a dictionary to convert the target variable (the newsgroup name)
    // to an integer, which is the newsGroup object
    List<File> files = new ArrayList<File>();
    for (File newsgroup : base.listFiles()) {
        newsGroups.intern(newsgroup.getName());
        files.addAll(Arrays.asList(newsgroup.listFiles()));
    }

    // mix up the files, helps training in OLR
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());

    // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    double averageLineCount = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };
    double lineCount = 0;

    // last line on p.269
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);

    Splitter onColon = Splitter.on(":").trimResults();

    int input_file_count = 0;

    // ----- p.270 ------------ "reading and tokenizing the data" ---------
    for (File file : files) {
        BufferedReader reader = new BufferedReader(new FileReader(file));

        input_file_count++;

        // identify newsgroup ----------------
        // convert newsgroup name to unique id
        // -----------------------------------
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);
        Multiset<String> words = ConcurrentHashMultiset.create();

        // check for line count header -------
        String line = reader.readLine();
        while (line != null && line.length() > 0) {

            // if this is a line that has a line count, let's pull that value out ------
            if (line.startsWith("Lines:")) {
                String count = Iterables.get(onColon.split(line), 1);
                try {
                    lineCount = Integer.parseInt(count);
                    averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);
                } catch (NumberFormatException e) {
                    // if anything goes wrong in parse: just use the avg count
                    lineCount = averageLineCount;
                }
            }

            boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:")
                    || line.startsWith("Keywords:") || line.startsWith("Summary:"));

            // loop through the lines in the file, while the line starts with: " "
            do {

                // get a reader for this specific string ------
                StringReader in = new StringReader(line);

                // ---- count words in header ---------            
                if (countHeader) {
                    countWords(analyzer, words, in);
                }

                // iterate to the next string ----
                line = reader.readLine();

            } while (line.startsWith(" "));

        } // while (lines in header) {

        //  -------- count words in body ----------
        countWords(analyzer, words, reader);
        reader.close();

        // ----- p.271 -----------
        Vector v = new RandomAccessSparseVector(FEATURES);

        // original value does nothing in a ConstantValueEncoder
        bias.addToVector("", 1, v);

        // original value does nothing in a ConstantValueEncoder
        lines.addToVector("", lineCount / 30, v);

        // original value does nothing in a ConstantValueEncoder
        logLines.addToVector("", Math.log(lineCount + 1), v);

        // now scan through all the words and add them
        for (String word : words.elementSet()) {
            encoder.addToVector(word, Math.log(1 + words.count(word)), v);
        }

        //Utils.PrintVectorNonZero(v);

        // calc stats ---------

        double mu = Math.min(k + 1, 200);
        double ll = learningAlgorithm.logLikelihood(actual, v);
        averageLL = averageLL + (ll - averageLL) / mu;

        Vector p = new DenseVector(20);
        learningAlgorithm.classifyFull(p, v);
        int estimated = p.maxValueIndex();

        int correct = (estimated == actual ? 1 : 0);
        averageCorrect = averageCorrect + (correct - averageCorrect) / mu;

        learningAlgorithm.train(actual, v);

        k++;

        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

        if (k % (bump * scale) == 0) {
            step += 0.25;
            System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng,
                    newsGroups.values().get(estimated));
        }

        learningAlgorithm.close();

        /*    if (k>4) {
              break;
            }
          */

    }

    Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3);

    long endTime = System.currentTimeMillis();

    //System.out.println("That took " + (endTime - startTime) + " milliseconds");
    long duration = (endTime - startTime);

    System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms");

    ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm);
    // learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

}
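
A natural counterpart to the writeBinary call above is reloading the serialized model. A short sketch, assuming a Mahout version whose ModelSerializer exposes readBinary(InputStream, Class):

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;

public class OlrModelLoaderSketch {
    // Sketch only: reload the model written to /tmp/olr-news-group.model above.
    public static OnlineLogisticRegression loadNewsGroupModel() throws IOException {
        try (InputStream in = new FileInputStream("/tmp/olr-news-group.model")) {
            return ModelSerializer.readBinary(in, OnlineLogisticRegression.class);
        }
    }
}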

From source file:com.technobium.MultinomialLogisticRegression.java

License:Apache License

public static void main(String[] args) throws Exception {
    // this test trains a 3-way classifier on the famous Iris dataset.
    // a similar exercise can be accomplished in R using this code:
    //    library(nnet)
    //    correct = rep(0,100)
    //    for (j in 1:100) {
    //      i = order(runif(150))
    //      train = iris[i[1:100],]
    //      test = iris[i[101:150],]
    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
    //    }
    //    hist(correct)
    //
    // Note that depending on the training/test split, performance can be better or worse.
    // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
    // of 100%
    //
    // This test uses a deterministic split that is neither outstandingly good nor bad

    RandomUtils.useTestSeed();
    Splitter onComma = Splitter.on(",");

    // read the data
    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

    // holds features
    List<Vector> data = Lists.newArrayList();

    // holds target variable
    List<Integer> target = Lists.newArrayList();

    // for decoding target values
    Dictionary dict = new Dictionary();

    // for permuting data later
    List<Integer> order = Lists.newArrayList();

    for (String line : raw.subList(1, raw.size())) {
        // order gets a list of indexes
        order.add(order.size());

        // parse the predictor variables
        Vector v = new DenseVector(5);
        v.set(0, 1);
        int i = 1;
        Iterable<String> values = onComma.split(line);
        for (String value : Iterables.limit(values, 4)) {
            v.set(i++, Double.parseDouble(value));
        }
        data.add(v);

        // and the target
        target.add(dict.intern(Iterables.get(values, 4)));
    }

    // randomize the order ... original data has each species all together
    // note that this randomization is deterministic
    Random random = RandomUtils.getRandom();
    Collections.shuffle(order, random);

    // select training and test data
    List<Integer> train = order.subList(0, 100);
    List<Integer> test = order.subList(100, 150);
    logger.warn("Training set = {}", train);
    logger.warn("Test set = {}", test);

    // now train many times and collect information on accuracy each time
    int[] correct = new int[test.size() + 1];
    for (int run = 0; run < 200; run++) {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
        // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
        for (int pass = 0; pass < 30; pass++) {
            Collections.shuffle(train, random);
            for (int k : train) {
                lr.train(target.get(k), data.get(k));
            }
        }

        // check the accuracy on held out data
        int x = 0;
        int[] count = new int[3];
        for (Integer k : test) {
            Vector vt = lr.classifyFull(data.get(k));
            int r = vt.maxValueIndex();
            count[r]++;
            x += r == target.get(k) ? 1 : 0;
        }
        correct[x]++;

        if (run == 199) {

            Vector v = new DenseVector(5);
            v.set(0, 1);
            int i = 1;
            Iterable<String> values = onComma.split("6.0,2.7,5.1,1.6,versicolor");
            for (String value : Iterables.limit(values, 4)) {
                v.set(i++, Double.parseDouble(value));
            }

            Vector vt = lr.classifyFull(v);
            for (String value : dict.values()) {
                System.out.println("target:" + value);
            }
            int t = dict.intern(Iterables.get(values, 4));

            int r = vt.maxValueIndex();
            boolean flag = r == t;
            lr.close();

            Closer closer = Closer.create();

            try {
                FileOutputStream byteArrayOutputStream = closer
                        .register(new FileOutputStream(new File("model.txt")));
                DataOutputStream dataOutputStream = closer
                        .register(new DataOutputStream(byteArrayOutputStream));
                PolymorphicWritable.write(dataOutputStream, lr);
            } finally {
                closer.close();
            }
        }
    }

    // verify we never saw worse than 95% correct,
    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
        System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i],
                100.0 * i / test.size()));
    }
    // nor perfect
    System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]));
}
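
The run above persists the model with PolymorphicWritable.write; the reverse direction is a DataInputStream plus PolymorphicWritable.read. A minimal sketch, assuming the same Mahout classes are on the classpath:

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;

public class IrisModelReloadSketch {
    // Sketch only: read back the regression written to model.txt above.
    public static OnlineLogisticRegression reload() throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream("model.txt"))) {
            return PolymorphicWritable.read(in, OnlineLogisticRegression.class);
        }
    }
}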

From source file:de.isabeldrostfromm.sof.Trainer.java

License:Open Source License

@Override
public OnlineLogisticRegression train(ExampleProvider provider) {
    OnlineLogisticRegression logReg = new OnlineLogisticRegression(ModelTargets.STATEVALUES.length,
            Vectoriser.getCardinality(), new L1());

    Multiset<String> set = HashMultiset.create();
    for (Example instance : provider) {
        set.add(instance.getState());
        logReg.train(ModelTargets.STATES.get(instance.getState()), instance.getVector());
    }

    return logReg;
}

From source file:opennlp.addons.mahout.OnlineLogisticRegressionTrainer.java

License:Apache License

@Override
public MaxentModel doTrain(DataIndexer indexer) throws IOException {

    // TODO: Lets use the predMap here as well for encoding
    int numberOfOutcomes = indexer.getOutcomeLabels().length;
    int numberOfFeatures = indexer.getPredLabels().length;

    // TODO: Make these parameters configurable ...
    OnlineLogisticRegression pa = new OnlineLogisticRegression(numberOfOutcomes, numberOfFeatures, new L1());

    pa.alpha(1).stepOffset(250).decayExponent(0.9).lambda(3.0e-5).learningRate(3000);

    for (int k = 0; k < iterations; k++) {
        trainOnlineLearner(indexer, pa);

        // What should be reported at the end of every iteration ?!
        System.out.println("Iteration " + (k + 1));
    }

    pa.close();

    return new VectorClassifierModel(pa, indexer.getOutcomeLabels(), createPrepMap(indexer));
}

From source file:OpioidePrescriberClassification.Driver.java

public static void main(String args[]) throws Exception {
    List<Opioides> calls = Lists.newArrayList(new Parser("/input1/try.csv"));
    double heldOutPercentage = 0.10;
    //        for (int run = 0; run < 20; run++) 
    {
        //            Random random = RandomUtils.getRandom();
        Collections.shuffle(calls);
        int cutoff = (int) (heldOutPercentage * calls.size());
        List<Opioides> test = calls.subList(0, cutoff);
        List<Opioides> train = calls.subList(cutoff, calls.size());

        OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, Opioides.FEATURES, new L1())
                .learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2);

        //            for (int pass = 0; pass < 2 ; pass++)
        {
            System.err.println("pass");
            for (Opioides observation : train) {
                lr.train(observation.getTarget(), observation.asVector());
            }
            //                if (pass % 2 == 0) 
            {
                Auc eval = new Auc(0.5);
                for (Opioides testCall : test) {
                    eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector()));
                }
                System.out.printf("%d, %.4f, %.4f\n", 1, lr.currentLearningRate(), eval.auc());
            }
        }
    }
}

From source file:org.deidentifier.arx.aggregates.classification.MultiClassLogisticRegression.java

License:Apache License

/**
 * Creates a new instance
 * @param specification
 * @param config
 */
public MultiClassLogisticRegression(ClassificationDataSpecification specification,
        ARXLogisticRegressionConfiguration config) {

    // Store
    this.config = config;
    this.specification = specification;

    // Prepare classifier
    PriorFunction prior = null;
    switch (config.getPriorFunction()) {
    case ELASTIC_BAND:
        prior = new ElasticBandPrior();
        break;
    case L1:
        prior = new L1();
        break;
    case L2:
        prior = new L2();
        break;
    case UNIFORM:
        prior = new UniformPrior();
        break;
    default:
        throw new IllegalArgumentException("Unknown prior function");
    }
    this.lr = new OnlineLogisticRegression(this.specification.classMap.size(), config.getVectorLength(), prior);

    // Configure
    this.lr.learningRate(config.getLearningRate());
    this.lr.alpha(config.getAlpha());
    this.lr.lambda(config.getLambda());
    this.lr.stepOffset(config.getStepOffset());
    this.lr.decayExponent(config.getDecayExponent());

    // Prepare encoders
    this.interceptEncoder = new ConstantValueEncoder("intercept");
    this.wordEncoder = new StaticWordValueEncoder("feature");

    // Configure
    this.lr.learningRate(1);
    this.lr.alpha(1);
    this.lr.lambda(0.000001);
    this.lr.stepOffset(10000);
    this.lr.decayExponent(0.2);
}

From source file:org.wso2.siddhi.extension.ModelInitializer.java

License:Open Source License

public static OnlineLogisticRegression InitializeLogisticRegression(String modelPath) {
    OnlineLogisticRegression LRmodel = null;
    FileInputStream fileInputStream = null;
    ObjectInputStream objectInputStream = null;
    double[][] modelWeights = null;
    LogisticRegressionModel LRmodelObject;
    try {
        // get the values for hyper-parameters from model file.
        fileInputStream = new FileInputStream(modelPath);
        objectInputStream = new ObjectInputStream(fileInputStream);
        LRmodelObject = (LogisticRegressionModel) objectInputStream.readObject();
        LRmodel = new OnlineLogisticRegression(LRmodelObject.getNumCategories(), LRmodelObject.getNumFeatures(),
                new L2(1));
        LRmodel.learningRate(LRmodelObject.getLearningRate());
        LRmodel.lambda(LRmodelObject.getLambda());
        LRmodel.alpha(LRmodelObject.getAlpha());
        LRmodel.stepOffset(LRmodelObject.getStepOffset());
        LRmodel.decayExponent(LRmodelObject.getDecayExponent());
        modelWeights = LRmodelObject.getWeights();
        fileInputStream.close();
        objectInputStream.close();
        for (int i = 0; i < modelWeights.length; i++) {
            for (int j = 0; j < modelWeights[0].length; j++) {
                LRmodel.setBeta(i, j, modelWeights[i][j]);
            }
        }
    } catch (Exception e) {
        logger.error("Failed to create a Logistic Regression model from the file \"" + modelPath + "\"", e);
    } finally {
        try {
            fileInputStream.close();
            objectInputStream.close();
        } catch (IOException e) {
            logger.error("Failed to close the model input stream!", e);
        }
    }
    logger.info("Logistic Regression model execution plan successfully intialized for \"" + modelPath
            + "\" model file.");
    return LRmodel;
}
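
InitializeLogisticRegression restores the coefficient matrix with setBeta; for reference, here is a short sketch of the opposite direction, pulling the weights out of a trained model via getBeta. The getBeta accessor and Matrix type come from Mahout; the helper itself is an assumption, not part of the extension.

import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.Matrix;

public class WeightExtractionSketch {
    // Sketch only: capture the weight matrix that setBeta(i, j, value) restores above.
    public static double[][] extractWeights(OnlineLogisticRegression trainedModel) {
        Matrix beta = trainedModel.getBeta();
        double[][] weights = new double[beta.numRows()][beta.numCols()];
        for (int i = 0; i < beta.numRows(); i++) {
            for (int j = 0; j < beta.numCols(); j++) {
                weights[i][j] = beta.get(i, j);
            }
        }
        return weights;
    }
}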