List of usage examples for the constructor OLSMultipleLinearRegression() of class org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression
OLSMultipleLinearRegression
From source file:modelcreation.ModelCreation.java
/** * @param args the command line arguments */// w w w .j av a 2 s . c o m public static void main(String[] args) { int size = writeDataIntoFile(); double[][] x = new double[size][2]; double[] y = new double[size]; readDataFromFile(x, y); // TTest tTest = new TTest(); // System.out.println("p value for home value = " + tTest.tTest(x[0], y)); // System.out.println("p value for away value = " + tTest.tTest(x[1], y)); // System.out.println("Average mean squared error: " + apply10FoldCrossValidation(x, y)); // double[] predictions = new double[size]; // for (int i = 0; i < size; i++) { // predictions[i] = 0.5622255342802198 + (1.0682845275289186E-9 * x[i][0]) + (-9.24614306976538E-10 * x[i][1]); // // //System.out.print("Actual: " + y[i]); // //System.out.println(" Predicted: " + predicted); // } // // System.out.println(calculateMeanSquaredError(y, predictions)); OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); regression.newSampleData(y, x); regression.setNoIntercept(true); printRegressionStatistics(regression); //Team[] teams2014 = getTeams(354); //Team[] teams2015 = getTeams(398, 2015); //Team[] teams = concatTeams(teams2014, teams2015); // HashMap<Integer, ArrayList<Integer>> marketValueGoalsDataset = createMarketValueGoalsDataset(teams2014); // // SimpleRegression regression = new SimpleRegression(); // // Set<Integer> marketValues = marketValueGoalsDataset.keySet(); // for (Integer marketValue:marketValues) { // ArrayList<Integer> goals = marketValueGoalsDataset.get(marketValue); // int totalGoals = 0; // for(Integer goal:goals) { // regression.addData(marketValue, goal); // totalGoals += goal; // } // double avg = (double) totalGoals / goals.size(); // System.out.println("Team Value: " + marketValue + ", Goal Average: " + avg); // } // // System.out.println("Intercept: " + regression.getIntercept()); // System.out.println("Slope: " + regression.getSlope()); // System.out.println("R^2: " + regression.getRSquare()); 
//LinearRegression.calculateLinearRegression(marketValueGoalsDataset); }
From source file:dase.timeseries.analysis.GrangerTest.java
/**
 * Returns the p-value of a Granger causality test.
 *
 * <p>Two OLS models are fitted to {@code strip(L, y)}: a restricted one whose regressors come
 * from {@code createLaggedSide(L, y)} and an unrestricted one whose regressors come from
 * {@code createLaggedSide(L, x, y)}. Their residual sums of squares are compared with an
 * F-test using {@code L} and {@code n - 2L - 1} degrees of freedom, where {@code n} is the
 * number of rows of the restricted regressor matrix.
 *
 * @param y predictable variable
 * @param x predictor
 * @param L lag, should be 1 or greater
 * @return p-value of Granger causality
 */
