package edu.cmu.minorthird.classify;

import java.awt.BorderLayout;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;

import org.apache.log4j.Logger;

import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import edu.cmu.minorthird.classify.experiments.Evaluation.Matrix;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;

/** 
 * A Tweaked Learner, optimizing the trade-off between precision and recall.
 * 
 * @author Giora Unger
 * Created on May 19, 2005
 *
 * A learner whose score is optimized according to an F_beta function,
 * for a given beta. This optimization is used to fine-tune the precision
 * vs. recall trade-off of the underlying classification algorithm.
 * Values of beta<1.0 favor precision over recall, while values of
 * beta>1.0 favor recall over precision. beta=1.0 grants equal weight 
 * to both precision and recall.  
 *  
 * <p>Reference: Jason D. M. Rennie,
 * <i>Derivation of the F-Measure</i>,
 * http://people.csail.mit.edu/jrennie/writing/fmeasure.pdf
 */
public class TweakedLearner extends BatchBinaryClassifierLearner{

  // inner learner given to this class during construction
  private BinaryClassifierLearner innerLearner;

  // the beta according to which F_beta is to be maximized
  private double beta;

  // dataset given 
  private Dataset m_dataset;

  // dataset schema
  private ExampleSchema schema;

  // flag indicating whether the given dataset is binary or not
  private boolean isBinary=true;

  // value to be returned if a non-binary dataset is given to 
  // precision() or recall() methods
  private static final int ILLEGAL_VALUE=-1;

  private static final double UNINITIALIZED=-1;

  // actual data structure in which the examples are stored, along with
  // additional fields required for executing the tweaking
  private List<Row> tweakingTable=new ArrayList<Row>();

  // confusion matrix used to perform the tweaking efficiently
  Matrix cm=null;

  // logger for this class
  private static Logger log=Logger.getLogger(TweakedLearner.class);

  /**
   ******************************************************************** 
   * Public methods  
   ******************************************************************** 
   */

  // TweakedLearner constructor
  public TweakedLearner(BinaryClassifierLearner innerLearner,double beta){
    this.beta=beta;
    this.innerLearner=innerLearner;
  }

  /*
   * Main method of the TweakedLearner class. Receives a binary
   * training dataset and then:
   * 1. Trains on it, based on the innerLearner, namely its inherent
   *    binary classifier.
   * 2. Tweaks the classifier (or, more precisely, the model this
   *    classifier came up with), so that F_beta is maximized. This is
   *    done by finding a threshold, see details below.
   * 3. Creates and returns a new TweakedClassifier, with the original
   *    (inner) classifier and the threshold that was found.
   */
  public Classifier batchTrain(Dataset dataset){

    // make sure the dataset given is indeed binary
    this.schema=dataset.getSchema();
    isBinary=schema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA);
    if(!isBinary) // make sure dataset is binary
    {
      throw new IllegalArgumentException(
          "Dataset given to TweakedLearner::batchTrain must be a binary dataset");
    }
    if(dataset.size()==0) // make sure dataset is not empty
    {
      throw new IllegalArgumentException(
          "Dataset given to TweakedLearner::batchTrain is empty");
    }
    this.m_dataset=dataset;

    // get the classifier resulting from training on the given dataset 
    BinaryClassifier bc=
        (BinaryClassifier)new DatasetClassifierTeacher(m_dataset)
            .train(innerLearner);

    // Initialize the data structure required for the tweaking. Please note
    // that the executeTweaking() method assumes that the rows in this table
    // are sorted by descending score
    initializeTable();

    // Execute the actual tweaking - figure out which threshold works best on the 
    // given dataset w.r.t. F_beta
    double threshold=executeTweaking();

    return new TweakedClassifier(bc,threshold);
  }
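
  // A minimal usage sketch (illustrative only; "trainData" is a hypothetical
  // binary Dataset assumed to be loaded by the caller):
  //
  //   BinaryClassifierLearner inner = new NaiveBayes();
  //   TweakedLearner learner = new TweakedLearner(inner, 0.5); // beta<1 favors precision
  //   Classifier tweaked = learner.batchTrain(trainData);
  //
  // The returned classifier is a TweakedClassifier whose score is the inner
  // classifier's score shifted by the threshold that maximized F_beta on trainData.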

  /**
   ******************************************************************** 
   * Getters & Setters  
   ******************************************************************** 
   */

  /**
   * @return Returns the beta.
   */
  public double getBeta(){
    return beta;
  }

  /**
   * @param beta The beta to set.
   */
  public void setBeta(double beta){
    this.beta=beta;
  }

  /**
   * @return Returns the innerLearner.
   */
  public BinaryClassifierLearner getInnerLearner(){
    return innerLearner;
  }

  /**
   * @param learner The innerLearner to set.
   */
  public void setInnerLearner(BinaryClassifierLearner learner){
    this.innerLearner=learner;
  }

  /**
   ******************************************************************** 
   * Private methods  
   ******************************************************************** 
   */

  // This method initializes tweakingTable, which is the data structure used for tweaking.
  // It loops over the examples in the given dataset and inserts them into the
  // table. According to the needs of the tweaking process, the rows are then 
  // sorted by descending score (also called posWeight).
  private void initializeTable(){
    int counter=0;
    for(Iterator<Example> i=m_dataset.iterator();i.hasNext();counter++){
      Example ex=i.next();
      ClassLabel predicted=
          innerLearner.getBinaryClassifier().classification(ex);
      // add the example to the tweaking data structure; note that the tweaked
      // prediction given during initialization is NEG for all examples!
      tweakingTable.add(new Row(ex.asInstance(),ex.getLabel(),predicted,
          ClassLabel.negativeLabel(-1.0)));

      // debug code
      //double score = innerLearner.getBinaryClassifier().score(ex);
      /*
      log.debug("Example number: "+ counter + 
          ", posWeight: " + predicted.posWeight() + 
          ", Score: " + score + 
          ", Label: " + ex.getLabel());
       */
    }

    // sort the table, after it was filled, by descending score
    sortByScore();
  }

  /*
   * This method is the very heart of the tweaking process. It assumes that 
   * the tweakingTable data structure was initialized and filled, with a row 
   * for every example. It further assumes that all the examples were given an 
   * initial tweak_predicted label of NEG and that they were sorted by descending score.
   * The method then:
   * 1. Initializes a confusion matrix, based on a NEG prediction for all examples.
   * 2. For every example, starting with the one with the highest positive score:
   *    a. sets its tweak_predicted label to POS
   *    b. updates the confusion matrix accordingly
   *    c. calculates precision, recall and F_beta with the new confusion matrix
   *       and fills these values into the tweakingTable data structure
   *   Please note that in any such iteration, all the examples/rows above
   *   the current example (including itself) have a POS prediction, while all the
   *   examples/rows below the current example have a NEG prediction. 
   *   That is, we effectively evaluate the F_beta when the "dividing line" is on 
   *   the current example.
   * 3. After all the rows/examples are handled, choose the row with the maximal F_beta.
   * 4. Select the score of this row, or more precisely the average between this score
   *    and the next row's score, to be the threshold.
   * 5. Return this number as the threshold constituting the new TweakedClassifier.
   *      
   */
  private double executeTweaking(){
    double threshold=UNINITIALIZED;
    initConfusionMatrix();

    // for every row, find and fill the precision, recall and F_beta
    // Note, that each row examined is first set to POS
    for(int i=0;i<tweakingTable.size();++i){
      // set dummy prediction of current example to POS
      getRow(i).tweak_predicted=ClassLabel.positiveLabel(1.0);
      // update the confusion matrix based on this prediction change
      updateConfusionMatrix(i);

      // calculate the precision, recall and F_beta with the updated confusion matrix
      getRow(i).precision=getCurrentPrecision();
      getRow(i).recall=getCurrentRecall();
      getRow(i).F_beta=calculateFBeta(getRow(i).precision,getRow(i).recall);

      /*
      log.debug("row " + i + ", precision: " + getRow(i).precision
          + ", recall: " + getRow(i).recall + ", F_beta: " + getRow(i).F_beta 
          + ", score: " + getRow(i).orig_predicted.posWeight());
       */
    }
    // choose the threshold row, i.e. the row with maximal F_beta,
    // and translate its score into the returned threshold
    int index=maxFBetaEntry();

    // if the row that was found is the last row in the table (VERY unlikely),
    // set the threshold to be its score 
    if((index+1)==tweakingTable.size()){
      threshold=getRow(index).orig_predicted.posWeight();
    }else // otherwise, set it to be the average between this row's score and 
    { // the next row's score
      double maxRowScore=getRow(index).orig_predicted.posWeight();
      double nextRowScore=getRow(index+1).orig_predicted.posWeight();
      threshold=(maxRowScore+nextRowScore)/2;
    }
    log.debug("Threshold found: "+threshold+" (in row "+index+")");

    return threshold; // return the threshold that was found
  }
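
  // Worked example of the threshold selection (hypothetical numbers, for
  // illustration only). Suppose that after sorting by descending score the
  // table holds:
  //   row 0: score  0.9, F_beta 0.67
  //   row 1: score  0.4, F_beta 0.80   <-- maximal F_beta
  //   row 2: score -0.2, F_beta 0.75
  // maxFBetaEntry() returns 1, and the threshold is (0.4 + (-0.2)) / 2 = 0.1,
  // so the resulting TweakedClassifier labels as POS exactly those examples
  // whose original score is above 0.1 (here, rows 0 and 1).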

  /** 
   * Initializes the confusion matrix. This method is called in the first step
   * of the tweaking process. Please note that at this step, all the examples
   * have a tweak_predicted field of the NEG class.
   */
  private void initConfusionMatrix(){
    String[] classes=getClasses();
    // count up the errors
    double[][] confused=new double[classes.length][classes.length];
    for(int i=0;i<tweakingTable.size();i++){
      Row row=getRow(i);
      confused[classIndexOf(row.actual)][classIndexOf(row.tweak_predicted)]++;
    }
    cm=new Matrix(confused);
  }

  /*
   * During the tweaking process, a single example is handled in each iteration,
   * so that its tweak_predicted field is changed from NEG to POS.
   * This method receives the index (in the tweakingTable) of the current example
   * and updates the confusion matrix accordingly
   */
  private void updateConfusionMatrix(int index){
    Row row=getRow(index);
    int actual=classIndexOf(row.actual);
    int p=classIndexOf(ExampleSchema.POS_CLASS_NAME);
    int n=classIndexOf(ExampleSchema.NEG_CLASS_NAME);

    // the confusion matrix (cm) is built as [actual][predicted]
    cm.values[actual][p]++;
    cm.values[actual][n]--;
  }
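
  // Worked example (assuming, for illustration, that POS maps to index 0 and
  // NEG to index 1 in the schema): flipping one truly-POS example from a NEG
  // prediction to a POS prediction changes
  //   [[3, 7],        [[4, 6],
  //    [1, 9]]   into  [1, 9]]
  // i.e. one count moves from the false-negative cell [POS][NEG] to the
  // true-positive cell [POS][POS]; the NEG row is untouched.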

  // This method simply returns, given precision and recall, the value
  // of F_beta. It uses the "beta" data member of this class to decide
  // which function is to be calculated.
  // The formula used is:
  // F_beta = ((beta+1) * precision * recall) / ((beta * precision) + recall)
  //
  // Reference: Jason D. M. Rennie,
  // "Derivation of the F-Measure",
  // http://people.csail.mit.edu/jrennie/writing/fmeasure.pdf
  private double calculateFBeta(double precision,double recall){
    double divisor=((beta*precision)+recall);

    // in case a division by zero will occur, return F_beta=0.0 (instead of NaN)
    if(divisor==0.0){
      log.warn("TweakedLearner::calculateFBeta, divisor of F_beta is zero !!!");
      return 0.0;
    }
    // in case of a division by NaN, return F_beta=0.0 (instead of NaN)
    if(Double.isNaN(divisor)){
      log.warn("TweakedLearner::calculateFBeta, divisor of F_beta is a NaN !!!");
      return 0.0;
    }

    return(((beta+1)*precision*recall)/divisor);
  }
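
  // Worked example: with beta = 0.5, precision = 0.8 and recall = 0.6,
  //   F_beta = ((0.5 + 1) * 0.8 * 0.6) / ((0.5 * 0.8) + 0.6)
  //          = 0.72 / 1.0
  //          = 0.72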

  // This method returns the precision based on the current confusion matrix. 
  // Note that during the tweaking process the confusion matrix is iteratively updated 
  // Precision is defined as:
  // true_positive / (true_positive + false_positive)
  private double getCurrentPrecision(){
    if(!isBinary)
      return ILLEGAL_VALUE; // to be on the safe side

    int p=classIndexOf(ExampleSchema.POS_CLASS_NAME);
    int n=classIndexOf(ExampleSchema.NEG_CLASS_NAME);

    // the confusion matrix (cm) is built as [actual][predicted]
    return cm.values[p][p]/(cm.values[p][p]+cm.values[n][p]);
  }

  // This method returns the recall based on the current confusion matrix. 
  // Note that during the tweaking process the confusion matrix is iteratively updated 
  // Recall is defined as:
  // true_positive / (true_positive + false_negative)
  private double getCurrentRecall(){
    if(!isBinary)
      return ILLEGAL_VALUE; // to be on the safe side

    int p=classIndexOf(ExampleSchema.POS_CLASS_NAME);
    int n=classIndexOf(ExampleSchema.NEG_CLASS_NAME);

    // the confusion matrix (cm) is built as [actual][predicted]
    return cm.values[p][p]/(cm.values[p][p]+cm.values[p][n]);
  }
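
  // Example (cells laid out as [actual][predicted]; assuming POS maps to
  // index 0 and NEG to index 1): for
  //   cm = [[4, 6],
  //         [1, 9]]
  // precision = 4 / (4 + 1) = 0.8 and recall = 4 / (4 + 6) = 0.4.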

  /**
   ******************************************************************** 
   * Private convenience methods  
   ******************************************************************** 
   */
  // sort the tweakingTable, after it was filled, by descending score
  private void sortByScore(){
    Collections.sort(tweakingTable,new Comparator<Row>(){
      public int compare(Row a,Row b){
        return MathUtil.sign(b.orig_predicted.posWeight()-a.orig_predicted.posWeight());
      }
    });
  }
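
  // Example: for rows a and b with scores 0.7 and 0.2, compare(a,b) returns
  // sign(0.2 - 0.7) = -1, so a (the higher score) sorts before b, giving the
  // descending order that executeTweaking() relies on.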

  /*
   * Returns the index (in tweakingTable) of the Row with maximal F_beta value
   */
  private int maxFBetaEntry(){
    double maxFBeta=ILLEGAL_VALUE; // initialize
    int maxIndex=(int)UNINITIALIZED; // index of the row with maximal F_beta

    for(int i=0;i<tweakingTable.size();++i){
      if(getRow(i).F_beta>maxFBeta){
        maxFBeta=getRow(i).F_beta;
        maxIndex=i;
      }
    }

    if(maxFBeta==ILLEGAL_VALUE){
      log.error("In TweakedLearner::maxFBetaEntry, maxFBeta has an illegal value");
    }

    return maxIndex;
  }

  private Row getRow(int i){
    return tweakingTable.get(i);
  }

  private String[] getClasses(){
    return schema.validClassNames();
  }

  private int classIndexOf(ClassLabel classLabel){
    return classIndexOf(classLabel.bestClassName());
  }

  private int classIndexOf(String classLabelName){
    return schema.getClassIndex(classLabelName);
  }

  // debug method - simply dumps the tweakingTable data structure to stdout
//  private void printTable(){
//    for(int i=0;i<tweakingTable.size();++i){
//      System.out.println("Row number "+i+": "+getRow(i));
//    }
//  }

  /**
   ******************************************************************** 
   ******************************************************************** 
   * This class represents the information, about a single example,
   * needed for executing the tweaking:
   * 1. The example itself
   * 2. Its true label/class. It could be retrieved each time via
   *    example.getLabel(), but for convenience it is stored in the table
   * 3. The predicted class (orig_predicted), as given by the original 
   *     (untweaked) classifier.
   * 4. A dummy prediction (tweak_predicted), which is used in the actual 
   *    tweaking process. During construction, all rows are initialized as NEG examples, 
   *    commensurate with the way the tweaking process is executed.
   *  
   * Please note that during the tweaking process, examples that were predicted
   * by the original (untweaked) classifier as POS can have a prediction of NEG,
   * and vice versa.
   * 
   * Note also that the actual score for an example is given by 
   * orig_predicted.posWeight(), where posWeight>0 means the original prediction 
   * of the untweaked classifier was that this example is of a POSITIVE class,
   * and posWeight<0 means NEGATIVE class.
   * 
   * In addition, for the actual tweaking process, 3 fields are 
   * maintained for each example/row:
   * 5. Precision
   * 6. Recall
   * 7. F_beta value
   * 
   * See the documentation of the actual tweaking method, executeTweaking(),
   * for further details
   ******************************************************************** 
   ******************************************************************** 
   */
  private static class Row implements Serializable{

    private static final long serialVersionUID=-4069980043842319180L;

    public transient Instance instance=null; // the example

    public ClassLabel actual; // true label

    public ClassLabel orig_predicted; // predicted label - see documentation above

    public ClassLabel tweak_predicted; // temporary prediction, for tweaking process

    public double precision=UNINITIALIZED;

    public double recall=UNINITIALIZED;

    public double F_beta=UNINITIALIZED;

    public Row(Instance i,ClassLabel a,ClassLabel orig_p,ClassLabel tweak_p){
      instance=i;
      actual=a;
      orig_predicted=orig_p;
      tweak_predicted=tweak_p;
    }

    public String toString(){
      return orig_predicted+"\t"+actual+"\t"+instance;
    }
  }

  /** 
   ******************************************************************** 
   ******************************************************************** 
   * A Tweaked Classifier, with an optimization of the precision vs. recall
   * Please note that this is an internal class of the TweakedLearner class.
   * It is constructed and returned by the TweakedLearner, based on 
   * an original untweaked binary classifier, and a threshold which was found
   * to optimize precision vs. recall
   * 
   * @author Giora Unger
   * Created on May 19, 2005
   ******************************************************************** 
   ******************************************************************** 
   */
  public static class TweakedClassifier extends BinaryClassifier implements
      Serializable,Visible{

    static private final long serialVersionUID=20080128L;

    private double m_threshold;

    private BinaryClassifier m_classifier;

    public TweakedClassifier(BinaryClassifier classifier,double threshold){
      m_classifier=classifier;
      m_threshold=threshold;
    }

    public double score(Instance instance){
      return m_classifier.score(instance)-m_threshold;
    }
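
    // Example: with a tuned threshold of 0.1, an instance the inner classifier
    // scores at 0.3 gets a tweaked score of 0.2 (still POS), while an inner
    // score of 0.05 becomes -0.05 and is therefore classified NEG.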

    /* (non-Javadoc)
     * @see edu.cmu.minorthird.util.gui.Visible#toGUI()
     * 
     * Shows the original (untweaked) classifier, and the threshold 
     * that was found
     * Code was copied from file CMM.java and adjusted 
     */
    public Viewer toGUI(){
      final Viewer v=new ComponentViewer(){
        static final long serialVersionUID=20080128L;
        public JComponent componentFor(Object o){
          TweakedClassifier c=(TweakedClassifier)o;
          JPanel mainPanel=new JPanel();
          mainPanel.setLayout(new BorderLayout());
          mainPanel.add(new JLabel("Optimal threshold for TweakedClassifier="+
              c.m_threshold),BorderLayout.NORTH);
          mainPanel.add(new JLabel("Original classifier before tweaking:"),
              BorderLayout.CENTER);
          Viewer subView=new SmartVanillaViewer(c.m_classifier);
          subView.setSuperView(this);
          mainPanel.add(subView,BorderLayout.SOUTH);
          mainPanel.setBorder(new TitledBorder("TweakedClassifier class"));
          return new JScrollPane(mainPanel);
        }
      };
      v.setContent(this);
      return v;
    }

    /* (non-Javadoc)
     * @see edu.cmu.minorthird.classify.Classifier#explain(edu.cmu.minorthird.classify.Instance)
     */
    public String explain(Instance instance){
      StringBuffer buf=new StringBuffer("");
      buf.append("Explanation of original untweaked classifier:\n");
      buf.append(m_classifier.explain(instance));
      buf.append("\nAdjusted score after tweaking = "+score(instance));
      return buf.toString();
    }

    public Explanation getExplanation(Instance instance){
      Explanation.Node top=new Explanation.Node("TweakedLearner Explanation");
      Explanation.Node orig=
          new Explanation.Node("Explanation of original untweaked classifier");
      Explanation.Node origEx=
          m_classifier.getExplanation(instance).getTopNode();
      orig.add(origEx);
      top.add(orig);
      Explanation.Node adjusted=
          new Explanation.Node("\nAdjusted score after tweaking = "+
              score(instance));
      top.add(adjusted);
      Explanation ex=new Explanation(top);
      return ex;
    }
  }

  /**
   ******************************************************************** 
   ******************************************************************** 
   * Main method for testing purposes 
   ******************************************************************** 
   ******************************************************************** 
   */
  public static void main(String[] args){
    System.out.println("Started the test program for TweakedLearner");
    NaiveBayes nb=new NaiveBayes();
    new TweakedLearner(nb,3.0);
    System.out.println("Created a TweakedLearner");
  }

}