edu.stanford.nlp.sentiment.SentimentCostAndGradient.java Source code

Java tutorial

Introduction

Here is the source code for edu.stanford.nlp.sentiment.SentimentCostAndGradient.java

Source

package edu.stanford.nlp.sentiment;

import java.util.List;
import java.util.Map;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;

// TODO: get rid of the word Sentiment everywhere
/**
 * Cost function and gradient used to train a {@link SentimentModel}.
 *
 * <p>Each call to {@link #calculate(double[])} loads {@code theta} into the model, forward-propagates
 * every tree of the training batch, backpropagates the per-node softmax classification error, and
 * stores in {@code value} / {@code derivative} the batch-averaged error plus an L2 penalty, and the
 * matching gradient. The gradient vector is laid out in the order: binary transforms, binary
 * classifiers, binary tensors, unary classifiers, word vectors.
 */
public class SentimentCostAndGradient extends AbstractCachingDiffFunction {

    private static final Redwood.RedwoodChannels log = Redwood.channels(SentimentCostAndGradient.class);

    /** Model being trained; its parameters are overwritten from {@code theta} on every calculate call. */
    private final SentimentModel model;
    /** Trees whose summed error, scaled by the batch size, forms the data term of the cost. */
    private final List<Tree> trainingBatch;

    /**
     * Creates the objective for one training batch.
     *
     * @param model the model whose parameters are optimized
     * @param trainingBatch the binarized trees to score; each one is deep-copied before it is annotated
     */
    public SentimentCostAndGradient(SentimentModel model, List<Tree> trainingBatch) {
        this.model = model;
        this.trainingBatch = trainingBatch;
    }

    /** Returns the number of parameters reported by the model. */
    @Override
    public int domainDimension() {
        // TODO: cache this for speed?
        return model.totalParamSize();
    }

    /**
     * Sums the prediction errors stored on the internal nodes of {@code tree}; leaves contribute 0.
     * The errors must already have been attached by {@link #backpropDerivativesAndError}.
     */
    private static double sumError(Tree tree) {
        if (tree.isLeaf()) {
            return 0.0;
        } else if (tree.isPreTerminal()) {
            return RNNCoreAnnotations.getPredictionError(tree);
        } else {
            double error = 0.0;
            for (Tree child : tree.children()) {
                error += sumError(child);
            }
            return RNNCoreAnnotations.getPredictionError(tree) + error;
        }
    }

    /**
     * Returns the index with the highest value in the {@code predictions} matrix.
     * Indexed from 0; ties resolve to the lowest index.
     */
    private static int getPredictedClass(SimpleMatrix predictions) {
        int argmax = 0;
        for (int i = 1; i < predictions.getNumElements(); ++i) {
            if (predictions.get(i) > predictions.get(argmax)) {
                argmax = i;
            }
        }
        return argmax;
    }

    /**
     * Gradient accumulators for every parameter group of the model, plus the summed prediction error.
     */
    private static class ModelDerivatives {
        // We use TreeMap for each of these so that they stay in a canonical sorted order
        // binaryTD stands for Transform Derivatives (see the SentimentModel)
        public final TwoDimensionalMap<String, String, SimpleMatrix> binaryTD;
        // the derivatives of the tensors for the binary nodes
        // will be empty if we aren't using tensors
        public final TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD;
        // binaryCD stands for Classification Derivatives
        // if we combined classification derivatives, we just use an empty map
        public final TwoDimensionalMap<String, String, SimpleMatrix> binaryCD;

        // unaryCD stands for (unary) Classification Derivatives
        public final Map<String, SimpleMatrix> unaryCD;

        // word vector derivatives
        // will be filled on an as-needed basis, as opposed to having all
        // the words with a lot of empty vectors
        public final Map<String, SimpleMatrix> wordVectorD;

        /** Sum of the prediction errors of all trees scored into this object. */
        public double error = 0.0;

