com.davidbracewell.ml.sequence.decoder.HMMViterbi.java Source code

Java tutorial

Introduction

Here is the source code for com.davidbracewell.ml.sequence.decoder.HMMViterbi.java

Source

/*
 * (c) 2005 David B. Bracewell
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.davidbracewell.ml.sequence.decoder;

import com.davidbracewell.conversion.Val;
import com.davidbracewell.ml.Feature;
import com.davidbracewell.ml.Instance;
import com.davidbracewell.ml.sequence.Sequence;
import com.davidbracewell.ml.sequence.SequenceModel;
import com.davidbracewell.ml.sequence.hmm.FirstOrderHMM;
import com.google.common.collect.MinMaxPriorityQueue;
import com.google.common.primitives.Doubles;

import java.util.NoSuchElementException;

/**
 * Beam-search Viterbi decoder for a {@link FirstOrderHMM}.
 *
 * <p>At every position only the <code>beamSize</code> highest scoring partial label sequences are kept.
 * A hypothesis score is the sum of the model's <code>pi</code>/<code>alpha</code> (start/transition) and
 * <code>beta</code> (emission) values along the path; the values are added, so they are presumably
 * log-probabilities (NOTE(review): confirm against FirstOrderHMM).</p>
 *
 * <p>Tag constraints are delegated to the <code>isValidStartTag</code>, <code>isValidEndTag</code>,
 * <code>isValidTag</code> and <code>isValidTransition</code> hooks inherited from {@link Viterbi}.</p>
 *
 * @param <V> the type of the data held at each position of the sequence
 * @author David B. Bracewell
 */
public class HMMViterbi<V> extends Viterbi<V> {

    private static final long serialVersionUID = -2719248351039964585L;
    private final int beamSize;

    /**
     * Instantiates a new HMM beam viterbi.
     *
     * @param beamSize the maximum number of hypotheses kept per position; must be positive
     * @throws IllegalArgumentException if <code>beamSize</code> is not positive
     */
    public HMMViterbi(int beamSize) {
        // Fail fast: MinMaxPriorityQueue.maximumSize rejects non-positive sizes, but only once decode runs.
        if (beamSize <= 0) {
            throw new IllegalArgumentException("beamSize must be positive, but was " + beamSize);
        }
        this.beamSize = beamSize;
    }

    /**
     * Finds the highest scoring label sequence that survives the beam.
     *
     * @param raw the model to decode with; it is cast to a {@link FirstOrderHMM}
     * @param seq the sequence to label
     * @return one class index per position of <code>seq</code> (an empty array for an empty sequence)
     * @throws NoSuchElementException if the tag constraints leave no complete hypothesis
     */
    @Override
    public double[] decode(SequenceModel<V> raw, Sequence<V> seq) {
        FirstOrderHMM<V> model = Val.of(raw).cast();
        int len = seq.length();
        if (len == 0) {
            // Nothing to label; avoids generating an instance for a non-existent position 0.
            return new double[0];
        }

        Feature classFeature = model.getTargetFeature();
        int NC = classFeature.alphabetSize();
        MinMaxPriorityQueue<State> beam = MinMaxPriorityQueue.maximumSize(beamSize).create();
        MinMaxPriorityQueue<State> tempBeam = MinMaxPriorityQueue.maximumSize(beamSize).create();

        // Seed the beam with every tag allowed at position 0.
        Instance firstInstance = seq.generateInstance(0, new double[] { 0 });
        boolean firstIsLast = (len == 1);
        for (int ci = 0; ci < NC; ci++) {
            String tag = classFeature.valueAtIndex(ci);
            if (isValidStartTag(tag)
                    && isValidTag(tag, seq.getData(0))
                    // In a one-element sequence position 0 is also the final position.
                    && (!firstIsLast || isValidEndTag(tag))) {
                beam.add(new State(model.pi(ci) + model.beta(ci, firstInstance), ci, null));
            }
        }

        // Extend every surviving hypothesis by one position at a time.
        for (int i = 1; i < len; i++) {
            boolean isLast = (i + 1 == len);
            for (int ci = 0; ci < NC; ci++) {
                String thisTag = classFeature.valueAtIndex(ci);
                if (!isValidTag(thisTag, seq.getData(i))) {
                    continue;
                }
                // Only the final position has to satisfy the end-tag constraint.
                if (!isLast || isValidEndTag(thisTag)) {
                    double pInst = model.beta(ci, seq.generateInstance(i, new double[] { ci }));
                    for (State state : beam) {
                        if (isValidTransition(classFeature.valueAtIndex(state.tag), thisTag)) {
                            // The bounded queue evicts its lowest scoring entry once it is full.
                            tempBeam.add(new State(state.probability + //previous probability
                                    model.alpha(state.tag, ci) + //transition probability
                                    pInst, //probability of the vector given the tag
                                    ci, state));
                        }
                    }
                }
            }
            // Swap the two queues so the old beam's storage is reused for the next position.
            MinMaxPriorityQueue<State> t = beam;
            beam = tempBeam;
            tempBeam = t;
            tempBeam.clear();
        }

        // The head of the queue is the best hypothesis because State orders by descending score.
        State max = beam.poll();
        if (max == null) {
            throw new NoSuchElementException(
                    "No valid tag sequence found for a sequence of length " + len);
        }

        // Walk the back-pointers from the last position to the first.
        double[] prediction = new double[len];
        for (int i = len - 1; i >= 0; i--) {
            prediction[i] = max.tag;
            max = max.prev;
        }

        return prediction;
    }

    /**
     * A partial hypothesis: the tag chosen at the current position, the accumulated score and a
     * back-pointer to the hypothesis for the previous position. The natural ordering is by descending
     * score, so the best hypothesis is the smallest element.
     */
    private static class State implements Comparable<State> {

        private final double probability;
        private final int tag;
        private final State prev;

        private State(double probability, int tag, State prev) {
            this.probability = probability;
            this.tag = tag;
            this.prev = prev;
        }

        @Override
        public int compareTo(State o) {
            // Arguments are swapped (rather than negating the result) to sort from highest to lowest.
            return Doubles.compare(o.probability, this.probability);
        }

        @Override
        public String toString() {
            return "(" + tag + ", " + probability + ")";
        }
    }

}//END OF HMMViterbi