Example usage for org.apache.mahout.classifier.sgd CsvRecordFactory firstLine

List of usage examples for org.apache.mahout.classifier.sgd CsvRecordFactory firstLine

Introduction

On this page you can find an example usage for org.apache.mahout.classifier.sgd CsvRecordFactory firstLine.

Prototype

@Override
public void firstLine(String line) 

Source Link

Document

Processes the first line of a file (which should contain the variable names).

Usage

From source file:com.ml.ira.algos.RunLogistic.java

License:Apache License

/**
 * Scores an input CSV file against a previously trained logistic regression
 * model and reports AUC, confusion matrix, entropy and/or per-record scores.
 *
 * <p>If none of the report flags is set, AUC and confusion reporting are
 * enabled by default. The model is loaded from HDFS when {@code modelFile}
 * starts with "hdfs://", otherwise from the local file system.</p>
 *
 * @param args   command-line arguments parsed into static flags by parseArgs
 * @param output destination for all report output
 * @throws Exception on I/O or model-loading failure
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        if (!showAuc && !showConfusion && !showScores) {
            // Nothing requested explicitly: default to the summary reports.
            showAuc = true;
            showConfusion = true;
        }

        Auc collector = new Auc();
        LogisticModelParameters lmp;
        if (modelFile.startsWith("hdfs://")) {
            lmp = LogisticModelParameters.loadFrom(new Path(modelFile));
        } else {
            lmp = LogisticModelParameters.loadFrom(new File(modelFile));
        }
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        BufferedReader in = TrainLogistic.open(inputFile);
        try {
            // Hand the header (variable names) to the record factory: either
            // the names stored with the model ("internal") or the first line
            // of the input file.
            if (fieldNames != null && fieldNames.equalsIgnoreCase("internal")) {
                csv.firstLine(lmp.getFieldNames());
            } else {
                csv.firstLine(in.readLine());
            }
            String line = in.readLine();
            if (showScores) {
                output.println("\"target\",\"model-output\",\"log-likelihood\"");
            }
            while (line != null) {
                Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                int target = csv.processLine(line, v);

                double score = lr.classifyScalar(v);
                if (showScores) {
                    output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
                }
                collector.add(target, score);
                line = in.readLine();
            }
        } finally {
            // BUGFIX: the reader was previously leaked; always release it.
            in.close();
        }

        if (showAuc) {
            output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
        }
        if (showConfusion) {
            Matrix m = collector.confusion();
            output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0),
                    m.get(0, 1), m.get(1, 1));
            m = collector.entropy();
            output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0),
                    m.get(0, 1), m.get(1, 1));
        }
    }
}

From source file:com.ml.ira.algos.TrainLogistic.java

License:Apache License

/**
 * Trains an online logistic regression model over several passes of a CSV
 * input file, optionally printing per-sample scores, then saves the model
 * (to HDFS or a local file) and dumps the learned coefficients.
 *
 * @param args   command-line arguments parsed into static fields by parseArgs
 * @param output destination for progress and model-dump output
 * @throws Exception on I/O or model-persistence failure
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        // Running estimate of the average log-likelihood, reported per sample.
        double logPEstimate = 0;
        int samples = 0;

        System.out.println("fieldNames: " + fieldNames);
        long ts = System.currentTimeMillis();
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        for (int pass = 0; pass < passes; pass++) {
            System.out.println("at Round: " + pass);
            // The file is re-read from the start on every pass.
            BufferedReader in = open(inputFile);
            try {
                // read variable names
                String line;
                // Header either comes from the configured fieldNames string or
                // from the first line of the input file.
                if (fieldNames != null && fieldNames.length() > 0) {
                    csv.firstLine(fieldNames);
                } else {
                    csv.firstLine(in.readLine());
                }
                line = in.readLine();
                while (line != null) {
                    // for each new line, get target and predictors
                    Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                    int targetValue = csv.processLine(line, input);

                    // check performance while this is still news
                    double logP = lr.logLikelihood(targetValue, input);
                    if (!Double.isInfinite(logP)) {
                        if (samples < 20) {
                            // Plain running mean for the first few samples...
                            logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                        } else {
                            // ...then an exponential moving average.
                            logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                        }
                        samples++;
                    }
                    double p = lr.classifyScalar(input);
                    if (scores) {
                        output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples,
                                targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                    }

                    // now update model
                    lr.train(targetValue, input);

                    line = in.readLine();
                }
            } finally {
                Closeables.close(in, true);
            }
            // NOTE(review): this prints the cumulative duration once per pass
            // and again after the save below — possibly redundant; confirm intent.
            output.println("duration: " + (System.currentTimeMillis() - ts));
        }

        // Persist the trained model either to HDFS or to a local file.
        if (outputFile.startsWith("hdfs://")) {
            lmp.saveTo(new Path(outputFile));
        } else {
            OutputStream modelOutput = new FileOutputStream(outputFile);
            try {
                lmp.saveTo(modelOutput);
            } finally {
                Closeables.close(modelOutput, false);
            }
        }

        output.println("duration: " + (System.currentTimeMillis() - ts));

        output.println(lmp.getNumFeatures());
        output.println(lmp.getTargetVariable() + " ~ ");
        // Print the model as a "target ~ w1*x1 + w2*x2 ..." formula, skipping
        // zero-weight predictors.
        String sep = "";
        for (String v : csv.getTraceDictionary().keySet()) {
            double weight = predictorWeight(lr, 0, csv, v);
            if (weight != 0) {
                output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                sep = " + ";
            }
        }
        output.printf("%n");
        model = lr;
        // Dump per-category weights and the raw beta matrix, one row per
        // target category.
        for (int row = 0; row < lr.getBeta().numRows(); row++) {
            for (String key : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, row, csv, key);
                if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
                }
            }
            for (int column = 0; column < lr.getBeta().numCols(); column++) {
                output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
            }
            output.println();
        }
    }
}

From source file:edu.isi.karma.cleaning.features.RecordClassifier2.java

License:Apache License

/**
 * Trains an online logistic regression classifier from labeled records.
 * The training data is first written to a temporary CSV file, then read
 * back through a CsvRecordFactory for a fixed number of passes.
 *
 * @param traindata map from label to the example strings for that label
 * @return the trained regression model; side effect: {@code labels} is set
 *         to the target categories seen by the record factory
 * @throws Exception on I/O failure while writing or reading the CSV file
 */
