edu.cmu.cs.lti.ark.fn.parsing.Decoding.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.cs.lti.ark.fn.parsing.Decoding.java

Source

/*******************************************************************************
 * Copyright (c) 2011 Dipanjan Das 
 * Language Technologies Institute, 
 * Carnegie Mellon University, 
 * All Rights Reserved.
 * 
 * Decoding.java is part of SEMAFOR 2.0.
 * 
 * SEMAFOR 2.0 is free software: you can redistribute it and/or modify  it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or 
 * (at your option) any later version.
 * 
 * SEMAFOR 2.0 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. 
 * 
 * You should have received a copy of the GNU General Public License along
 * with SEMAFOR 2.0.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package edu.cmu.cs.lti.ark.fn.parsing;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.collect.*;
import edu.cmu.cs.lti.ark.util.FileUtil;
import edu.cmu.cs.lti.ark.util.ds.Scored;
import org.pcollections.HashTreePMap;
import org.pcollections.PMap;

import java.util.*;

import static com.google.common.collect.ImmutableList.copyOf;
import static com.google.common.collect.Iterables.transform;
import static edu.cmu.cs.lti.ark.util.ds.Scored.scored;
import static java.lang.Integer.parseInt;
import static java.lang.Math.min;

/**
 * Predict spans for roles using beam search.
 */
public class Decoding {
    /** Maximum number of partial role assignments kept in the beam after each role is processed. */
    private static final int DEFAULT_BEAM_WIDTH = 100;
    /** Joins the fields of an output line with tab characters. */
    private static final Joiner TAB_JOINER = Joiner.on("\t");

    /** Feature weights, indexed by feature index. Index 0 holds the bias weight (see getWeightSum). */
    protected double[] modelWeights;

    /** 0-indexed. Both ends inclusive. Null span is represented as [-1,-1]. */
    public static class Span implements Comparable<Span> {
        public final int start;
        public final int end;

        public Span(int start, int end) {
            this.start = start;
            this.end = end;
        }

        public boolean isEmpty() {
            return start == -1;
        }

        /** Determines whether this and the other span overlap */
        public boolean overlaps(Span other) {
            // empty spans can't overlap with anything
            if (isEmpty() || other.isEmpty())
                return false;
            if (start < other.start) {
                return end >= other.start;
            } else {
                return other.end >= start;
            }
        }

        @Override
        public int compareTo(Span other) {
            return ComparisonChain.start().compare(start, other.start).compare(end, other.end).result();
        }

        @Override
        public String toString() {
            return (start == end) ? ("" + start) : (start + ":" + end);
        }
    }

    /**
     * An assignment of spans to the roles of a particular frame.
     * Roles given the null span are stored separately from roles given a real span.
     * Adding an assignment never modifies the receiver; it produces a new instance.
     */
    public static class RoleAssignments implements Comparable<RoleAssignments> {
        /** Formats one (role, span) entry as the role name, a tab, and the span. */
        private final static Function<Map.Entry<String, Span>, String> JOIN_ENTRY = new Function<Map.Entry<String, Span>, String>() {
            @Override
            public String apply(Map.Entry<String, Span> entry) {
                final String role = entry.getKey();
                final String span = entry.getValue().toString();
                return TAB_JOINER.join(role, span);
            }
        };
        /** Roles that were assigned a non-empty span. */
        private final PMap<String, Span> nonNullAssignments;
        /** Roles that were assigned the null span. */
        private final PMap<String, Span> nullAssignments;

        public RoleAssignments(PMap<String, Span> nonNullAssignments, PMap<String, Span> nullAssignments) {
            this.nonNullAssignments = nonNullAssignments;
            this.nullAssignments = nullAssignments;
        }

        /** Creates an assignment in which no role has been given a span yet. */
        public RoleAssignments() {
            this(HashTreePMap.<String, Span>empty(), HashTreePMap.<String, Span>empty());
        }

