// LibsvmTest.java : Natural-Language-Processing » MinorThird » edu » cmu » minorthird » classify » Java Open Source
//
// Java Open Source » Natural Language Processing » MinorThird
// MinorThird » edu » cmu » minorthird » classify » LibsvmTest.java
package edu.cmu.minorthird.classify;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URI;
import java.net.URL;
import java.util.Random;

import junit.framework.Test;
import junit.framework.TestSuite;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import edu.cmu.minorthird.classify.algorithms.svm.SVMClassifier;
import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import edu.cmu.minorthird.classify.experiments.Evaluation;

/**
 * JUnit 3 tests for the libsvm wrappers ({@link SVMLearner} and
 * {@link SVMClassifier}).
 *
 * <p>The tests train an {@link SVMLearner} on bundled or generated datasets
 * and compare the resulting statistics against reference values using the
 * checks inherited from {@link AbstractClassificationChecks}.
 *
 * <p>Test methods declare {@code throws Exception} on purpose: any unexpected
 * exception (missing resource, I/O failure, deserialization problem) must
 * reach the JUnit runner and be reported as an error rather than being
 * printed and silently ignored.
 *
 * @author ksteppe
 */
public class LibsvmTest extends AbstractClassificationChecks{

  Logger log=Logger.getLogger(this.getClass());

  /** Classpath location of the SVM-style training data used by {@link #testWrapper()}. */
  private static final String trainFile="edu/cmu/minorthird/classify/testcases/a1a.dat";

  /** Classpath location of the SVM-style test data used by {@link #testWrapper()}. */
  private static final String testFile="edu/cmu/minorthird/classify/testcases/a1a.t.dat";

  /**
   * Standard test class constructor for LibsvmTest.
   * @param name Name of the test
   */
  public LibsvmTest(String name){
    super(name);
  }

  /**
   * Convenience constructor for LibsvmTest; uses "LibsvmTest" as the test name.
   */
  public LibsvmTest(){
    super("LibsvmTest");
  }

  /**
   * Runs before each test: resets log4j to the basic console configuration,
   * enables DEBUG output for this class and turns the standard checks off
   * (individual tests re-enable them where needed).
   */
  protected void setUp(){
    org.apache.log4j.Logger.getRootLogger().removeAllAppenders();
    org.apache.log4j.BasicConfigurator.configure();
    log.setLevel(Level.DEBUG);
    super.setCheckStandards(false);
  }

  /**
   * Runs after each test. Nothing to release at the moment.
   */
  protected void tearDown(){
  }

  /**
   * Trains the wrapper on the bundled a1a data and checks the resulting
   * statistics against the reference values.
   *
   * @throws Exception if a data file is missing or cannot be loaded
   */
  public void testWrapper() throws Exception{
    Dataset trainData=loadSvmResource(trainFile);
    Dataset testData=loadSvmResource(testFile);

    // Reference statistics handed to checkClassify().
    double[] expect=
      new double[]{
        0.13769470404984424,
        0.6011745705024105,
        0.6934812760055479,
        // should be infinity if not calculating probabilities
        // (was 1.3132616875183545)
        Double.POSITIVE_INFINITY,
      };
    super.setCheckStandards(true);
    super.checkClassify(new SVMLearner(),trainData,testData,expect);
  }

  /**
   * Runs the svm wrapper on the toy sample data and checks it against the
   * per-position reference values.
   */
  public void testSampleData(){

    double[] refs=new double[]{
        0.0,0.0,0.0,0.0,0.0,0.0,0.0, //0-6 are 0
        1.0,1.0, //7-8 are 1
        1.3132616875182228,1.0,1.0,1.0, //9 differs; 10-12 are 1
        1.0 //13 is 1
    };

    super.checkClassify(new SVMLearner(),SampleDatasets.toyTrain(),
        SampleDatasets.toyTest(),refs);
  }

  /**
   * Tests a full cycle of training, testing, saving (serializing), loading,
   * and testing again. The statistics before and after the round trip must
   * match.
   *
   * @throws Exception if training, evaluation or (de)serialization fails
   */
  public void testSerialization() throws Exception{
    // Create a classifier using the SVMLearner and the toyTrain dataset
    SVMLearner learner=new SVMLearner();
    Classifier original=
      new DatasetClassifierTeacher(SampleDatasets.toyTrain()).train(learner);

    // Evaluate it immediately, saving the stats
    double[] statsBefore=
      basicStats(original,SampleDatasets.toyTrain(),SampleDatasets.toyTest());

    File tempFile=File.createTempFile("SVMTest","classifier");
    try{
      // Serialize the classifier to disk and load it back in
      Classifier restored=roundTrip(original,tempFile);

      // Evaluate again, saving the stats
      double[] statsAfter=
        basicStats(restored,SampleDatasets.toyTrain(),SampleDatasets.toyTest());

      // Only use the basic stats for now because some of the advanced stats
      //  come back as NaN for both datasets and the check stats method can't
      //  handle NaN's
      log.info("using Standard stats only (4 of them)");

      // Compare the stats produced from each run to make sure they are identical
      checkStats(statsBefore,statsAfter);
    }finally{
      // Remove the temporary classifier file even when the test fails
      if(!tempFile.delete()){
        log.warn("could not delete temporary file "+tempFile);
      }
    }
  }