@SuppressWarnings({ "deprecation" })
public OnlineLogisticRegression train(HashMap<String, Vector<String>> traindata) throws Exception {
    String csvTrainFile = "./target/tmp/csvtrain.csv";
    Data2Features.Traindata2CSV(traindata, csvTrainFile, rf);
    lmp = new LogisticModelParameters();
    lmp.setTargetVariable("label");
    lmp.setMaxTargetCategories(rf.labels.size());
    lmp.setNumFeatures(rf.getFeatureNames().size());
    List<String> typeList = Lists.newArrayList();
    typeList.add("numeric");
    List<String> predictorList = Lists.newArrayList();
    for (String attr : rf.getFeatureNames()) {
        // Exclude the target column from the predictors.
        // BUGFIX: this previously compared against the misspelled "lable",
        // so the label column was never filtered out.
        if (attr.compareTo("label") != 0) {
            predictorList.add(attr);
        }
    }
    lmp.setTypeMap(predictorList, typeList);
    lmp.setLambda(1e-4);
    lmp.setLearningRate(50);
    int passes = 100;
    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();
    for (int pass = 0; pass < passes; pass++) {
        // NOTE(review): FileReader uses the platform default charset; confirm
        // the CSV writer uses the same encoding.
        BufferedReader in = new BufferedReader(new FileReader(new File(csvTrainFile)));
        try {
            // First line of the CSV holds the variable names.
            csv.firstLine(in.readLine());
            String line = in.readLine();
            while (line != null) {
                // For each new line, get target and predictors, then update
                // the model online.
                RandomAccessSparseVector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                int targetValue = csv.processLine(line, input);
                lr.train(targetValue, input);
                line = in.readLine();
            }
        } finally {
            Closeables.closeQuietly(in);
        }
    }
    labels = csv.getTargetCategories();
    return lr;
}

From source file:guipart.view.GUIOverviewController.java

/**
 * Classifies every record in the selected CSV file with the selected model,
 * adds each record as a Person to the GUI table, and shows accuracy plus the
 * confusion matrix. Shows an error dialog when either file is unselected.
 *
 * @param event the triggering UI event (unused)
 * @throws IOException if the model or CSV file cannot be read
 */
