// Java tutorial
/* * (c) 2005 David B. Bracewell * * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package com.davidbracewell.ml.sequence.decoder; import com.davidbracewell.conversion.Val; import com.davidbracewell.ml.Instance; import com.davidbracewell.ml.Feature; import com.davidbracewell.ml.sequence.Sequence; import com.davidbracewell.ml.sequence.SequenceModel; import com.davidbracewell.ml.sequence.hmm.FirstOrderHMM; import com.google.common.collect.MinMaxPriorityQueue; import com.google.common.primitives.Doubles; /** * The type HMM beam viterbi. * * @author David B. Bracewell */ public class HMMViterbi<V> extends Viterbi<V> { private static final long serialVersionUID = -2719248351039964585L; private final int beamSize; /** * Instantiates a new HMM beam viterbi. 
* * @param beamSize the beam size */ public HMMViterbi(int beamSize) { this.beamSize = beamSize; } @Override public double[] decode(SequenceModel<V> raw, Sequence<V> seq) { FirstOrderHMM<V> model = Val.of(raw).cast(); int len = seq.length(); Feature classFeature = model.getTargetFeature(); int NC = classFeature.alphabetSize(); MinMaxPriorityQueue<State> beam = MinMaxPriorityQueue.maximumSize(beamSize).create(); MinMaxPriorityQueue<State> tempBeam = MinMaxPriorityQueue.maximumSize(beamSize).create(); Instance firstInstance = seq.generateInstance(0, new double[] { 0 }); for (int ci = 0; ci < NC; ci++) { if (isValidStartTag(classFeature.valueAtIndex(ci)) && isValidTag(classFeature.valueAtIndex(ci), seq.getData(0))) { beam.add(new State(model.pi(ci) + model.beta(ci, firstInstance), ci, null)); } } for (int i = 1; i < len; i++) { for (int ci = 0; ci < NC; ci++) { String thisTag = classFeature.valueAtIndex(ci); if (!isValidTag(thisTag, seq.getData(i))) { continue; } if ((i + 1 < seq.length()) || isValidEndTag(thisTag)) { double pInst = model.beta(ci, seq.generateInstance(i, new double[] { ci })); for (State state : beam) { if (isValidTransition(classFeature.valueAtIndex(state.tag), thisTag)) { tempBeam.add(new State(state.probability + //previous probability model.alpha(state.tag, ci) + //transition probability pInst, //probability of the vector given the tag ci, state)); } } } } MinMaxPriorityQueue<State> t = beam; beam = tempBeam; tempBeam = t; tempBeam.clear(); } State max = beam.remove(); double[] prediction = new double[len]; for (int i = len - 1; i >= 0; i--) { prediction[i] = max.tag; max = max.prev; } return prediction; } private static class State implements Comparable<State> { private final double probability; private final int tag; private final State prev; private State(double probability, int tag, State prev) { this.probability = probability; this.tag = tag; this.prev = prev; } @Override public int compareTo(State o) { return 
-Doubles.compare(this.probability, o.probability); } @Override public String toString() { return "(" + tag + ", " + probability + ")"; } } }//END OF HMMViterbi