edu.isi.karma.cleaning.features.RecordClassifier2.java Source code

Java tutorial

Introduction

Here is the source code for edu.isi.karma.cleaning.features.RecordClassifier2.java

Source

/*******************************************************************************
 * Copyright 2012 University of Southern California
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *    http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
 * This code was developed by the Information Integration Group as part 
 * of the Karma project at the Information Sciences Institute of the 
 * University of Southern California.  For more information, publications, 
 * and related projects, please see: http://www.isi.edu/integration
 ******************************************************************************/

package edu.isi.karma.cleaning.features;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Vector;

import org.apache.mahout.classifier.sgd.CsvRecordFactory;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.RecordFactory;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;

import edu.isi.karma.cleaning.PartitionClassifierType;

/**
 * {@link PartitionClassifierType} implementation that trains a Mahout
 * {@link OnlineLogisticRegression} model on features computed by a
 * {@link RecordFeatureSet}.
 *
 * <p>Typical use: call {@link #addTrainingData(String, String)} for every
 * labelled example, then {@link #learnClassifer()} once, then
 * {@link #getLabel(String)} for each value to classify.</p>
 */
public class RecordClassifier2 implements PartitionClassifierType {
    /** Intermediate CSV file the training data is written to before it is fed to the learner. */
    private static final String CSV_TRAIN_FILE = "./target/tmp/csvtrain.csv";
    /** Name of the CSV column that holds the class label. */
    private static final String TARGET_VARIABLE = "label";
    /** Number of times the whole training file is replayed through the learner. */
    private static final int TRAINING_PASSES = 100;
    /** Sentinel returned by {@link #getLabel(String)} when no label can be produced. */
    private static final String NO_LABEL = "null_in_classification";

    /** Training examples grouped by label. */
    HashMap<String, Vector<String>> trainData = new HashMap<String, Vector<String>>();
    RecordFeatureSet rf = new RecordFeatureSet();
    /** The trained model; {@code null} until {@link #learnClassifer()} succeeds. */
    OnlineLogisticRegression cf;
    /** Target categories reported by the CSV record factory after training. */
    List<String> labels = new ArrayList<String>();
    /** Model parameters of the most recent successful {@link #train(HashMap)} call. */
    LogisticModelParameters lmp;

    public RecordClassifier2() {

    }

    /**
     * Writes the training data to {@link #CSV_TRAIN_FILE}, then fits a logistic
     * regression model by replaying that file {@link #TRAINING_PASSES} times.
     *
     * <p>On success the {@code lmp} and {@code labels} fields are updated; the
     * {@code cf} field is not touched (see {@link #learnClassifer()}).</p>
     *
     * @param traindata training examples grouped by label
     * @return the trained model
     * @throws Exception if the training file cannot be read or is empty, or if
     *         the learner rejects the data
     */
    public OnlineLogisticRegression train(HashMap<String, Vector<String>> traindata) throws Exception {
        File csvFile = new File(CSV_TRAIN_FILE);
        File parentDir = csvFile.getParentFile();
        if (parentDir != null && !parentDir.isDirectory()) {
            // Best effort: if this fails, reading the file below reports the problem.
            parentDir.mkdirs();
        }
        Data2Features.Traindata2CSV(traindata, CSV_TRAIN_FILE, rf);

        // Build the parameters locally and publish them only after training
        // succeeded, so a failed run does not leave lmp out of sync with cf.
        LogisticModelParameters params = new LogisticModelParameters();
        params.setTargetVariable(TARGET_VARIABLE);
        params.setMaxTargetCategories(rf.labels.size());
        params.setNumFeatures(rf.getFeatureNames().size());
        List<String> typeList = Lists.newArrayList();
        typeList.add("numeric");
        List<String> predictorList = Lists.newArrayList();
        for (String attr : rf.getFeatureNames()) {
            // The target column must never be used as a predictor.
            if (!TARGET_VARIABLE.equals(attr)) {
                predictorList.add(attr);
            }
        }
        params.setTypeMap(predictorList, typeList);
        params.setLambda(1e-4);
        params.setLearningRate(50);

        CsvRecordFactory csv = params.getCsvRecordFactory();
        OnlineLogisticRegression lr = params.createRegression();

        // The file does not change between passes, so read it once and replay it.
        List<String> lines = readLines(csvFile);
        if (lines.isEmpty()) {
            throw new IllegalStateException("Training file is empty: " + CSV_TRAIN_FILE);
        }
        String header = lines.get(0);
        List<String> records = lines.subList(1, lines.size());
        for (int pass = 0; pass < TRAINING_PASSES; pass++) {
            // first line carries the variable names
            csv.firstLine(header);
            for (String record : records) {
                // for each record, get target and predictors, then update the model
                RandomAccessSparseVector input = new RandomAccessSparseVector(params.getNumFeatures());
                int targetValue = csv.processLine(record, input);
                lr.train(targetValue, input);
            }
        }
        lmp = params;
        labels = csv.getTargetCategories();
        return lr;

    }