@FXML
void handleClassifyModel(ActionEvent event) throws IOException {

    if (pathModel != null && pathCSV != null) {

        Auc collector = new Auc();
        LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel));

        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();

        BufferedReader in = Utils.open(pathCSV);

        int correct = 0;
        int wrong = 0;

        try {
            // First line of the CSV holds the variable names.
            String line = in.readLine();
            csv.firstLine(line);
            line = in.readLine();

            while (line != null) {

                Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                int target = csv.processLine(line, v);
                String[] split = line.split(",");

                // Predicted class = index of the highest-scoring category.
                double score = lr.classifyFull(v).maxValueIndex();
                if (score == target)
                    correct++;
                else
                    wrong++;

                System.out.println("Target is: " + target + " Score: " + score);

                // Primitive boolean instead of boxed Boolean (no null case).
                boolean booltemp = score != 0;

                String gender;
                if (split[1].contentEquals("1"))
                    gender = "male";
                else
                    gender = "female";

                Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]),
                        Integer.parseInt(split[7]), booltemp, gender, Integer.parseInt(split[5]),
                        Integer.parseInt(split[6]), Integer.parseInt(split[3]));

                guiPart.addPerson(temp);

                line = in.readLine();
                collector.add(target, score);

            }
        } finally {
            // BUGFIX: the reader was previously leaked; always release it.
            in.close();
        }

        double posto = ((double) wrong / (double) (correct + wrong)) * 100;
        System.out.println("Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong
                + " Wrong pct: " + posto + "%");
        Matrix m = collector.confusion();
        System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t   " + m.get(0, 1) + " "
                + m.get(1, 1) + " ");
        textAnalyze2.setText("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t \t   " + m.get(0, 1) + " "
                + m.get(1, 1) + "\n" + "Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: "
                + wrong + " Wrong pct: " + posto + "%");
    } else {

        Dialogs.create().owner(guiPart.getPrimaryStage()).title("Error Dialog")
                .masthead("Look, an Error Dialog").message("One or more files aren't selected").showError();

    }
}

From source file:guipart.view.GUIOverviewController.java

/**
 * Classifies a single record assembled from the form fields and adds the
 * resulting Person to the GUI table.
 *
 * @param e the triggering UI event (unused)
 * @throws IOException if the model file cannot be read
 */
@FXML
void singlClassify(ActionEvent e) throws IOException {

    // Load the trained model and re-declare the CSV header it was trained on.
    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();
    csv.firstLine("custID,gender,state,cardholder,balance,numTrans,numIntlTrans,creditLine,fraudRisk");

    // Assemble one CSV record from the form fields; the final "0 \n" is a
    // placeholder for the (unknown) fraudRisk target column.
    String line = String.join(",", scID.getText(), scGender.getText(), scState.getText(),
            scCardholders.getText(), scBalance.getText(), scTrans.getText(), scIntlTrans.getText(),
            scCreditLine.getText(), "0 \n");

    Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
    int target = csv.processLine(line, v);
    String[] split = line.split(",");

    // Predicted class = index of the highest-scoring category.
    double score = lr.classifyFull(v).maxValueIndex();
    boolean booltemp = score != 0;

    String gender = split[1].contentEquals("1") ? "male" : "female";

    Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]), Integer.parseInt(split[7]),
            booltemp, gender, Integer.parseInt(split[5]), Integer.parseInt(split[6]),
            Integer.parseInt(split[3]));

    guiPart.addPerson(temp);
}

From source file:haflow.component.mahout.logistic.RunLogistic.java

License:Apache License

/**
 * Scores an HDFS-hosted input file against an HDFS-hosted model and writes
 * per-record scores to {@code outputFile} and AUC/confusion/entropy metrics
 * to {@code accurateFile}.
 *
 * <p>If none of the report flags is set, AUC and confusion reporting are
 * enabled by default.</p>
 *
 * @param args command-line arguments parsed into static flags by parseArgs
 * @throws Exception on HDFS I/O or model-loading failure
 */
static void mainToOutput(String[] args) throws Exception {
    if (parseArgs(args)) {
        if (!showAuc && !showConfusion && !showScores) {
            // Nothing requested explicitly: default to the summary reports.
            showAuc = true;
            showConfusion = true;
        }

        PrintWriter output = new PrintWriter(HdfsUtil.writeHdfs(outputFile), true);
        PrintWriter acc_output = new PrintWriter(HdfsUtil.writeHdfs(accurateFile), true);
        try {
            Auc collector = new Auc();
            LogisticModelParameters lmp = LogisticModelParameters.loadFrom(HdfsUtil.open(modelFile));

            CsvRecordFactory csv = lmp.getCsvRecordFactory();
            OnlineLogisticRegression lr = lmp.createRegression();
            BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile)));
            try {
                // First line of the input holds the variable names.
                String line = in.readLine();
                csv.firstLine(line);
                line = in.readLine();
                if (showScores) {
                    output.println("\"target\",\"model-output\",\"log-likelihood\"");
                }
                while (line != null) {
                    Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                    int target = csv.processLine(line, v);

                    double score = lr.classifyScalar(v);
                    if (showScores) {
                        output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
                    }
                    collector.add(target, score);
                    line = in.readLine();
                }
            } finally {
                // BUGFIX: the reader was previously leaked; always release it.
                in.close();
            }

            if (showAuc) {
                acc_output.printf(Locale.ENGLISH, "AUC , %.2f%n", collector.auc());
            }
            if (showConfusion) {
                Matrix m = collector.confusion();
                acc_output.printf(Locale.ENGLISH, "confusion, [[%.1f  %.1f], [%.1f  %.1f]]%n", m.get(0, 0),
                        m.get(1, 0), m.get(0, 1), m.get(1, 1));
                m = collector.entropy();
                acc_output.printf(Locale.ENGLISH, "entropy, [[%.1f  %.1f], [%.1f  %.1f]]%n", m.get(0, 0),
                        m.get(1, 0), m.get(0, 1), m.get(1, 1));
            }
        } finally {
            // BUGFIX: writers are now closed even when scoring throws.
            output.close();
            acc_output.close();
        }
    }
}

