com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomComplexGaussianPolynomial.java Source code

Java tutorial

Introduction

Here is the source code for com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomComplexGaussianPolynomial.java

Source

/*******************************************************************************
*   Copyright 2012 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.solvers.sumproduct.customFactors;

import static com.analog.lyric.math.MoreMatrixUtils.*;
import static org.apache.commons.math3.linear.MatrixUtils.*;

import java.util.List;

import org.apache.commons.math3.complex.Complex;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;

/**
 * Sum-product custom factor for the deterministic complex relation {@code y = P(x)}, where both
 * variables carry two-dimensional Gaussian messages over their (real, imaginary) parts and
 *
 * <pre>
 *     P(x) = sum_k c_k * x * (|x|^2)^(p_k)
 * </pre>
 *
 * The powers {@code p_k} come from factor constant 0, the real parts of the coefficients
 * {@code c_k} from constant 1, and the optional imaginary parts from constant 2 (zero when
 * absent). Sibling 0 is {@code y} and sibling 1 is {@code x}.
 * <p>
 * Messages use a sigma-point approximation. {@code 2n} deterministic points are placed at
 * {@code mean +/- sqrt(n) * L_i}, where {@code L_i} are the columns of the Cholesky factor of the
 * incoming covariance. They are mapped through {@code P} (towards {@code y}) or through a
 * Newton-Raphson inverse of {@code P} (towards {@code x}). The results are then summarised by
 * their equally weighted mean and covariance.
 */
public class CustomComplexGaussianPolynomial extends MultivariateGaussianFactorBase {
    /** Exponents {@code p_k} applied to {@code |x|^2}, one per coefficient. */
    private final double[] _powers;
    /** Complex coefficients {@code c_k}; same length as {@code _powers}. */
    private final Complex[] _coeffs;
    /** Number of Newton-Raphson iterations used when inverting {@code P} for the message to x. */
    private int _iterations;

    /*
     * TODO: changes
     * expect complex x and y with covariance matrices
     * expect array of complex variables for coefficients.
     * modify newton raphson to deal with complex coefficients
     * modify the code that calculates means and covariance
     */

    /**
     * Creates the solver-side factor from a model factor with two complex siblings.
     *
     * @param factor model factor. Its constants are: [0] powers, [1] real coefficient parts,
     *        and optionally [2] imaginary coefficient parts.
     * @param parent owning sum-product solver graph
     * @throws DimpleException if the factor does not have exactly two siblings, has fewer than
     *         two constants, or the powers and coefficient arrays differ in length
     */
    public CustomComplexGaussianPolynomial(Factor factor, SumProductSolverGraph parent) {
        super(factor, parent);

        if (factor.getSiblingCount() != 2)
            throw new DimpleException("expected two complex numbers");

        final List<Value> constants = factor.getConstantValues();

        if (constants.size() < 2)
            throw new DimpleException(
                "expected at least two constants (powers and real coefficients), got " + constants.size());

        final double[] powers = constants.get(0).getDoubleArray();
        final double[] rcoeffs = constants.get(1).getDoubleArray();
        // Imaginary parts are optional; a missing array means purely real coefficients.
        final double[] icoeffs =
            constants.size() > 2 ? constants.get(2).getDoubleArray() : new double[rcoeffs.length];

        // P() and buildJacobian() index powers and coefficients in lock-step, so the lengths must agree.
        if (rcoeffs.length != powers.length)
            throw new DimpleException("number of real coefficients (" + rcoeffs.length
                + ") must match number of powers (" + powers.length + ")");
        if (icoeffs.length != rcoeffs.length)
            throw new DimpleException("number of imaginary coefficients (" + icoeffs.length
                + ") must match number of real coefficients (" + rcoeffs.length + ")");

        _coeffs = new Complex[rcoeffs.length];
        for (int i = 0; i < _coeffs.length; i++)
            _coeffs[i] = new Complex(rcoeffs[i], icoeffs[i]);

        //TODO: only allow powers of 1, 3, 5, etc...

        _powers = powers;
        _iterations = 3;
    }

    /**
     * Sets the number of Newton-Raphson iterations used when computing the message to x.
     *
     * @param num iteration count; no validation is performed
     */
    public void setNumIterations(int num) {
        _iterations = num;
    }

    /** Recomputes both outgoing messages: first the one to x, then the one to y. */
    @Override
    protected void doUpdate() {
        updateToX();
        updateToY();
    }

    /**
     * Recomputes the outgoing message on one edge.
     *
     * @param outPortNum 0 for the edge to y, any other value for the edge to x
     */
    @Override
    public void doUpdateEdge(int outPortNum) {
        //TODO: somehow avoid double computation.
        //Maybe this is avoided when we have multivariate gaussian

        if (outPortNum == 0)
            updateToY();
        else
            updateToX();
    }

