Example usage of the org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression constructor

List of usage examples for the org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression constructor

Introduction

On this page you can find example usages of the constructor of org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression.

Prototype

OLSMultipleLinearRegression

Source Link

Usage

From source file:modelcreation.ModelCreation.java

/**
 * Program entry point.
 *
 * <p>Sizes the sample arrays from the value returned by {@code writeDataIntoFile()},
 * fills them through {@code readDataFromFile}, prints the value returned by
 * {@link #apply10FoldCrossValidation(double[][], double[])} and finally prints the
 * statistics of a regression fitted on the complete sample.
 *
 * @param args the command line arguments (not used)
 */
public static void main(String[] args) {
    final int sampleCount = writeDataIntoFile();
    final double[][] features = new double[sampleCount][2];
    final double[] targets = new double[sampleCount];
    readDataFromFile(features, targets);

    final double averageError = apply10FoldCrossValidation(features, targets);
    System.out.println("Average mean squared error: " + averageError);

    final OLSMultipleLinearRegression fullSampleModel = new OLSMultipleLinearRegression();
    fullSampleModel.newSampleData(targets, features);
    // NOTE(review): Commons Math decides whether to add the intercept column while
    // loading the sample, so setting the flag afterwards presumably leaves the fitted
    // model with an intercept and only affects how R^2 is computed. Confirm whether
    // this call was meant to come before newSampleData(...).
    fullSampleModel.setNoIntercept(true);
    printRegressionStatistics(fullSampleModel);
}

From source file:dase.timeseries.analysis.GrangerTest.java

/**
 * Returns the p-value of a Granger causality F-test.
 *
 * <p>Two ordinary least squares models are fitted to {@code strip(L, y)}: a restricted
 * one on the regressors built by {@code createLaggedSide(L, y)} and an unrestricted one
 * on those built by {@code createLaggedSide(L, x, y)}. Their residual sums of squares
 * are combined into an F statistic with {@code L} and {@code n - 2L - 1} degrees of
 * freedom, where {@code n} is the number of rows of the restricted regressor matrix.
 * Both sums of squares, the F statistic and the p-value are printed to standard output.
 *
 * @param y
 *            - predictable variable
 * @param x
 *            - predictor
 * @param L
 *            - lag, should be 1 or greater.
 * @return p-value of Granger causality
 */
public static double granger(double[] y, double[] x, int L) {
    final double[][] ownLags = createLaggedSide(L, y);
    final double[][] jointLags = createLaggedSide(L, x, y);

    final int observations = ownLags.length;
    final int denominatorDf = observations - 2 * L - 1;

    // Restricted model: y explained by its own lags only.
    final OLSMultipleLinearRegression restricted = new OLSMultipleLinearRegression();
    restricted.newSampleData(strip(L, y), ownLags);

    // Unrestricted model: y explained by the lags of both series.
    final OLSMultipleLinearRegression unrestricted = new OLSMultipleLinearRegression();
    unrestricted.newSampleData(strip(L, y), jointLags);

    final double[] restrictedResiduals = restricted.estimateResiduals();
    final double[] unrestrictedResiduals = unrestricted.estimateResiduals();

    final double restrictedRss = sqrSum(restrictedResiduals);
    final double unrestrictedRss = sqrSum(unrestrictedResiduals);

    final double explainedPerLag = (restrictedRss - unrestrictedRss) / L;
    final double unexplainedPerDf = unrestrictedRss / denominatorDf;
    final double fStatistic = explainedPerLag / unexplainedPerDf;

    System.out.println(restrictedRss + " " + unrestrictedRss);
    System.out.println("F-test " + fStatistic);

    final FDistribution distribution = new FDistribution(L, denominatorDf);
    final double significance = 1.0 - distribution.cumulativeProbability(fStatistic);
    System.out.println("P-value " + significance);
    return significance;
}

From source file:net.gtl.movieanalytics.model.LinearRegression.java

/**
 * Fits an ordinary least squares model to the sample and returns its coefficients.
 *
 * @param x regressor values, passed as the design data to {@code newSampleData}
 * @param y observed responses
 * @return the array produced by {@code estimateRegressionParameters()}
 */
