tagger.Supertagger.java Source code

Introduction

Here is the source code for tagger.Supertagger.java, a beam-search supertagger trained with a max-violation structured perceptron and decoded with averaged weights.

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package tagger;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import ling.Sentence;
import ling.Token;
import static org.apache.commons.math3.linear.MatrixUtils.createRealVector;
import org.apache.commons.math3.linear.RealVector;

/**
 * Beam-search supertagger trained with a max-violation structured perceptron.
 *
 * @author hiroki
 */

final public class Supertagger {
    final private int BEAM_WIDTH;
    final private int N_LABELS;
    final private boolean IS_HASH;
    final private MatrixPerceptron perceptron;
    final private SequentialFeaturizer featurizer;
    final private ArrayList<Integer> trainIndex = new ArrayList<>();
    private Sentence[] trainData;
    private double[][] weight;
    private int[][][] cacheFeatureID;
    private Hypothesis[] cacheOracleHypo;

    float total = 0.0f;
    float correct = 0.0f;

    public Supertagger(int beamWidth, int numLabels, int weightSize, int windowSize, boolean isHash) {
        BEAM_WIDTH = beamWidth;
        N_LABELS = numLabels;
        IS_HASH = isHash;
        perceptron = new MatrixPerceptron(numLabels, weightSize);
        featurizer = new SequentialFeaturizer(weightSize, windowSize, IS_HASH);
    }

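    /**
     * Runs one additional training epoch over the shuffled training data, reusing
     * the feature IDs and oracle hypotheses cached by train(Sentence[]).
     */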
    final public void train() {
        Collections.shuffle(trainIndex);

        total = 0.0f;
        correct = 0.0f;

        for (int i = 0; i < trainData.length; ++i) {
            if (i % 1000 == 0 && i != 0)
                System.out.print(String.format("%d ", i));

            int index = trainIndex.get(i);
            Sentence sent = trainData[index];
            int[][] featureID = cacheFeatureID[index];
            Hypothesis oracleHypo = cacheOracleHypo[index];
            Hypothesis[] reversedOracleHypo = oracleHypo.reverse(sent.N_TOKENS);
            Hypothesis[] beam = setMaxViolationPoint(sent, reversedOracleHypo, featureID);
            perceptron.updateWeights();
            checkAccuracy(oracleHypo, beam[0]);
        }

        System.out.println(String.format("\n\tacc:%f  crr:%f  ttl:%f", correct / total, correct, total));
    }

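    /**
     * Runs the first training epoch: featurizes each sentence, builds its oracle
     * hypothesis, updates the perceptron at the max-violation point, and caches
     * the feature IDs and oracle hypotheses for later epochs.
     */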
    final public void train(Sentence[] trainSents) {
        trainData = trainSents;
        cacheFeatureID = new int[trainData.length][][];
        cacheOracleHypo = new Hypothesis[trainData.length];

        for (int i = 0; i < trainData.length; ++i)
            trainIndex.add(i);
        Collections.shuffle(trainIndex);

        total = 0.0f;
        correct = 0.0f;

        for (int i = 0; i < trainData.length; ++i) {
            if (i % 1000 == 0 && i != 0)
                System.out.print(String.format("%d ", i));

            int index = trainIndex.get(i);
            Sentence sent = trainData[index];
            int[][] featureID = getFeatureID(sent);
            Hypothesis oracleHypo = getOracleHypo(sent, featureID);
            Hypothesis[] rOracleHypo = oracleHypo.reverse(sent.N_TOKENS);
            Hypothesis[] beam = setMaxViolationPoint(sent, rOracleHypo, featureID);
            perceptron.updateWeights();
            checkAccuracy(oracleHypo, beam[0]);
            cacheFeatureID[index] = featureID;
            cacheOracleHypo[index] = oracleHypo;
        }

        System.out.println(String.format("\n\tacc:%f  crr:%f  ttl:%f", correct / total, correct, total));
    }

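    /**
     * Decodes each test sentence with the averaged weights, prints tagging accuracy,
     * and optionally writes the predictions in CoNLL format.
     */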
    final public void test(Sentence[] testSents, String outputFileName) throws IOException {
        averageWeight();
        Hypothesis[] results = new Hypothesis[testSents.length];

        total = 0.0f;
        correct = 0.0f;

        for (int i = 0; i < testSents.length; ++i) {
            if (i % 1000 == 0 && i != 0)
                System.out.print(String.format("%d ", i));

            Sentence sent = testSents[i];
            Hypothesis[] beam = decode(sent);
            checkAccuracy(sent, beam[0]);
            results[i] = beam[0];
        }

        System.out.println(String.format("\n\tacc:%f  crr:%f  ttl:%f", correct / total, correct, total));

        if (outputFileName != null)
            outputResultingText(outputFileName, testSents, results);
    }

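    /** Beam-search decoding over all tokens of a sentence, using the averaged weights. */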
    private Hypothesis[] decode(Sentence sent) {
        Hypothesis[] beam = new Hypothesis[BEAM_WIDTH];
        beam[0] = new Hypothesis();

        for (int t = 0; t < sent.N_TOKENS; ++t) {
            beam = beamSearch(t, beam, sent);
        }

        return beam;
    }

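    /** Accumulates token-level accuracy by walking the oracle and system hypothesis chains in parallel. */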
    private void checkAccuracy(Hypothesis oracle, Hypothesis system) {
        Hypothesis oh = oracle;
        Hypothesis sh = system;
        while (oh.label > -1) {
            if (oh.label == sh.label)
                correct += 1.0;
            total += 1.0;

            oh = oh.prevHypo;
            sh = sh.prevHypo;
        }
    }

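    /** Accumulates token-level accuracy against the gold supertags stored in the sentence. */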
    private void checkAccuracy(Sentence sent, Hypothesis system) {
        Hypothesis[] sh = system.reverse(sent.N_TOKENS);
        for (int t = 0; t < sent.N_TOKENS; ++t) {
            int label = sh[t].label;
            if (label == sent.stags[t])
                correct += 1.0;
            total += 1.0;
        }
    }

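    /** Writes the predicted supertags to fn + ".conll" in a CoNLL-style column format. */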
    private void outputResultingText(String fn, Sentence[] sents, Hypothesis[] results) throws IOException {
        PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter(fn + ".conll")));

        for (int i = 0; i < results.length; ++i) {
            Sentence sent = sents[i];
            Hypothesis[] h = results[i].reverse(sent.N_TOKENS);

            for (int j = 0; j < h.length; ++j) {
                Token t = sent.tokens[j];
                String stag = Sentence.stagDict.getKey(h[j].label);
                String text = String.format("%d\t%s\t_\t%s\t%s\t%s\t%d\t%s\t_\t_", t.INDEX, t.O_FORM, t.O_CPOS,
                        t.O_POS, stag, t.O_HEAD, t.O_LABEL);
                pw.println(text);
            }
            pw.println();
        }
        pw.close();
    }

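    /** Writes the gold supertags to fn + ".conll" in the same CoNLL-style column format. */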
    public void outputGoldStagText(String fn, Sentence[] sents) throws IOException {
        PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter(fn + ".conll")));

        for (int i = 0; i < sents.length; ++i) {
            Sentence sent = sents[i];

            for (int j = 0; j < sent.size(); ++j) {
                Token t = sent.tokens[j];
                String stag = Sentence.stagDict.getKey(sent.stags[j]);
                String text = String.format("%d\t%s\t_\t%s\t%s\t%s\t%d\t%s\t_\t_", t.INDEX, t.O_FORM, t.O_CPOS,
                        t.O_POS, stag, t.O_HEAD, t.O_LABEL);
                pw.println(text);
            }
            pw.println();
        }
        pw.close();
    }

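    /** Builds the oracle hypothesis chain by scoring the gold supertag sequence token by token. */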
    private Hypothesis getOracleHypo(Sentence sent, int[][] featureID) {
        return updateHypo(0, sent, featureID, new Hypothesis());
    }

    private Hypothesis updateHypo(int t, Sentence sent, int[][] featureID, Hypothesis prevHypo) {
        if (t == sent.N_TOKENS)
            return prevHypo;

        int label = sent.stags[t];
        int[] featureID_t = featureID[t];
        int[] markovFeatureID = featurizer.featurize(sent, t, prevHypo.label, prevHypo.prevHypo.label);
        double score = getScore(featureID_t, label) + getScore(markovFeatureID, label);

        return updateHypo(t + 1, sent, featureID,
                new Hypothesis(label, score, featureID_t, markovFeatureID, prevHypo));
    }

    private double getScore(int[] features, int label) {
        return perceptron.getScore(features, label);
    }

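    /**
     * Runs beam search over the sentence and, at each step, lets the perceptron record
     * the point of maximum violation between the oracle hypothesis and the best
     * hypothesis in the beam.
     */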
    private Hypothesis[] setMaxViolationPoint(Sentence sent, Hypothesis[] oracleHypo, int[][] featureID) {
        Hypothesis[] beam = new Hypothesis[BEAM_WIDTH];
        beam[0] = new Hypothesis();

        for (int t = 0; t < sent.N_TOKENS; ++t) {
            beam = beamSearch(t, beam, sent, featureID);
            perceptron.setMaxViolationPoint(t, oracleHypo[t], beam[0]);
        }

        return beam;
    }

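    /**
     * One beam-search step during training: every surviving hypothesis is expanded with
     * all labels, scored with the current perceptron weights, and only the best
     * candidates are kept, up to BEAM_WIDTH.
     */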
    private Hypothesis[] beamSearch(int t, Hypothesis[] beam, Sentence sent, int[][] featureID) {
        HypoComparator orderedHypoQueue = new HypoComparator(BEAM_WIDTH);

        int[] featureID_t = featureID[t];
        RealVector featureScore = perceptron.getScore(featureID_t);
        int[][] markovFeature = null;
        String[] markovStrFeature = null;
        if (IS_HASH)
            markovFeature = featurizer.getMarkovFeature(sent, t);
        else
            markovStrFeature = featurizer.getStrMarkovFeature(sent, t);

        for (int k = 0; k < BEAM_WIDTH; ++k) {
            Hypothesis hypo = beam[k];

            if (hypo == null)
                break;

            int[] markovFeatureID;
            if (IS_HASH)
                markovFeatureID = featurizer.getFeatureID(markovFeature, hypo.label, hypo.prevHypo.label);
            else
                markovFeatureID = featurizer.getFeatureID(markovStrFeature, hypo.label, hypo.prevHypo.label, false);
            RealVector markovFeatureScore = perceptron.getScore(markovFeatureID);
            double[] score = featureScore.add(markovFeatureScore).toArray();

            for (int n = 0; n < N_LABELS; ++n)
                orderedHypoQueue.addSort(n, score[n], featureID_t, markovFeatureID, hypo);
        }

        for (int k = 0; k < BEAM_WIDTH; ++k) {
            beam[k] = orderedHypoQueue.get(k);
            if (k == orderedHypoQueue.size() - 1)
                break;
        }

        return beam;
    }

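    /**
     * One beam-search step at test time: same expansion as above, but features are
     * extracted on the fly and scored with the averaged weights.
     */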
    private Hypothesis[] beamSearch(int t, Hypothesis[] beam, Sentence sent) {
        HypoComparator orderedHypoQueue = new HypoComparator(BEAM_WIDTH);

        int[] featureID_t;
        if (IS_HASH)
            featureID_t = featurizer.getFeatureID(featurizer.getFeature(sent, t));
        else
            featureID_t = featurizer.getFeatureID(featurizer.getStrFeature(sent, t), false);

        int[][] markovFeature = null;
        String[] markovStrFeature = null;
        if (IS_HASH)
            markovFeature = featurizer.getMarkovFeature(sent, t);
        else
            markovStrFeature = featurizer.getStrMarkovFeature(sent, t);

        RealVector featureScore = getScore(featureID_t);

        for (int k = 0; k < BEAM_WIDTH; ++k) {
            Hypothesis hypo = beam[k];

            if (hypo == null)
                break;

            int[] markovFeatureID;
            if (IS_HASH)
                markovFeatureID = featurizer.getFeatureID(markovFeature, hypo.label, hypo.prevHypo.label);
            else
                markovFeatureID = featurizer.getFeatureID(markovStrFeature, hypo.label, hypo.prevHypo.label, true);
            RealVector markovFeatureScore = getScore(markovFeatureID);
            double[] score = featureScore.add(markovFeatureScore).toArray();

            for (int n = 0; n < N_LABELS; ++n)
                orderedHypoQueue.addSort(n, score[n], featureID_t, markovFeatureID, hypo);
        }

        for (int k = 0; k < BEAM_WIDTH; ++k) {
            beam[k] = orderedHypoQueue.get(k);
            if (k == orderedHypoQueue.size() - 1)
                break;
        }

        return beam;
    }

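    /** Sums the averaged-weight rows of all active features into a vector of label scores. */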
    private RealVector getScore(int[] featureID) {
        RealVector score = createRealVector(new double[N_LABELS]);
        for (int i = 0; i < featureID.length; ++i)
            score = score.add(createRealVector(weight[featureID[i]]));
        return score;
    }

    private int[][] getFeatureID(Sentence sent) {
        int[][] featureID = new int[sent.N_TOKENS][];
        for (int t = 0; t < sent.N_TOKENS; ++t)
            featureID[t] = featurizer.featurize(sent, t);
        return featureID;
    }

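    /** Computes the averaged-perceptron weights: weight[i][j] = w[i][j] - aw[i][j] / t. */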
    private void averageWeight() {
        double t = perceptron.t;
        double[][] w = perceptron.weight;
        double[][] aw = perceptron.aweight;

        this.weight = new double[w.length][w[0].length];

        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                this.weight[i][j] = w[i][j] - aw[i][j] / t;
            }
            }
        }
    }
}
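
Usage

For reference, a minimal driver for this class might look like the sketch below. It relies only on the constructor and the train/test methods shown above; the constructor values, the file names, and the loadSentences helper are hypothetical placeholders, since corpus loading is handled elsewhere in the project.

package tagger;

import java.io.IOException;
import ling.Sentence;

public class SupertaggerDemo {

    public static void main(String[] args) throws IOException {
        // Hypothetical corpus loading; not part of Supertagger.java.
        Sentence[] trainSents = loadSentences("train.conll");
        Sentence[] testSents = loadSentences("test.conll");

        // Illustrative values only: beam width, label count, weight-table size,
        // context window size, and whether feature hashing is used.
        Supertagger tagger = new Supertagger(16, 1000, 1 << 20, 5, true);

        // The first epoch builds the feature and oracle caches while training.
        tagger.train(trainSents);
        // Later epochs reuse the caches.
        for (int epoch = 1; epoch < 10; ++epoch)
            tagger.train();

        // Averages the weights, decodes, prints accuracy, and writes output.conll.
        tagger.test(testSents, "output");
    }

    // Placeholder: sentence reading lives elsewhere in the project.
    private static Sentence[] loadSentences(String path) {
        throw new UnsupportedOperationException("corpus loading is outside this file");
    }
}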