Example usage for org.apache.mahout.classifier.sgd CsvRecordFactory getTargetCategories

List of usage examples for org.apache.mahout.classifier.sgd CsvRecordFactory getTargetCategories

Introduction

On this page you can find example usage for the getTargetCategories method of org.apache.mahout.classifier.sgd.CsvRecordFactory.

Prototype

@Override
    public List<String> getTargetCategories() 

Source Link

Usage

From source file: edu.isi.karma.cleaning.features.RecordClassifier2.java

License:Apache License

/**
 * Trains an online logistic regression model on the supplied training data.
 *
 * <p>The data is first written as CSV to {@code ./target/tmp/csvtrain.csv} through
 * {@code Data2Features.Traindata2CSV}. That file is then read once per training pass
 * (100 passes); on every pass the header line is handed to the record factory and each
 * following line is converted to a sparse vector and used for one training step.
 *
 * <p>Side effects: replaces the {@code lmp} field with freshly configured parameters and
 * assigns the {@code labels} field from the record factory's target categories.
 *
 * @param traindata training examples, passed unchanged to {@code Data2Features.Traindata2CSV}
 *                  (presumably label mapped to its example strings; confirm against that helper)
 * @return the trained regression model
 * @throws Exception if writing or reading the CSV file fails, or if the record factory or
 *                   model rejects the data
 */
@SuppressWarnings({ "deprecation" })
public OnlineLogisticRegression train(HashMap<String, Vector<String>> traindata) throws Exception {
    // Single source of truth for the target column name, so the predictor filter
    // below cannot drift from the configured target variable.
    final String targetVariable = "label";
    String csvTrainFile = "./target/tmp/csvtrain.csv";
    Data2Features.Traindata2CSV(traindata, csvTrainFile, rf);

    lmp = new LogisticModelParameters();
    lmp.setTargetVariable(targetVariable);
    lmp.setMaxTargetCategories(rf.labels.size());
    lmp.setNumFeatures(rf.getFeatureNames().size());

    // Only one type is supplied for all predictors; presumably it is applied to every
    // predictor by setTypeMap (confirm against LogisticModelParameters).
    List<String> typeList = Lists.newArrayList();
    typeList.add("numeric");

    // Every feature except the target column is a predictor. The original compared
    // against the misspelling "lable", which never excluded the target.
    List<String> predictorList = Lists.newArrayList();
    for (String attr : rf.getFeatureNames()) {
        if (!targetVariable.equals(attr)) {
            predictorList.add(attr);
        }
    }
    lmp.setTypeMap(predictorList, typeList);
    lmp.setLambda(1e-4);
    lmp.setLearningRate(50);

    int passes = 100;
    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();
    for (int pass = 0; pass < passes; pass++) {
        // NOTE(review): FileReader decodes with the platform default charset; confirm
        // this matches the encoding used by Data2Features.Traindata2CSV.
        BufferedReader in = new BufferedReader(new FileReader(new File(csvTrainFile)));
        try {
            // The first line carries the variable names.
            csv.firstLine(in.readLine());
            String line = in.readLine();
            while (line != null) {
                // Encode the line's predictors into a fresh vector and get its target index.
                RandomAccessSparseVector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                int targetValue = csv.processLine(line, input);
                // Update the model with this example.
                lr.train(targetValue, input);
                line = in.readLine();
            }
        } finally {
            Closeables.closeQuietly(in);
        }
    }
    labels = csv.getTargetCategories();
    return lr;

}