private double[] estimateParameter(double[][] x, double[] y) {
    final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
    regression.newSampleData(y, x);
    final double[] coefficients = regression.estimateRegressionParameters();
    return coefficients;
}

From source file:com.insightml.models.regression.OLS.java

/**
 * Trains a linear model by ordinary least squares.
 *
 * @param features     regressor values handed to the OLS solver
 * @param expected     observed responses
 * @param featureNames names forwarded unchanged to the resulting model
 * @return a {@code LinearRegressionModel} built from the estimated regression
 *         parameters and {@code featureNames}
 */
@Override
public IModel<Sample, Double> train(final double[][] features, final double[] expected,
        final String[] featureNames) {
    final OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
    ols.newSampleData(expected, features);
    final double[] coefficients = ols.estimateRegressionParameters();
    return new LinearRegressionModel(coefficients, featureNames);
}

From source file:com.davidbracewell.ml.regression.LeastSquaresLearner.java

/**
 * Fits an ordinary least squares model to the training instances and stores the
 * outcome in {@code model}: the first estimated parameter becomes the bias and the
 * remaining parameters become the weight vector.
 *
 * @param trainingData instances supplying the target value and the feature array
 */
@Override
protected void trainAll(List<Instance> trainingData) {
    final int sampleCount = trainingData.size();
    final double[] targets = new double[sampleCount];
    final double[][] features = new double[sampleCount][];

    int row = 0;
    for (Instance instance : trainingData) {
        targets[row] = instance.getTargetValue();
        features[row] = instance.toArray();
        row++;
    }

    final OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
    ols.newSampleData(targets, features);
    final double[] estimates = ols.estimateRegressionParameters();

    model.bias = estimates[0];

    // Everything after the first estimate is a per-feature weight.
    final double[] weights = new double[estimates.length - 1];
    for (int w = 0; w < weights.length; w++) {
        weights[w] = estimates[w + 1];
    }
    model.weights = new DenseVector(weights);
}

From source file:lu.lippmann.cdb.lab.regression.Regression.java

/**
 * Constructor.
 *
 * <p>Regresses attribute 0 of the data set returned by
 * {@code WekaDataProcessingUtil.buildDataSetSortedByAttribute(ds, idx)} on all of its
 * remaining attributes using ordinary least squares, then stores the R-squared value,
 * the estimated parameters and the result of {@code calculateEstimations}.
 *
 * @param ds  source data set
 * @param idx attribute index forwarded to {@code buildDataSetSortedByAttribute}
 * @throws Exception as declared by the signature
 */
public Regression(final Instances ds, final int idx) throws Exception {
    this.newds = WekaDataProcessingUtil.buildDataSetSortedByAttribute(ds, idx);

    final int rowCount = this.newds.numInstances();
    final int attributeCount = this.newds.numAttributes();

    final double[][] x = new double[rowCount][attributeCount - 1];
    final double[] y = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        // Attribute 0 is the response; attributes 1..M-1 are the regressors.
        y[row] = this.newds.instance(row).value(0);
        for (int attr = 1; attr < attributeCount; attr++) {
            x[row][attr - 1] = this.newds.instance(row).value(attr);
        }
    }

    final OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
    ols.newSampleData(y, x);

    this.r2 = ols.calculateRSquared();
    this.coe = ols.estimateRegressionParameters();
    this.estims = calculateEstimations(x, y, coe);
}

From source file:com.mebigfatguy.damus.main.DamusCalculator.java

/**
 * Runs one ordinary least squares regression per result metric of the current
 * prediction model, using every training item as an observation and every input
 * metric as a regressor, and stores a value for each result metric on
 * {@code trainingItem}.
 *
 * <p>NOTE(review): the stored value is {@code estimateRegressandVariance()}, which
 * Commons Math documents as the variance of the dependent variable rather than a
 * fitted prediction; confirm that this is the intended quantity.
 *
 * @return {@code true} when every result metric was processed, {@code false} when any
 *         exception was raised along the way (the exception itself is discarded)
 */
