Usage examples for org.apache.mahout.classifier.sgd.OnlineLogisticRegression#classifyFull, which returns a vector containing one probability per category for the given input instance.
public Vector classifyFull(Vector instance)
From source file:com.memonews.mahout.sentiment.SentimentModelTester.java
License:Apache License
public void run(final PrintWriter output) throws IOException { final File base = new File(inputFile); // contains the best model final OnlineLogisticRegression classifier = ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class); final Dictionary newsGroups = new Dictionary(); final Multiset<String> overallCounts = HashMultiset.create(); final List<File> files = Lists.newArrayList(); for (final File newsgroup : base.listFiles()) { if (newsgroup.isDirectory()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); }/* w w w . j a va 2 s. c om*/ } System.out.printf("%d test files\n", files.size()); final ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT"); for (final File file : files) { final String ng = file.getParentFile().getName(); final int actual = newsGroups.intern(ng); final SentimentModelHelper helper = new SentimentModelHelper(); final Vector input = helper.encodeFeatureVector(file, overallCounts);// no // leak // type // ensures // this // is // a // normal // vector final Vector result = classifier.classifyFull(input); final int cat = result.maxValueIndex(); final double score = result.maxValue(); final double ll = classifier.logLikelihood(actual, input); final ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll); ra.addInstance(newsGroups.values().get(actual), cr); } output.printf("%s\n\n", ra.toString()); }
From source file:com.sixgroup.samplerecommender.Point.java
/**
 * Trains a two-class logistic regression on a toy 2-D point set and classifies
 * one new point, printing the per-class probabilities.
 */
public static void main(String[] args) {
    // Toy training data: points near the origin are class 0, points near (9,9) class 1.
    Map<Point, Integer> points = new HashMap<Point, Integer>();
    points.put(new Point(0, 0), 0);
    points.put(new Point(1, 1), 0);
    points.put(new Point(1, 0), 0);
    points.put(new Point(0, 1), 0);
    points.put(new Point(2, 2), 0);
    points.put(new Point(8, 8), 1);
    points.put(new Point(8, 9), 1);
    points.put(new Point(9, 8), 1);
    points.put(new Point(9, 9), 1);

    // 2 categories, 3 features (bias + x + y), L1 regularization.
    // BUG FIX: removed the dead no-arg OnlineLogisticRegression instance the
    // original created and immediately overwrote.
    OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
    learningAlgo.lambda(0.1);
    learningAlgo.learningRate(10);

    System.out.println("training model \n");
    // NOTE(review): HashMap iteration order is unspecified, so the online
    // training order (and thus the exact learned weights) may vary across runs.
    for (Point point : points.keySet()) {
        Vector v = getVector(point);
        System.out.println(point + " belongs to " + points.get(point));
        learningAlgo.train(points.get(point), v);
    }
    learningAlgo.close();

    // Classify one unseen point (bias = 0.5 here, matching getVector's encoding
    // only if that helper also uses index 0 as bias — TODO confirm).
    Vector v = new RandomAccessSparseVector(3);
    v.set(0, 0.5);
    v.set(1, 0.5);
    v.set(2, 1);
    Vector r = learningAlgo.classifyFull(v);

    System.out.println(r);
    System.out.println("ans = ");
    System.out.println("no of categories = " + learningAlgo.numCategories());
    System.out.println("no of features = " + learningAlgo.numFeatures());
    System.out.println("Probability of cluster 0 = " + r.get(0));
    System.out.println("Probability of cluster 1 = " + r.get(1));
}
From source file:com.technobium.MultinomialLogisticRegression.java
License:Apache License
public static void main(String[] args) throws Exception { // this test trains a 3-way classifier on the famous Iris dataset. // a similar exercise can be accomplished in R using this code: // library(nnet) // correct = rep(0,100) // for (j in 1:100) { // i = order(runif(150)) // train = iris[i[1:100],] // test = iris[i[101:150],] // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train) // correct[j] = mean(predict(m, newdata=test) == test$Species) // }/* w w w .j av a2 s . co m*/ // hist(correct) // // Note that depending on the training/test split, performance can be better or worse. // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy // of 100% // // This test uses a deterministic split that is neither outstandingly good nor bad RandomUtils.useTestSeed(); Splitter onComma = Splitter.on(","); // read the data List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8); // holds features List<Vector> data = Lists.newArrayList(); // holds target variable List<Integer> target = Lists.newArrayList(); // for decoding target values Dictionary dict = new Dictionary(); // for permuting data later List<Integer> order = Lists.newArrayList(); for (String line : raw.subList(1, raw.size())) { // order gets a list of indexes order.add(order.size()); // parse the predictor variables Vector v = new DenseVector(5); v.set(0, 1); int i = 1; Iterable<String> values = onComma.split(line); for (String value : Iterables.limit(values, 4)) { v.set(i++, Double.parseDouble(value)); } data.add(v); // and the target target.add(dict.intern(Iterables.get(values, 4))); } // randomize the order ... 
original data has each species all together // note that this randomization is deterministic Random random = RandomUtils.getRandom(); Collections.shuffle(order, random); // select training and test data List<Integer> train = order.subList(0, 100); List<Integer> test = order.subList(100, 150); logger.warn("Training set = {}", train); logger.warn("Test set = {}", test); // now train many times and collect information on accuracy each time int[] correct = new int[test.size() + 1]; for (int run = 0; run < 200; run++) { OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1)); // 30 training passes should converge to > 95% accuracy nearly always but never to 100% for (int pass = 0; pass < 30; pass++) { Collections.shuffle(train, random); for (int k : train) { lr.train(target.get(k), data.get(k)); } } // check the accuracy on held out data int x = 0; int[] count = new int[3]; for (Integer k : test) { Vector vt = lr.classifyFull(data.get(k)); int r = vt.maxValueIndex(); count[r]++; x += r == target.get(k) ? 
1 : 0; } correct[x]++; if (run == 199) { Vector v = new DenseVector(5); v.set(0, 1); int i = 1; Iterable<String> values = onComma.split("6.0,2.7,5.1,1.6,versicolor"); for (String value : Iterables.limit(values, 4)) { v.set(i++, Double.parseDouble(value)); } Vector vt = lr.classifyFull(v); for (String value : dict.values()) { System.out.println("target:" + value); } int t = dict.intern(Iterables.get(values, 4)); int r = vt.maxValueIndex(); boolean flag = r == t; lr.close(); Closer closer = Closer.create(); try { FileOutputStream byteArrayOutputStream = closer .register(new FileOutputStream(new File("model.txt"))); DataOutputStream dataOutputStream = closer .register(new DataOutputStream(byteArrayOutputStream)); PolymorphicWritable.write(dataOutputStream, lr); } finally { closer.close(); } } } // verify we never saw worse than 95% correct, for (int i = 0; i < Math.floor(0.95 * test.size()); i++) { System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size())); } // nor perfect System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1])); }
From source file:guipart.view.GUIOverviewController.java
@FXML void handleClassifyModel(ActionEvent event) throws IOException { if (pathModel != null && pathCSV != null) { Auc collector = new Auc(); LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel)); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); BufferedReader in = Utils.open(pathCSV); String line = in.readLine(); csv.firstLine(line);/*from w w w . j ava 2s. c o m*/ line = in.readLine(); int correct = 0; int wrong = 0; Boolean booltemp; String gender; while (line != null) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); String[] split = line.split(","); double score = lr.classifyFull(v).maxValueIndex(); if (score == target) correct++; else wrong++; System.out.println("Target is: " + target + " Score: " + score); booltemp = score != 0; if (split[1].contentEquals("1")) gender = "male"; else gender = "female"; Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]), Integer.parseInt(split[7]), booltemp, gender, Integer.parseInt(split[5]), Integer.parseInt(split[6]), Integer.parseInt(split[3])); guiPart.addPerson(temp); line = in.readLine(); collector.add(target, score); } double posto = ((double) wrong / (double) (correct + wrong)) * 100; System.out.println("Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong + " Wrong pct: " + posto + "%"); //PrintWriter output = null; Matrix m = collector.confusion(); //output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t " + m.get(0, 1) + " " + m.get(1, 1) + " "); // m = collector.entropy(); //output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); textAnalyze2.setText("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t \t " + 
m.get(0, 1) + " " + m.get(1, 1) + "\n" + "Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong + " Wrong pct: " + posto + "%"); } else { Dialogs.create().owner(guiPart.getPrimaryStage()).title("Error Dialog") .masthead("Look, an Error Dialog").message("One or more files aren't selected").showError(); } }
From source file:guipart.view.GUIOverviewController.java
@FXML void singlClassify(ActionEvent e) throws IOException { LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel)); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); csv.firstLine("custID,gender,state,cardholder,balance,numTrans,numIntlTrans,creditLine,fraudRisk"); String line;//from w w w .j av a 2 s .c om line = scID.getText(); line = line.concat("," + scGender.getText()); line = line.concat("," + scState.getText()); line = line.concat("," + scCardholders.getText()); line = line.concat("," + scBalance.getText()); line = line.concat("," + scTrans.getText()); line = line.concat("," + scIntlTrans.getText()); line = line.concat("," + scCreditLine.getText()); line = line.concat(",0 \n"); Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); String[] split = line.split(","); double score = lr.classifyFull(v).maxValueIndex(); boolean booltemp = score != 0; String gender; if (split[1].contentEquals("1")) gender = "male"; else gender = "female"; Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]), Integer.parseInt(split[7]), booltemp, gender, Integer.parseInt(split[5]), Integer.parseInt(split[6]), Integer.parseInt(split[3])); guiPart.addPerson(temp); }
From source file:javaapplication3.RunLogistic.java
public static void main(String[] args) throws IOException { // TODO code application logic here Auc collector = new Auc(); LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile)); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); BufferedReader in = open(inputFile); String line = in.readLine();// w w w . j a v a 2 s. c om csv.firstLine(line); line = in.readLine(); int correct = 0; int wrong = 0; while (line != null) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); System.out.println(line); String[] split = line.split(","); double score = lr.classifyFull(v).maxValueIndex(); if (score == target) correct++; else wrong++; System.out.println("Target is: " + target + " Score: " + score); line = in.readLine(); collector.add(target, score); } double posto = ((double) wrong / (double) (correct + wrong)) * 100; System.out.println("Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong + " Wrong pct: " + posto + "%"); //PrintWriter output = null; Matrix m = collector.confusion(); //output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t " + m.get(0, 1) + " " + m.get(1, 1) + " "); // m = collector.entropy(); //output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); }
From source file:technobium.OnlineLogisticRegressionTest.java
License:Apache License
public static void main(String[] args) throws Exception { // this test trains a 3-way classifier on the famous Iris dataset. // a similar exercise can be accomplished in R using this code: // library(nnet) // correct = rep(0,100) // for (j in 1:100) { // i = order(runif(150)) // train = iris[i[1:100],] // test = iris[i[101:150],] // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train) // correct[j] = mean(predict(m, newdata=test) == test$Species) // }// ww w.j ava2 s .c om // hist(correct) // // Note that depending on the training/test split, performance can be better or worse. // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy // of 100% // // This test uses a deterministic split that is neither outstandingly good nor bad RandomUtils.useTestSeed(); Splitter onComma = Splitter.on(","); // read the data List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8); // holds features List<Vector> data = Lists.newArrayList(); // holds target variable List<Integer> target = Lists.newArrayList(); // for decoding target values Dictionary dict = new Dictionary(); // for permuting data later List<Integer> order = Lists.newArrayList(); for (String line : raw.subList(1, raw.size())) { // order gets a list of indexes order.add(order.size()); // parse the predictor variables Vector v = new DenseVector(5); v.set(0, 1); int i = 1; Iterable<String> values = onComma.split(line); for (String value : Iterables.limit(values, 4)) { v.set(i++, Double.parseDouble(value)); } data.add(v); // and the target target.add(dict.intern(Iterables.get(values, 4))); } // randomize the order ... 
original data has each species all together // note that this randomization is deterministic Random random = RandomUtils.getRandom(); Collections.shuffle(order, random); // select training and test data List<Integer> train = order.subList(0, 100); List<Integer> test = order.subList(100, 150); logger.warn("Training set = {}", train); logger.warn("Test set = {}", test); // now train many times and collect information on accuracy each time int[] correct = new int[test.size() + 1]; for (int run = 0; run < 200; run++) { OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1)); // 30 training passes should converge to > 95% accuracy nearly always but never to 100% for (int pass = 0; pass < 30; pass++) { Collections.shuffle(train, random); for (int k : train) { lr.train(target.get(k), data.get(k)); } } // check the accuracy on held out data int x = 0; int[] count = new int[3]; for (Integer k : test) { Vector vt = lr.classifyFull(data.get(k)); int r = vt.maxValueIndex(); count[r]++; x += r == target.get(k) ? 1 : 0; } correct[x]++; } // verify we never saw worse than 95% correct, for (int i = 0; i < Math.floor(0.95 * test.size()); i++) { System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size())); } // nor perfect System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1])); }