    /**
     * Computes the factor-to-x message from the incoming y message. It does this by inverting
     * {@code P} at each sigma point with Newton-Raphson. A null y message produces a null x
     * message.
     */
    public void updateToX() {
        final MultivariateNormalParameters yIn = getSiblingEdgeState(0).varToFactorMsg;
        final MultivariateNormalParameters xOut = getSiblingEdgeState(1).factorToVarMsg;

        if (yIn.isNull()) {
            xOut.setNull();
            return;
        }

        final double[] yMean = yIn.getMean();
        final Complex[] samples = getSamples(new Complex(yMean[0], yMean[1]), yIn.getCovariance());
        final Complex[] results = new Complex[samples.length];

        for (int i = 0; i < samples.length; i++)
            results[i] = newtonRaphson(samples[i], _iterations, _powers, _coeffs);

        storeMoments(xOut, calculateWeightedSums(results));
    }

    /**
     * Computes the factor-to-y message from the incoming x message. It does this by evaluating
     * {@code P} at each sigma point. A null x message produces a null y message.
     */
    public void updateToY() {
        final MultivariateNormalParameters yOut = getSiblingEdgeState(0).factorToVarMsg;
        final MultivariateNormalParameters xIn = getSiblingEdgeState(1).varToFactorMsg;

        if (xIn.isNull()) {
            yOut.setNull();
            return;
        }

        final double[] xMean = xIn.getMean();
        final Complex[] samples = getSamples(new Complex(xMean[0], xMean[1]), xIn.getCovariance());
        final Complex[] results = new Complex[samples.length];

        for (int i = 0; i < samples.length; i++)
            results[i] = P(samples[i], _powers, _coeffs);

        storeMoments(yOut, calculateWeightedSums(results));
    }

    /** Writes a sample mean/covariance pair into a 2-D Gaussian message. */
    private static void storeMoments(MultivariateNormalParameters msg, SampleMoments moments) {
        msg.setMeanAndCovariance(new double[] { moments.mean.getReal(), moments.mean.getImaginary() },
            moments.covariance);
    }

    /**
     * Generates the {@code 2n} sigma points {@code mean +/- sqrt(n) * L_i}. Here
     * {@code covar = L * L^T} and {@code L_i} is column {@code i} of {@code L}.
     * <p>
     * The {@code sqrt(n)} scale makes the equally weighted sample covariance of the points equal
     * {@code covar}. Each +/- pair contributes {@code 2 * n * L_i L_i^T}, which divided by
     * {@code 2n} gives {@code L_i L_i^T}, and these sum to {@code covar}.
     *
     * @param mean centre of the distribution
     * @param covar 2x2 covariance over (real, imaginary); must be positive definite for the
     *        Cholesky decomposition
     * @return the sigma points, ordered as +L_0, -L_0, +L_1, -L_1
     * @throws DimpleException if {@code covar} is not 2x2
     */
    private Complex[] getSamples(Complex mean, double[][] covar) {
        final int n = covar.length;
        if (n != 2)
            throw new DimpleException("expected a 2x2 covariance for a complex variable, got " + n + "x" + n);

        final double meanRe = mean.getReal();
        final double meanIm = mean.getImaginary();
        final double scale = Math.sqrt(n);

        CholeskyDecomposition cd = new CholeskyDecomposition(wrapRealMatrix(covar));

        // Row i of L^T is column i of L.
        double[][] chol = matrixGetDataRef(cd.getLT());

        Complex[] retval = new Complex[2 * chol.length];
        for (int i = 0; i < chol.length; i++) {
            final double dRe = scale * chol[i][0];
            final double dIm = scale * chol[i][1];
            retval[2 * i] = new Complex(meanRe + dRe, meanIm + dIm);
            retval[2 * i + 1] = new Complex(meanRe - dRe, meanIm - dIm);
        }

        return retval;
    }

    /** Equally weighted mean and (population) covariance of complex samples viewed as 2-D points. */
    private static final class SampleMoments {
        final Complex mean;
        final double[][] covariance;

        SampleMoments(Complex mean, double[][] covariance) {
            this.mean = mean;
            this.covariance = covariance;
        }
    }

    /**
     * Computes the equally weighted mean and covariance of the given samples. The covariance is
     * normalised by the sample count, not count - 1.
     *
     * @param in non-empty array of samples
     * @return mean as a complex number and 2x2 covariance over (real, imaginary)
     */
    private static SampleMoments calculateWeightedSums(Complex[] in) {
        final int count = in.length;

        double[] means = new double[] { 0, 0 };
        for (Complex c : in) {
            means[0] += c.getReal();
            means[1] += c.getImaginary();
        }
        means[0] /= count;
        means[1] /= count;

        RealMatrix cm = createRealMatrix(2, 2);
        RealVector mm = wrapRealVector(means);

        for (Complex c : in) {
            RealVector centred = createRealVector(new double[] { c.getReal(), c.getImaginary() }).subtract(mm);
            cm = cm.add(centred.outerProduct(centred));
        }

        cm = cm.scalarMultiply(1.0 / count);

        return new SampleMoments(new Complex(means[0], means[1]), matrixGetDataRef(cm));
    }

