Example usage for org.apache.mahout.classifier.sgd OnlineLogisticRegression classifyFull

List of usage examples for org.apache.mahout.classifier.sgd OnlineLogisticRegression classifyFull

Introduction

On this page you can find example usages of org.apache.mahout.classifier.sgd.OnlineLogisticRegression#classifyFull.

Prototype

public Vector classifyFull(Vector instance) 

Source Link

Document

Computes and returns a vector containing n scores, where n is numCategories(), given an input vector instance.

Usage

From source file:com.memonews.mahout.sentiment.SentimentModelTester.java

License:Apache License

/**
 * Evaluates the serialized sentiment model against every file found in the category
 * subdirectories of {@code inputFile} and writes a {@link ResultAnalyzer} report.
 *
 * <p>Each immediate subdirectory of {@code inputFile} is treated as one category; its name is the
 * expected label of every file it contains.
 *
 * @param output destination for the final report
 * @throws IOException if the model cannot be read, a directory cannot be listed, or a test file
 *     cannot be encoded
 */
public void run(final PrintWriter output) throws IOException {

    final File base = new File(inputFile);

    // contains the best model; close the stream here instead of relying on the reader to do it
    final OnlineLogisticRegression classifier;
    final FileInputStream modelIn = new FileInputStream(modelFile);
    try {
        classifier = ModelSerializer.readBinary(modelIn, OnlineLogisticRegression.class);
    } finally {
        modelIn.close();
    }

    final Dictionary newsGroups = new Dictionary();
    final Multiset<String> overallCounts = HashMultiset.create();

    // File.listFiles() returns null (not an empty array) when the path cannot be listed
    final File[] categoryDirs = base.listFiles();
    if (categoryDirs == null) {
        throw new IOException("Cannot list input directory: " + base);
    }

    final List<File> files = Lists.newArrayList();
    for (final File newsgroup : categoryDirs) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            final File[] members = newsgroup.listFiles();
            if (members == null) {
                throw new IOException("Cannot list category directory: " + newsgroup);
            }
            files.addAll(Arrays.asList(members));
        }
    }
    System.out.printf("%d test files\n", files.size());

    final ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
    for (final File file : files) {
        // the parent directory name is the expected category
        final String ng = file.getParentFile().getName();

        final int actual = newsGroups.intern(ng);
        final SentimentModelHelper helper = new SentimentModelHelper();
        // original author's note: "no leak type ensures this is a normal vector"
        final Vector input = helper.encodeFeatureVector(file, overallCounts);
        final Vector result = classifier.classifyFull(input);
        final int cat = result.maxValueIndex();
        final double score = result.maxValue();
        final double ll = classifier.logLikelihood(actual, input);
        final ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
        ra.addInstance(newsGroups.values().get(actual), cr);
    }
    output.printf("%s\n\n", ra.toString());
}

From source file:com.sixgroup.samplerecommender.Point.java

/**
 * Trains a two-category {@link OnlineLogisticRegression} on two small, well-separated clusters of
 * points, then prints the score vector produced by {@code classifyFull} for a query point that
 * lies inside the first cluster.
 *
 * @param args ignored
 */
public static void main(String[] args) {

    Map<Point, Integer> points = new HashMap<Point, Integer>();

    // cluster 0: points near the origin
    points.put(new Point(0, 0), 0);
    points.put(new Point(1, 1), 0);
    points.put(new Point(1, 0), 0);
    points.put(new Point(0, 1), 0);
    points.put(new Point(2, 2), 0);

    // cluster 1: points near (8.5, 8.5)
    points.put(new Point(8, 8), 1);
    points.put(new Point(8, 9), 1);
    points.put(new Point(9, 8), 1);
    points.put(new Point(9, 9), 1);

    // 2 categories, 3 features, L1 prior
    OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
    learningAlgo.lambda(0.1);
    learningAlgo.learningRate(10);

    System.out.println("training model  \n");

    // iterate entries directly instead of keySet() + get() for every point
    for (Map.Entry<Point, Integer> entry : points.entrySet()) {
        Point point = entry.getKey();
        int label = entry.getValue();

        Vector v = getVector(point);
        System.out.println(point + " belongs to " + label);
        learningAlgo.train(label, v);
    }

    learningAlgo.close();

    // query point (0.5, 0.5); index 2 is presumably the constant bias term -- TODO confirm against getVector
    Vector v = new RandomAccessSparseVector(3);
    v.set(0, 0.5);
    v.set(1, 0.5);
    v.set(2, 1);

    Vector r = learningAlgo.classifyFull(v);
    System.out.println(r);

    System.out.println("ans = ");
    System.out.println("no of categories = " + learningAlgo.numCategories());
    System.out.println("no of features = " + learningAlgo.numFeatures());
    System.out.println("Probability of cluster 0 = " + r.get(0));
    System.out.println("Probability of cluster 1 = " + r.get(1));

}