    /**
     * Reads every line of a text file into memory.
     *
     * @param file the file to read
     * @return the lines of the file in order, without line terminators
     * @throws java.io.IOException if the file cannot be opened or read
     */
    private static List<String> readLines(File file) throws java.io.IOException {
        List<String> lines = new ArrayList<String>();
        BufferedReader in = new BufferedReader(new FileReader(file));
        try {
            String line = in.readLine();
            while (line != null) {
                lines.add(line);
                line = in.readLine();
            }
        } finally {
            Closeables.closeQuietly(in);
        }
        return lines;
    }

    /**
     * Sums the coefficients in one row of the model's beta matrix over all
     * columns that the record factory's trace dictionary lists for a predictor.
     *
     * @param lr the model whose coefficients are read
     * @param row row of the beta matrix to read
     * @param csv record factory providing the trace dictionary
     * @param predictor name of the predictor to look up
     * @return the summed coefficient
     */
    private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv,
            String predictor) {
        double weight = 0;
        for (Integer column : csv.getTraceDictionary().get(predictor)) {
            weight += lr.getBeta().get(row, column);
        }
        return weight;
    }

    /**
     * Classifies a single value with the trained model.
     *
     * @param instance the raw value to classify
     * @return the label whose score is highest
     * @throws IllegalStateException if no model has been trained yet
     */
    public String Classify(String instance) {
        if (lmp == null || cf == null) {
            throw new IllegalStateException("Classifier has not been trained; call learnClassifer() first");
        }
        Collection<Feature> cfeat = rf.computeFeatures(instance, "");
        RandomAccessSparseVector row = new RandomAccessSparseVector(cfeat.size());
        // Render the feature scores as one CSV record so the record factory can parse it.
        StringBuilder line = new StringBuilder();
        for (Feature feature : cfeat) {
            line.append(feature.getScore()).append(',');
        }
        line.append("label"); // dummy class label for testing
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        csv.processLine(line.toString(), row);
        return labels.get(this.cf.classifyFull(row).maxValueIndex());
    }

    /**
     * Records one labelled training example.
     *
     * @param value the example value
     * @param label the label the value belongs to
     */
    @Override
    public void addTrainingData(String value, String label) {
        Vector<String> examples = trainData.get(label);
        if (examples == null) {
            examples = new Vector<String>();
            trainData.put(label, examples);
        }
        examples.add(value);
    }

    /**
     * Trains the model on everything passed to
     * {@link #addTrainingData(String, String)} so far.
     *
     * <p>If training fails the error is reported on {@code System.err} and a
     * previously trained model, if any, is kept.</p>
     *
     * @return the string form of the current model
     * @throws IllegalStateException if training failed and no earlier model exists
     */
    @Override
    public String learnClassifer() {
        Exception failure = null;
        try {
            this.cf = this.train(trainData);
        } catch (Exception e) {
            failure = e;
            System.err.println("RecordClassifier2: training failed: " + e);
        }
        if (this.cf == null) {
            throw new IllegalStateException("No classifier available: training failed", failure);
        }
        return this.cf.toString();
    }

    /**
     * Classifies a value, never throwing.
     *
     * @param value the value to classify
     * @return the predicted label, or {@code "null_in_classification"} when the
     *         prediction is empty or classification fails for any reason
     */
    @Override
    public String getLabel(String value) {
        try {
            String label = this.Classify(value);
            return label.length() > 0 ? label : NO_LABEL;
        } catch (Exception ignored) {
            // Callers rely on the sentinel rather than on an exception.
            return NO_LABEL;
        }
    }

    /**
     * Small manual smoke test: trains on three tiny partitions and prints the
     * label predicted for one unseen value.
     */
    public static void main(String[] args) {
        try {
            HashMap<String, Vector<String>> trainData = new HashMap<String, Vector<String>>();
            Vector<String> test = new Vector<String>();
            Vector<String> par1 = new Vector<String>();
            par1.add("1286 adams blvd");
            par1.add("3711 catalina st");
            Vector<String> par2 = new Vector<String>();
            par2.add("1142 37st");
            Vector<String> par3 = new Vector<String>();
            par3.add("710 27");
            trainData.put("c1", par1);
            trainData.put("c2", par2);
            trainData.put("c3", par3);
            test.add("2353 portland st");

            RecordClassifier2 rc = new RecordClassifier2();
            for (String key : trainData.keySet()) {
                for (String value : trainData.get(key)) {
                    rc.addTrainingData(value, key);
                }
            }
            rc.learnClassifer();
            System.out.println(rc.Classify(test.get(0)));
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}