Example usage for org.apache.mahout.classifier.sgd OnlineLogisticRegression train

Introduction

This page collects example usages of org.apache.mahout.classifier.sgd.OnlineLogisticRegression.train from open-source projects.

Prototype

@Override
public void train(int actual, Vector instance)
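
A minimal, self-contained sketch of calling train directly, assuming mahout-core and mahout-math on the classpath. The data, class name, and hyperparameters below are illustrative placeholders, not values taken from any of the examples:

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;

public class TrainSketch {
    public static void main(String[] args) {
        // 2 categories, 3 features (bias + x + y); the prior, rate, and lambda are placeholders
        OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 3, new L1())
                .learningRate(1)
                .lambda(1.0e-5);

        double[][] features = { { 1, 0, 0 }, { 1, 0, 1 }, { 1, 8, 8 }, { 1, 9, 9 } };
        int[] targets = { 0, 0, 1, 1 };

        // each call to train performs one stochastic gradient step toward the observed class
        for (int pass = 0; pass < 10; pass++) {
            for (int i = 0; i < features.length; i++) {
                lr.train(targets[i], new DenseVector(features[i]));
            }
        }
        lr.close(); // apply any pending regularization before using the model

        // probability of class 1 for a point near the second cluster
        System.out.println(lr.classifyScalar(new DenseVector(new double[] { 1, 9, 8 })));
    }
}

The full examples below follow the same pattern: encode each record as a Vector, call train once per record per pass, then close() the learner before saving or classifying.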

Usage

From source file:TrainLogistic.java

License:Apache License

static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        double logPEstimate = 0;
        int samples = 0;
        /*read files in dir of inputFile*/
        int fi = 0;//file ID
        File file = new File(inputFile);
        String[] fns = file.list(new FilenameFilter() {
            public boolean accept(File dir, String name) {
                return name.endsWith(".svm");
            }
        });

        String[] ss = new String[lmp.getNumFeatures() + 1];
        String[] iv = new String[2];
        OnlineLogisticRegression lr = lmp.createRegression();
        while (fi < fns.length) {
            for (int pass = 0; pass < passes; pass++) {
                BufferedReader in = open(inputFile + fns[fi]);
                System.out.println(pass + 1);
                try {
                    // read variable names

                    String line = in.readLine();
                    int lineCount = 1;
                    while (line != null) {
                        // for each new line, get target and predictors
                        Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                        ss = line.split(" ");
                        int targetValue = ss[0].startsWith("+") ? 1 : 0;
                        int k = 1;
                        while (k < ss.length) {
                            iv = ss[k].split(":");
                            input.setQuick(Integer.valueOf(iv[0]) - 1, Double.valueOf(iv[1]));
                            //System.out.printf("%d-----%d:%.4f====%d\n", k,Integer.valueOf(iv[0])-1,Double.valueOf(iv[1]),lineCount);
                            k++;
                        }
                        input.setQuick(lmp.getNumFeatures() - 1, 1);
                        // check performance while this is still news
                        double logP = lr.logLikelihood(targetValue, input);
                        if (!Double.isInfinite(logP)) {
                            if (samples < 20) {
                                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                            } else {
                                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                            }
                            samples++;
                        }
                        double p = lr.classifyScalar(input);
                        if (scores) {
                            output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples,
                                    targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                        }
                        // now update model
                        lr.train(targetValue, input);
                        if (lineCount % 1000 == 0) {
                            System.out.printf("%d\t", lineCount);
                        }
                        line = in.readLine();
                        lineCount++;
                    }
                } finally {
                    Closeables.closeQuietly(in);
                }
                System.out.println();
            }
            fi++;
        }

        FileOutputStream modelOutput = new FileOutputStream(outputFile);
        try {
            saveTo(modelOutput, lr);
        } finally {
            Closeables.closeQuietly(modelOutput);
        }
        /*
              output.printf(Locale.ENGLISH, "%d\n", lmp.getNumFeatures());
              output.printf(Locale.ENGLISH, "%s ~ ", lmp.getTargetVariable());
              String sep = "";
              for (String v : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, 0, csv, v);
                if (weight != 0) {
                  output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                  sep = " + ";
                }
              }
              output.printf("\n");
              model = lr;
              for (int row = 0; row < lr.getBeta().numRows(); row++) {
                for (String key : csv.getTraceDictionary().keySet()) {
                  double weight = predictorWeight(lr, row, csv, key);
                  if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%20s %.5f\n", key, weight);
                  }
                }
                for (int column = 0; column < lr.getBeta().numCols(); column++) {
                  output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
                }
                output.println();
              }*/
    }
}

From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java

License:Apache License

public static void main(String[] args) throws Exception {
    List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));

    double heldOutPercentage = 0.10;

    double biggestScore = 0.0;

    for (int run = 0; run < 20; run++) {
        Collections.shuffle(calls);
        int cutoff = (int) (heldOutPercentage * calls.size());
        List<TelephoneCall> testAccuracyData = calls.subList(0, cutoff);
        List<TelephoneCall> trainData = calls.subList(cutoff, calls.size());

        List<TelephoneCall> testUnknownData = new ArrayList<>();

        testUnknownData.add(getUnknownTelephoneCall(trainData));

        OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES,
                new L1()).learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2);

        for (int pass = 0; pass < 20; pass++) {
            for (TelephoneCall observation : trainData) {
                lr.train(observation.getTarget(), observation.asVector());
            }
            Auc eval = new Auc(0.5);
            for (TelephoneCall testCall : testAccuracyData) {
                biggestScore = evaluateTheCallAndGetBiggestScore(biggestScore, lr, eval, testCall);
            }
            System.out.printf("run: %-5d pass: %-5d current learning rate: %-5.4f \teval auc %-5.4f\n", run,
                    pass, lr.currentLearningRate(), eval.auc());

            for (TelephoneCall testCall : testUnknownData) {
                final double score = lr.classifyScalar(testCall.asVector());
                System.out.println(" score: " + score + " accuracy " + eval.auc() + " call fields: "
                        + testCall.getFields());
            }
        }
    }
}

From source file:com.cloudera.knittingboar.records.TestTwentyNewsgroupsCustomRecordParseOLRRun.java

License:Apache License

@Test
public void testRecordFactoryOnDatasetShard() throws Exception {
    // TODO a test with assertions is not a test
    // p.270 ----- metrics to track lucene's parsing mechanics, progress,
    // performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };

    TwentyNewsgroupsRecordFactory rec_factory = new TwentyNewsgroupsRecordFactory("\t");
    // rec_factory.setClassSplitString("\t");

    JobConf job = new JobConf(defaultConf);

    long block_size = localFs.getDefaultBlockSize(workDir);

    LOG.info("default block size: " + (block_size / 1024 / 1024) + "MB");

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    @SuppressWarnings("resource")
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    FileInputFormat.setInputPaths(job, workDir);

    // try splitting the file in a variety of sizes
    TextInputFormat format = new TextInputFormat();
    format.configure(job);
    Text value = new Text();

    int numSplits = 1;

    InputSplit[] splits = format.getSplits(job, numSplits);

    LOG.info("requested " + numSplits + " splits, splitting: got =        " + splits.length);
    LOG.info("---- debug splits --------- ");
    rec_factory.Debug();
    int total_read = 0;

    for (int x = 0; x < splits.length; x++) {

        LOG.info("> Split [" + x + "]: " + splits[x].getLength());

        int count = 0;
        InputRecordsSplit custom_reader = new InputRecordsSplit(job, splits[x]);
        while (custom_reader.next(value)) {
            Vector v = new RandomAccessSparseVector(TwentyNewsgroupsRecordFactory.FEATURES);
            int actual = rec_factory.processLine(value.toString(), v);

            String ng = rec_factory.GetNewsgroupNameByID(actual);

            // calc stats ---------

            double mu = Math.min(k + 1, 200);
            double ll = learningAlgorithm.logLikelihood(actual, v);
            averageLL = averageLL + (ll - averageLL) / mu;

            Vector p = new DenseVector(20);
            learningAlgorithm.classifyFull(p, v);
            int estimated = p.maxValueIndex();

            int correct = (estimated == actual ? 1 : 0);
            averageCorrect = averageCorrect + (correct - averageCorrect) / mu;
            learningAlgorithm.train(actual, v);
            k++;
            int bump = bumps[(int) Math.floor(step) % bumps.length];
            int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

            if (k % (bump * scale) == 0) {
                step += 0.25;
                LOG.info(String.format("%10d %10.3f %10.3f %10.2f %s %s", k, ll, averageLL,
                        averageCorrect * 100, ng, rec_factory.GetNewsgroupNameByID(estimated)));
            }

            learningAlgorithm.close();
            count++;
        }

        LOG.info("read: " + count + " records for split " + x);
        total_read += count;
    } // for each split
    LOG.info("total read across all splits: " + total_read);
    rec_factory.Debug();
}

From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java

License:Apache License

public void testTrainNewsGroups() throws IOException {

    File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/");
    overallCounts = HashMultiset.create();

    long startTime = System.currentTimeMillis();

    // p.269 ---------------------------------------------------------
    Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();

    // encodes the text content in both the subject and the body of the email
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);

    // provides a constant offset that the model can use to encode the average frequency 
    // of each class
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);

    // used to encode the number of lines in a message
    FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
    lines.setTraceDictionary(traceDictionary);

    FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines");
    logLines.setTraceDictionary(traceDictionary);

    Dictionary newsGroups = new Dictionary();

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    // bottom of p.269 ------------------------------
    // OLR expects integer class IDs for the target variable during training,
    // so the newsGroups dictionary converts each target value (the newsgroup
    // name) to an integer
    List<File> files = new ArrayList<File>();
    for (File newsgroup : base.listFiles()) {
        newsGroups.intern(newsgroup.getName());
        files.addAll(Arrays.asList(newsgroup.listFiles()));
    }

    // mix up the files, helps training in OLR
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());

    // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    double averageLineCount = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };
    double lineCount = 0;

    // last line on p.269
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);

    Splitter onColon = Splitter.on(":").trimResults();

    int input_file_count = 0;

    // ----- p.270 ------------ "reading and tokenzing the data" ---------
    for (File file : files) {
        BufferedReader reader = new BufferedReader(new FileReader(file));

        input_file_count++;

        // identify newsgroup ----------------
        // convert newsgroup name to unique id
        // -----------------------------------
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);
        Multiset<String> words = ConcurrentHashMultiset.create();

        // check for line count header -------
        String line = reader.readLine();
        while (line != null && line.length() > 0) {

            // if this is a line that has a line count, let's pull that value out ------
            if (line.startsWith("Lines:")) {
                String count = Iterables.get(onColon.split(line), 1);
                try {
                    lineCount = Integer.parseInt(count);
                    averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);
                } catch (NumberFormatException e) {
                    // if anything goes wrong in parse: just use the avg count
                    lineCount = averageLineCount;
                }
            }

            boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:")
                    || line.startsWith("Keywords:") || line.startsWith("Summary:"));

            // consume this line and any continuation lines, which start with a space
            do {

                // get a reader for this specific string ------
                StringReader in = new StringReader(line);

                // ---- count words in header ---------            
                if (countHeader) {
                    countWords(analyzer, words, in);
                }

                // iterate to the next string ----
                line = reader.readLine();

            } while (line != null && line.startsWith(" "));

        } // while (lines in header)

        //  -------- count words in body ----------
        countWords(analyzer, words, reader);
        reader.close();

        // ----- p.271 -----------
        Vector v = new RandomAccessSparseVector(FEATURES);

        // original value does nothing in a ConstantValueEncoder
        bias.addToVector("", 1, v);

        // original value does nothing in a ConstantValueEncoder
        lines.addToVector("", lineCount / 30, v);

        // original value does nothing in a ConstantValueEncoder
        logLines.addToVector("", Math.log(lineCount + 1), v);

        // now scan through all the words and add them
        for (String word : words.elementSet()) {
            encoder.addToVector(word, Math.log(1 + words.count(word)), v);
        }

        //Utils.PrintVectorNonZero(v);

        // calc stats ---------

        double mu = Math.min(k + 1, 200);
        double ll = learningAlgorithm.logLikelihood(actual, v);
        averageLL = averageLL + (ll - averageLL) / mu;

        Vector p = new DenseVector(20);
        learningAlgorithm.classifyFull(p, v);
        int estimated = p.maxValueIndex();

        int correct = (estimated == actual ? 1 : 0);
        averageCorrect = averageCorrect + (correct - averageCorrect) / mu;

        learningAlgorithm.train(actual, v);

        k++;

        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

        if (k % (bump * scale) == 0) {
            step += 0.25;
            System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng,
                    newsGroups.values().get(estimated));
        }

        learningAlgorithm.close();

        // if (k > 4) { break; }

    }

    Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3);

    long endTime = System.currentTimeMillis();

    //System.out.println("That took " + (endTime - startTime) + " milliseconds");
    long duration = (endTime - startTime);

    System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms");

    ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm);
    // learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

}

From source file:com.ml.ira.algos.TrainLogistic.java

License:Apache License

static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        double logPEstimate = 0;
        int samples = 0;

        System.out.println("fieldNames: " + fieldNames);
        long ts = System.currentTimeMillis();
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        for (int pass = 0; pass < passes; pass++) {
            System.out.println("at Round: " + pass);
            BufferedReader in = open(inputFile);
            try {
                // read variable names
                String line;
                if (fieldNames != null && fieldNames.length() > 0) {
                    csv.firstLine(fieldNames);
                } else {
                    csv.firstLine(in.readLine());
                }
                line = in.readLine();
                while (line != null) {
                    // for each new line, get target and predictors
                    Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                    int targetValue = csv.processLine(line, input);

                    // check performance while this is still news
                    double logP = lr.logLikelihood(targetValue, input);
                    if (!Double.isInfinite(logP)) {
                        if (samples < 20) {
                            logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                        } else {
                            logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                        }
                        samples++;
                    }
                    double p = lr.classifyScalar(input);
                    if (scores) {
                        output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples,
                                targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                    }

                    // now update model
                    lr.train(targetValue, input);

                    line = in.readLine();
                }
            } finally {
                Closeables.close(in, true);
            }
            output.println("duration: " + (System.currentTimeMillis() - ts));
        }

        if (outputFile.startsWith("hdfs://")) {
            lmp.saveTo(new Path(outputFile));
        } else {
            OutputStream modelOutput = new FileOutputStream(outputFile);
            try {
                lmp.saveTo(modelOutput);
            } finally {
                Closeables.close(modelOutput, false);
            }
        }

        output.println("duration: " + (System.currentTimeMillis() - ts));

        output.println(lmp.getNumFeatures());
        output.println(lmp.getTargetVariable() + " ~ ");
        String sep = "";
        for (String v : csv.getTraceDictionary().keySet()) {
            double weight = predictorWeight(lr, 0, csv, v);
            if (weight != 0) {
                output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                sep = " + ";
            }
        }
        output.printf("%n");
        model = lr;
        for (int row = 0; row < lr.getBeta().numRows(); row++) {
            for (String key : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, row, csv, key);
                if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
                }
            }
            for (int column = 0; column < lr.getBeta().numCols(); column++) {
                output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
            }
            output.println();
        }
    }
}

From source file:com.sixgroup.samplerecommender.Point.java

public static void main(String[] args) {

    Map<Point, Integer> points = new HashMap<Point, Integer>();

    points.put(new Point(0, 0), 0);
    points.put(new Point(1, 1), 0);
    points.put(new Point(1, 0), 0);
    points.put(new Point(0, 1), 0);
    points.put(new Point(2, 2), 0);

    points.put(new Point(8, 8), 1);
    points.put(new Point(8, 9), 1);
    points.put(new Point(9, 8), 1);
    points.put(new Point(9, 9), 1);

    OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
    learningAlgo.lambda(0.1);
    learningAlgo.learningRate(10);

    System.out.println("training model\n");

    for (Point point : points.keySet()) {

        Vector v = getVector(point);
        System.out.println(point + " belongs to " + points.get(point));
        learningAlgo.train(points.get(point), v);
    }

    learningAlgo.close();

    Vector v = new RandomAccessSparseVector(3);
    v.set(0, 0.5);
    v.set(1, 0.5);
    v.set(2, 1);

    Vector r = learningAlgo.classifyFull(v);
    System.out.println(r);

    System.out.println("ans = ");
    System.out.println("no of categories = " + learningAlgo.numCategories());
    System.out.println("no of features = " + learningAlgo.numFeatures());
    System.out.println("Probability of cluster 0 = " + r.get(0));
    System.out.println("Probability of cluster 1 = " + r.get(1));

}

From source file:com.technobium.MultinomialLogisticRegression.java

License:Apache License

public static void main(String[] args) throws Exception {
    // this test trains a 3-way classifier on the famous Iris dataset.
    // a similar exercise can be accomplished in R using this code:
    //    library(nnet)
    //    correct = rep(0,100)
    //    for (j in 1:100) {
    //      i = order(runif(150))
    //      train = iris[i[1:100],]
    //      test = iris[i[101:150],]
    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
    //    }
    //    hist(correct)
    //
    // Note that depending on the training/test split, performance can be better or worse.
    // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
    // of 100%
    //
    // This test uses a deterministic split that is neither outstandingly good nor bad

    RandomUtils.useTestSeed();
    Splitter onComma = Splitter.on(",");

    // read the data
    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

    // holds features
    List<Vector> data = Lists.newArrayList();

    // holds target variable
    List<Integer> target = Lists.newArrayList();

    // for decoding target values
    Dictionary dict = new Dictionary();

    // for permuting data later
    List<Integer> order = Lists.newArrayList();

    for (String line : raw.subList(1, raw.size())) {
        // order gets a list of indexes
        order.add(order.size());

        // parse the predictor variables
        Vector v = new DenseVector(5);
        v.set(0, 1);
        int i = 1;
        Iterable<String> values = onComma.split(line);
        for (String value : Iterables.limit(values, 4)) {
            v.set(i++, Double.parseDouble(value));
        }
        data.add(v);

        // and the target
        target.add(dict.intern(Iterables.get(values, 4)));
    }

    // randomize the order ... original data has each species all together
    // note that this randomization is deterministic
    Random random = RandomUtils.getRandom();
    Collections.shuffle(order, random);

    // select training and test data
    List<Integer> train = order.subList(0, 100);
    List<Integer> test = order.subList(100, 150);
    logger.warn("Training set = {}", train);
    logger.warn("Test set = {}", test);

    // now train many times and collect information on accuracy each time
    int[] correct = new int[test.size() + 1];
    for (int run = 0; run < 200; run++) {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
        // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
        for (int pass = 0; pass < 30; pass++) {
            Collections.shuffle(train, random);
            for (int k : train) {
                lr.train(target.get(k), data.get(k));
            }
        }

        // check the accuracy on held out data
        int x = 0;
        int[] count = new int[3];
        for (Integer k : test) {
            Vector vt = lr.classifyFull(data.get(k));
            int r = vt.maxValueIndex();
            count[r]++;
            x += r == target.get(k) ? 1 : 0;
        }
        correct[x]++;

        if (run == 199) {

            Vector v = new DenseVector(5);
            v.set(0, 1);
            int i = 1;
            Iterable<String> values = onComma.split("6.0,2.7,5.1,1.6,versicolor");
            for (String value : Iterables.limit(values, 4)) {
                v.set(i++, Double.parseDouble(value));
            }

            Vector vt = lr.classifyFull(v);
            for (String value : dict.values()) {
                System.out.println("target:" + value);
            }
            int t = dict.intern(Iterables.get(values, 4));

            int r = vt.maxValueIndex();
            boolean flag = r == t;
            lr.close();

            Closer closer = Closer.create();

            try {
                FileOutputStream byteArrayOutputStream = closer
                        .register(new FileOutputStream(new File("model.txt")));
                DataOutputStream dataOutputStream = closer
                        .register(new DataOutputStream(byteArrayOutputStream));
                PolymorphicWritable.write(dataOutputStream, lr);
            } finally {
                closer.close();
            }
        }
    }

    // verify we never saw worse than 95% correct,
    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
        System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i],
                100.0 * i / test.size()));
    }
    // nor perfect
    System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]));
}

From source file:de.isabeldrostfromm.sof.Trainer.java

License:Open Source License

@Override
public OnlineLogisticRegression train(ExampleProvider provider) {
    OnlineLogisticRegression logReg = new OnlineLogisticRegression(ModelTargets.STATEVALUES.length,
            Vectoriser.getCardinality(), new L1());

    Multiset<String> set = HashMultiset.create();
    for (Example instance : provider) {
        set.add(instance.getState());
        logReg.train(ModelTargets.STATES.get(instance.getState()), instance.getVector());
    }

    return logReg;
}

From source file:edu.isi.karma.cleaning.features.RecordClassifier2.java

License:Apache License

@SuppressWarnings({ "deprecation" })
public OnlineLogisticRegression train(HashMap<String, Vector<String>> traindata) throws Exception {
    String csvTrainFile = "./target/tmp/csvtrain.csv";
    Data2Features.Traindata2CSV(traindata, csvTrainFile, rf);
    lmp = new LogisticModelParameters();
    lmp.setTargetVariable("label");
    lmp.setMaxTargetCategories(rf.labels.size());
    lmp.setNumFeatures(rf.getFeatureNames().size());
    List<String> typeList = Lists.newArrayList();
    typeList.add("numeric");
    List<String> predictorList = Lists.newArrayList();
    for (String attr : rf.getFeatureNames()) {
        if (attr.compareTo("label") != 0) {
            predictorList.add(attr);
        }
    }
    lmp.setTypeMap(predictorList, typeList);
    // lmp.setUseBias(!getBooleanArgument(cmdLine, noBias));
    // lmp.setTypeMap(predictorList, typeList);
    lmp.setLambda(1e-4);
    lmp.setLearningRate(50);
    int passes = 100;
    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();
    for (int pass = 0; pass < passes; pass++) {
        BufferedReader in = new BufferedReader(new FileReader(new File(csvTrainFile)));
        try {
            // read variable names
            csv.firstLine(in.readLine());
            String line = in.readLine();
            while (line != null) {
                // for each new line, get target and predictors
                RandomAccessSparseVector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                int targetValue = csv.processLine(line, input);
                // String label =
                // csv.getTargetCategories().get(lr.classifyFull(input).maxValueIndex());
                // now update model
                lr.train(targetValue, input);
                line = in.readLine();
            }
        } finally {
            Closeables.closeQuietly(in);
        }
    }
    labels = csv.getTargetCategories();
    return lr;

}

From source file:haflow.component.mahout.logistic.TrainLogistic.java

License:Apache License

static void mainToOutput(String[] args) throws Exception {
    if (parseArgs(args)) {

        double logPEstimate = 0;
        int samples = 0;

        OutputStream o = HdfsUtil.writeHdfs(inforFile);
        PrintWriter output = new PrintWriter(o, true);

        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        for (int pass = 0; pass < passes; pass++) {
            BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile)));
            try {
                // read variable names
                csv.firstLine(in.readLine());

                String line = in.readLine();

                while (line != null) {
                    // for each new line, get target and predictors
                    Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                    int targetValue = csv.processLine(line, input);

                    // check performance while this is still news
                    double logP = lr.logLikelihood(targetValue, input);
                    if (!Double.isInfinite(logP)) {
                        if (samples < 20) {
                            logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                        } else {
                            logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                        }
                        samples++;
                    }
                    double p = lr.classifyScalar(input);
                    if (scores) {
                        output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples,
                                targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                    }

                    // now update model
                    lr.train(targetValue, input);

                    line = in.readLine();
                }
            } finally {
                Closeables.close(in, true);
            }
        }

        //OutputStream modelOutput = new FileOutputStream(outputFile);
        OutputStream modelOutput = HdfsUtil.writeHdfs(outputFile);
        try {
            lmp.saveTo(modelOutput);
        } finally {
            Closeables.close(modelOutput, false);
        }

        output.println(lmp.getNumFeatures());
        output.println(lmp.getTargetVariable() + " ~ ");
        String sep = "";
        for (String v : csv.getTraceDictionary().keySet()) {
            double weight = predictorWeight(lr, 0, csv, v);
            if (weight != 0) {
                output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                sep = " + ";
            }
        }
        output.printf("%n");
        model = lr;
        for (int row = 0; row < lr.getBeta().numRows(); row++) {
            for (String key : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, row, csv, key);
                if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
                }
            }
            for (int column = 0; column < lr.getBeta().numCols(); column++) {
                output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
            }
            output.println();
        }
        output.close();
    }

}