    /**
     * Evaluates {@code P(x) = sum_k coeffs[k] * x * (|x|^2)^powers[k]}.
     *
     * @param input the point {@code x}
     * @param powers exponents applied to {@code |x|^2}
     * @param coeffs complex coefficients, at least as long as {@code powers}
     * @return {@code P(input)}
     */
    private static Complex P(Complex input, double[] powers, Complex[] coeffs) {

        double a = input.getReal();
        double b = input.getImaginary();

        double a2pb2 = a * a + b * b;

        Complex retval = new Complex(0, 0);

        for (int i = 0; i < powers.length; i++) {

            Complex tmp = coeffs[i].multiply(input);
            double tmp2 = Math.pow(a2pb2, powers[i]);
            tmp = tmp.multiply(new Complex(tmp2, 0));

            retval = retval.add(tmp);
        }

        return retval;
    }

    /**
     * For real coefficients, returns d(Re P)/da, where
     * {@code Re P = a * sum_i c_i * (a^2+b^2)^p_i}.
     *
     * @throws DimpleException if the result is NaN
     */
    public static double derivativeOfP1OverA(double a, double b, double[] powers, double[] coeffs) {
        double sum = 0;

        double a2pb2 = a * a + b * b;

        for (int index = 0; index < powers.length; index++) {
            // Use the power exactly as P() does (no integer truncation).
            double i = powers[index];
            double c_i = coeffs[index];

            sum += c_i * Math.pow(a2pb2, i);
            if (i > 0)
                sum += i * c_i * Math.pow(a2pb2, i - 1) * 2 * a * a;
        }

        if (Double.isNaN(sum))
            throw new DimpleException("derivativeOfP1OverA generated NaN");

        return sum;
    }

    /**
     * Returns {@code sum_{p_i > 0} p_i * c_i * (a^2+b^2)^(p_i-1) * 2ab}. For real coefficients
     * this equals both d(Re P)/db and d(Im P)/da.
     */
    private static double sharedDerivative(double a, double b, double[] powers, double[] coeffs) {
        double sum = 0;
        double a2pb2 = a * a + b * b;

        for (int index = 0; index < powers.length; index++) {
            double i = powers[index];
            double c_i = coeffs[index];

            if (i > 0)
                sum += i * c_i * Math.pow(a2pb2, i - 1) * 2 * a * b;
        }

        return sum;
    }

    /**
     * For real coefficients, returns d(Re P)/db.
     *
     * @throws DimpleException if the result is NaN
     */
    public static double derivativeOfP1OverB(double a, double b, double[] powers, double[] coeffs) {
        double tmp = sharedDerivative(a, b, powers, coeffs);
        if (Double.isNaN(tmp))
            throw new DimpleException("derivativeOfP1OverB generated NaN");
        return tmp;
    }

    /**
     * For real coefficients, returns d(Im P)/da.
     *
     * @throws DimpleException if the result is NaN
     */
    public static double derivativeOfP2OverA(double a, double b, double[] powers, double[] coeffs) {
        double tmp = sharedDerivative(a, b, powers, coeffs);
        if (Double.isNaN(tmp))
            throw new DimpleException("derivativeOfP2OverA generated NaN");
        return tmp;
    }

    /**
     * For real coefficients, returns d(Im P)/db, where
     * {@code Im P = b * sum_i c_i * (a^2+b^2)^p_i}.
     *
     * @throws DimpleException if the result is NaN
     */
    public static double derivativeOfP2OverB(double a, double b, double[] powers, double[] coeffs) {
        double sum = 0;

        double a2pb2 = a * a + b * b;

        for (int index = 0; index < powers.length; index++) {
            // Use the power exactly as P() does (no integer truncation).
            double i = powers[index];
            double c_i = coeffs[index];

            sum += c_i * Math.pow(a2pb2, i);
            if (i > 0)
                sum += i * c_i * Math.pow(a2pb2, i - 1) * 2 * b * b;
        }

        if (Double.isNaN(sum))
            throw new DimpleException("derivativeOfP2OverB generated NaN");

        return sum;
    }