        /**
         * Returns a new assignment that additionally maps the given role to the given span.
         *
         * @param key the role name
         * @param value the span for that role; the null span is recorded among the null assignments
         */
        public RoleAssignments plus(String key, Span value) {
            if (!value.isEmpty()) {
                return new RoleAssignments(nonNullAssignments.plus(key, value), nullAssignments);
            }
            return new RoleAssignments(nonNullAssignments, nullAssignments.plus(key, value));
        }

        private Map<String, Span> getNonNullAssignments() {
            return nonNullAssignments;
        }

        /** Determines whether the given span overlaps with any span already assigned to a role. */
        private boolean overlaps(Span otherSpan) {
            if (!otherSpan.isEmpty()) {
                for (Span assigned : nonNullAssignments.values()) {
                    if (assigned.overlaps(otherSpan)) {
                        return true;
                    }
                }
            }
            return false;
        }

        /** Tab-separated "role, span" pairs for the non-null assignments only. */
        @Override
        public String toString() {
            final Iterable<String> renderedEntries = transform(nonNullAssignments.entrySet(), JOIN_ENTRY);
            return TAB_JOINER.join(renderedEntries);
        }

        /** Delegates to Guava's arbitrary ordering; the result does not depend on the assigned spans. */
        @Override
        public int compareTo(RoleAssignments other) {
            final Ordering<Object> arbitraryOrder = Ordering.arbitrary();
            return arbitraryOrder.compare(this, other);
        }
    }

    /**
     * The candidate spans for a particular role, each paired with its log score.
     * This is a {@link TreeSet} created with the default constructor, so elements are kept in the
     * natural ordering of {@code Scored<Span>}, and a candidate that compares equal to one already
     * present is not added.
     * NOTE(review): the ordering of Scored is defined outside this file -- confirm whether it sorts
     * best-first or worst-first, and how ties between equal scores are broken.
     */
    public static class CandidatesForRole extends TreeSet<Scored<Span>> {
    }

    /**
     * Creates a decoder that scores candidate spans with the given weights.
     *
     * @param modelWeights feature weights indexed by feature index; the array is stored as-is, not copied
     */
    public Decoding(double[] modelWeights) {
        this.modelWeights = modelWeights;
    }

    /**
     * Creates a decoder whose weights are read from disk via {@code readModel}.
     *
     * @param modelFile path of the file holding the weights
     * @param alphabetFile path of the alphabet file that determines how many weights are read
     */
    public static Decoding fromFile(String modelFile, String alphabetFile) {
        final double[] weights = readModel(modelFile, alphabetFile);
        return new Decoding(weights);
    }

    /**
     * Reads the weight vector from disk.
     * The vector length is one more than the first integer found in the alphabet file; that many
     * lines are then read from the model file and each is parsed as a double.
     *
     * @param modelFile path of the file with one weight per line
     * @param alphabetFile path of the file whose first integer token determines the vector length
     * @return the weights, in the order they appear in the model file
     */
    protected static double[] readModel(String modelFile, String alphabetFile) {
        final int weightCount;
        final Scanner alphabetScanner = FileUtil.openInFile(alphabetFile);
        try {
            // presumably the extra slot accounts for the bias weight at index 0 (see getWeightSum)
            weightCount = alphabetScanner.nextInt() + 1;
        } finally {
            alphabetScanner.close();
        }
        final Scanner modelScanner = FileUtil.openInFile(modelFile);
        final double[] weights = new double[weightCount];
        try {
            int index = 0;
            while (index < weightCount) {
                weights[index] = Double.parseDouble(modelScanner.nextLine());
                index++;
            }
        } finally {
            modelScanner.close();
        }
        return weights;
    }

