OpioidePrescriberClassification.Driver.java Source code

Java tutorial

Introduction

Here is the source code for OpioidePrescriberClassification.Driver.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package OpioidePrescriberClassification;

import com.google.common.collect.Lists;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.common.RandomUtils;

/**
 * Command-line driver that trains a Mahout {@link OnlineLogisticRegression} (L1 prior) on the
 * {@code Opioides} records produced by {@code Parser} and reports the AUC on a randomly
 * held-out test split.
 *
 * <p>Usage: {@code Driver [inputCsv [shuffleSeed]]}. When {@code inputCsv} is omitted the
 * historical default {@code /input1/try.csv} is used. When {@code shuffleSeed} is supplied the
 * train/test split is reproducible across runs.
 *
 * <p>Output: one line on stdout of the form {@code pass, learningRate, auc}, always formatted
 * with {@link Locale#ROOT} so the decimal separator never collides with the field separator.
 *
 * @author vishakha
 */
public class Driver {

    /** Number of target categories passed to the classifier. */
    public static final int NUM_CATEGORIES = 2;

    /** Input file used when no path is given on the command line. */
    private static final String DEFAULT_INPUT = "/input1/try.csv";

    /** Fraction of the shuffled records held out for evaluation. */
    private static final double HELD_OUT_PERCENTAGE = 0.10;

    /**
     * Loads the records, shuffles them, trains a single pass over the training split and prints
     * the learning rate and held-out AUC.
     *
     * @param args optional {@code [inputCsv [shuffleSeed]]}
     * @throws NumberFormatException if {@code shuffleSeed} is present but not a valid {@code long}
     * @throws IllegalStateException if the held-out split would be empty (too few records)
     * @throws Exception whatever {@code Parser} or the Mahout calls propagate
     */
    public static void main(String[] args) throws Exception {
        String input = args.length > 0 ? args[0] : DEFAULT_INPUT;
        // A fixed seed makes the train/test split, and therefore the reported AUC, repeatable.
        Random random = args.length > 1 ? new Random(Long.parseLong(args[1])) : new Random();

        List<Opioides> calls = Lists.newArrayList(new Parser(input));
        Collections.shuffle(calls, random);

        int cutoff = (int) (HELD_OUT_PERCENTAGE * calls.size());
        // HELD_OUT_PERCENTAGE < 1, so cutoff < size whenever size > 0; the only degenerate
        // case is an empty test split, on which an AUC would be meaningless.
        if (cutoff == 0) {
            throw new IllegalStateException("Held-out split is empty: " + calls.size()
                    + " record(s) read from " + input + " with held-out fraction "
                    + HELD_OUT_PERCENTAGE);
        }
        // subList views are safe here because calls is not modified after this point.
        List<Opioides> test = calls.subList(0, cutoff);
        List<Opioides> train = calls.subList(cutoff, calls.size());

        OnlineLogisticRegression lr = newModel();

        System.err.println("pass");
        trainOnePass(lr, train);

        double auc = evaluateAuc(lr, test);
        System.out.printf(Locale.ROOT, "%d, %.4f, %.4f%n", 1, lr.currentLearningRate(), auc);
    }

    /** Builds a fresh classifier configured with this experiment's hyper-parameters. */
    private static OnlineLogisticRegression newModel() {
        return new OnlineLogisticRegression(NUM_CATEGORIES, Opioides.FEATURES, new L1())
                .learningRate(1)
                .alpha(1)
                .lambda(0.000001)
                .stepOffset(10000)
                .decayExponent(0.2);
    }

    /** Feeds every training observation to the model once, in list order. */
    private static void trainOnePass(OnlineLogisticRegression lr, List<Opioides> train) {
        for (Opioides observation : train) {
            lr.train(observation.getTarget(), observation.asVector());
        }
    }

    /** Scores every held-out observation with the model and returns the resulting AUC. */
    private static double evaluateAuc(OnlineLogisticRegression lr, List<Opioides> test) {
        Auc eval = new Auc(0.5);
        for (Opioides testCall : test) {
            eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector()));
        }
        return eval.auc();
    }
}