From source file:com.technobium.MultinomialLogisticRegression.java

License:Apache License

/**
 * Trains a 3-way logistic regression on the Iris data 200 times and reports how the held-out
 * accuracy is distributed across runs. After the last run it classifies one hand-written sample
 * and serializes the model to {@code model.txt}.
 *
 * @param args ignored
 * @throws Exception if the data cannot be read or the model cannot be written
 */
public static void main(String[] args) throws Exception {
    // this test trains a 3-way classifier on the famous Iris dataset.
    // a similar exercise can be accomplished in R using this code:
    //    library(nnet)
    //    correct = rep(0,100)
    //    for (j in 1:100) {
    //      i = order(runif(150))
    //      train = iris[i[1:100],]
    //      test = iris[i[101:150],]
    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
    //    }
    //    hist(correct)
    //
    // Note that depending on the training/test split, performance can be better or worse.
    // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
    // of 100%
    //
    // This test uses a deterministic split that is neither outstandingly good nor bad

    RandomUtils.useTestSeed();
    Splitter onComma = Splitter.on(",");

    // read the data
    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

    // holds features
    List<Vector> data = Lists.newArrayList();

    // holds target variable
    List<Integer> target = Lists.newArrayList();

    // for decoding target values
    Dictionary dict = new Dictionary();

    // for permuting data later
    List<Integer> order = Lists.newArrayList();

    // skip the header line
    for (String line : raw.subList(1, raw.size())) {
        // order gets a list of indexes
        order.add(order.size());

        // parse the predictor variables; slot 0 is a constant 1 (intercept)
        Vector v = new DenseVector(5);
        v.set(0, 1);
        int i = 1;
        Iterable<String> values = onComma.split(line);
        for (String value : Iterables.limit(values, 4)) {
            v.set(i++, Double.parseDouble(value));
        }
        data.add(v);

        // and the target
        target.add(dict.intern(Iterables.get(values, 4)));
    }

    // randomize the order ... original data has each species all together
    // note that this randomization is deterministic
    Random random = RandomUtils.getRandom();
    Collections.shuffle(order, random);

    // select training and test data (views onto 'order'; the two ranges are disjoint)
    List<Integer> train = order.subList(0, 100);
    List<Integer> test = order.subList(100, 150);
    logger.warn("Training set = {}", train);
    logger.warn("Test set = {}", test);

    // now train many times and collect information on accuracy each time;
    // correct[x] counts runs that got exactly x held-out examples right (0..test.size())
    int[] correct = new int[test.size() + 1];
    for (int run = 0; run < 200; run++) {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
        // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
        for (int pass = 0; pass < 30; pass++) {
            Collections.shuffle(train, random);
            for (int k : train) {
                lr.train(target.get(k), data.get(k));
            }
        }

        // check the accuracy on held out data
        int x = 0;
        for (Integer k : test) {
            Vector vt = lr.classifyFull(data.get(k));
            int r = vt.maxValueIndex();
            x += r == target.get(k) ? 1 : 0;
        }
        correct[x]++;

        if (run == 199) {
            // classify one hand-written versicolor sample with the final model
            Vector v = new DenseVector(5);
            v.set(0, 1);
            int i = 1;
            Iterable<String> values = onComma.split("6.0,2.7,5.1,1.6,versicolor");
            for (String value : Iterables.limit(values, 4)) {
                v.set(i++, Double.parseDouble(value));
            }

            Vector vt = lr.classifyFull(v);
            for (String value : dict.values()) {
                System.out.println("target:" + value);
            }
            int t = dict.intern(Iterables.get(values, 4));

            int r = vt.maxValueIndex();
            boolean flag = r == t;
            System.out.println("predicted:" + dict.values().get(r) + " correct:" + flag);
            lr.close();

            // Guava Closer idiom: rethrow through the closer so a failure in close()
            // cannot hide the exception that caused it
            Closer closer = Closer.create();
            try {
                FileOutputStream fileOutputStream = closer
                        .register(new FileOutputStream(new File("model.txt")));
                DataOutputStream dataOutputStream = closer
                        .register(new DataOutputStream(fileOutputStream));
                PolymorphicWritable.write(dataOutputStream, lr);
            } catch (Throwable e) {
                throw closer.rethrow(e);
            } finally {
                closer.close();
            }
        }
    }

    // verify we never saw worse than 95% correct,
    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
        System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i],
                100.0 * i / test.size()));
    }
    // nor perfect: 100% means all test.size() held-out examples were right
    System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size()]));
}