    /**
     * Decodes every frame and formats its predictions.
     *
     * @param frameFeaturesList features for each frame
     * @param frameLines the input line for each frame, parallel to frameFeaturesList
     * @param offset amount added to the sentence field of every frame line
     * @param kBestOutput the number of configurations requested per frame
     * @return one string per frame, containing that frame's predictions separated by newlines
     */
    public List<String> decodeAll(List<FrameFeatures> frameFeaturesList, List<String> frameLines, int offset,
            int kBestOutput) {
        final ArrayList<String> decodedFrames = new ArrayList<String>();
        for (int frameIdx = 0; frameIdx < frameFeaturesList.size(); frameIdx++) {
            final FrameFeatures features = frameFeaturesList.get(frameIdx);
            final String sharedFields = getInitialDecisionLine(frameLines.get(frameIdx), offset);
            final List<Scored<RoleAssignments>> ranked = getPredictions(features, kBestOutput);
            final List<String> lines = Lists.newArrayList();
            // the rank written to each line is the prediction's position in the returned list
            int rank = 0;
            for (Scored<RoleAssignments> prediction : ranked) {
                lines.add(formatPrediction(rank, sharedFields, prediction.value, prediction.score));
                rank++;
            }
            decodedFrames.add(Joiner.on("\n").join(lines));
        }
        return decodedFrames;
    }

    /**
     * Formats one prediction as a tab-separated line: rank, score, count, the shared fields of
     * the frame line, and the non-null role assignments.
     */
    private String formatPrediction(int rank, String initialDecisionLine, RoleAssignments assignments,
            double score) {
        // the count includes the target, hence the +1
        final int count = assignments.getNonNullAssignments().size() + 1;
        final String roleFields = assignments.toString();
        return TAB_JOINER.join(rank, score, count, initialDecisionLine, roleFields);
    }

    /**
     * Sums the weights of the features that fire, plus the bias.
     *
     * @param feats indexes of firing features; each must be a valid index into weights
     * @param weights the weight of every feature, with the bias weight at index 0
     * @return the bias weight plus the weight of every entry of feats other than index 0
     */
    public static double getWeightSum(int[] feats, double[] weights) {
        // the bias sits at index 0 and always fires, so start from it
        double total = weights[0];
        for (int i = 0; i < feats.length; i++) {
            final int featureIndex = feats[i];
            if (featureIndex == 0) {
                // the bias has already been counted
                continue;
            }
            total += weights[featureIndex];
        }
        return total;
    }

    /**
     * Builds the part of an output line that is shared by all predictions for one frame.
     * Adds 'offset' to the sentence field (index 7, 0-based, of the tab-separated frame line) and
     * discards the first three fields.
     *
     * @param frameLine a tab-separated frame line with at least 8 fields, the 8th being an integer
     * @param offset the amount to add to the sentence field
     * @return the remaining fields joined by tabs, with surrounding whitespace trimmed
     */
    protected String getInitialDecisionLine(String frameLine, int offset) {
        final String[] fields = frameLine.split("\t");
        final int shiftedSentence = parseInt(fields[7]) + offset;
        fields[7] = "" + shiftedSentence;
        final List<String> keptFields = copyOf(fields).subList(3, fields.length);
        return TAB_JOINER.join(keptFields).trim();
    }

    /**
     * Returns the leading portion of the list, at most beamWidth elements long.
     * The result is produced by {@code subList}; shorter lists are returned in full.
     */
    private static <T> List<T> safeTruncate(List<T> list, int beamWidth) {
        final int limit = min(list.size(), beamWidth);
        return list.subList(0, limit);
    }

