Example usage for org.apache.mahout.classifier.evaluation Auc Auc

List of usage examples for org.apache.mahout.classifier.evaluation Auc Auc

Introduction

On this page you can find example usages of the org.apache.mahout.classifier.evaluation.Auc constructor Auc(double threshold).

Prototype

public Auc(double threshold) 

Source Link

Document

Allocates a new data-structure for accumulating information about AUC and a few other accuracy measures.

Usage

From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java

License:Apache License

/**
 * Trains and evaluates an online logistic-regression model on the calls parsed from
 * {@code bank-full.csv}.
 *
 * <p>Twenty runs are performed. Each run reshuffles the full list of calls, uses the first 10%
 * as the evaluation set and the remainder as the training set, and builds a new model. The
 * model is trained for twenty passes over the training set; after every pass a new
 * {@code Auc} evaluator (threshold 0.5) is filled from the evaluation set, the AUC is printed,
 * and the model's scalar score for one additional call is printed.
 *
 * @param args command-line arguments (not read)
 * @throws Exception propagated from parsing, training or evaluation
 */
public static void main(String[] args) throws Exception {
    final List<TelephoneCall> allCalls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));
    final double heldOutFraction = 0.10;

    // Threaded through the evaluation helper across every run and pass.
    double bestScore = 0.0;

    for (int run = 0; run < 20; run++) {
        Collections.shuffle(allCalls);

        // Both partitions are subList views over the freshly shuffled list.
        final int splitIndex = (int) (heldOutFraction * allCalls.size());
        final List<TelephoneCall> heldOut = allCalls.subList(0, splitIndex);
        final List<TelephoneCall> training = allCalls.subList(splitIndex, allCalls.size());

        // NOTE(review): getUnknownTelephoneCall is defined elsewhere; presumably it derives an
        // unlabeled call from the training data - confirm against its implementation.
        final List<TelephoneCall> unknownCalls = new ArrayList<>();
        unknownCalls.add(getUnknownTelephoneCall(training));

        final OnlineLogisticRegression model =
                new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1())
                        .learningRate(1)
                        .alpha(1)
                        .lambda(0.000001)
                        .stepOffset(10000)
                        .decayExponent(0.2);

        for (int pass = 0; pass < 20; pass++) {
            // One full sweep over the training partition.
            for (TelephoneCall example : training) {
                model.train(example.getTarget(), example.asVector());
            }

            // A new evaluator per pass, so the reported AUC covers only the current model state.
            final Auc evaluator = new Auc(0.5);
            for (TelephoneCall heldOutCall : heldOut) {
                // NOTE(review): the helper is defined elsewhere; presumably it records the call
                // in the evaluator and returns the updated maximum score - confirm.
                bestScore = evaluateTheCallAndGetBiggestScore(bestScore, model, evaluator, heldOutCall);
            }

            final double learningRate = model.currentLearningRate();
            final double heldOutAuc = evaluator.auc();
            System.out.printf("run: %-5d pass: %-5d current learning rate: %-5.4f \teval auc %-5.4f\n", run,
                    pass, learningRate, heldOutAuc);

            for (TelephoneCall unknownCall : unknownCalls) {
                final double score = model.classifyScalar(unknownCall.asVector());
                // The value printed after " accuracy " is the evaluator's AUC.
                final String report = " score: " + score + " accuracy " + evaluator.auc() + " call fields: "
                        + unknownCall.getFields();
                System.out.println(report);
            }
        }
    }
}

From source file:OpioidePrescriberClassification.Driver.java

/**
 * Trains an online logistic-regression model on the records parsed from
 * {@code /input1/try.csv} and prints its AUC on a held-out split.
 *
 * <p>The records are shuffled once; the first 10% form the test set and the rest the training
 * set. The model makes a single pass over the training set, then an {@code Auc} evaluator
 * (threshold 0.5) is filled from the test set and one result line is written to standard output.
 *
 * @param args command-line arguments (not read)
 * @throws Exception propagated from parsing, training or evaluation
 */
public static void main(String[] args) throws Exception {
    final List<Opioides> records = Lists.newArrayList(new Parser("/input1/try.csv"));
    final double heldOutFraction = 0.10;

    Collections.shuffle(records);

    // Both partitions are subList views over the shuffled list.
    final int splitIndex = (int) (heldOutFraction * records.size());
    final List<Opioides> heldOut = records.subList(0, splitIndex);
    final List<Opioides> training = records.subList(splitIndex, records.size());

    final OnlineLogisticRegression model =
            new OnlineLogisticRegression(NUM_CATEGORIES, Opioides.FEATURES, new L1())
                    .learningRate(1)
                    .alpha(1)
                    .lambda(0.000001)
                    .stepOffset(10000)
                    .decayExponent(0.2);

    // Progress marker on stderr, emitted before the single training pass.
    System.err.println("pass");
    for (Opioides example : training) {
        model.train(example.getTarget(), example.asVector());
    }

    final Auc evaluator = new Auc(0.5);
    for (Opioides heldOutRecord : heldOut) {
        final double score = model.classifyScalar(heldOutRecord.asVector());
        evaluator.add(heldOutRecord.getTarget(), score);
    }

    // NOTE(review): the leading literal 1 is presumably a placeholder for a pass number - confirm.
    final double learningRate = model.currentLearningRate();
    final double heldOutAuc = evaluator.auc();
    System.out.printf("%d, %.4f, %.4f\n", 1, learningRate, heldOutAuc);
}