private boolean calcLinearRegression() {
    try {
        Context context = Context.instance();
        PredictionModel model = context.getPredictionModel();
        TrainingData data = context.getTrainingData();

        int numMetrics = model.getNumMetrics();
        int numResults = model.getNumResults();
        int numItems = data.getNumItems();

        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();

        // Design matrix: one row per training item, one column per input metric.
        double[][] xx = new double[numItems][];
        for (int i = 0; i < numItems; i++) {
            double[] x = new double[numMetrics];
            for (int m = 0; m < numMetrics; m++) {
                x[m] = itemValueAsDouble(data, i, model.getMetric(m));
            }
            xx[i] = x;
        }

        for (int r = 0; r < numResults; r++) {
            Metric metric = model.getResult(r);
            double[] y = new double[numItems];
            for (int i = 0; i < numItems; i++) {
                y[i] = itemValueAsDouble(data, i, metric);
            }

            regression.newSampleData(y, xx);
            double result = regression.estimateRegressandVariance();
            // valueOf uses the canonical decimal rendering of the double instead of
            // its exact binary expansion, which is what new BigDecimal(double) yields.
            trainingItem.setValue(metric, BigDecimal.valueOf(result));
        }
        return true;
    } catch (Exception e) {
        // Callers only receive a success flag, so any failure (bad cast, too few
        // observations, non-finite result, ...) is reported as false.
        return false;
    }
}

/**
 * Reads the value of {@code metric} from training item {@code itemIndex} and converts
 * it to a double according to the metric type: numeric types use their double value,
 * yes/no values map to 1.0 / 0.0 and any other type yields 0.0 without reading the item.
 */
private double itemValueAsDouble(TrainingData data, int itemIndex, Metric metric) {
    switch (metric.getType()) {
    case Percent:
        return ((Number) data.getItem(itemIndex).getValue(metric)).doubleValue();

    case Real:
        return ((BigDecimal) data.getItem(itemIndex).getValue(metric)).doubleValue();

    case YesNo:
        return ((Boolean) data.getItem(itemIndex).getValue(metric)).booleanValue() ? 1.0 : 0.0;

    default:
        return 0.0;
    }
}

From source file:msi.gama.util.GamaRegression.java

/**
 * Fits a multiple linear regression to the contents of {@code data} and stores the
 * estimated parameters in {@code param}.
 *
 * <p>The matrix values are copied into a flat array and loaded with
 * {@code newSampleData(double[], int, int)}, using {@code numRows} as the number of
 * observations and {@code numCols - 1} as the number of regressors. Commons Math
 * reads that array as consecutive observations, each starting with the response.
 * NOTE(review): confirm that {@code GamaFloatMatrix.getMatrix()} is laid out that way.
 *
 * <p>NOTE(review): no covariance matrix is supplied on the {@code "GLS"} path;
 * confirm that {@code GLSMultipleLinearRegression} can estimate parameters from
 * sample data loaded this way.
 *
 * @param scope  scope forwarded to {@code data.length(scope)}
 * @param data   matrix holding the sample
 * @param method {@code "GLS"} selects generalized least squares; any other value,
 *               including {@code null}, selects ordinary least squares
 * @throws Exception as declared by the signature
 */
public GamaRegression(final IScope scope, final GamaFloatMatrix data, final String method) throws Exception {
    final AbstractMultipleLinearRegression regressionMethod;
    // Constant-first comparison so a null method falls through to OLS instead of
    // raising a NullPointerException.
    if ("GLS".equals(method)) {
        regressionMethod = new GLSMultipleLinearRegression();
    } else {
        regressionMethod = new OLSMultipleLinearRegression();
    }
    final int nbFeatures = data.numCols - 1;
    final int nbInstances = data.numRows;
    final double[] instances = new double[data.numCols * data.numRows];
    for (int i = 0; i < data.length(scope); i++) {
        instances[i] = data.getMatrix()[i];
    }
    regressionMethod.newSampleData(instances, nbInstances, nbFeatures);
    param = regressionMethod.estimateRegressionParameters();
}

From source file:modelcreation.ModelCreation.java

