Example usage of org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization#getFittedModel

List of usage examples of org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization#getFittedModel

Introduction

On this page you can find example usages of org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization#getFittedModel.

Prototype

public MixtureMultivariateNormalDistribution getFittedModel() 

Source Link

Documentation

Gets the fitted model.

Usage

From source file: sly.speakrecognizer.test.math.MultivariateNormalMixtureExpectationMaximizationFitterTest.java

/**
 * Fits a two-component multivariate normal mixture to the samples returned by
 * {@code getTestSamples()} and checks the fitted log-likelihood, component weights
 * and component distributions against precomputed reference values.
 */
@Test
public void testFit() {
    // Test that the loglikelihood, weights, and models are determined and
    // fitted correctly
    double[][] data = getTestSamples();
    double correctLogLikelihood = -4.292431006791994;
    double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
    MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
    correctMVNs[0] = new MultivariateNormalDistribution(
            new double[] { -1.4213112715121132, 1.6924690505757753 }, new double[][] {
                    { 1.739356907285747, -0.5867644251487614 }, { -0.5867644251487614, 1.0232932029324642 } });

    correctMVNs[1] = new MultivariateNormalDistribution(new double[] { 4.213612224374709, 7.975621325853645 },
            new double[][] { { 4.245384898007161, 2.5797798966382155 },
                    { 2.5797798966382155, 3.9200272522448367 } });
    //=========================================
    MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(
            data);
    // Initial two-component mixture estimated from the same data; fit() then refines it.
    MixtureMultivariateNormalDistribution initialMix = MultivariateNormalMixtureExpectationMaximization
            .estimate(data, 2);
    fitter.fit(initialMix);
    MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
    printMMND(fittedMix);
    List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();

    Assert.assertEquals(correctLogLikelihood, fitter.getLogLikelihood(), Math.ulp(1d));

    // Guard the index-based comparison below: without this check, extra components
    // would throw ArrayIndexOutOfBoundsException and missing ones would pass silently.
    Assert.assertEquals("number of fitted components", correctWeights.length, components.size());

    int i = 0;
    for (Pair<Double, MultivariateNormalDistribution> component : components) {
        double weight = component.getFirst();
        MultivariateNormalDistribution mvn = component.getSecond();
        Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
        // Tolerance 0: presumably the helper compares means/covariances exactly — TODO confirm.
        assertMultivariateNormalDistribution(correctMVNs[i], mvn, 0);
        i++;
    }
}

From source file: sly.speakrecognizer.test.math.MultivariateNormalMixtureExpectationMaximizationFitterTest.java

/**
 * Draws data sets of increasing size from a known two-component mixture, fits a
 * two-component model to each, checks that the fitted model keeps the reference
 * component count, and prints the error reported by {@code printErrorsMMND} for
 * every size followed by a summary table.
 */
@Test
public void testFitForDifferentSizesData() {
    // Polish: "test for different data sizes"
    System.out.println("TEST DLA ROZNYCH ROZMIAROW DANYCH");
    double[][][] covariances = { new double[][] { new double[] { 1.74, -0.59 }, new double[] { -0.59, 1.02 }, },
            new double[][] { new double[] { 4.24, 2.58 }, new double[] { 2.58, 3.92 }, }, };
    double[][] means = { new double[] { -1.42, 1.69 }, new double[] { 4.21, 7.98 }, };

    // Reference mixture the samples are drawn from.
    double[] weights = new double[] { 0.296, 0.704 };
    MixtureMultivariateNormalDistribution mmnd = new MixtureMultivariateNormalDistribution(weights, means,
            covariances);

    int[] lengths = new int[] { 10, 50, 250, 1500, 10000, 100000 };
    double[] errors = new double[lengths.length];
    int counter = 0;
    for (int length : lengths) {
        // Polish: "for data length"; diacritics are written as Unicode escapes so the
        // printed text does not depend on the encoding used to compile this file.
        System.out.println("Dla d\u0142ugo\u015bci danych: " + length);
        double[][] data = getTestSamples(length, mmnd);
        MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(
                data);
        MixtureMultivariateNormalDistribution initialMix = MultivariateNormalMixtureExpectationMaximization
                .estimate(data, 2);
        fitter.fit(initialMix);
        MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
        printMMND(fittedMix);
        // Without at least one assertion this test could never fail; the fitted model
        // must have as many components as the reference mixture.
        Assert.assertEquals("number of fitted components", weights.length,
                fittedMix.getComponents().size());
        errors[counter++] = printErrorsMMND(mmnd, fittedMix);
    }
    // Polish: "summary"
    System.out.println("Podsumowanie");
    for (int x = 0; x < lengths.length; x++) {
        // Polish: "for length = ... the error is = ..."
        System.out.println("Dla length = " + lengths[x] + " b\u0142\u0105d wynosi = " + errors[x]);
    }

}