public static double granger(double[] y, double[] x, int L) {
    final double[][] laggedY = createLaggedSide(L, y);
    final double[][] laggedXY = createLaggedSide(L, x, y);
    final int n = laggedY.length;
    final int denominatorDof = n - 2 * L - 1;

    // Restricted model: y explained by its own lags only.
    final OLSMultipleLinearRegression restricted = new OLSMultipleLinearRegression();
    restricted.newSampleData(strip(L, y), laggedY);

    // Unrestricted model: y explained by the lags of both x and y.
    final OLSMultipleLinearRegression unrestricted = new OLSMultipleLinearRegression();
    unrestricted.newSampleData(strip(L, y), laggedXY);

    final double[] restrictedResiduals = restricted.estimateResiduals();
    final double[] unrestrictedResiduals = unrestricted.estimateResiduals();
    final double RSS0 = sqrSum(restrictedResiduals);
    final double RSS1 = sqrSum(unrestrictedResiduals);

    final double explainedPerLag = (RSS0 - RSS1) / L;
    final double unexplainedPerDof = RSS1 / denominatorDof;
    final double ftest = explainedPerLag / unexplainedPerDof;

    System.out.println(RSS0 + " " + RSS1);
    System.out.println("F-test " + ftest);

    final FDistribution fDist = new FDistribution(L, denominatorDof);
    // Upper-tail probability of the F statistic.
    final double pValue = 1.0 - fDist.cumulativeProbability(ftest);
    System.out.println("P-value " + pValue);
    return pValue;
}
From source file:net.gtl.movieanalytics.model.LinearRegression.java
private double[] estimateParameter(double[][] x, double[] y) { //printTestData(x, y); OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression(); ols.newSampleData(y, x);/*from ww w . ja va2 s . c om*/ return ols.estimateRegressionParameters(); }
From source file:com.insightml.models.regression.OLS.java
/**
 * Trains a linear model by ordinary least squares.
 *
 * @param features     regressor matrix, one row per sample
 * @param expected     target value of each sample
 * @param featureNames names handed to the resulting model together with the coefficients
 * @return a {@code LinearRegressionModel} built from the estimated regression parameters
 */
@Override
public IModel<Sample, Double> train(final double[][] features, final double[] expected,
        final String[] featureNames) {
    final OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
    ols.newSampleData(expected, features);
    final double[] coefficients = ols.estimateRegressionParameters();
    return new LinearRegressionModel(coefficients, featureNames);
}
From source file:com.davidbracewell.ml.regression.LeastSquaresLearner.java
/**
 * Fits an OLS regression to all training instances and stores the result in {@code model}:
 * the first estimated parameter becomes the bias, the remaining ones become the weights.
 *
 * @param trainingData instances supplying a target value and a feature array each
 */
@Override
protected void trainAll(List<Instance> trainingData) {
    final int sampleCount = trainingData.size();
    final double[] targets = new double[sampleCount];
    final double[][] features = new double[sampleCount][];

    int row = 0;
    for (Instance instance : trainingData) {
        targets[row] = instance.getTargetValue();
        features[row] = instance.toArray();
        row++;
    }

    final OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
    ols.newSampleData(targets, features);
    final double[] coefficients = ols.estimateRegressionParameters();

    // Parameter 0 is stored as the bias; parameters 1..n are the per-feature weights.
    model.bias = coefficients[0];
    final double[] featureWeights = new double[coefficients.length - 1];
    for (int k = 1; k < coefficients.length; k++) {
        featureWeights[k - 1] = coefficients[k];
    }
    model.weights = new DenseVector(featureWeights);
}
From source file:lu.lippmann.cdb.lab.regression.Regression.java
/** * Constructor./* w ww . j a va 2 s. c o m*/ */ public Regression(final Instances ds, final int idx) throws Exception { this.newds = WekaDataProcessingUtil.buildDataSetSortedByAttribute(ds, idx); //System.out.println("Regression -> "+newds.toSummaryString()); final int N = this.newds.numInstances(); final int M = this.newds.numAttributes(); final double[][] x = new double[N][M - 1]; final double[] y = new double[N]; for (int i = 0; i < N; i++) { y[i] = this.newds.instance(i).value(0); } for (int i = 0; i < N; i++) { for (int j = 1; j < M; j++) { x[i][j - 1] = this.newds.instance(i).value(j); } } final OLSMultipleLinearRegression reg = new OLSMultipleLinearRegression(); //reg.setNoIntercept(true); reg.newSampleData(y, x); this.r2 = reg.calculateRSquared(); //this.r2=-1d; this.coe = reg.estimateRegressionParameters(); this.estims = calculateEstimations(x, y, coe); }
From source file:com.mebigfatguy.damus.main.DamusCalculator.java
/**
 * Runs one OLS regression per result metric of the prediction model, using every training
 * item's metric values as regressors, and stores the estimated regressand variance of each
 * regression on {@code trainingItem} under that result metric.
 *
 * @return {@code true} when every regression completed, {@code false} when any exception
 *         was raised along the way
 */
private boolean calcLinearRegression() {
    try {
        Context context = Context.instance();
        PredictionModel model = context.getPredictionModel();
        TrainingData data = context.getTrainingData();

        int numMetrics = model.getNumMetrics();
        int numResults = model.getNumResults();
        int numItems = data.getNumItems();

        // Regressor matrix: one row per training item, one column per metric.
        double[][] xx = new double[numItems][];
        for (int i = 0; i < numItems; i++) {
            double[] x = new double[numMetrics];
            for (int m = 0; m < numMetrics; m++) {
                Metric metric = model.getMetric(m);
                x[m] = metricValueAsDouble(metric, data.getItem(i).getValue(metric));
            }
            xx[i] = x;
        }

        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        for (int r = 0; r < numResults; r++) {
            Metric metric = model.getResult(r);
            double[] y = new double[numItems];
            for (int i = 0; i < numItems; i++) {
                y[i] = metricValueAsDouble(metric, data.getItem(i).getValue(metric));
            }

            regression.newSampleData(y, xx);
            double result = regression.estimateRegressandVariance();
            // BigDecimal.valueOf uses the canonical decimal string of the double, avoiding the
            // long binary-expansion tail produced by the BigDecimal(double) constructor.
            trainingItem.setValue(metric, BigDecimal.valueOf(result));
        }
        return true;
    } catch (Exception e) {
        // Deliberate best-effort: any failure (bad cast, too few samples, non-finite result,
        // ...) is reported to the caller only through the boolean return value.
        return false;
    }
}

/**
 * Converts a raw metric value to a double according to the metric's type: {@code Percent}
 * values are read as {@link Number}, {@code Real} values as {@link BigDecimal},
 * {@code YesNo} values as {@link Boolean} (1.0 for true, 0.0 for false); any other type
 * yields 0.0.
 */
private double metricValueAsDouble(Metric metric, Object rawValue) {
    switch (metric.getType()) {
    case Percent:
        return ((Number) rawValue).doubleValue();
    case Real:
        return ((BigDecimal) rawValue).doubleValue();
    case YesNo:
        return ((Boolean) rawValue).booleanValue() ? 1.0 : 0.0;
    default:
        return 0.0;
    }
}
From source file:msi.gama.util.GamaRegression.java
/**
 * Fits a multiple linear regression to the content of {@code data} and stores the estimated
 * parameters in {@code param}.
 *
 * <p>The matrix content is copied into a flat array and handed to
 * {@code newSampleData(double[], int, int)} with {@code numRows} observations and
 * {@code numCols - 1} regressors.
 *
 * @param scope  scope used to query the length of {@code data}
 * @param data   matrix holding the samples
 * @param method {@code "GLS"} selects {@code GLSMultipleLinearRegression}; any other value
 *               selects {@code OLSMultipleLinearRegression}
 * @throws Exception declared by the original signature
 */
public GamaRegression(final IScope scope, final GamaFloatMatrix data, final String method) throws Exception {
    final AbstractMultipleLinearRegression regression = method.equals("GLS")
            ? new GLSMultipleLinearRegression()
            : new OLSMultipleLinearRegression();

    final int observationCount = data.numRows;
    final int regressorCount = data.numCols - 1;

    final double[] flatSamples = new double[data.numCols * data.numRows];
    for (int cell = 0; cell < data.length(scope); cell++) {
        flatSamples[cell] = data.getMatrix()[cell];
    }

    regression.newSampleData(flatSamples, observationCount, regressorCount);
    param = regression.estimateRegressionParameters();
}
From source file:modelcreation.ModelCreation.java
/**
 * Estimates the mean squared error of a two-regressor OLS model by 10-fold cross-validation.
 *
 * <p>The sample indices are shuffled once. Fold {@code i} uses the {@code i}-th window of
 * {@code y.length / 10} shuffled positions as its test set and all remaining positions as
 * its training set; when {@code y.length} is not a multiple of 10 the leftover positions at
 * the end are therefore always used for training. Each fold's model is scored with
 * {@code evaluateModel} and the ten scores are averaged.
 *
 * @param x regressor matrix; only columns 0 and 1 of each row are used
 * @param y responses, one per row of {@code x}
 * @return the average of the ten per-fold values returned by {@code evaluateModel}
 */
public static double apply10FoldCrossValidation(double[][] x, double[] y) {
    final int foldCount = 10;
    final int sampleCount = y.length;
    final int foldSize = sampleCount / foldCount;
    final int trainingSize = sampleCount - foldSize;

    final ArrayList<Integer> shuffled = new ArrayList<Integer>(sampleCount);
    for (int i = 0; i < sampleCount; i++) {
        shuffled.add(i);
    }
    Collections.shuffle(shuffled);

    final double[] meanSquaredErrors = new double[foldCount];
    for (int fold = 0; fold < foldCount; fold++) {
        System.out.println("-------------Fold " + fold + "--------------");

        final double[][] testX = new double[foldSize][2];
        final double[] testY = new double[foldSize];
        final double[][] trainingX = new double[trainingSize][2];
        final double[] trainingY = new double[trainingSize];

        // Shuffled positions [testStart, testEnd) form the test set; every other position
        // goes to the training set, in shuffled order.
        final int testStart = fold * foldSize;
        final int testEnd = testStart + foldSize;
        int trainingRow = 0;
        int testRow = 0;
        for (int position = 0; position < sampleCount; position++) {
            final int index = shuffled.get(position);
            if (position >= testStart && position < testEnd) {
                testX[testRow][0] = x[index][0];
                testX[testRow][1] = x[index][1];
                testY[testRow] = y[index];
                testRow++;
            } else {
                trainingX[trainingRow][0] = x[index][0];
                trainingX[trainingRow][1] = x[index][1];
                trainingY[trainingRow] = y[index];
                trainingRow++;
            }
        }

        final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        regression.newSampleData(trainingY, trainingX);
        // NOTE(review): the no-intercept flag is set AFTER the sample data was loaded. In
        // Commons Math 3 the design matrix is built inside newSampleData using the flag's
        // value at that moment, so this presumably does not remove the intercept from the
        // fit. Confirm the intended model (and what evaluateModel expects) before reordering.
        regression.setNoIntercept(true);

        meanSquaredErrors[fold] = evaluateModel(regression, testX, testY);
    }

    double sum = 0;
    for (int i = 0; i < meanSquaredErrors.length; i++) {
        sum += meanSquaredErrors[i];
    }
    return sum / meanSquaredErrors.length;
}
From source file:eqtlmappingpipeline.interactionanalysis.InteractionPlotter.java
public InteractionPlotter(String interactionFile, String genotypeDir, String expressionDataFile, String covariateDataFile, String gteFile, String outdir) throws IOException { outdir = Gpio.formatAsDirectory(outdir); Gpio.createDir(outdir);//from w ww.j a v a 2 s. co m Map<String, String> gte = null; if (gteFile != null) { TextFile tf = new TextFile(gteFile, TextFile.R); gte = tf.readAsHashMap(0, 1); tf.close(); } HashSet<String> expressionProbes = new HashSet<String>(); TextFile tf = new TextFile(interactionFile, TextFile.R); String[] elems = tf.readLineElems(TextFile.tab); ArrayList<Triple<String, String, String>> triples = new ArrayList<Triple<String, String, String>>(); while (elems != null) { if (elems.length == 2) { String snp = elems[0]; String probe = elems[1]; expressionProbes.add(probe); triples.add(new Triple<String, String, String>(snp, null, probe)); } else if (elems.length == 3) { String snp = elems[0]; String covariate = elems[1]; String probe = elems[2]; expressionProbes.add(probe); triples.add(new Triple<String, String, String>(snp, covariate, probe)); } elems = tf.readLineElems(TextFile.tab); } tf.close(); System.out.println(triples.size() + " SNP - covariate - probe combinations read from: " + interactionFile); DoubleMatrixDataset<String, String> expressionData = new DoubleMatrixDataset<String, String>( expressionDataFile, expressionProbes); DoubleMatrixDataset<String, String> covariateData = new DoubleMatrixDataset<String, String>( covariateDataFile); int samplesHaveCovariatesOnCols = 0; int samplesHaveCovariatesOnRows = 0; for (int i = 0; i < expressionData.colObjects.size(); i++) { String expSample = expressionData.colObjects.get(i); Integer id1 = covariateData.hashCols.get(expSample); Integer id2 = covariateData.hashRows.get(expSample); if (id1 != null) { samplesHaveCovariatesOnCols++; } if (id2 != null) { samplesHaveCovariatesOnRows++; } } if (samplesHaveCovariatesOnRows > samplesHaveCovariatesOnCols) { System.out.println("Rows contain 
covariate samples in covariate file. Transposing covariates."); covariateData.transposeDataset(); } TriTyperGenotypeData geno = new TriTyperGenotypeData(genotypeDir); SNPLoader loader = geno.createSNPLoader(); int[] genotypeToCovariate = new int[geno.getIndividuals().length]; int[] genotypeToExpression = new int[geno.getIndividuals().length]; String[] genoIndividuals = geno.getIndividuals(); for (int i = 0; i < genotypeToCovariate.length; i++) { String genoSample = genoIndividuals[i]; if (geno.getIsIncluded()[i] != null && geno.getIsIncluded()[i]) { if (gte != null) { genoSample = gte.get(genoSample); } Integer covariateSample = covariateData.hashCols.get(genoSample); Integer expressionSample = expressionData.hashCols.get(genoSample); if (genoSample != null && covariateSample != null && expressionSample != null) { genotypeToCovariate[i] = covariateSample; genotypeToExpression[i] = expressionSample; } else { genotypeToCovariate[i] = -9; genotypeToExpression[i] = -9; } } else { genotypeToCovariate[i] = -9; genotypeToExpression[i] = -9; } } OLSMultipleLinearRegression regressionFullWithInteraction = new OLSMultipleLinearRegression(); cern.jet.random.tdouble.StudentT tDistColt = null; org.apache.commons.math3.distribution.FDistribution fDist = null; cern.jet.random.tdouble.engine.DoubleRandomEngine randomEngine = null; Color[] colorarray = new Color[3]; colorarray[0] = new Color(171, 178, 114); colorarray[1] = new Color(98, 175, 255); colorarray[2] = new Color(204, 86, 78); DecimalFormat decFormat = new DecimalFormat("#.###"); DecimalFormat decFormatSmall = new DecimalFormat("0.#E0"); for (Triple<String, String, String> triple : triples) { String snp = triple.getLeft(); String covariate = triple.getMiddle(); String probe = triple.getRight(); Integer snpId = geno.getSnpToSNPId().get(snp); Integer probeId = expressionData.hashRows.get(probe); int startCovariate = -1; int endCovariate = -1; if (covariate == null) { startCovariate = 0; endCovariate = covariateData.nrRows; 
} else { Integer covariateId = covariateData.hashRows.get(covariate); if (covariateId != null) { startCovariate = covariateId; endCovariate = covariateId + 1; } } if (snpId >= 0 && probeId != null && startCovariate >= 0) { SNP snpObj = geno.getSNPObject(snpId); loader.loadGenotypes(snpObj); if (loader.hasDosageInformation()) { loader.loadDosage(snpObj); } double signInteractionEffectDirection = 1; String[] genotypeDescriptions = new String[3]; if (snpObj.getAlleles()[1] == snpObj.getMinorAllele()) { signInteractionEffectDirection = -1; genotypeDescriptions[2] = BaseAnnot.toString(snpObj.getAlleles()[0]) + "" + BaseAnnot.toString(snpObj.getAlleles()[0]); genotypeDescriptions[1] = BaseAnnot.toString(snpObj.getAlleles()[0]) + "" + BaseAnnot.toString(snpObj.getAlleles()[1]); genotypeDescriptions[0] = BaseAnnot.toString(snpObj.getAlleles()[1]) + "" + BaseAnnot.toString(snpObj.getAlleles()[1]); } else { genotypeDescriptions[0] = BaseAnnot.toString(snpObj.getAlleles()[0]) + "" + BaseAnnot.toString(snpObj.getAlleles()[0]); genotypeDescriptions[1] = BaseAnnot.toString(snpObj.getAlleles()[0]) + "" + BaseAnnot.toString(snpObj.getAlleles()[1]); genotypeDescriptions[2] = BaseAnnot.toString(snpObj.getAlleles()[1]) + "" + BaseAnnot.toString(snpObj.getAlleles()[1]); } for (int q = startCovariate; q < endCovariate; q++) { System.out.println("Plotting: " + snp + "\t" + covariateData.rowObjects.get(q) + "\t" + probe); System.out.println( "Individual\tAllele1\tAllele2\tGenotype\tGenotypeFlipped\tCovariate\tExpression"); byte[] alleles1 = snpObj.getAllele1(); byte[] alleles2 = snpObj.getAllele2(); byte[] genotypes = snpObj.getGenotypes(); ArrayList<Byte> genotypeArr = new ArrayList<Byte>(); ArrayList<Double> covariateArr = new ArrayList<Double>(); ArrayList<Double> expressionArr = new ArrayList<Double>(); int nrCalled = 0; for (int i = 0; i < genoIndividuals.length; i++) { if (genotypes[i] != -1 && genotypeToCovariate[i] != -9 && genotypeToExpression[i] != -9) { if 
(!Double.isNaN(covariateData.rawData[q][genotypeToCovariate[i]])) { int genotypeflipped = genotypes[i]; if (signInteractionEffectDirection == -1) { genotypeflipped = 2 - genotypeflipped; } String output = genoIndividuals[i] + "\t" + BaseAnnot.toString(alleles1[i]) + "\t" + BaseAnnot.toString(alleles2[i]) + "\t" + genotypes[i] + "\t" + genotypeflipped + "\t" + covariateData.rawData[q][genotypeToCovariate[i]] + "\t" + expressionData.rawData[probeId][genotypeToExpression[i]]; System.out.println(output); genotypeArr.add(genotypes[i]); covariateArr.add(covariateData.rawData[q][genotypeToCovariate[i]]); expressionArr.add(expressionData.rawData[probeId][genotypeToExpression[i]]); nrCalled++; } } } System.out.println(""); //Fill arrays with data in order to be able to perform the ordinary least squares analysis: double[] olsY = new double[nrCalled]; //Ordinary least squares: Our gene expression double[][] olsXFullWithInteraction = new double[nrCalled][3]; //With interaction term, linear model: y ~ a * SNP + b * CellCount + c + d * SNP * CellCount int itr = 0; double[] dataExp = new double[nrCalled]; double[] dataCov = new double[nrCalled]; int[] dataGen = new int[nrCalled]; for (int s = 0; s < nrCalled; s++) { byte originalGenotype = genotypeArr.get(s); int genotype = originalGenotype; if (signInteractionEffectDirection == -1) { genotype = 2 - genotype; } olsY[s] = expressionArr.get(s); olsXFullWithInteraction[s][0] = genotype; olsXFullWithInteraction[s][1] = covariateArr.get(s); olsXFullWithInteraction[s][2] = olsXFullWithInteraction[s][0] * olsXFullWithInteraction[s][1]; dataExp[s] = olsY[s]; dataGen[s] = genotype; dataCov[s] = covariateArr.get(s); itr++; } regressionFullWithInteraction.newSampleData(olsY, olsXFullWithInteraction); double rss2 = regressionFullWithInteraction.calculateResidualSumOfSquares(); double[] regressionParameters = regressionFullWithInteraction.estimateRegressionParameters(); double[] regressionStandardErrors = regressionFullWithInteraction 
.estimateRegressionParametersStandardErrors(); double betaInteraction = regressionParameters[3]; double seInteraction = regressionStandardErrors[3]; double tInteraction = betaInteraction / seInteraction; double pValueInteraction = 1; double zScoreInteraction = 0; if (fDist == null) { fDist = new org.apache.commons.math3.distribution.FDistribution((int) (3 - 2), (int) (olsY.length - 3)); randomEngine = new cern.jet.random.tdouble.engine.DRand(); tDistColt = new cern.jet.random.tdouble.StudentT(olsY.length - 4, randomEngine); } if (tInteraction < 0) { pValueInteraction = tDistColt.cdf(tInteraction); if (pValueInteraction < 2.0E-323) { pValueInteraction = 2.0E-323; } zScoreInteraction = cern.jet.stat.tdouble.Probability.normalInverse(pValueInteraction); } else { pValueInteraction = tDistColt.cdf(-tInteraction); if (pValueInteraction < 2.0E-323) { pValueInteraction = 2.0E-323; } zScoreInteraction = -cern.jet.stat.tdouble.Probability.normalInverse(pValueInteraction); } pValueInteraction *= 2; String pvalFormatted = ""; if (pValueInteraction >= 0.001) { pvalFormatted = decFormat.format(pValueInteraction); } else { pvalFormatted = decFormatSmall.format(pValueInteraction); } ScatterPlot scatterPlot = new ScatterPlot(500, 500, dataCov, dataExp, dataGen, genotypeDescriptions, colorarray, ScatterPlot.OUTPUTFORMAT.PDF, "Interaction between SNP " + snp + ", probe " + probe + " and covariate " + covariateData.rowObjects.get(q), "Z: " + decFormat.format(zScoreInteraction) + " Pvalue: " + pvalFormatted + " n: " + nrCalled, outdir + snp + "-" + probe + "-" + covariateData.rowObjects.get(q) + ".pdf", false); } snpObj.clearGenotypes(); } } loader.close(); }