Example usage for org.apache.mahout.classifier.sgd LogisticModelParameters getCsvRecordFactory

List of usage examples for org.apache.mahout.classifier.sgd LogisticModelParameters getCsvRecordFactory

Introduction

On this page you can find example usages of getCsvRecordFactory() from org.apache.mahout.classifier.sgd.LogisticModelParameters.

Prototype

public CsvRecordFactory getCsvRecordFactory() 

Source Link

Document

Returns a CsvRecordFactory compatible with this logistic model.

Usage

From source file: com.ml.ira.algos.RunLogistic.java

License: Apache License

/**
 * Scores every record of {@code inputFile} with the logistic model stored in {@code modelFile}
 * and writes the requested reports to {@code output}.
 *
 * <p>When none of {@code showAuc}, {@code showConfusion} or {@code showScores} is set, AUC and the
 * confusion/entropy matrices are reported. If {@code fieldNames} equals {@code "internal"}
 * (case-insensitive), the CSV header is taken from the field names saved with the model and the
 * first input line is treated as data; otherwise the first input line is used as the header.
 *
 * @param args command-line arguments; nothing is done when {@code parseArgs} rejects them
 * @param output destination for the score, AUC, confusion and entropy lines
 * @throws Exception if the model or the input cannot be read
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (!parseArgs(args)) {
        return;
    }
    if (!showAuc && !showConfusion && !showScores) {
        showAuc = true;
        showConfusion = true;
    }

    Auc collector = new Auc();
    // Model paths starting with hdfs:// use the Path overload of loadFrom; anything else is a local File.
    LogisticModelParameters lmp;
    if (modelFile.startsWith("hdfs://")) {
        lmp = LogisticModelParameters.loadFrom(new Path(modelFile));
    } else {
        lmp = LogisticModelParameters.loadFrom(new File(modelFile));
    }
    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();

    // try-with-resources so the input reader is always closed (it previously leaked).
    try (BufferedReader in = TrainLogistic.open(inputFile)) {
        // Constant-first comparison is null-safe: a null fieldNames falls through to the else branch.
        if ("internal".equalsIgnoreCase(fieldNames)) {
            csv.firstLine(lmp.getFieldNames());
        } else {
            csv.firstLine(in.readLine());
        }
        String line = in.readLine();
        if (showScores) {
            output.println("\"target\",\"model-output\",\"log-likelihood\"");
        }
        while (line != null) {
            Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
            int target = csv.processLine(line, v);

            double score = lr.classifyScalar(v);
            if (showScores) {
                output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
            }
            collector.add(target, score);
            line = in.readLine();
        }
    }

    if (showAuc) {
        output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
    }
    if (showConfusion) {
        Matrix m = collector.confusion();
        output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0),
                m.get(0, 1), m.get(1, 1));
        m = collector.entropy();
        output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0),
                m.get(0, 1), m.get(1, 1));
    }
}

From source file: guipart.view.GUIOverviewController.java

/**
 * Classifies every data row of the CSV file at {@code pathCSV} with the model at {@code pathModel}.
 * It adds one {@code Person} per row to the GUI. It then shows the totals, the error rate and the
 * confusion matrix in {@code textAnalyze2} and also prints them to standard output.
 *
 * <p>If either file has not been selected, an error dialog is shown instead.
 *
 * @param event the triggering UI event (not used)
 * @throws IOException if the model or the CSV file cannot be read
 */