  /**
   * Tests multi-class classification. There are two cases to consider: with
   * and without calculation of probability estimates. The libsvm
   * documentation states that these two cases may return different
   * classifications. Both cases classify a generated 3-class dataset and
   * check the statistics against expected values.
   *
   * @throws Exception if training or evaluation fails
   */
  public void testMultiClassClassification() throws Exception{
    Dataset trainSet=SampleDatasets.makeToy3ClassData(new Random(12345),100);
    Dataset testSet=SampleDatasets.makeToy3ClassData(new Random(67890),100);

    SVMLearner learner=new SVMLearner();

    // First run the test without probability estimates
    learner.setDoProbabilityEstimates(false);
    SVMClassifier plain=
      (SVMClassifier)(new DatasetClassifierTeacher(trainSet).train(learner));
    double[] plainStats=basicStats(plain,trainSet,testSet);

    System.out.println("Error Rate: "+plainStats[0]);
    System.out.println("Avg Precision: "+plainStats[1]);
    System.out.println("Max F1: "+plainStats[2]);
    System.out.println("Avg Log Loss: "+plainStats[3]);

    // The stats we expect the classification to return.
    // NOTE(review): the -1.0 entries presumably tell checkStats() to skip
    // that statistic - confirm against AbstractClassificationChecks.
    double[] expected=new double[]{
        0.07,
        -1.0,
        -1.0,
        Double.POSITIVE_INFINITY,
    };

    // Compare the stats produced from the run without probability estimates with expected values
    checkStats(plainStats,expected);

    // Now do it with probability estimates. On a small dataset libsvm may
    //  return noticeably different stats from run to run, which is why the
    //  comparison below uses a wider delta.
    learner.setDoProbabilityEstimates(true);
    SVMClassifier probabilistic=
      (SVMClassifier)(new DatasetClassifierTeacher(trainSet).train(learner));
    double[] probabilisticStats=basicStats(probabilistic,trainSet,testSet);

    System.out.println("Error Rate2: "+probabilisticStats[0]);
    System.out.println("Avg Precision2: "+probabilisticStats[1]);
    System.out.println("Max F1-2: "+probabilisticStats[2]);
    System.out.println("Avg Log Loss2: "+probabilisticStats[3]);

    // The stats we expect the classification to return.
    expected[0]=0.08;
    expected[1]=-1.0;
    expected[2]=-1.0;
    expected[3]=1.194999431381944;

    // Compare the stats produced from the run with probability estimates with expected values.  The libsvm
    //  package doesn't always come up with the "exact" same stats, but they are within 0.05 of each other
    //  so update the delta accordingly.
    setDelta(0.05);
    checkStats(probabilisticStats,expected);
  }

  /**
   * Creates a TestSuite from all testXXX methods
   * @return TestSuite
   */
  public static Test suite(){
    return new TestSuite(LibsvmTest.class);
  }

  /**
   * Run the full suite of tests with text output
   * @param args - unused
   */
  public static void main(String args[]){
    junit.textui.TestRunner.run(suite());
  }

  /**
   * Loads an SVM-style dataset from a classpath resource.
   *
   * @param resource classpath-relative path of the data file
   * @return the dataset read from the resource
   * @throws IllegalStateException if the resource is not on the classpath
   * @throws Exception if the resource cannot be converted to a file or parsed
   */
  private Dataset loadSvmResource(String resource) throws Exception{
    URL url=this.getClass().getClassLoader().getResource(resource);
    if(url==null){
      // getResource() signals "not found" with null; fail with a clear message
      // instead of a bare NullPointerException.
      throw new IllegalStateException("Test resource not found on classpath: "+resource);
    }
    return DatasetLoader.loadSVMStyle(new File(url.toURI()));
  }

  /**
   * Evaluates a classifier and collects the four basic statistics.
   *
   * @param classifier classifier to evaluate
   * @param train dataset whose schema is used to set up the evaluation
   * @param test dataset the classifier is evaluated on
   * @return {error rate, average precision, max F1, average log loss}
   * @throws Exception if the evaluation fails
   */
  private static double[] basicStats(Classifier classifier,Dataset train,Dataset test) throws Exception{
    Evaluation evaluation=new Evaluation(train.getSchema());
    evaluation.extend(classifier,test,1);
    return new double[]{
        evaluation.errorRate(),
        evaluation.averagePrecision(),
        evaluation.maxF1(),
        evaluation.averageLogLoss(),
    };
  }

  /**
   * Serializes a classifier to the given file and reads it back.
   * Both streams are closed even when writing or reading fails.
   *
   * @param classifier classifier to serialize
   * @param file file used as intermediate storage
   * @return the classifier deserialized from {@code file}
   * @throws Exception if writing, reading or class resolution fails
   */
  private static Classifier roundTrip(Classifier classifier,File file) throws Exception{
    ObjectOutputStream out=
      new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(file)));
    try{
      out.writeObject(classifier);
      out.flush();
    }finally{
      out.close();
    }

    ObjectInputStream in=
      new ObjectInputStream(new BufferedInputStream(new FileInputStream(file)));
    try{
      return (Classifier)in.readObject();
    }finally{
      in.close();
    }
  }

}
// java2s.com  | Contact Us | Privacy Policy
// Copyright 2009 - 12 Demo Source and Support. All rights reserved.
// All other trademarks are property of their respective owners.