From source file:haflow.component.mahout.logistic.TrainLogistic.java

License:Apache License

/**
 * Trains an online logistic regression model over several passes of an
 * HDFS-hosted input file, writes the model to {@code outputFile} on HDFS,
 * and dumps the learned coefficients to {@code inforFile}.
 *
 * @param args command-line arguments parsed into static fields by parseArgs
 * @throws Exception on HDFS I/O or model-persistence failure
 */
static void mainToOutput(String[] args) throws Exception {
    if (parseArgs(args)) {

        // Running estimate of the average log-likelihood, reported per sample.
        double logPEstimate = 0;
        int samples = 0;

        OutputStream o = HdfsUtil.writeHdfs(inforFile);
        // NOTE(review): output is closed at the end of the method but not in a
        // finally block — it leaks if training throws; consider try/finally.
        PrintWriter output = new PrintWriter(o, true);

        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        for (int pass = 0; pass < passes; pass++) {
            // The file is re-read from the start on every pass.
            BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile)));
            try {
                // read variable names
                csv.firstLine(in.readLine());

                String line = in.readLine();

                while (line != null) {
                    // for each new line, get target and predictors
                    Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                    int targetValue = csv.processLine(line, input);

                    // check performance while this is still news
                    double logP = lr.logLikelihood(targetValue, input);
                    if (!Double.isInfinite(logP)) {
                        if (samples < 20) {
                            // Plain running mean for the first few samples...
                            logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                        } else {
                            // ...then an exponential moving average.
                            logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                        }
                        samples++;
                    }
                    double p = lr.classifyScalar(input);
                    if (scores) {
                        output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples,
                                targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                    }

                    // now update model
                    lr.train(targetValue, input);

                    line = in.readLine();
                }
            } finally {
                Closeables.close(in, true);
            }
        }

        // Persist the trained model to HDFS.
        OutputStream modelOutput = HdfsUtil.writeHdfs(outputFile);
        try {
            lmp.saveTo(modelOutput);
        } finally {
            Closeables.close(modelOutput, false);
        }

        output.println(lmp.getNumFeatures());
        output.println(lmp.getTargetVariable() + " ~ ");
        // Print the model as a "target ~ w1*x1 + w2*x2 ..." formula, skipping
        // zero-weight predictors.
        String sep = "";
        for (String v : csv.getTraceDictionary().keySet()) {
            double weight = predictorWeight(lr, 0, csv, v);
            if (weight != 0) {
                output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                sep = " + ";
            }
        }
        output.printf("%n");
        model = lr;
        // Dump per-category weights and the raw beta matrix, one row per
        // target category.
        for (int row = 0; row < lr.getBeta().numRows(); row++) {
            for (String key : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, row, csv, key);
                if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
                }
            }
            for (int column = 0; column < lr.getBeta().numCols(); column++) {
                output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
            }
            output.println();
        }
        output.close();
    }

}

From source file:javaapplication3.RunLogistic.java

/**
 * Scores an input CSV file against a stored logistic regression model and
 * prints per-record predictions, overall accuracy, and the confusion matrix
 * to standard output.
 *
 * @param args unused
 * @throws IOException if the model or input file cannot be read
 */
public static void main(String[] args) throws IOException {
    Auc collector = new Auc();

    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();

    BufferedReader in = open(inputFile);

    int correct = 0;
    int wrong = 0;

    try {
        // First line of the CSV holds the variable names.
        String line = in.readLine();
        csv.firstLine(line);
        line = in.readLine();
        while (line != null) {
            Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
            int target = csv.processLine(line, v);

            System.out.println(line);
            String[] split = line.split(",");

            // Predicted class = index of the highest-scoring category.
            double score = lr.classifyFull(v).maxValueIndex();
            if (score == target)
                correct++;
            else
                wrong++;

            System.out.println("Target is: " + target + " Score: " + score);
            line = in.readLine();
            collector.add(target, score);

        }
    } finally {
        // BUGFIX: the reader was previously leaked; always release it.
        in.close();
    }

    double posto = ((double) wrong / (double) (correct + wrong)) * 100;
    System.out.println("Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong
            + " Wrong pct: " + posto + "%");
    Matrix m = collector.confusion();
    System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t   " + m.get(0, 1) + " "
            + m.get(1, 1) + " ");

}