    /**
     * Inverts a 2x2 matrix using the closed-form adjugate formula. A singular matrix (zero
     * determinant) yields non-finite entries rather than an exception.
     *
     * @param M 2x2 matrix
     * @return a new 2x2 array holding {@code M^-1}
     */
    public static double[][] invertMatrix(double[][] M) {
        double a = M[0][0];
        double b = M[0][1];
        double c = M[1][0];
        double d = M[1][1];

        double constant = 1 / (a * d - b * c);

        return new double[][] { new double[] { d * constant, -b * constant },
                new double[] { -c * constant, a * constant } };
    }

    /**
     * Multiplies a square matrix by a column vector.
     *
     * @param M matrix with at least {@code input.length} rows and columns
     * @param input vector
     * @return a new vector {@code M * input} of length {@code input.length}
     */
    public static double[] matrixMultiply(double[][] M, double[] input) {
        double[] output = new double[input.length];

        for (int row = 0; row < input.length; row++) {
            double sum = 0;
            for (int col = 0; col < input.length; col++) {
                sum += input[col] * M[row][col];
            }
            output[row] = sum;
        }

        return output;
    }

    /** Returns a new array with every element of {@code vec} multiplied by {@code scalar}. */
    public static double[] vectorMultiply(double[] vec, double scalar) {
        double[] retval = vec.clone();

        for (int i = 0; i < retval.length; i++)
            retval[i] *= scalar;

        return retval;
    }

    /** Returns a new array with {@code b} added to every element of {@code a}. */
    public static double[] addScalar(double[] a, double b) {
        double[] retval = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            retval[i] = a[i] + b;
        }
        return retval;
    }

    /** Returns the element-wise sum of {@code a} and {@code b}; {@code b} must be at least as long as {@code a}. */
    public static double[] addVectors(double[] a, double[] b) {
        double[] retval = new double[a.length];
        for (int i = 0; i < a.length; i++)
            retval[i] = a[i] + b[i];

        return retval;
    }

    /**
     * Builds the 2x2 real Jacobian of {@code P} at {@code x = a + ib}:
     * {@code [[dRe/da, dRe/db], [dIm/da, dIm/db]]}.
     * <p>
     * With {@code r = a^2 + b^2}, each term {@code c * x * r^p} has derivatives
     * {@code c * (r^p + x * p * r^(p-1) * 2a)} with respect to a and
     * {@code c * (i * r^p + x * p * r^(p-1) * 2b)} with respect to b. The {@code p == 0} case is
     * handled separately, to avoid evaluating {@code r^(-1)} at the origin.
     *
     * @param coeffs complex coefficients
     * @param powers exponents applied to {@code |x|^2}, same length as {@code coeffs}
     * @param x point at which to evaluate the Jacobian
     * @return the 2x2 Jacobian as row-major arrays
     */
    public static double[][] buildJacobian(Complex[] coeffs, double[] powers, Complex x) {
        double a = x.getReal();
        double b = x.getImaginary();
        double a2pb2 = a * a + b * b;

        Complex dyda = new Complex(0, 0);
        Complex dydb = new Complex(0, 0);

        for (int k = 0; k < coeffs.length; k++) {
            double pow = powers[k];

            if (pow == 0) {
                dyda = dyda.add(coeffs[k]);
                dydb = dydb.add(coeffs[k].multiply(new Complex(0, 1)));
            } else {
                // Use the term's power (not its index k) in the chain-rule factor.
                double rp = Math.pow(a2pb2, pow);
                double dRp = pow * Math.pow(a2pb2, pow - 1);
                dyda = dyda.add(coeffs[k].multiply(x.multiply(new Complex(2 * a * dRp, 0))
                        .add(new Complex(rp, 0))));
                dydb = dydb.add(coeffs[k].multiply(x.multiply(new Complex(2 * b * dRp, 0))
                        .add(new Complex(0, rp))));
            }
        }

        return new double[][] { new double[] { dyda.getReal(), dydb.getReal() },
                new double[] { dyda.getImaginary(), dydb.getImaginary() } };
    }

    /**
     * Approximately solves {@code P(x) = input} for {@code x} with a fixed number of
     * Newton-Raphson steps, starting from {@code x = input}.
     *
     * @param input target value of {@code P}
     * @param numIterations number of Newton steps to take
     * @param powers exponents applied to {@code |x|^2}
     * @param coeffs complex coefficients
     * @return the final iterate; convergence is not checked
     */
    public static Complex newtonRaphson(Complex input, int numIterations, double[] powers, Complex[] coeffs) {
        Complex output = input;

        for (int i = 0; i < numIterations; i++) {
            double[][] J = buildJacobian(coeffs, powers, output);

            // Residual input - P(output); step is J^-1 * residual.
            Complex residual = input.subtract(P(output, powers, coeffs));

            double[] step = matrixMultiply(invertMatrix(J),
                    new double[] { residual.getReal(), residual.getImaginary() });

            output = output.add(new Complex(step[0], step[1]));
        }

        return output;
    }

}