/**
 * Runs a 10-fold cross-validation of a two-regressor OLS model and returns the mean of
 * the ten values returned by {@code evaluateModel}.
 *
 * <p>The sample indices are shuffled once. Fold {@code i} uses the shuffled positions
 * {@code [i * subSize, (i + 1) * subSize)} as its test set, with
 * {@code subSize = y.length / 10}, and every other position as training data. The
 * {@code y.length % 10} trailing positions therefore never serve as test data, and
 * the test arrays are empty when {@code y.length < 10}.
 *
 * @param x regressor rows; only columns 0 and 1 are used
 * @param y observed responses, one per row of {@code x}
 * @return the average over the ten folds of the value returned by {@code evaluateModel}
 */
public static double apply10FoldCrossValidation(double[][] x, double[] y) {
    final int foldCount = 10;
    final int subSize = y.length / foldCount;

    ArrayList<Integer> indices = new ArrayList<Integer>(y.length);
    for (int i = 0; i < y.length; i++) {
        indices.add(i);
    }
    Collections.shuffle(indices);

    double sum = 0;
    for (int fold = 0; fold < foldCount; fold++) {
        System.out.println("-------------Fold " + fold + "--------------");
        double[][] subXTest = new double[subSize][2];
        double[] subYTest = new double[subSize];
        double[][] subXTraining = new double[y.length - subSize][2];
        double[] subYTraining = new double[y.length - subSize];

        // Single pass over the shuffled order: positions inside the fold's window go
        // to the test set, everything else is appended to the training set.
        final int testStart = fold * subSize;
        final int testEnd = testStart + subSize;
        int trainingRow = 0;
        int testRow = 0;
        for (int position = 0; position < y.length; position++) {
            int index = indices.get(position);
            if (position >= testStart && position < testEnd) {
                subXTest[testRow][0] = x[index][0];
                subXTest[testRow][1] = x[index][1];
                subYTest[testRow] = y[index];
                testRow++;
            } else {
                subXTraining[trainingRow][0] = x[index][0];
                subXTraining[trainingRow][1] = x[index][1];
                subYTraining[trainingRow] = y[index];
                trainingRow++;
            }
        }

        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        regression.newSampleData(subYTraining, subXTraining);
        // NOTE(review): Commons Math decides whether to add the intercept column while
        // loading the sample, so setting the flag here presumably does not remove the
        // intercept from the fitted model. Confirm the intended order against what
        // evaluateModel expects.
        regression.setNoIntercept(true);
        sum += evaluateModel(regression, subXTest, subYTest);
    }

    return sum / foldCount;
}

From source file:eqtlmappingpipeline.interactionanalysis.InteractionPlotter.java

/**
 * Reads SNP / covariate / probe combinations from {@code interactionFile} and, for
 * each combination, fits the model
 * {@code expression ~ genotype + covariate + genotype * covariate}, prints the
 * underlying per-individual data to standard output and constructs a
 * {@code ScatterPlot} whose output path is
 * {@code outdir + snp + "-" + probe + "-" + covariate + ".pdf"}.
 *
 * <p>Lines of {@code interactionFile} with two tab-separated fields are read as
 * (SNP, probe) and are evaluated against every covariate row; lines with three fields
 * are read as (SNP, covariate, probe). Other lines are ignored.
 *
 * @param interactionFile    tab-separated file listing the combinations to plot
 * @param genotypeDir        directory handed to {@code TriTyperGenotypeData}
 * @param expressionDataFile expression matrix; only probes named in
 *                           {@code interactionFile} are requested
 * @param covariateDataFile  covariate matrix; transposed when more expression samples
 *                           match its rows than its columns
 * @param gteFile            optional two-column file mapping genotype sample names to
 *                           expression sample names, may be {@code null}
 * @param outdir             output directory, created through {@code Gpio.createDir}
 * @throws IOException as declared by the signature
 */
