Usage examples for the org.apache.mahout.classifier.evaluation.Auc constructor Auc(double threshold)
public Auc(double threshold)
From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java
License:Apache License
public static void main(String[] args) throws Exception { List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv")); double heldOutPercentage = 0.10; double biggestScore = 0.0; for (int run = 0; run < 20; run++) { Collections.shuffle(calls); int cutoff = (int) (heldOutPercentage * calls.size()); List<TelephoneCall> testAccuracyData = calls.subList(0, cutoff); List<TelephoneCall> trainData = calls.subList(cutoff, calls.size()); List<TelephoneCall> testUnknownData = new ArrayList<>(); testUnknownData.add(getUnknownTelephoneCall(trainData)); OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1()).learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2); for (int pass = 0; pass < 20; pass++) { for (TelephoneCall observation : trainData) { lr.train(observation.getTarget(), observation.asVector()); }//from ww w . jav a 2 s. c o m Auc eval = new Auc(0.5); for (TelephoneCall testCall : testAccuracyData) { biggestScore = evaluateTheCallAndGetBiggestScore(biggestScore, lr, eval, testCall); } System.out.printf("run: %-5d pass: %-5d current learning rate: %-5.4f \teval auc %-5.4f\n", run, pass, lr.currentLearningRate(), eval.auc()); for (TelephoneCall testCall : testUnknownData) { final double score = lr.classifyScalar(testCall.asVector()); System.out.println(" score: " + score + " accuracy " + eval.auc() + " call fields: " + testCall.getFields()); } } } }
From source file:OpioidePrescriberClassification.Driver.java
public static void main(String args[]) throws Exception { List<Opioides> calls = Lists.newArrayList(new Parser("/input1/try.csv")); double heldOutPercentage = 0.10; // for (int run = 0; run < 20; run++) {//ww w.j a v a2 s.c om // Random random = RandomUtils.getRandom(); Collections.shuffle(calls); int cutoff = (int) (heldOutPercentage * calls.size()); List<Opioides> test = calls.subList(0, cutoff); List<Opioides> train = calls.subList(cutoff, calls.size()); OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, Opioides.FEATURES, new L1()) .learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2); // for (int pass = 0; pass < 2 ; pass++) { System.err.println("pass"); for (Opioides observation : train) { lr.train(observation.getTarget(), observation.asVector()); } // if (pass % 2 == 0) { Auc eval = new Auc(0.5); for (Opioides testCall : test) { eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector())); } System.out.printf("%d, %.4f, %.4f\n", 1, lr.currentLearningRate(), eval.auc()); } } } }