    /**
     * Decode, respecting the constraint that arguments do not overlap.
     * Find the k (approximately) best configurations of non-overlapping role-filling spans using beam search.
     *
     * Roles are processed one at a time. For each role, every candidate span is combined with every
     * partial assignment in the current beam that it does not overlap with, and at most
     * DEFAULT_BEAM_WIDTH of the combinations are kept as the beam for the next role.
     * If no combination survives for some role, the beam (and hence the result) becomes empty.
     *
     * @param frameFeatures features for the given frame
     * @param kBestOutput the number of top configurations we should return
     * @return at most kBestOutput scored role assignments: the leading entries of the final beam
     */
    public List<Scored<RoleAssignments>> getPredictions(FrameFeatures frameFeatures, int kBestOutput) {
        // group by role
        final Map<String, CandidatesForRole> candidatesAndScoresByRole = scoreCandidatesForRoles(
                frameFeatures.fElements, frameFeatures.fElementSpansAndFeatures);

        // run beam search to find the (approximately) k-best non-overlapping configurations
        // our beam; it starts with a single empty assignment whose log score is 0
        List<Scored<RoleAssignments>> currentBeam = Lists.newArrayList(scored(new RoleAssignments(), 0.0));
        // run beam search; roles are visited in the iteration order of the map's key set
        for (String roleName : candidatesAndScoresByRole.keySet()) {
            // peek()/poll() operate on the least element under the natural ordering of Scored
            final PriorityQueue<Scored<RoleAssignments>> newBeam = Queues.newPriorityQueue();
            for (Scored<Span> candidate : candidatesAndScoresByRole.get(roleName)) {
                for (Scored<RoleAssignments> partialAssignment : currentBeam) {
                    final double newScore = partialAssignment.score + candidate.score; // multiply in log-space
                    // once the beam is full, stop trying this candidate as soon as a combination
                    // fails to beat the score at the head of the queue
                    // NOTE(review): skipping the remaining partial assignments is only sound if
                    // currentBeam is sorted by score, which copyOf(newBeam) below does not ensure -- confirm
                    if (newBeam.size() >= DEFAULT_BEAM_WIDTH && newScore <= newBeam.peek().score)
                        break;
                    if (!partialAssignment.value.overlaps(candidate.value)) {
                        final RoleAssignments newAssignment = partialAssignment.value.plus(roleName,
                                candidate.value);
                        newBeam.add(scored(newAssignment, newScore));
                    }
                    // keep the beam within its width by evicting the head of the queue
                    if (newBeam.size() > DEFAULT_BEAM_WIDTH)
                        newBeam.poll();
                }
            }
            // NOTE(review): this copies in the PriorityQueue's iteration order, which is not
            // guaranteed to be sorted; only the first element (the queue head) is known.
            // Whether the head is the best or the worst entry depends on Scored.compareTo,
            // which is defined outside this file -- verify before relying on the output order.
            currentBeam = copyOf(newBeam);
            // stale debugging output (PriorityQueue has no first() method):
            //System.out.println("Considering " + roleName);
            //System.out.println("Beam grew to " + newBeam.size());
            //System.out.println("Current best: " + newBeam.first().value + " " + newBeam.first().score);
        }
        // returns the first kBestOutput entries of the beam in its current (unsorted, see above) order
        return safeTruncate(currentBeam, kBestOutput);
    }

    /**
     * Scores every candidate span of every role with the model weights and groups them by role.
     *
     * @param roleNames the name of each role, parallel to featuresList
     * @param featuresList for each role, its candidate spans with the features firing for each
     * @return a map from role name to that role's scored candidates; if a role name occurs more
     *         than once, the candidates of its last occurrence replace the earlier ones
     */
    private Map<String, CandidatesForRole> scoreCandidatesForRoles(List<String> roleNames,
            List<SpanAndCorrespondingFeatures[]> featuresList) {
        final Map<String, CandidatesForRole> candidatesByRole = Maps.newHashMap();
        for (int roleIdx = 0; roleIdx < featuresList.size(); roleIdx++) {
            final String role = roleNames.get(roleIdx);
            final CandidatesForRole scoredSpans = new CandidatesForRole();
            final SpanAndCorrespondingFeatures[] candidates = featuresList.get(roleIdx);
            for (int c = 0; c < candidates.length; c++) {
                final SpanAndCorrespondingFeatures candidate = candidates[c];
                // span[0] is used as the start and span[1] as the end of the span
                final Span candidateSpan = new Span(candidate.span[0], candidate.span[1]);
                final double logScore = getWeightSum(candidate.features, modelWeights);
                scoredSpans.add(scored(candidateSpan, logScore));
            }
            candidatesByRole.put(role, scoredSpans);
        }
        return candidatesByRole;
    }

    /**
     * Hook that does nothing in this class; subclasses may override it.
     * NOTE(review): presumably called once decoding is finished -- the call sites are not in this file.
     */
    public void wrapUp() {
        /* no op unless overridden */ }
}