com.analog.lyric.dimple.data.DataStack.java Source code

Java tutorial

Introduction

Here is the source code for com.analog.lyric.dimple.data.DataStack.java

Source

/*******************************************************************************
*   Copyright 2015 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.data;

import static java.lang.String.*;

import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collection;

import org.eclipse.jdt.annotation.Nullable;

import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.FactorGraphIterables;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.IVariableToValue;
import com.analog.lyric.dimple.model.variables.Variable;
import com.google.common.collect.Lists;

/**
 * An ordered stack of {@link DataLayer}s that all belong to the same root {@link FactorGraph}.
 * <p>
 * The value of a variable is taken from the first layer, in stack order (index 0 first), that holds a
 * {@link Value} for it (see {@link #getValue(Variable)}). Non-value data such as
 * {@link IUnaryFactorFunction}s found in earlier layers contribute to {@link #computeTotalEnergy()}.
 * <p>
 * This class exposes the layers as a {@link java.util.List} view. It does not override any of the
 * mutating {@link AbstractList} methods, so those throw {@link UnsupportedOperationException}.
 * 
 * @since 0.08
 * @author Christopher Barber
 */
public class DataStack extends AbstractList<DataLayer<?>> implements IVariableToValue {
    /*-------
     * State
     */

    /** Layers in stack order; fixed at construction. */
    private final ArrayList<DataLayer<?>> _stack;

    /*--------------
     * Construction
     */

    /**
     * Creates a stack containing the given layers in the iteration order of {@code layers}.
     * <p>
     * The layers are copied into this stack, so later changes to {@code layers} do not affect it.
     * 
     * @param layers the layers to stack. Must be non-empty, and every layer must have the same
     * {@linkplain DataLayer#rootGraph() root graph} (compared by identity).
     * @throws IllegalArgumentException if {@code layers} is empty or contains layers for different
     * root graphs.
     * @since 0.08
     */
    public DataStack(Collection<? extends DataLayer<?>> layers) {
        if (layers.isEmpty()) {
            throw new IllegalArgumentException(
                    format("Cannot create %s with no layers.", getClass().getSimpleName()));
        }

        _stack = new ArrayList<>(layers);

        // Ensure that all layers are for the same graph
        final FactorGraph root = _stack.get(0).rootGraph();
        for (int i = 1, n = _stack.size(); i < n; ++i) {
            if (_stack.get(i).rootGraph() != root) {
                throw new IllegalArgumentException(
                        format("Cannot create %s with layers from different graphs", getClass().getSimpleName()));
            }
        }
    }

    /**
     * Creates a stack whose first layer is {@code firstLayer}, followed by {@code additionalLayers} in order.
     * 
     * @throws IllegalArgumentException if the layers belong to different root graphs.
     * @since 0.08
     */
    public DataStack(DataLayer<?> firstLayer, DataLayer<?>... additionalLayers) {
        this(Lists.asList(firstLayer, additionalLayers));
    }

    /*--------------
     * List methods
     */

    /**
     * Returns the layer at position {@code index} in the stack (0 is the first layer searched).
     * 
     * @throws IndexOutOfBoundsException if {@code index} is out of range.
     */
    @Override
    public DataLayer<?> get(int index) {
        return _stack.get(index);
    }

    /**
     * Returns the number of layers in the stack; always at least one.
     */
    @Override
    public int size() {
        return _stack.size();
    }

    /*------------------
     * IVariableToValue
     */

    /**
     * Same as {@link #getValue(Variable)}.
     */
    @Override
    @Nullable
    public Value varToValue(Variable var) {
        return getValue(var);
    }

    /*-------------------
     * DataStack methods
     */

    /**
     * Computes the total energy for the graph tree represented by this data stack.
     * <p>
     * Computes the total energy by adding the energy evaluated for all the factors
     * and variable priors and conditioning functions given the value specified for
     * each variable in the data stack. Specifically:
     * <ul>
     * <li><b>for each factor</b>: {@linkplain #getValue(Variable) looks up the value} for each of the variables
     * connected to the factor and passes them to the {@linkplain FactorFunction#evalEnergy(Value[])
     * evalEnergy} method of the factor's {@linkplain Factor#getFactorFunction() factor function}.
     * 
     * <li><b>for each variable</b>: {@linkplain #getValue(Variable) looks up the value} for the variable and
     * passes it to the {@linkplain IUnaryFactorFunction#evalEnergy(Value) evalEnergy} method of each
     * {@link IUnaryFactorFunction} specified for that variable in layers that precede the layer containing
     * the variable value.
     * </ul>
     * <p>
     * Returns {@link Double#POSITIVE_INFINITY} immediately if a variable's value is not in the variable's
     * domain, or as soon as the running total becomes positive infinity.
     * 
     * @since 0.08
     * @throws IllegalStateException if any variable in the graph lacks a value.
     */
    public double computeTotalEnergy() {
        final FactorGraph root = rootGraph();
        final int nLayers = _stack.size();

        double energy = 0.0;

        // Scratch buffer: functions[i] holds the unary function found in layer i for the current variable,
        // or null. Every entry below the value layer is rewritten for each variable before it is read,
        // so the buffer never needs to be cleared between variables.
        final IUnaryFactorFunction[] functions = new IUnaryFactorFunction[nLayers];

        for (Variable var : FactorGraphIterables.variables(root)) {
            // Locate the first layer holding a Value for var, recording unary functions seen on the way.
            Value value = null;
            int valueLayer = 0;
            for (; valueLayer < nLayers; ++valueLayer) {
                final IDatum datum = _stack.get(valueLayer).get(var);
                if (datum instanceof Value) {
                    value = (Value) datum;
                    break;
                }
                functions[valueLayer] = datum instanceof IUnaryFactorFunction ? (IUnaryFactorFunction) datum : null;
            }

            if (value == null) {
                // Documented contract: every variable must have a value somewhere in the stack.
                throw new IllegalStateException(
                        format("Variable %s has no value in %s", var, getClass().getSimpleName()));
            }

            if (!var.getDomain().valueInDomain(value)) {
                return Double.POSITIVE_INFINITY;
            }

            // Apply functions from the nearest preceding layer down to layer 0, accumulating directly
            // into the running total.
            for (int i = valueLayer - 1; i >= 0; --i) {
                final IUnaryFactorFunction function = functions[i];
                if (function != null) {
                    energy += function.evalEnergy(value);
                    if (energy == Double.POSITIVE_INFINITY) {
                        return energy;
                    }
                }
            }
        }

        for (Factor factor : FactorGraphIterables.factors(root)) {
            final Value[] values = factor.fillInArgumentValues(this, null);
            energy += factor.evalEnergy(values);
            if (energy == Double.POSITIVE_INFINITY) {
                return energy;
            }
        }

        return energy;
    }

    /**
     * Returns the value for {@code var} from the first layer, in stack order, that holds a {@link Value}
     * for it.
     * <p>
     * Layers whose datum for {@code var} is not a {@link Value} (e.g. a unary factor function) are skipped.
     * 
     * @return the first value found, or null if no layer holds a {@link Value} for {@code var}.
     * @since 0.08
     */
    public @Nullable Value getValue(Variable var) {
        for (DataLayer<?> layer : _stack) {
            IDatum datum = layer.get(var);
            if (datum instanceof Value) {
                return (Value) datum;
            }
        }

        return null;
    }

    /**
     * Root graph for all layers in this stack.
     * @since 0.08
     * @see DataLayer#rootGraph()
     */
    public FactorGraph rootGraph() {
        return _stack.get(0).rootGraph();
    }
}