// Java tutorial
/*
 * (c) 2005 David B. Bracewell
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.davidbracewell.ml.sequence.decoder;

import com.davidbracewell.conversion.Val;
import com.davidbracewell.ml.classification.ClassificationResult;
import com.davidbracewell.ml.Feature;
import com.davidbracewell.ml.sequence.Sequence;
import com.davidbracewell.ml.sequence.SequenceModel;
import com.davidbracewell.ml.sequence.linear.LinearSequenceModel;
import com.google.common.collect.MinMaxPriorityQueue;
import com.google.common.primitives.Doubles;

import java.io.Serializable;
import java.util.NoSuchElementException;

/**
 * Viterbi-style decoder for {@link LinearSequenceModel}s that keeps only a bounded number of
 * candidate label paths (the "beam") per sequence position.
 *
 * <p>Candidates are scored by the sum of the log confidences the model assigns to each chosen
 * label; the highest scoring candidate left after the final position is returned.
 *
 * @param <V> the type parameter passed through to {@link SequenceModel} and {@link Sequence}
 * @author David B. Bracewell
 */
public class LinearViterbi<V> extends Viterbi<V> {

  private static final long serialVersionUID = 555426483507549520L;

  /** Maximum number of candidate paths retained per position. */
  private final int beamSize;

  /**
   * Instantiates a new Linear viterbi.
   *
   * @param beamSize the maximum number of candidate paths kept per position; must be positive
   * @throws IllegalArgumentException if {@code beamSize} is not positive
   */
  public LinearViterbi(int beamSize) {
    // Fail fast: MinMaxPriorityQueue.maximumSize rejects non-positive sizes anyway, but only
    // later, when decode is first called.
    if (beamSize <= 0) {
      throw new IllegalArgumentException("beamSize must be positive: " + beamSize);
    }
    this.beamSize = beamSize;
  }

  /**
   * Decodes the given sequence into one label index per position.
   *
   * @param raw      the model to decode with; it is cast to a {@link LinearSequenceModel}
   * @param sequence the sequence to label
   * @return the label indices (as doubles) of the best surviving path, one entry per position;
   *         an empty array when the sequence is empty
   * @throws NoSuchElementException if no candidate path satisfies the validity constraints
   */
  @Override
  public double[] decode(SequenceModel<V> raw, Sequence<V> sequence) {
    final int length = sequence.length();
    if (length == 0) {
      // Nothing to label; avoids asking the model and the sequence for position 0.
      return new double[0];
    }

    LinearSequenceModel<V> model = Val.of(raw).cast();
    final Feature classFeature = model.getTargetFeature();
    final int numStates = classFeature.alphabetSize();

    // Seed the beam with every label that may start a sequence and is valid for the first item.
    // NOTE(review): for a sequence of length 1 isValidEndTag is never consulted, whereas the
    // last position of longer sequences is checked below - confirm this is intended.
    MinMaxPriorityQueue<State> beam = MinMaxPriorityQueue.maximumSize(beamSize).create();
    ClassificationResult result = model.classifyItem(0, sequence, new double[0]);
    for (int ci = 0; ci < numStates; ci++) {
      String tag = classFeature.valueAtIndex(ci);
      if (isValidStartTag(tag) && isValidTag(tag, sequence.getData(0))) {
        beam.add(new State(Math.log(result.getConfidence(ci)), ci, null, 0));
      }
    }

    MinMaxPriorityQueue<State> tempBeam = MinMaxPriorityQueue.maximumSize(beamSize).create();
    for (int i = 1; i < length; i++) {
      final boolean isLast = (i + 1 == length);

      // Expand every surviving path by one position. tempBeam is bounded as well, so only the
      // best beamSize extensions are carried over.
      while (!beam.isEmpty()) {
        State state = beam.removeFirst();
        String previousTag = classFeature.valueAtIndex(state.tag);
        result = model.classifyItem(i, sequence, state.labels());
        for (int ci = 0; ci < numStates; ci++) {
          String tag = classFeature.valueAtIndex(ci);
          if (isValidTransition(previousTag, tag)
              && (!isLast || isValidEndTag(tag))
              && isValidTag(tag, sequence.getData(i))) {
            tempBeam.add(
                new State(state.probability + Math.log(result.getConfidence(ci)), ci, state, i));
          }
        }
      }

      beam.addAll(tempBeam);
      tempBeam.clear();
    }

    if (beam.isEmpty()) {
      // Same exception type that beam.remove() raises on an empty queue, with a usable message.
      throw new NoSuchElementException(
          "No valid label path found for a sequence of length " + length);
    }
    return beam.remove().labels();
  }

  /**
   * One candidate path through the label lattice, stored as a back-pointer chain: the label
   * chosen at {@link #index} plus a link to the state chosen at the previous position.
   *
   * <p>States order by descending score, so the head of a queue of states is the most probable
   * one and a size-bounded {@link MinMaxPriorityQueue} evicts the least probable one.
   */
  private static class State implements Comparable<State>, Serializable {

    private static final long serialVersionUID = -6696246937456087462L;

    /** Accumulated log confidence of the path ending in this state. */
    private final double probability;
    /** Index of the label chosen at {@link #index}. */
    private final int tag;
    /** State chosen at the previous position, or {@code null} for the first position. */
    private final State prev;
    /** Position in the sequence that this state labels. */
    private final int index;

    private State(double probability, int tag, State prev, int index) {
      this.probability = probability;
      this.tag = tag;
      this.prev = prev;
      this.index = index;
    }

    @Override
    public int compareTo(State o) {
      // Operands are swapped on purpose: a higher score must sort first.
      return Doubles.compare(o.probability, this.probability);
    }

    @Override
    public String toString() {
      return "(" + tag + ", " + probability + ")";
    }

    /**
     * Collects the labels along the path ending in this state.
     *
     * @return an array of length {@code index + 1} holding the label index chosen at each
     *         position of the path
     */
    public double[] labels() {
      double[] labels = new double[index + 1];
      fillLabels(labels);
      return labels;
    }

    /**
     * Collects the labels along the path ending in this state and appends one extra label.
     *
     * @param t0 the value to store in the final slot, after the labels of this path
     * @return an array of length {@code index + 2}: the path labels followed by {@code t0}
     */
    public double[] labels(double t0) {
      double[] labels = new double[index + 2];
      labels[labels.length - 1] = t0;
      fillLabels(labels);
      return labels;
    }

    /** Walks the back-pointer chain, writing each state's label at its own position. */
    private void fillLabels(double[] labels) {
      for (State s = this; s != null; s = s.prev) {
        labels[s.index] = s.tag;
      }
    }
  }

}//END OF LinearViterbi