@FXML
void handleClassifyModel(ActionEvent event) throws IOException {
    if (pathModel == null || pathCSV == null) {
        Dialogs.create().owner(guiPart.getPrimaryStage()).title("Error Dialog")
                .masthead("Look, an Error Dialog").message("One or more files aren't selected").showError();
        return;
    }

    Auc collector = new Auc();
    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();

    int correct = 0;
    int wrong = 0;

    // try-with-resources so the CSV file is closed even if a row fails to parse (it previously leaked).
    try (BufferedReader in = Utils.open(pathCSV)) {
        // The first line is the CSV header.
        csv.firstLine(in.readLine());

        String line = in.readLine();
        while (line != null) {
            Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
            int target = csv.processLine(line, v);
            String[] split = line.split(",");

            // Predicted class = index of the highest entry returned by classifyFull.
            double score = lr.classifyFull(v).maxValueIndex();
            if (score == target) {
                correct++;
            } else {
                wrong++;
            }

            System.out.println("Target is: " + target + " Score: " + score);

            boolean predictedPositive = score != 0;
            String gender = split[1].contentEquals("1") ? "male" : "female";

            // NOTE(review): the column indices below assume the order
            // custID,gender,state,cardholder,balance,numTrans,numIntlTrans,creditLine,fraudRisk
            // (the header singlClassify hard-codes); confirm it matches the file at pathCSV.
            Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]),
                    Integer.parseInt(split[7]), predictedPositive, gender, Integer.parseInt(split[5]),
                    Integer.parseInt(split[6]), Integer.parseInt(split[3]));

            guiPart.addPerson(temp);

            collector.add(target, score);
            line = in.readLine();
        }
    }

    int total = correct + wrong;
    // Guard against 0/0, which would otherwise report "NaN%" for a CSV with no data rows.
    double posto = total == 0 ? 0.0 : ((double) wrong / (double) total) * 100;
    System.out.println("Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong
            + " Wrong pct: " + posto + "%");
    Matrix m = collector.confusion();
    System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t   " + m.get(0, 1) + " "
            + m.get(1, 1) + " ");
    textAnalyze2.setText("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t \t   " + m.get(0, 1) + " "
            + m.get(1, 1) + "\n" + "Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: "
            + wrong + " Wrong pct: " + posto + "%");
}

From source file: guipart.view.GUIOverviewController.java

/**
 * Classifies the single record typed into the form fields with the model at {@code pathModel}
 * and adds the resulting {@code Person} to the GUI.
 *
 * <p>If no model file has been selected, an error dialog is shown instead.
 *
 * @param e the triggering UI event (not used)
 * @throws IOException if the model cannot be read
 * @throws NumberFormatException if a numeric field does not contain an integer
 */
@FXML
void singlClassify(ActionEvent e) throws IOException {
    if (pathModel == null) {
        // Same guard as handleClassifyModel: new File(null) would otherwise throw a NullPointerException.
        Dialogs.create().owner(guiPart.getPrimaryStage()).title("Error Dialog")
                .masthead("Look, an Error Dialog").message("Model file isn't selected").showError();
        return;
    }

    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();
    csv.firstLine("custID,gender,state,cardholder,balance,numTrans,numIntlTrans,creditLine,fraudRisk");

    // Field values are trimmed so stray whitespace typed into the form cannot break Integer.parseInt
    // or the gender check below. The trailing "0" is a placeholder for the fraudRisk target column.
    // It has no trailing whitespace or newline, matching rows read via readLine(). The target value
    // returned by processLine is not used here.
    String line = String.join(",",
            scID.getText().trim(),
            scGender.getText().trim(),
            scState.getText().trim(),
            scCardholders.getText().trim(),
            scBalance.getText().trim(),
            scTrans.getText().trim(),
            scIntlTrans.getText().trim(),
            scCreditLine.getText().trim(),
            "0");

    Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
    csv.processLine(line, v);
    String[] split = line.split(",");

    // Predicted class = index of the highest entry returned by classifyFull.
    double score = lr.classifyFull(v).maxValueIndex();
    boolean predictedPositive = score != 0;

    String gender = split[1].contentEquals("1") ? "male" : "female";

    Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]), Integer.parseInt(split[7]),
            predictedPositive, gender, Integer.parseInt(split[5]), Integer.parseInt(split[6]),
            Integer.parseInt(split[3]));

    guiPart.addPerson(temp);
}

From source file: haflow.component.mahout.logistic.RunLogistic.java

License: Apache License

/**
 * Scores every record of {@code inputFile} with the model read from {@code modelFile}. All files
 * are opened through {@code HdfsUtil}.
 *
 * <p>Per-record scores (when {@code showScores} is set) are written to {@code outputFile}. AUC and
 * the confusion/entropy matrices are written to {@code accurateFile}. When none of the show flags
 * is set, AUC and the confusion/entropy matrices are reported. The first input line is used as
 * the CSV header.
 *
 * @param args command-line arguments; nothing is done when {@code parseArgs} rejects them
 * @throws Exception if any of the files cannot be opened, read or written
 */