From source file:guipart.view.GUIOverviewController.java

/**
 * Scores every record of the selected CSV file with the selected logistic model, adds one
 * {@code Person} per record to the GUI, and shows accuracy and confusion figures in
 * {@code textAnalyze2}. Shows an error dialog instead when either file has not been chosen.
 *
 * @param event the triggering UI event (unused)
 * @throws IOException if the model or the CSV file cannot be read
 */
@FXML
void handleClassifyModel(ActionEvent event) throws IOException {

    if (pathModel == null || pathCSV == null) {
        Dialogs.create().owner(guiPart.getPrimaryStage()).title("Error Dialog")
                .masthead("Look, an Error Dialog").message("One or more files aren't selected").showError();
        return;
    }

    Auc collector = new Auc();
    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();

    int correct = 0;
    int wrong = 0;

    // close the reader even if a record fails to parse
    BufferedReader in = Utils.open(pathCSV);
    try {
        // the first line is the CSV header and configures the record factory
        String line = in.readLine();
        csv.firstLine(line);
        line = in.readLine();

        while (line != null) {
            Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
            int target = csv.processLine(line, v);
            String[] split = line.split(",");

            // index of the highest-scoring category, kept as double for the printout and Auc
            double score = lr.classifyFull(v).maxValueIndex();
            if (score == target) {
                correct++;
            } else {
                wrong++;
            }

            System.out.println("Target is: " + target + " Score: " + score);

            // NOTE(review): presumably the predicted fraud flag expected by Person -- confirm
            boolean booltemp = score != 0;
            String gender = "1".equals(split[1]) ? "male" : "female";

            Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]),
                    Integer.parseInt(split[7]), booltemp, gender, Integer.parseInt(split[5]),
                    Integer.parseInt(split[6]), Integer.parseInt(split[3]));

            guiPart.addPerson(temp);

            line = in.readLine();
            collector.add(target, score);
        }
    } finally {
        in.close();
    }

    int total = correct + wrong;
    // avoid 0/0 = NaN when the CSV has no data rows
    double posto = total == 0 ? 0.0 : ((double) wrong / (double) total) * 100;
    System.out.println("Total: " + total + " Correct: " + correct + " Wrong: " + wrong
            + " Wrong pct: " + posto + "%");
    Matrix m = collector.confusion();
    System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t   " + m.get(0, 1) + " "
            + m.get(1, 1) + " ");
    textAnalyze2.setText("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t \t   " + m.get(0, 1) + " "
            + m.get(1, 1) + "\n" + "Total: " + total + " Correct: " + correct + " Wrong: "
            + wrong + " Wrong pct: " + posto + "%");
}

From source file:guipart.view.GUIOverviewController.java

/**
 * Classifies the single record typed into the "single classify" text fields with the selected
 * model and adds the resulting {@code Person} to the GUI. Shows an error dialog when no model
 * file has been chosen.
 *
 * @param e the triggering UI event (unused)
 * @throws IOException if the model file cannot be read
 */
@FXML
void singlClassify(ActionEvent e) throws IOException {

    // same guard as handleClassifyModel: new File(null) would throw a bare NPE
    if (pathModel == null) {
        Dialogs.create().owner(guiPart.getPrimaryStage()).title("Error Dialog")
                .masthead("Look, an Error Dialog").message("Model file isn't selected").showError();
        return;
    }

    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();
    // the text fields come without a header, so declare the column layout explicitly
    csv.firstLine("custID,gender,state,cardholder,balance,numTrans,numIntlTrans,creditLine,fraudRisk");

    // assemble one record in header order; the trailing fraudRisk column is a 0 placeholder
    String line = scID.getText()
            + "," + scGender.getText()
            + "," + scState.getText()
            + "," + scCardholders.getText()
            + "," + scBalance.getText()
            + "," + scTrans.getText()
            + "," + scIntlTrans.getText()
            + "," + scCreditLine.getText()
            + ",0 \n";

    Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
    // fills v; the returned target is just the placeholder, so it is not needed
    csv.processLine(line, v);
    String[] split = line.split(",");

    double score = lr.classifyFull(v).maxValueIndex();
    boolean booltemp = score != 0;

    String gender = "1".equals(split[1]) ? "male" : "female";

    Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]), Integer.parseInt(split[7]),
            booltemp, gender, Integer.parseInt(split[5]), Integer.parseInt(split[6]),
            Integer.parseInt(split[3]));

    guiPart.addPerson(temp);
}

From source file:javaapplication3.RunLogistic.java