public InteractionPlotter(String interactionFile, String genotypeDir, String expressionDataFile,
        String covariateDataFile, String gteFile, String outdir) throws IOException {
    outdir = Gpio.formatAsDirectory(outdir);
    Gpio.createDir(outdir);
    Map<String, String> gte = null;
    if (gteFile != null) {
        TextFile gteReader = new TextFile(gteFile, TextFile.R);
        try {
            gte = gteReader.readAsHashMap(0, 1);
        } finally {
            gteReader.close();
        }
    }

    HashSet<String> expressionProbes = new HashSet<String>();
    ArrayList<Triple<String, String, String>> triples = new ArrayList<Triple<String, String, String>>();
    TextFile tf = new TextFile(interactionFile, TextFile.R);
    try {
        String[] elems = tf.readLineElems(TextFile.tab);
        while (elems != null) {
            if (elems.length == 2) {
                // No covariate given: a null middle element means "all covariates".
                String snp = elems[0];
                String probe = elems[1];
                expressionProbes.add(probe);
                triples.add(new Triple<String, String, String>(snp, null, probe));
            } else if (elems.length == 3) {
                String snp = elems[0];
                String covariate = elems[1];
                String probe = elems[2];
                expressionProbes.add(probe);
                triples.add(new Triple<String, String, String>(snp, covariate, probe));
            }
            elems = tf.readLineElems(TextFile.tab);
        }
    } finally {
        tf.close();
    }
    System.out.println(triples.size() + " SNP - covariate - probe combinations read from: " + interactionFile);

    DoubleMatrixDataset<String, String> expressionData = new DoubleMatrixDataset<String, String>(
            expressionDataFile, expressionProbes);
    DoubleMatrixDataset<String, String> covariateData = new DoubleMatrixDataset<String, String>(
            covariateDataFile);

    // Work out whether the covariate file has samples on its columns or on its rows.
    int samplesHaveCovariatesOnCols = 0;
    int samplesHaveCovariatesOnRows = 0;
    for (int i = 0; i < expressionData.colObjects.size(); i++) {
        String expSample = expressionData.colObjects.get(i);
        Integer id1 = covariateData.hashCols.get(expSample);
        Integer id2 = covariateData.hashRows.get(expSample);
        if (id1 != null) {
            samplesHaveCovariatesOnCols++;
        }
        if (id2 != null) {
            samplesHaveCovariatesOnRows++;
        }
    }
    if (samplesHaveCovariatesOnRows > samplesHaveCovariatesOnCols) {
        System.out.println("Rows contain covariate samples in covariate file. Transposing covariates.");
        covariateData.transposeDataset();
    }

    TriTyperGenotypeData geno = new TriTyperGenotypeData(genotypeDir);
    SNPLoader loader = geno.createSNPLoader();

    // Per genotype individual: column index in the covariate / expression matrices,
    // or -9 when the individual is excluded or missing from either matrix.
    int[] genotypeToCovariate = new int[geno.getIndividuals().length];
    int[] genotypeToExpression = new int[geno.getIndividuals().length];
    String[] genoIndividuals = geno.getIndividuals();
    for (int i = 0; i < genotypeToCovariate.length; i++) {
        String genoSample = genoIndividuals[i];
        if (geno.getIsIncluded()[i] != null && geno.getIsIncluded()[i]) {
            if (gte != null) {
                genoSample = gte.get(genoSample);
            }

            Integer covariateSample = covariateData.hashCols.get(genoSample);
            Integer expressionSample = expressionData.hashCols.get(genoSample);
            if (genoSample != null && covariateSample != null && expressionSample != null) {
                genotypeToCovariate[i] = covariateSample;
                genotypeToExpression[i] = expressionSample;
            } else {
                genotypeToCovariate[i] = -9;
                genotypeToExpression[i] = -9;

            }
        } else {
            genotypeToCovariate[i] = -9;
            genotypeToExpression[i] = -9;
        }
    }

    OLSMultipleLinearRegression regressionFullWithInteraction = new OLSMultipleLinearRegression();
    cern.jet.random.tdouble.engine.DoubleRandomEngine randomEngine = new cern.jet.random.tdouble.engine.DRand();

    Color[] colorarray = new Color[3];
    colorarray[0] = new Color(171, 178, 114);
    colorarray[1] = new Color(98, 175, 255);
    colorarray[2] = new Color(204, 86, 78);

    DecimalFormat decFormat = new DecimalFormat("#.###");
    DecimalFormat decFormatSmall = new DecimalFormat("0.#E0");
    for (Triple<String, String, String> triple : triples) {

        String snp = triple.getLeft();
        String covariate = triple.getMiddle();
        String probe = triple.getRight();

        Integer snpId = geno.getSnpToSNPId().get(snp);

        Integer probeId = expressionData.hashRows.get(probe);

        int startCovariate = -1;
        int endCovariate = -1;

        if (covariate == null) {
            startCovariate = 0;
            endCovariate = covariateData.nrRows;
        } else {
            Integer covariateId = covariateData.hashRows.get(covariate);
            if (covariateId != null) {
                startCovariate = covariateId;
                endCovariate = covariateId + 1;
            }
        }

        // The null check guards the unboxing in case the SNP lookup yields no entry.
        if (snpId != null && snpId >= 0 && probeId != null && startCovariate >= 0) {

            SNP snpObj = geno.getSNPObject(snpId);
            loader.loadGenotypes(snpObj);
            if (loader.hasDosageInformation()) {
                loader.loadDosage(snpObj);
            }

            // When the second allele is the minor allele, genotypes are flipped
            // (2 - genotype) and the genotype labels are listed in reverse order.
            double signInteractionEffectDirection = 1;
            String[] genotypeDescriptions = new String[3];
            if (snpObj.getAlleles()[1] == snpObj.getMinorAllele()) {
                signInteractionEffectDirection = -1;
                genotypeDescriptions[2] = BaseAnnot.toString(snpObj.getAlleles()[0]) + ""
                        + BaseAnnot.toString(snpObj.getAlleles()[0]);
                genotypeDescriptions[1] = BaseAnnot.toString(snpObj.getAlleles()[0]) + ""
                        + BaseAnnot.toString(snpObj.getAlleles()[1]);
                genotypeDescriptions[0] = BaseAnnot.toString(snpObj.getAlleles()[1]) + ""
                        + BaseAnnot.toString(snpObj.getAlleles()[1]);
            } else {
                genotypeDescriptions[0] = BaseAnnot.toString(snpObj.getAlleles()[0]) + ""
                        + BaseAnnot.toString(snpObj.getAlleles()[0]);
                genotypeDescriptions[1] = BaseAnnot.toString(snpObj.getAlleles()[0]) + ""
                        + BaseAnnot.toString(snpObj.getAlleles()[1]);
                genotypeDescriptions[2] = BaseAnnot.toString(snpObj.getAlleles()[1]) + ""
                        + BaseAnnot.toString(snpObj.getAlleles()[1]);

            }

            for (int q = startCovariate; q < endCovariate; q++) {
                System.out.println("Plotting: " + snp + "\t" + covariateData.rowObjects.get(q) + "\t" + probe);
                System.out.println(
                        "Individual\tAllele1\tAllele2\tGenotype\tGenotypeFlipped\tCovariate\tExpression");
                byte[] alleles1 = snpObj.getAllele1();
                byte[] alleles2 = snpObj.getAllele2();
                byte[] genotypes = snpObj.getGenotypes();
                ArrayList<Byte> genotypeArr = new ArrayList<Byte>();
                ArrayList<Double> covariateArr = new ArrayList<Double>();
                ArrayList<Double> expressionArr = new ArrayList<Double>();

                int nrCalled = 0;

                // Keep individuals with a called genotype, a mapping into both
                // matrices and a non-NaN covariate value.
                for (int i = 0; i < genoIndividuals.length; i++) {
                    if (genotypes[i] != -1 && genotypeToCovariate[i] != -9 && genotypeToExpression[i] != -9) {

                        if (!Double.isNaN(covariateData.rawData[q][genotypeToCovariate[i]])) {

                            int genotypeflipped = genotypes[i];
                            if (signInteractionEffectDirection == -1) {
                                genotypeflipped = 2 - genotypeflipped;
                            }

                            String output = genoIndividuals[i] + "\t" + BaseAnnot.toString(alleles1[i]) + "\t"
                                    + BaseAnnot.toString(alleles2[i]) + "\t" + genotypes[i] + "\t"
                                    + genotypeflipped + "\t" + covariateData.rawData[q][genotypeToCovariate[i]]
                                    + "\t" + expressionData.rawData[probeId][genotypeToExpression[i]];
                            System.out.println(output);

                            genotypeArr.add(genotypes[i]);

                            covariateArr.add(covariateData.rawData[q][genotypeToCovariate[i]]);
                            expressionArr.add(expressionData.rawData[probeId][genotypeToExpression[i]]);
                            nrCalled++;
                        }

                    }
                }
                System.out.println("");
                //Fill arrays with data in order to be able to perform the ordinary least squares analysis:
                double[] olsY = new double[nrCalled]; //Ordinary least squares: Our gene expression
                double[][] olsXFullWithInteraction = new double[nrCalled][3]; //With interaction term, linear model: y ~ a * SNP + b * CellCount + c + d * SNP * CellCount

                double[] dataExp = new double[nrCalled];
                double[] dataCov = new double[nrCalled];
                int[] dataGen = new int[nrCalled];

                for (int s = 0; s < nrCalled; s++) {
                    byte originalGenotype = genotypeArr.get(s);
                    int genotype = originalGenotype;
                    if (signInteractionEffectDirection == -1) {
                        genotype = 2 - genotype;
                    }

                    olsY[s] = expressionArr.get(s);

                    olsXFullWithInteraction[s][0] = genotype;
                    olsXFullWithInteraction[s][1] = covariateArr.get(s);
                    olsXFullWithInteraction[s][2] = olsXFullWithInteraction[s][0]
                            * olsXFullWithInteraction[s][1];

                    dataExp[s] = olsY[s];
                    dataGen[s] = genotype;
                    dataCov[s] = covariateArr.get(s);
                }

                regressionFullWithInteraction.newSampleData(olsY, olsXFullWithInteraction);

                double[] regressionParameters = regressionFullWithInteraction.estimateRegressionParameters();
                double[] regressionStandardErrors = regressionFullWithInteraction
                        .estimateRegressionParametersStandardErrors();

                // Index 0 is the intercept, so index 3 is the interaction term.
                double betaInteraction = regressionParameters[3];
                double seInteraction = regressionStandardErrors[3];
                double tInteraction = betaInteraction / seInteraction;
                double pValueInteraction = 1;
                double zScoreInteraction = 0;

                // The residual degrees of freedom (n minus the four fitted parameters)
                // depend on how many individuals passed the filters for this particular
                // SNP / covariate pair, so the distribution is rebuilt for every model
                // instead of being cached from the first one.
                cern.jet.random.tdouble.StudentT tDistColt = new cern.jet.random.tdouble.StudentT(
                        olsY.length - 4, randomEngine);

                if (tInteraction < 0) {
                    pValueInteraction = tDistColt.cdf(tInteraction);
                    if (pValueInteraction < 2.0E-323) {
                        pValueInteraction = 2.0E-323;
                    }
                    zScoreInteraction = cern.jet.stat.tdouble.Probability.normalInverse(pValueInteraction);
                } else {
                    pValueInteraction = tDistColt.cdf(-tInteraction);
                    if (pValueInteraction < 2.0E-323) {
                        pValueInteraction = 2.0E-323;
                    }

                    zScoreInteraction = -cern.jet.stat.tdouble.Probability.normalInverse(pValueInteraction);
                }
                // One tail was computed above; double it for the two-sided p-value.
                pValueInteraction *= 2;
                String pvalFormatted = "";
                if (pValueInteraction >= 0.001) {
                    pvalFormatted = decFormat.format(pValueInteraction);
                } else {
                    pvalFormatted = decFormatSmall.format(pValueInteraction);
                }
                // Constructed only for its effect; the instance itself is not used.
                new ScatterPlot(500, 500, dataCov, dataExp, dataGen,
                        genotypeDescriptions, colorarray, ScatterPlot.OUTPUTFORMAT.PDF,
                        "Interaction between SNP " + snp + ", probe " + probe + " and covariate "
                                + covariateData.rowObjects.get(q),
                        "Z: " + decFormat.format(zScoreInteraction) + " Pvalue: " + pvalFormatted + " n: "
                                + nrCalled,
                        outdir + snp + "-" + probe + "-" + covariateData.rowObjects.get(q) + ".pdf", false);

            }

            snpObj.clearGenotypes();
        }
    }

    loader.close();
}