com.davidbracewell.ml.sequence.decoder.LinearViterbi.java Source code

Java tutorial

Introduction

Here is the source code for com.davidbracewell.ml.sequence.decoder.LinearViterbi.java

Source

/*
 * (c) 2005 David B. Bracewell
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.davidbracewell.ml.sequence.decoder;

import com.davidbracewell.conversion.Val;
import com.davidbracewell.ml.classification.ClassificationResult;
import com.davidbracewell.ml.Feature;
import com.davidbracewell.ml.sequence.Sequence;
import com.davidbracewell.ml.sequence.SequenceModel;
import com.davidbracewell.ml.sequence.linear.LinearSequenceModel;
import com.google.common.collect.MinMaxPriorityQueue;
import com.google.common.primitives.Doubles;

import java.io.Serializable;

/**
 * Beam-search decoder for {@link LinearSequenceModel}s.
 *
 * <p>The decoder walks the sequence left to right. At every position it keeps at most
 * {@code beamSize} partial labelings (ranked by summed log confidence), extends each of them with
 * every tag permitted by the constraint hooks inherited from {@link Viterbi}
 * ({@code isValidStartTag}, {@code isValidTransition}, {@code isValidEndTag}, {@code isValidTag}),
 * and finally returns the best surviving labeling.</p>
 *
 * @param <V> the type of the items in the sequence
 * @author David B. Bracewell
 */
public class LinearViterbi<V> extends Viterbi<V> {

    private static final long serialVersionUID = 555426483507549520L;
    private final int beamSize;

    /**
     * Instantiates a new Linear viterbi.
     *
     * @param beamSize the maximum number of partial labelings kept per position; must be positive
     * @throws IllegalArgumentException if {@code beamSize} is less than one
     */
    public LinearViterbi(int beamSize) {
        // Fail fast: MinMaxPriorityQueue.maximumSize would reject this value later, inside decode.
        if (beamSize < 1) {
            throw new IllegalArgumentException("beamSize must be positive, but was " + beamSize);
        }
        this.beamSize = beamSize;
    }

    /**
     * Decodes the most probable labeling of the sequence that survives the beam.
     *
     * @param raw      the model to decode with; it is cast to a {@link LinearSequenceModel}
     * @param sequence the sequence to label
     * @return one entry per item of the sequence, each holding the index of the chosen tag in the
     * model's target feature; an empty array if the sequence is empty
     * @throws java.util.NoSuchElementException if no labeling satisfies the tag constraints, i.e.
     *                                          the beam is empty after the last item
     */
    @Override
    public double[] decode(SequenceModel<V> raw, Sequence<V> sequence) {
        final int length = sequence.length();
        if (length == 0) {
            // Nothing to label; avoid classifying a non-existent first item.
            return new double[0];
        }

        LinearSequenceModel<V> model = Val.of(raw).cast();
        final Feature classFeature = model.getTargetFeature();
        final int numStates = classFeature.alphabetSize();

        // Seed the beam with every tag allowed on the first item. When the sequence has a single
        // item that item is also the last one, so it must satisfy the end-tag constraint as well,
        // exactly like the last item of a longer sequence does below.
        final boolean firstIsLast = length == 1;
        MinMaxPriorityQueue<State> beam = MinMaxPriorityQueue.maximumSize(beamSize).create();
        ClassificationResult result = model.classifyItem(0, sequence, new double[0]);
        for (int ci = 0; ci < numStates; ci++) {
            String tag = classFeature.valueAtIndex(ci);
            if (isValidStartTag(tag)
                    && (!firstIsLast || isValidEndTag(tag))
                    && isValidTag(tag, sequence.getData(0))) {
                beam.add(new State(Math.log(result.getConfidence(ci)), ci, null, 0));
            }
        }

        MinMaxPriorityQueue<State> tempBeam = MinMaxPriorityQueue.maximumSize(beamSize).create();
        for (int i = 1; i < length; i++) {
            final boolean isLast = i + 1 == length;
            while (!beam.isEmpty()) { // go through all the previous states
                State state = beam.removeFirst();
                String previousTag = classFeature.valueAtIndex(state.tag);
                // The model is given the labels chosen so far on this particular path.
                result = model.classifyItem(i, sequence, state.labels());
                for (int ci = 0; ci < numStates; ci++) {
                    String tag = classFeature.valueAtIndex(ci);
                    if (isValidTransition(previousTag, tag)
                            && (!isLast || isValidEndTag(tag))
                            && isValidTag(tag, sequence.getData(i))) {
                        // The bounded queue evicts its greatest element when full, which under
                        // State's ordering is the least probable candidate.
                        tempBeam.add(
                                new State(state.probability + Math.log(result.getConfidence(ci)), ci, state, i));
                    }
                }
            }
            beam.addAll(tempBeam);
            tempBeam.clear();
        }

        // Head of the queue is the least element, i.e. the most probable complete labeling.
        return beam.remove().labels();
    }

    /**
     * A node in the search lattice: a tag assigned to one position plus a back pointer to the
     * state chosen for the previous position. Ordered so that more probable states come first.
     */
    private static class State implements Comparable<State>, Serializable {

        private static final long serialVersionUID = -6696246937456087462L;
        /** Sum of the log confidences along the path ending in this state. */
        private final double probability;
        /** Index of this state's tag in the target feature. */
        private final int tag;
        /** State at the previous position, or {@code null} for the first position. */
        private final State prev;
        /** Position in the sequence this state labels. */
        private final int index;

        private State(double probability, int tag, State prev, int index) {
            this.probability = probability;
            this.tag = tag;
            this.prev = prev;
            this.index = index;
        }

        /**
         * Orders states by descending probability, so the most probable state is the least element
         * of a priority queue.
         */
        @Override
        public int compareTo(State o) {
            return Doubles.compare(o.probability, this.probability);
        }

        @Override
        public String toString() {
            return "(" + tag + ", " + probability + ")";
        }

        /**
         * Collects the tags along the path ending in this state.
         *
         * @return an array of length {@code index + 1} whose entry at each position is the tag
         * index chosen for that position
         */
        public double[] labels() {
            return fill(new double[index + 1]);
        }

        /**
         * Collects the tags along the path ending in this state and appends one extra label.
         *
         * @param t0 the label to store in the final slot, after this state's own position
         * @return an array of length {@code index + 2}: the path's tag indices followed by
         * {@code t0}
         */
        public double[] labels(double t0) {
            double[] labels = new double[index + 2];
            labels[labels.length - 1] = t0;
            return fill(labels);
        }

        /** Follows the back pointers from this state, writing each tag at its own position. */
        private double[] fill(double[] labels) {
            for (State s = this; s != null; s = s.prev) {
                labels[s.index] = s.tag;
            }
            return labels;
        }

    }

}//END OF LinearViterbi