static void mainToOutput(String[] args) throws Exception {
    if (!parseArgs(args)) {
        return;
    }
    if (!showAuc && !showConfusion && !showScores) {
        showAuc = true;
        showConfusion = true;
    }

    // try-with-resources: the writers, the model stream and the input reader are closed even when
    // an exception is thrown (previously the two input streams were never closed at all).
    try (PrintWriter output = new PrintWriter(HdfsUtil.writeHdfs(outputFile), true);
         PrintWriter accOutput = new PrintWriter(HdfsUtil.writeHdfs(accurateFile), true)) {
        Auc collector = new Auc();
        LogisticModelParameters lmp;
        try (java.io.InputStream modelIn = HdfsUtil.open(modelFile)) {
            lmp = LogisticModelParameters.loadFrom(modelIn);
        }

        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        // Decode explicitly as UTF-8 rather than the JVM's platform-default charset.
        // NOTE(review): confirm the input files are UTF-8 encoded.
        try (BufferedReader in = new BufferedReader(
                new InputStreamReader(HdfsUtil.open(inputFile), java.nio.charset.StandardCharsets.UTF_8))) {
            csv.firstLine(in.readLine());
            String line = in.readLine();
            if (showScores) {
                output.println("\"target\",\"model-output\",\"log-likelihood\"");
            }
            while (line != null) {
                Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                int target = csv.processLine(line, v);

                double score = lr.classifyScalar(v);
                if (showScores) {
                    output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
                }
                collector.add(target, score);
                line = in.readLine();
            }
        }

        if (showAuc) {
            accOutput.printf(Locale.ENGLISH, "AUC , %.2f%n", collector.auc());
        }
        if (showConfusion) {
            Matrix m = collector.confusion();
            accOutput.printf(Locale.ENGLISH, "confusion, [[%.1f  %.1f], [%.1f  %.1f]]%n", m.get(0, 0),
                    m.get(1, 0), m.get(0, 1), m.get(1, 1));
            m = collector.entropy();
            accOutput.printf(Locale.ENGLISH, "entropy, [[%.1f  %.1f], [%.1f  %.1f]]%n", m.get(0, 0),
                    m.get(1, 0), m.get(0, 1), m.get(1, 1));
        }
    }
}

From source file: javaapplication3.RunLogistic.java

/**
 * Classifies every data row of {@code inputFile} with the model in {@code modelFile}. It prints
 * each row, its target and predicted class, the totals and error rate, and the confusion matrix
 * to standard output.
 *
 * @param args not used; the paths come from {@code modelFile} and {@code inputFile}
 * @throws IOException if the model or the input cannot be read
 */
public static void main(String[] args) throws IOException {
    Auc collector = new Auc();

    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();

    int correct = 0;
    int wrong = 0;

    // try-with-resources so the input reader is always closed (it previously leaked).
    try (BufferedReader in = open(inputFile)) {
        // The first line is the CSV header.
        csv.firstLine(in.readLine());

        String line = in.readLine();
        while (line != null) {
            Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
            int target = csv.processLine(line, v);

            System.out.println(line);

            // Predicted class = index of the highest entry returned by classifyFull.
            double score = lr.classifyFull(v).maxValueIndex();
            if (score == target) {
                correct++;
            } else {
                wrong++;
            }

            System.out.println("Target is: " + target + " Score: " + score);
            collector.add(target, score);
            line = in.readLine();
        }
    }

    int total = correct + wrong;
    // Guard against 0/0, which would otherwise report "NaN%" for an input with no data rows.
    double posto = total == 0 ? 0.0 : ((double) wrong / (double) total) * 100;
    System.out.println("Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong
            + " Wrong pct: " + posto + "%");
    Matrix m = collector.confusion();
    System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t   " + m.get(0, 1) + " "
            + m.get(1, 1) + " ");
}