        /** Creates zero-valued accumulators shaped like {@code model}'s parameters (word vectors lazily). */
        public ModelDerivatives(SentimentModel model) {
            binaryTD = initDerivatives(model.binaryTransform);
            binaryTensorTD = (model.op.useTensors) ? initTensorDerivatives(model.binaryTensors)
                    : TwoDimensionalMap.treeMap();
            binaryCD = (!model.op.combineClassification) ? initDerivatives(model.binaryClassification)
                    : TwoDimensionalMap.treeMap();
            unaryCD = initDerivatives(model.unaryClassification);
            // wordVectorD will be filled on an as-needed basis
            wordVectorD = Generics.newTreeMap();
        }

        /**
         * Adds {@code other}'s derivatives and error into this object, in place. Entries present only
         * in {@code other} are shared by reference rather than copied.
         */
        public void add(ModelDerivatives other) {
            addMatrices(binaryTD, other.binaryTD);
            addTensors(binaryTensorTD, other.binaryTensorTD);
            addMatrices(binaryCD, other.binaryCD);
            addMatrices(unaryCD, other.unaryCD);
            addMatrices(wordVectorD, other.wordVectorD);

            error += other.error;
        }

        /**
         * Add matrices from the second map to the first map, in place.
         * Keys missing from {@code first} receive {@code second}'s matrix itself (no copy).
         */
        public static void addMatrices(TwoDimensionalMap<String, String, SimpleMatrix> first,
                TwoDimensionalMap<String, String, SimpleMatrix> second) {
            // Replacing the value of an existing key while iterating does not add or remove keys.
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : first) {
                if (second.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    first.put(entry.getFirstKey(), entry.getSecondKey(),
                            entry.getValue().plus(second.get(entry.getFirstKey(), entry.getSecondKey())));
                }
            }
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : second) {
                if (!first.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue());
                }
            }
        }

        /**
         * Add tensors from the second map to the first map, in place.
         * Keys missing from {@code first} receive {@code second}'s tensor itself (no copy).
         */
        public static void addTensors(TwoDimensionalMap<String, String, SimpleTensor> first,
                TwoDimensionalMap<String, String, SimpleTensor> second) {
            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : first) {
                if (second.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    first.put(entry.getFirstKey(), entry.getSecondKey(),
                            entry.getValue().plus(second.get(entry.getFirstKey(), entry.getSecondKey())));
                }
            }
            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : second) {
                if (!first.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue());
                }
            }
        }

        /**
         * Add matrices from the second map to the first map, in place.
         * Keys missing from {@code first} receive {@code second}'s matrix itself (no copy).
         */
        public static void addMatrices(Map<String, SimpleMatrix> first, Map<String, SimpleMatrix> second) {
            for (Map.Entry<String, SimpleMatrix> entry : first.entrySet()) {
                if (second.containsKey(entry.getKey())) {
                    first.put(entry.getKey(), entry.getValue().plus(second.get(entry.getKey())));
                }
            }
            for (Map.Entry<String, SimpleMatrix> entry : second.entrySet()) {
                if (!first.containsKey(entry.getKey())) {
                    first.put(entry.getKey(), entry.getValue());
                }
            }
        }

        /**
         * Init a TwoDimensionalMap with 0 matrices for all the matrices in the original map.
         */
        private static TwoDimensionalMap<String, String, SimpleMatrix> initDerivatives(
                TwoDimensionalMap<String, String, SimpleMatrix> map) {
            TwoDimensionalMap<String, String, SimpleMatrix> derivatives = TwoDimensionalMap.treeMap();

            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : map) {
                int numRows = entry.getValue().numRows();
                int numCols = entry.getValue().numCols();

                derivatives.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
            }

            return derivatives;
        }

        /**
         * Init a TwoDimensionalMap with 0 tensors for all the tensors in the original map.
         */
        private static TwoDimensionalMap<String, String, SimpleTensor> initTensorDerivatives(
                TwoDimensionalMap<String, String, SimpleTensor> map) {
            TwoDimensionalMap<String, String, SimpleTensor> derivatives = TwoDimensionalMap.treeMap();

            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : map) {
                int numRows = entry.getValue().numRows();
                int numCols = entry.getValue().numCols();
                int numSlices = entry.getValue().numSlices();

                derivatives.put(entry.getFirstKey(), entry.getSecondKey(),
                        new SimpleTensor(numRows, numCols, numSlices));
            }

            return derivatives;
        }

        /**
         * Init a Map with 0 matrices for all the matrices in the original map.
         */
        private static Map<String, SimpleMatrix> initDerivatives(Map<String, SimpleMatrix> map) {
            Map<String, SimpleMatrix> derivatives = Generics.newTreeMap();

            for (Map.Entry<String, SimpleMatrix> entry : map.entrySet()) {
                int numRows = entry.getValue().numRows();
                int numCols = entry.getValue().numCols();
                derivatives.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
            }

            return derivatives;
        }
    }

    /**
     * Scores a list of trees with the current model parameters.
     *
     * @return the unscaled, unregularized derivatives summed over all trees, and the summed error
     */
    private ModelDerivatives scoreDerivatives(List<Tree> trainingBatch) {
        final ModelDerivatives derivatives = new ModelDerivatives(model);

        List<Tree> forwardPropTrees = Generics.newArrayList();
        for (Tree tree : trainingBatch) {
            Tree trainingTree = tree.deepCopy();
            // this will attach the node vectors and the predictions
            // to each non-leaf node in the tree
            forwardPropagateTree(trainingTree);
            forwardPropTrees.add(trainingTree);
        }

        for (Tree tree : forwardPropTrees) {
            backpropDerivativesAndError(tree, derivatives.binaryTD, derivatives.binaryCD,
                    derivatives.binaryTensorTD, derivatives.unaryCD, derivatives.wordVectorD);
            derivatives.error += sumError(tree);
        }

        return derivatives;
    }

    /** Adapter that lets {@link MulticoreWrapper} run {@link #scoreDerivatives} on chunks of the batch. */
    class ScoringProcessor implements ThreadsafeProcessor<List<Tree>, ModelDerivatives> {
        @Override
        public ModelDerivatives process(List<Tree> trainingBatch) {
            return scoreDerivatives(trainingBatch);
        }

        @Override
        public ThreadsafeProcessor<List<Tree>, ModelDerivatives> newInstance() {
            // should be threadsafe: each call builds its own ModelDerivatives and deep-copied trees
            return this;
        }
    }

    /**
     * Loads {@code theta} into the model and computes the cost ({@code value}) and its gradient
     * ({@code derivative}) over the whole training batch.
     *
     * <p>The cost is the summed prediction error divided by the batch size, plus the L2
     * regularization of each parameter group. For an empty batch the data term is 0 and only the
     * regularization remains.
     */
    @Override
    public void calculate(double[] theta) {
        model.vectorToParams(theta);

        final ModelDerivatives derivatives;
        if (model.op.trainOptions.nThreads == 1) {
            derivatives = scoreDerivatives(trainingBatch);
        } else {
            // TODO: because some addition operations happen in different
            // orders now, this results in slightly different values, which
            // over time add up to significantly different models even when
            // given the same random seed.  Probably not a big deal.
            // To be more specific, for trees T1, T2, T3, ... Tn,
            // when using one thread, we sum the derivatives T1 + T2 ...
            // When using multiple threads, we first sum T1 + ... + Tk,
            // then sum Tk+1 + ... + T2k, etc, for split size k.
            // The splits are then summed in order.
            // This different sum order results in slightly different numbers.
            MulticoreWrapper<List<Tree>, ModelDerivatives> wrapper = new MulticoreWrapper<>(
                    model.op.trainOptions.nThreads, new ScoringProcessor());
            // use wrapper.nThreads in case the number of threads was automatically changed
            for (List<Tree> chunk : CollectionUtils.partitionIntoFolds(trainingBatch, wrapper.nThreads())) {
                wrapper.put(chunk);
            }
            wrapper.join();

            derivatives = new ModelDerivatives(model);
            while (wrapper.peek()) {
                ModelDerivatives batchDerivatives = wrapper.poll();
                derivatives.add(batchDerivatives);
            }
        }

        // Scale the error by the number of sentences so that the regularization isn't drowned
        // out for large training batches.  An empty batch has no data term: 1.0 / 0 would be
        // Infinity, and 0 * Infinity would turn the cost and every derivative into NaN.
        // Use a scale of 0 instead, so that only the regularization remains.
        double scale = trainingBatch.isEmpty() ? 0.0 : 1.0 / trainingBatch.size();
        value = derivatives.error * scale;

        value += scaleAndRegularize(derivatives.binaryTD, model.binaryTransform, scale,
                model.op.trainOptions.regTransformMatrix, false);
        value += scaleAndRegularize(derivatives.binaryCD, model.binaryClassification, scale,
                model.op.trainOptions.regClassification, true);
        value += scaleAndRegularizeTensor(derivatives.binaryTensorTD, model.binaryTensors, scale,
                model.op.trainOptions.regTransformTensor);
        value += scaleAndRegularize(derivatives.unaryCD, model.unaryClassification, scale,
                model.op.trainOptions.regClassification, false, true);
        value += scaleAndRegularize(derivatives.wordVectorD, model.wordVectors, scale,
                model.op.trainOptions.regWordVector, true, false);

        derivative = NeuralUtils.paramsToVector(theta.length, derivatives.binaryTD.valueIterator(),
                derivatives.binaryCD.valueIterator(),
                SimpleTensor.iteratorSimpleMatrix(derivatives.binaryTensorTD.valueIterator()),
                derivatives.unaryCD.values().iterator(), derivatives.wordVectorD.values().iterator());
    }

    /**
     * Returns a copy of {@code matrix} whose last column (the column treated as the bias by the
     * {@code dropBiasColumn} options below) is zero, so the bias is not regularized.
     */
    private static SimpleMatrix zeroBiasColumn(SimpleMatrix matrix) {
        SimpleMatrix copy = new SimpleMatrix(matrix);
        copy.insertIntoThis(0, copy.numCols() - 1, new SimpleMatrix(copy.numRows(), 1));
        return copy;
    }

    /**
     * Replaces each derivative D with {@code D * scale + regCost * W} for the matching current
     * matrix W, and returns the L2 cost {@code regCost / 2 * sum(W .* W)} over all matrices.
     * NOTE(review): assumes {@code derivatives} has an entry for every key of {@code currentMatrices}.
     *
     * @param dropBiasColumn if true, the last column of W is excluded from the regularization
     */
    private static double scaleAndRegularize(TwoDimensionalMap<String, String, SimpleMatrix> derivatives,
            TwoDimensionalMap<String, String, SimpleMatrix> currentMatrices, double scale, double regCost,
            boolean dropBiasColumn) {
        double cost = 0.0; // the regularization cost
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : currentMatrices) {
            SimpleMatrix D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            SimpleMatrix regMatrix = dropBiasColumn ? zeroBiasColumn(entry.getValue()) : entry.getValue();
            D = D.scale(scale).plus(regMatrix.scale(regCost));
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
            cost += regMatrix.elementMult(regMatrix).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Map version of the scale-and-regularize step.
     *
     * @param activeMatricesOnly if true, keys with no derivative get a zero derivative and are neither
     *        scaled nor regularized (used for word vectors, which are only filled when seen)
     * @param dropBiasColumn if true, the last column of each matrix is excluded from the regularization
     * @return the L2 regularization cost of the regularized matrices
     */
    private static double scaleAndRegularize(Map<String, SimpleMatrix> derivatives,
            Map<String, SimpleMatrix> currentMatrices, double scale, double regCost, boolean activeMatricesOnly,
            boolean dropBiasColumn) {
        double cost = 0.0; // the regularization cost
        for (Map.Entry<String, SimpleMatrix> entry : currentMatrices.entrySet()) {
            SimpleMatrix D = derivatives.get(entry.getKey());
            if (activeMatricesOnly && D == null) {
                // Fill in an empty matrix so the length of theta can match.
                // TODO: might want to allow for sparse parameter vectors
                derivatives.put(entry.getKey(),
                        new SimpleMatrix(entry.getValue().numRows(), entry.getValue().numCols()));
                continue;
            }
            SimpleMatrix regMatrix = dropBiasColumn ? zeroBiasColumn(entry.getValue()) : entry.getValue();
            D = D.scale(scale).plus(regMatrix.scale(regCost));
            derivatives.put(entry.getKey(), D);
            cost += regMatrix.elementMult(regMatrix).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Tensor version of the scale-and-regularize step; every tensor entry is regularized.
     *
     * @return the L2 regularization cost {@code regCost / 2 * sum(T .* T)} over all tensors
     */
    private static double scaleAndRegularizeTensor(TwoDimensionalMap<String, String, SimpleTensor> derivatives,
            TwoDimensionalMap<String, String, SimpleTensor> currentTensors, double scale, double regCost) {
        double cost = 0.0; // the regularization cost
        for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : currentTensors) {
            SimpleTensor D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            D = D.scale(scale).plus(entry.getValue().scale(regCost));
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
            cost += entry.getValue().elementMult(entry.getValue()).elementSum() * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Backpropagates from the root of {@code tree}, which receives no delta from a parent (a zero
     * vector of size numHid).
     */
    private void backpropDerivativesAndError(Tree tree, TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
            TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
            TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD, Map<String, SimpleMatrix> unaryCD,
            Map<String, SimpleMatrix> wordVectorD) {
        SimpleMatrix delta = new SimpleMatrix(model.op.numHid, 1);
        backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, delta);
    }

    /**
     * Accumulates into the given maps the gradients contributed by {@code tree} and its subtrees, and
     * stores each node's prediction error on the node. Requires {@link #forwardPropagateTree} to have
     * annotated the tree first.
     *
     * @param deltaUp error signal coming from the parent node; the parent has already multiplied it
     *        by this node's tanh derivative
     */
    private void backpropDerivativesAndError(Tree tree, TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
            TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
            TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD, Map<String, SimpleMatrix> unaryCD,
            Map<String, SimpleMatrix> wordVectorD, SimpleMatrix deltaUp) {
        if (tree.isLeaf()) {
            return;
        }

        SimpleMatrix currentVector = RNNCoreAnnotations.getNodeVector(tree);
        String category = tree.label().value();
        category = model.basicCategory(category);

        // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class.
        // A negative gold class marks an unlabeled node.
        SimpleMatrix goldLabel = new SimpleMatrix(model.numClasses, 1);
        int goldClass = RNNCoreAnnotations.getGoldClass(tree);
        boolean labeled = goldClass >= 0;
        if (labeled) {
            goldLabel.set(goldClass, 1.0);
        }

        // Only look up a weight for a real class index.  An unlabeled node's delta and error are zero
        // whatever its weight, and a negative index may not be valid for the class-weight lookup.
        double nodeWeight = labeled ? model.op.trainOptions.getClassWeight(goldClass) : 0.0;

        SimpleMatrix predictions = RNNCoreAnnotations.getPredictions(tree);

        // If this is an unlabeled class, set deltaClass to 0.  We could
        // make this more efficient by eliminating various of the below
        // calculations, but this would be the easiest way to handle the
        // unlabeled class
        SimpleMatrix deltaClass = labeled ? predictions.minus(goldLabel).scale(nodeWeight)
                : new SimpleMatrix(predictions.numRows(), predictions.numCols());
        SimpleMatrix localCD = deltaClass.mult(NeuralUtils.concatenateWithBias(currentVector).transpose());

        // Cross-entropy against a one-hot target is -log(p[gold]).  Reading that single entry, instead
        // of summing log(p) .* goldLabel, avoids 0 * -Infinity = NaN when another class's predicted
        // probability underflows to exactly 0.
        double error = labeled ? -Math.log(predictions.get(goldClass)) * nodeWeight : 0.0;
        RNNCoreAnnotations.setPredictionError(tree, error);

        if (tree.isPreTerminal()) { // below us is a word vector
            unaryCD.put(category, unaryCD.get(category).plus(localCD));

            String word = tree.children()[0].label().value();
            word = model.getVocabWord(word);

            SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
            SimpleMatrix deltaFromClass = model.getUnaryClassification(category).transpose().mult(deltaClass);
            // keep only the numHid rows that feed back into the node vector (drop the bias row)
            deltaFromClass = deltaFromClass.extractMatrix(0, model.op.numHid, 0, 1)
                    .elementMult(currentVectorDerivative);
            SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);
            SimpleMatrix oldWordVectorD = wordVectorD.get(word);
            if (oldWordVectorD == null) {
                wordVectorD.put(word, deltaFull);
            } else {
                wordVectorD.put(word, oldWordVectorD.plus(deltaFull));
            }
        } else {
            // Otherwise, this must be a binary node
            Tree leftChild = tree.children()[0];
            Tree rightChild = tree.children()[1];
            String leftCategory = model.basicCategory(leftChild.label().value());
            String rightCategory = model.basicCategory(rightChild.label().value());
            if (model.op.combineClassification) {
                // with combined classification, binary-node classifier gradients go to unaryCD[""]
                unaryCD.put("", unaryCD.get("").plus(localCD));
            } else {
                binaryCD.put(leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).plus(localCD));
            }

            SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
            SimpleMatrix deltaFromClass = model.getBinaryClassification(leftCategory, rightCategory).transpose()
                    .mult(deltaClass);
            deltaFromClass = deltaFromClass.extractMatrix(0, model.op.numHid, 0, 1)
                    .elementMult(currentVectorDerivative);
            SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);

            SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(leftChild);
            SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(rightChild);
            SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
            SimpleMatrix W_df = deltaFull.mult(childrenVector.transpose());
            binaryTD.put(leftCategory, rightCategory, binaryTD.get(leftCategory, rightCategory).plus(W_df));
            SimpleMatrix deltaDown;
            if (model.op.useTensors) {
                SimpleTensor Wt_df = getTensorGradient(deltaFull, leftVector, rightVector);
                binaryTensorTD.put(leftCategory, rightCategory,
                        binaryTensorTD.get(leftCategory, rightCategory).plus(Wt_df));
                deltaDown = computeTensorDeltaDown(deltaFull, leftVector, rightVector,
                        model.getBinaryTransform(leftCategory, rightCategory),
                        model.getBinaryTensor(leftCategory, rightCategory));
            } else {
                deltaDown = model.getBinaryTransform(leftCategory, rightCategory).transpose().mult(deltaFull);
            }

            // split the delta between the two children and apply each child's tanh derivative
            int hidden = deltaFull.numRows();
            SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
            SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
            SimpleMatrix leftDeltaDown = deltaDown.extractMatrix(0, hidden, 0, 1);
            SimpleMatrix rightDeltaDown = deltaDown.extractMatrix(hidden, hidden * 2, 0, 1);
            backpropDerivativesAndError(leftChild, binaryTD, binaryCD, binaryTensorTD, unaryCD,
                    wordVectorD, leftDerivative.elementMult(leftDeltaDown));
            backpropDerivativesAndError(rightChild, binaryTD, binaryCD, binaryTensorTD, unaryCD,
                    wordVectorD, rightDerivative.elementMult(rightDeltaDown));
        }
    }

    /**
     * Computes the delta passed down to the concatenated children {@code [left; right]} of a tensor
     * node: {@code W^T * delta} (without the bias row) plus, for each slice k,
     * {@code delta_k * (V_k + V_k^T) * [left; right]}.
     */
    private static SimpleMatrix computeTensorDeltaDown(SimpleMatrix deltaFull, SimpleMatrix leftVector,
            SimpleMatrix rightVector, SimpleMatrix W, SimpleTensor Wt) {
        SimpleMatrix WTDelta = W.transpose().mult(deltaFull);
        // keep only the rows for the 2 * size child entries, dropping the bias row
        SimpleMatrix WTDeltaNoBias = WTDelta.extractMatrix(0, deltaFull.numRows() * 2, 0, 1);
        int size = deltaFull.getNumElements();
        SimpleMatrix deltaTensor = new SimpleMatrix(size * 2, 1);
        SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
        for (int slice = 0; slice < size; ++slice) {
            SimpleMatrix sliceMatrix = Wt.getSlice(slice);
            SimpleMatrix scaledFullVector = fullVector.scale(deltaFull.get(slice));
            deltaTensor = deltaTensor.plus(sliceMatrix.plus(sliceMatrix.transpose()).mult(scaledFullVector));
        }
        return deltaTensor.plus(WTDeltaNoBias);
    }

    /**
     * Computes the gradient of the tensor term for one node: slice k is
     * {@code delta_k * x * x^T}, where {@code x = [left; right]}.
     */
    private static SimpleTensor getTensorGradient(SimpleMatrix deltaFull, SimpleMatrix leftVector,
            SimpleMatrix rightVector) {
        int size = deltaFull.getNumElements();
        SimpleTensor Wt_df = new SimpleTensor(size * 2, size * 2, size);
        // TODO: combine this concatenation with computeTensorDeltaDown?
        SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
        for (int slice = 0; slice < size; ++slice) {
            Wt_df.setSlice(slice, fullVector.scale(deltaFull.get(slice)).mult(fullVector.transpose()));
        }
        return Wt_df;
    }

    /**
     * This is the method to call for assigning labels and node vectors
     * to the Tree.  After calling this, each of the non-leaf nodes will
     * have the node vector and the predictions of their classes
     * assigned to that subtree's node.  The annotations filled in are
     * the RNNCoreAnnotations.NodeVector, Predictions, and
     * PredictedClass.  In general, PredictedClass will be the most
     * useful annotation except when training.
     *
     * @throws AssertionError if the tree is a bare leaf, is not binarized (a non-preterminal with 1
     *         or more than 2 children), or has labels that are not {@link CoreLabel}s
     */
    public void forwardPropagateTree(Tree tree) {
        SimpleMatrix nodeVector; // initialized below or Exception thrown
        SimpleMatrix classification; // initialized below or Exception thrown

        if (tree.isLeaf()) {
            // We do nothing for the leaves.  The preterminals will
            // calculate the classification for this word/tag.  In fact, the
            // recursion should not have gotten here (unless there are
            // degenerate trees of just one leaf)
            log.info("SentimentCostAndGradient: warning: We reached leaves in forwardPropagate: " + tree);
            throw new AssertionError("We should not have reached leaves in forwardPropagate");
        } else if (tree.isPreTerminal()) {
            classification = model.getUnaryClassification(tree.label().value());
            String word = tree.children()[0].label().value();
            SimpleMatrix wordVector = model.getWordVector(word);
            nodeVector = NeuralUtils.elementwiseApplyTanh(wordVector);
        } else if (tree.children().length == 1) {
            log.info("SentimentCostAndGradient: warning: Non-preterminal nodes of size 1: " + tree);
            throw new AssertionError("Non-preterminal nodes of size 1 should have already been collapsed");
        } else if (tree.children().length == 2) {
            forwardPropagateTree(tree.children()[0]);
            forwardPropagateTree(tree.children()[1]);

            String leftCategory = tree.children()[0].label().value();
            String rightCategory = tree.children()[1].label().value();
            SimpleMatrix W = model.getBinaryTransform(leftCategory, rightCategory);
            classification = model.getBinaryClassification(leftCategory, rightCategory);

            SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(tree.children()[0]);
            SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(tree.children()[1]);
            SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
            if (model.op.useTensors) {
                SimpleTensor tensor = model.getBinaryTensor(leftCategory, rightCategory);
                SimpleMatrix tensorIn = NeuralUtils.concatenate(leftVector, rightVector);
                SimpleMatrix tensorOut = tensor.bilinearProducts(tensorIn);
                nodeVector = NeuralUtils.elementwiseApplyTanh(W.mult(childrenVector).plus(tensorOut));
            } else {
                nodeVector = NeuralUtils.elementwiseApplyTanh(W.mult(childrenVector));
            }
        } else {
            log.info("SentimentCostAndGradient: warning: Tree not correctly binarized: " + tree);
            throw new AssertionError("Tree not correctly binarized");
        }

        SimpleMatrix predictions = NeuralUtils
                .softmax(classification.mult(NeuralUtils.concatenateWithBias(nodeVector)));

        int index = getPredictedClass(predictions);
        if (!(tree.label() instanceof CoreLabel)) {
            log.info("SentimentCostAndGradient: warning: No CoreLabels in nodes: " + tree);
            throw new AssertionError("Expected CoreLabels in the nodes");
        }
        CoreLabel label = (CoreLabel) tree.label();
        label.set(RNNCoreAnnotations.Predictions.class, predictions);
        label.set(RNNCoreAnnotations.PredictedClass.class, index);
        label.set(RNNCoreAnnotations.NodeVector.class, nodeVector);
    } // end forwardPropagateTree

}