/**
 * Scores every record of {@code inputFile} with the logistic model stored in {@code modelFile},
 * printing each record with its target and predicted category, then overall accuracy and a
 * 2x2 confusion matrix.
 *
 * @param args ignored
 * @throws IOException if the model or the input file cannot be read
 */
public static void main(String[] args) throws IOException {
    Auc collector = new Auc();

    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();

    int correct = 0;
    int wrong = 0;

    // close the reader even if a record fails to parse
    BufferedReader in = open(inputFile);
    try {
        // the first line is the CSV header and configures the record factory
        String line = in.readLine();
        csv.firstLine(line);
        line = in.readLine();
        while (line != null) {
            Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
            int target = csv.processLine(line, v);

            System.out.println(line);

            // index of the highest-scoring category, kept as double for the printout and Auc
            double score = lr.classifyFull(v).maxValueIndex();
            if (score == target) {
                correct++;
            } else {
                wrong++;
            }

            System.out.println("Target is: " + target + " Score: " + score);
            line = in.readLine();
            collector.add(target, score);
        }
    } finally {
        in.close();
    }

    int total = correct + wrong;
    // avoid 0/0 = NaN when the input has no data rows
    double posto = total == 0 ? 0.0 : ((double) wrong / (double) total) * 100;
    System.out.println("Total: " + total + " Correct: " + correct + " Wrong: " + wrong
            + " Wrong pct: " + posto + "%");
    Matrix m = collector.confusion();
    System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t   " + m.get(0, 1) + " "
            + m.get(1, 1) + " ");
}

From source file:technobium.OnlineLogisticRegressionTest.java

License:Apache License

/**
 * Trains a 3-way logistic regression on the Iris data 200 times. It prints how many runs fell
 * into each held-out accuracy bucket below 95%, and how many reached 100%.
 *
 * @param args ignored
 * @throws Exception if the data cannot be read
 */
public static void main(String[] args) throws Exception {
    // this test trains a 3-way classifier on the famous Iris dataset.
    // a similar exercise can be accomplished in R using this code:
    //    library(nnet)
    //    correct = rep(0,100)
    //    for (j in 1:100) {
    //      i = order(runif(150))
    //      train = iris[i[1:100],]
    //      test = iris[i[101:150],]
    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
    //    }
    //    hist(correct)
    //
    // Note that depending on the training/test split, performance can be better or worse.
    // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
    // of 100%
    //
    // This test uses a deterministic split that is neither outstandingly good nor bad

    RandomUtils.useTestSeed();
    Splitter onComma = Splitter.on(",");

    // read the data
    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

    // holds features
    List<Vector> data = Lists.newArrayList();

    // holds target variable
    List<Integer> target = Lists.newArrayList();

    // for decoding target values
    Dictionary dict = new Dictionary();

    // for permuting data later
    List<Integer> order = Lists.newArrayList();

    // skip the header line
    for (String line : raw.subList(1, raw.size())) {
        // order gets a list of indexes
        order.add(order.size());

        // parse the predictor variables; slot 0 is a constant 1 (intercept)
        Vector v = new DenseVector(5);
        v.set(0, 1);
        int i = 1;
        Iterable<String> values = onComma.split(line);
        for (String value : Iterables.limit(values, 4)) {
            v.set(i++, Double.parseDouble(value));
        }
        data.add(v);

        // and the target
        target.add(dict.intern(Iterables.get(values, 4)));
    }

    // randomize the order ... original data has each species all together
    // note that this randomization is deterministic
    Random random = RandomUtils.getRandom();
    Collections.shuffle(order, random);

    // select training and test data (views onto 'order'; the two ranges are disjoint)
    List<Integer> train = order.subList(0, 100);
    List<Integer> test = order.subList(100, 150);
    logger.warn("Training set = {}", train);
    logger.warn("Test set = {}", test);

    // now train many times and collect information on accuracy each time;
    // correct[x] counts runs that got exactly x held-out examples right (0..test.size())
    int[] correct = new int[test.size() + 1];
    for (int run = 0; run < 200; run++) {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
        // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
        for (int pass = 0; pass < 30; pass++) {
            Collections.shuffle(train, random);
            for (int k : train) {
                lr.train(target.get(k), data.get(k));
            }
        }

        // check the accuracy on held out data
        int x = 0;
        for (Integer k : test) {
            Vector vt = lr.classifyFull(data.get(k));
            int r = vt.maxValueIndex();
            x += r == target.get(k) ? 1 : 0;
        }
        correct[x]++;
    }

    // verify we never saw worse than 95% correct,
    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
        System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i],
                100.0 * i / test.size()));
    }
    // nor perfect: 100% means all test.size() held-out examples were right
    System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size()]));
}