com.analog.lyric.dimple.test.solvers.sumproduct.TestSampledFactors.java Source code

Java tutorial

Introduction

Here is the source code for com.analog.lyric.dimple.test.solvers.sumproduct.TestSampledFactors.java

Source

/*******************************************************************************
*   Copyright 2015 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.test.solvers.sumproduct;

import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;

import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.random.CorrelatedRandomVectorGenerator;
import org.apache.commons.math3.random.GaussianRandomGenerator;
import org.apache.commons.math3.stat.correlation.StorelessCovariance;
import org.junit.Test;

import com.analog.lyric.dimple.factorfunctions.MultivariateNormal;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel;
import com.analog.lyric.dimple.model.variables.Complex;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.google.common.primitives.Doubles;

/**
 * Tests for sampled factors in SumProduct solver
 */
public class TestSampledFactors extends DimpleTestBase {
    /**
     * Builds the complex product {@code c = a * b} of two complex variables with multivariate
     * normal priors. It solves the graph with the SumProduct solver and compares the resulting
     * belief for {@code c} against a Monte Carlo estimate computed directly from the priors.
     * <p>
     * The mean is accepted when its Euclidean distance from the Monte Carlo mean, relative to the
     * Monte Carlo mean's norm, is within 0.02. The covariance is accepted when the
     * {@link RealMatrix#getNorm()} of the difference, relative to that of the Monte Carlo
     * covariance, is within 0.2.
     * <p>
     * Adapted from MATLAB test4 in tests/algoGaussian/testSampledFactors.m
     */
    @Test
    public void sampledComplexProduct() {
        // NOTE: the tolerances were chosen with this seed in mind, so changing the seed may make the
        // test fail. The sample count is kept modest to limit run time; raising it improves agreement.
        testRand.setSeed(42);

        try (CurrentModel model = using(new FactorGraph())) {
            final Complex left = complex("a");
            final Complex right = complex("b");
            final Complex prod = product(left, right);

            // Priors are set in the same order the covariances are drawn from testRand.
            final double[] leftMean = { 10, 10 };
            final RealMatrix leftCov = randCovariance(2);
            left.setPrior(new MultivariateNormal(leftMean, leftCov.getData()));

            final double[] rightMean = { -20, 20 };
            final RealMatrix rightCov = randCovariance(2);
            right.setPrior(new MultivariateNormalParameters(rightMean, rightCov.getData()));

            // Both samplers share one Gaussian source backed by testRand.
            final GaussianRandomGenerator gaussian = new GaussianRandomGenerator(testRand);
            final CorrelatedRandomVectorGenerator leftSampler =
                new CorrelatedRandomVectorGenerator(leftMean, leftCov, 1e-12, gaussian);
            final CorrelatedRandomVectorGenerator rightSampler =
                new CorrelatedRandomVectorGenerator(rightMean, rightCov, 1e-12, gaussian);

            final int nSamples = 10000;
            final StorelessCovariance sampleCov = new StorelessCovariance(2);
            final RealVector sampleMean = MatrixUtils.createRealVector(new double[2]);
            final double[] z = new double[2];

            for (int n = 0; n < nSamples; ++n) {
                // Draw from the left sampler first, then the right one, to keep the RNG sequence fixed.
                final double[] u = leftSampler.nextVector();
                final double[] v = rightSampler.nextVector();
                multiplyComplex(u, v, z);

                for (int k = 0; k < z.length; ++k) {
                    sampleMean.addToEntry(k, z[k]);
                }
                sampleCov.increment(z);
            }
            sampleMean.mapDivideToSelf(nSamples); // running sum -> average

            final SumProductSolverGraph solverGraph =
                requireNonNull(model.graph.setSolverFactory(new SumProductSolver()));
            solverGraph.setOption(GibbsOptions.numSamples, nSamples);
            solverGraph.solve();

            final MultivariateNormalParameters belief = requireNonNull(prod.getBelief());

            final RealVector beliefMean = MatrixUtils.createRealVector(belief.getMean());
            final double relMeanError = sampleMean.getDistance(beliefMean) / sampleMean.getNorm();
            assertEquals(0.0, relMeanError, .02);

            final RealMatrix sampleCovMatrix = sampleCov.getCovarianceMatrix();
            final RealMatrix beliefCov = MatrixUtils.createRealMatrix(belief.getCovariance());
            final double relCovError =
                sampleCovMatrix.subtract(beliefCov).getNorm() / sampleCovMatrix.getNorm();
            assertEquals(0.0, relCovError, .2);
        }
    }

    /**
     * Writes the complex product of {@code x} and {@code y} into {@code out}.
     * Each array holds {@code {real, imaginary}}.
     */
    private static void multiplyComplex(double[] x, double[] y, double[] out) {
        out[0] = x[0] * y[0] - x[1] * y[1];
        out[1] = x[0] * y[1] + x[1] * y[0];
    }

    /**
     * Generates a random {@code n x n} covariance matrix.
     * <p>
     * Entries are first filled from {@code testRand.nextDouble()}, and the matrix is symmetrized by
     * adding its transpose. The diagonal is then shifted so that, up to floating-point rounding, the
     * smallest eigenvalue becomes the square of a {@code testRand.nextGaussian()} draw plus one ulp
     * of 1.0. The result is symmetric and, in exact arithmetic, positive definite.
     */
    RealMatrix randCovariance(int n) {
        final RealMatrix noise = MatrixUtils.createRealMatrix(n, n);
        noise.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(int row, int column, double value) {
                return testRand.nextDouble();
            }
        });

        final RealMatrix symmetric = noise.add(noise.transpose());
        final double[] eigenvalues = new EigenDecomposition(symmetric).getRealEigenvalues();
        final double smallest = Doubles.min(eigenvalues);

        // The ulp keeps the target strictly positive even when the Gaussian draw is zero.
        final double g = testRand.nextGaussian();
        final double targetMinEig = g * g + Math.ulp(1.0);

        final RealMatrix identity = MatrixUtils.createRealIdentityMatrix(n);
        return symmetric.add(identity.scalarMultiply(targetMinEig - smallest));
    }
}