Example usage for org.apache.commons.math3.stat.correlation PearsonsCorrelation getCorrelationMatrix

List of usage examples for org.apache.commons.math3.stat.correlation PearsonsCorrelation getCorrelationMatrix

Introduction

On this page you can find example usages of the getCorrelationMatrix method of org.apache.commons.math3.stat.correlation.PearsonsCorrelation.

Prototype

public RealMatrix getCorrelationMatrix() 

Source Link

Document

Returns the correlation matrix.

Usage

From source file:de.tudarmstadt.ukp.experiments.argumentation.convincingness.sampling.Step6GraphTransitivityCleaner.java

/**
 * Processes one XML file of annotated argument pairs: adds the pairs one by one (in the order
 * produced by {@code argumentPairListSorter}) to a graph, drops every pair whose addition
 * creates a cycle, and writes the per-step statistics and the retained pairs to
 * {@code outputDir}.
 *
 * <p>Files written (all names are built from {@code prefix + file.getName()}):
 * <ul>
 * <li>{@code ..._table.csv} - one tab-separated row per processed pair</li>
 * <li>{@code ..._info.txt} - data sizes and the Pearson correlation between edge weight and
 * the "edge was removed" flag</li>
 * <li>the retained pairs serialized as XML under the original file name</li>
 * <li>{@code generated_...} - only when {@code collectGeneratedArgumentPairs} is true</li>
 * <li>{@code ....dgs} - the graph from the last loop iteration</li>
 * </ul>
 *
 * @param file                          XML file deserialized into a list of {@code AnnotatedArgumentPair}
 * @param outputDir                     directory receiving all output files
 * @param prefix                        prefix prepended to every output file name
 * @param collectGeneratedArgumentPairs whether to also derive and store pairs implied by paths in
 *                                      the final graph; it is auto-unboxed, so {@code null} fails
 * @return the collected statistics of the cleaning run
 * @throws Exception declared by the signature; an {@code IllegalStateException} is thrown
 *                   explicitly when no graph was built or an argument cannot be found
 */
public GraphCleaningResults processSingleFile(File file, File outputDir, String prefix,
        Boolean collectGeneratedArgumentPairs) throws Exception {
    GraphCleaningResults result = new GraphCleaningResults();

    File outFileTable = new File(outputDir, prefix + file.getName() + "_table.csv");
    File outFileInfo = new File(outputDir, prefix + file.getName() + "_info.txt");

    // NOTE(review): both streams are closed only on the normal path further below; any
    // exception thrown in between leaves them open
    PrintStream psTable = new PrintStream(new FileOutputStream(outFileTable));
    PrintStream psInfo = new PrintStream(new FileOutputStream(outFileInfo));

    // load one topic/side
    List<AnnotatedArgumentPair> pairs = new ArrayList<>(
            (List<AnnotatedArgumentPair>) XStreamTools.getXStream().fromXML(file));

    int fullDataSize = pairs.size();

    // filter out missing gold data
    Iterator<AnnotatedArgumentPair> iterator = pairs.iterator();
    while (iterator.hasNext()) {
        AnnotatedArgumentPair pair = iterator.next();
        if (pair.getGoldLabel() == null) {
            iterator.remove();
        }
        // or we want to completely remove equal edges in advance!
        else if (this.removeEqualEdgesParam && "equal".equals(pair.getGoldLabel())) {
            iterator.remove();
        }
    }

    // sort pairs by their weight
    this.argumentPairListSorter.sortArgumentPairs(pairs);

    int preFilteredDataSize = pairs.size();

    // compute correlation between score threshold and number of removed edges
    double[] correlationEdgeWeights = new double[pairs.size()];
    double[] correlationRemovedEdges = new double[pairs.size()];

    // only cycles of length 0 to 5 are interesting (5+ are too big)
    Range<Integer> range = Range.between(0, 5);

    // table header: fixed columns followed by one "Cycles_<length>" column per length in range
    psTable.print(
            "EdgeWeightThreshold\tPairs\tignoredEdgesCount\tIsDAG\tTransitivityScoreMean\tTransitivityScoreMax\tTransitivityScoreSamples\tEdges\tNodes\t");
    for (int j = range.getMinimum(); j <= range.getMaximum(); j++) {
        psTable.print("Cycles_" + j + "\t");
    }
    psTable.println();

    // store the indices of all pairs (edges) that have been successfully added without
    // generating cycles
    TreeSet<Integer> addedPairsIndices = new TreeSet<>();

    // number of edges ignored as they generated cycles
    int ignoredEdgesCount = 0;

    // graph built in the most recent loop iteration; stays null when the loop never runs
    Graph lastGraph = null;

    // flag that the first cycle was already processed
    boolean firstCycleAlreadyHit = false;

    // NOTE(review): the loop starts at index 1, so the pair at index 0 is never added to the
    // graph and slot 0 of both correlation arrays keeps its default 0.0 - confirm this is intended
    for (int i = 1; i < pairs.size(); i++) {
        // now filter the finalArgumentPairList and add only pairs that have not generated cycles
        List<AnnotatedArgumentPair> subList = new ArrayList<>();

        for (Integer index : addedPairsIndices) {
            subList.add(pairs.get(index));
        }

        // and add the current at the end
        subList.add(pairs.get(i));

        // what is the current lowest value of a pair weight?
        // (this is the weight of the candidate pair just appended to subList)
        double weakestEdgeWeight = computeEdgeWeight(subList.get(subList.size() - 1), LAMBDA_PENALTY);

        //            Graph graph = buildGraphFromArgumentPairs(finalArgumentPairList);
        int numberOfLoops;

        // map for storing cycles by their length
        TreeMap<Integer, TreeSet<String>> lengthCyclesMap = new TreeMap<>();

        // graph of all previously accepted pairs plus the current candidate
        Graph graph = buildGraphFromArgumentPairs(subList);

        lastGraph = graph;

        List<List<Object>> cyclesInGraph = findCyclesInGraph(graph);

        // stays empty (mean is NaN) unless the graph is cycle-free
        DescriptiveStatistics transitivityScore = new DescriptiveStatistics();

        if (cyclesInGraph.isEmpty()) {
            // we have DAG
            transitivityScore = computeTransitivityScores(graph);

            // update results
            result.maxTransitivityScore = (int) transitivityScore.getMax();
            result.avgTransitivityScore = transitivityScore.getMean();
        }

        numberOfLoops = cyclesInGraph.size();

        // initialize map
        for (int r = range.getMinimum(); r <= range.getMaximum(); r++) {
            lengthCyclesMap.put(r, new TreeSet<String>());
        }

        // we hit a loop
        if (numberOfLoops > 0) {
            // let's update the result

            // the "first cycle" statistics are recorded only once per file
            if (!firstCycleAlreadyHit) {
                result.graphSizeEdgesBeforeFirstCycle = graph.getEdgeCount();
                result.graphSizeNodesBeforeFirstCycle = graph.getNodeCount();

                // find the shortest cycle
                int shortestCycleLength = Integer.MAX_VALUE;

                for (List<Object> cycle : cyclesInGraph) {
                    shortestCycleLength = Math.min(shortestCycleLength, cycle.size());
                }
                result.lengthOfFirstCircle = shortestCycleLength;

                result.pairsBeforeFirstCycle = i;

                firstCycleAlreadyHit = true;
            }

            // ignore this edge further
            // (its index is not added to addedPairsIndices, so later sub-lists exclude it)
            ignoredEdgesCount++;

            // update counts of different cycles lengths
            for (List<Object> cycle : cyclesInGraph) {
                int currentSize = cycle.size();

                // convert to sorted set of nodes
                // (sorting makes the string key independent of the order the cycle was reported in)
                List<String> cycleAsSortedIDs = new ArrayList<>();
                for (Object o : cycle) {
                    cycleAsSortedIDs.add(o.toString());
                }
                Collections.sort(cycleAsSortedIDs);

                // cycles whose length is outside the range are not counted
                if (range.contains(currentSize)) {
                    lengthCyclesMap.get(currentSize).add(cycleAsSortedIDs.toString());
                }
            }
        } else {
            // no cycle: the candidate pair is accepted
            addedPairsIndices.add(i);
        }

        // collect loop sizes
        StringBuilder loopsAsString = new StringBuilder();
        for (int j = range.getMinimum(); j <= range.getMaximum(); j++) {
            //                    loopsAsString.append(j).append(":");
            loopsAsString.append(lengthCyclesMap.get(j).size());
            loopsAsString.append("\t");
        }

        // one table row per candidate pair; a NaN mean (no transitivity scores) is printed as 0
        psTable.printf(Locale.ENGLISH, "%.4f\t%d\t%d\t%b\t%.2f\t%d\t%d\t%d\t%d\t%s%n", weakestEdgeWeight, i,
                ignoredEdgesCount, numberOfLoops == 0,
                Double.isNaN(transitivityScore.getMean()) ? 0d : transitivityScore.getMean(),
                (int) transitivityScore.getMax(), transitivityScore.getN(), graph.getEdgeCount(),
                graph.getNodeCount(), loopsAsString.toString().trim());

        // update result
        result.finalGraphSizeEdges = graph.getEdgeCount();
        result.finalGraphSizeNodes = graph.getNodeCount();
        result.ignoredEdgesThatBrokeDAG = ignoredEdgesCount;

        // update stats for correlation
        correlationEdgeWeights[i] = weakestEdgeWeight;
        //            correlationRemovedEdges[i] =  (double) ignoredEdgesCount;
        // let's try: if we keep = 0, if we remove = 1
        correlationRemovedEdges[i] = numberOfLoops == 0 ? 0.0 : 1.0;
    }

    psInfo.println("Original: " + fullDataSize + ", removed by MACE: " + (fullDataSize - preFilteredDataSize)
            + ", final: " + (preFilteredDataSize - ignoredEdgesCount) + " (removed: " + ignoredEdgesCount
            + ")");

    // build an N x 2 matrix: column 0 = edge weight, column 1 = removed flag (0.0 or 1.0)
    double[][] matrix = new double[correlationEdgeWeights.length][];
    for (int i = 0; i < correlationEdgeWeights.length; i++) {
        matrix[i] = new double[2];
        matrix[i][0] = correlationEdgeWeights[i];
        matrix[i][1] = correlationRemovedEdges[i];
    }

    // NOTE(review): presumably PearsonsCorrelation rejects a matrix with fewer than two rows,
    // which would fail here before the explicit lastGraph null check below - confirm
    PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation(matrix);

    // entry (0, 1) relates column 0 (edge weight) to column 1 (removed flag)
    double pValue = pearsonsCorrelation.getCorrelationPValues().getEntry(0, 1);
    double correlation = pearsonsCorrelation.getCorrelationMatrix().getEntry(0, 1);

    psInfo.printf(Locale.ENGLISH, "Correlation: %.3f, p-Value: %.4f%n", correlation, pValue);
    // lastGraph is null only when the main loop did not execute a single iteration
    if (lastGraph == null) {
        throw new IllegalStateException("Graph is null");
    }

    // close
    psInfo.close();
    psTable.close();

    // save filtered final gold data
    List<AnnotatedArgumentPair> finalArgumentPairList = new ArrayList<>();

    for (Integer index : addedPairsIndices) {
        finalArgumentPairList.add(pairs.get(index));
    }
    XStreamTools.toXML(finalArgumentPairList, new File(outputDir, prefix + file.getName()));

    // TODO: here, we can add newly generated edges from graph transitivity
    if (collectGeneratedArgumentPairs) {
        Set<GeneratedArgumentPair> generatedArgumentPairs = new HashSet<>();
        // collect all arguments
        // (indexed by argument id, taken from all pre-filtered pairs, not only the retained ones)
        Map<String, Argument> allArguments = new HashMap<>();
        for (ArgumentPair argumentPair : pairs) {
            allArguments.put(argumentPair.getArg1().getId(), argumentPair.getArg1());
            allArguments.put(argumentPair.getArg2().getId(), argumentPair.getArg2());
        }

        // graph of the retained pairs only, with every edge weight set to 1.0
        Graph finalGraph = buildGraphFromArgumentPairs(finalArgumentPairList);
        for (Edge e : finalGraph.getEdgeSet()) {
            e.setAttribute(WEIGHT, 1.0);
        }

        // for every ordered pair of distinct nodes (j, k): a non-empty shortest path from j to k
        // yields a generated pair labeled "a1"
        for (Node j : finalGraph) {
            for (Node k : finalGraph) {
                if (j != k) {
                    // is there a path between?
                    // NOTE(review): the computation is set up from source node j only, yet it is
                    // re-run for every k - it looks hoistable to the outer loop; confirm
                    BellmanFord bfShortest = new BellmanFord(WEIGHT, j.getId());
                    bfShortest.init(finalGraph);
                    bfShortest.compute();

                    Path shortestPath = bfShortest.getShortestPath(k);

                    if (shortestPath.size() > 0) {
                        // we have a path
                        GeneratedArgumentPair ap = new GeneratedArgumentPair();
                        Argument arg1 = allArguments.get(j.getId());

                        if (arg1 == null) {
                            throw new IllegalStateException("Cannot find argument " + j.getId());
                        }
                        ap.setArg1(arg1);

                        Argument arg2 = allArguments.get(k.getId());

                        if (arg2 == null) {
                            throw new IllegalStateException("Cannot find argument " + k.getId());
                        }
                        ap.setArg2(arg2);

                        ap.setGoldLabel("a1");
                        generatedArgumentPairs.add(ap);
                    }
                }
            }
        }
        // and now add the reverse ones
        // (arguments swapped and labeled "a2")
        Set<GeneratedArgumentPair> generatedReversePairs = new HashSet<>();
        for (GeneratedArgumentPair pair : generatedArgumentPairs) {
            GeneratedArgumentPair ap = new GeneratedArgumentPair();
            ap.setArg1(pair.getArg2());
            ap.setArg2(pair.getArg1());
            ap.setGoldLabel("a2");
            generatedReversePairs.add(ap);
        }
        generatedArgumentPairs.addAll(generatedReversePairs);
        // and save it
        XStreamTools.toXML(generatedArgumentPairs, new File(outputDir, "generated_" + prefix + file.getName()));
    }

    result.fullPairsSize = fullDataSize;
    result.removedApriori = (fullDataSize - preFilteredDataSize);
    result.finalPairsRetained = finalArgumentPairList.size();

    // save the final graph
    // NOTE(review): lastGraph is the graph of the last loop iteration, which contains the last
    // candidate pair even when that pair produced a cycle and was ignored - confirm that
    // cleanCopyGraph accounts for this
    Graph outGraph = cleanCopyGraph(lastGraph);
    FileSinkDGS dgs1 = new FileSinkDGS();
    File outFile = new File(outputDir, prefix + file.getName() + ".dgs");

    System.out.println("Saved to " + outFile);
    // NOTE(review): the writer is not closed if writeAll throws
    FileWriter w1 = new FileWriter(outFile);

    dgs1.writeAll(outGraph, w1);
    w1.close();

    return result;
}

From source file:org.apache.solr.client.solrj.io.eval.CorrelationEvaluator.java

/**
 * Computes a correlation of the configured {@code type} (pearsons, kendalls or spearmans).
 *
 * <p>Two calling conventions are supported:
 * <ul>
 * <li>two values, each a {@code List} of numbers: returns the correlation coefficient of the
 * two series;</li>
 * <li>one value that is a {@code Matrix}: returns a new {@code Matrix} holding the correlation
 * matrix of its data, with the underlying correlation object stored in the matrix context
 * under the key {@code "corr"}.</li>
 * </ul>
 *
 * @param values the evaluated parameters of the expression
 * @return the correlation coefficient, the correlation {@code Matrix}, or {@code null} when
 *         {@code type} matches none of the handled correlation types
 * @throws IOException if a value is null, is not a list, or the parameter count/kind is unsupported
 */
@Override
public Object doWork(Object... values) throws IOException {

    if (values.length == 2) {
        Object first = values[0];
        Object second = values[1];

        if (null == first) {
            throw new IOException(
                    String.format(Locale.ROOT, "Invalid expression %s - null found for the first value",
                            toExpression(constructingFactory)));
        }
        if (null == second) {
            throw new IOException(
                    String.format(Locale.ROOT, "Invalid expression %s - null found for the second value",
                            toExpression(constructingFactory)));
        }
        if (!(first instanceof List<?>)) {
            throw new IOException(String.format(Locale.ROOT,
                    "Invalid expression %s - found type %s for the first value, expecting a list of numbers",
                    toExpression(constructingFactory), first.getClass().getSimpleName()));
        }
        if (!(second instanceof List<?>)) {
            // report the type of the offending (second) value, not of the first one
            throw new IOException(String.format(Locale.ROOT,
                    "Invalid expression %s - found type %s for the second value, expecting a list of numbers",
                    toExpression(constructingFactory), second.getClass().getSimpleName()));
        }

        // The lists are converted only inside a matching branch so that an unhandled type
        // still yields null without touching the list contents.
        if (type.equals(CorrelationType.pearsons)) {
            PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
            return pearsonsCorrelation.correlation(toPrimitiveDoubles((List<?>) first),
                    toPrimitiveDoubles((List<?>) second));
        } else if (type.equals(CorrelationType.kendalls)) {
            KendallsCorrelation kendallsCorrelation = new KendallsCorrelation();
            return kendallsCorrelation.correlation(toPrimitiveDoubles((List<?>) first),
                    toPrimitiveDoubles((List<?>) second));
        } else if (type.equals(CorrelationType.spearmans)) {
            SpearmansCorrelation spearmansCorrelation = new SpearmansCorrelation();
            return spearmansCorrelation.correlation(toPrimitiveDoubles((List<?>) first),
                    toPrimitiveDoubles((List<?>) second));
        } else {
            return null;
        }
    } else if (values.length == 1) {
        if (values[0] instanceof Matrix) {
            Matrix matrix = (Matrix) values[0];
            double[][] data = matrix.getData();
            if (type.equals(CorrelationType.pearsons)) {
                PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation(data);
                RealMatrix corrMatrix = pearsonsCorrelation.getCorrelationMatrix();
                Matrix realMatrix = new Matrix(corrMatrix.getData());
                realMatrix.addToContext("corr", pearsonsCorrelation);
                return realMatrix;
            } else if (type.equals(CorrelationType.kendalls)) {
                KendallsCorrelation kendallsCorrelation = new KendallsCorrelation(data);
                RealMatrix corrMatrix = kendallsCorrelation.getCorrelationMatrix();
                Matrix realMatrix = new Matrix(corrMatrix.getData());
                realMatrix.addToContext("corr", kendallsCorrelation);
                return realMatrix;
            } else if (type.equals(CorrelationType.spearmans)) {
                SpearmansCorrelation spearmansCorrelation = new SpearmansCorrelation(
                        new Array2DRowRealMatrix(data));
                RealMatrix corrMatrix = spearmansCorrelation.getCorrelationMatrix();
                Matrix realMatrix = new Matrix(corrMatrix.getData());
                // for spearmans the context receives the underlying rank correlation object
                realMatrix.addToContext("corr", spearmansCorrelation.getRankCorrelation());
                return realMatrix;
            } else {
                return null;
            }
        } else {
            throw new IOException(
                    "corr function operates on either two numeric arrays or a single matrix as parameters.");
        }
    } else {
        throw new IOException(
                "corr function operates on either two numeric arrays or a single matrix as parameters.");
    }
}

/**
 * Converts a list of numeric values into a primitive {@code double} array.
 *
 * <p>Any {@link Number} element is accepted (a superset of the {@code BigDecimal}-only cast
 * used before); a non-numeric element still fails with a {@link ClassCastException}.
 */
private static double[] toPrimitiveDoubles(List<?> numbers) {
    return numbers.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
}

From source file:org.meteoinfo.math.stats.StatsUtil.java

/**
 * Computes the Pearson correlation coefficient of two data series together with its p-value.
 *
 * @param x X data; the length of its first dimension is the number of samples read
 * @param y Y data, read at the same flat indices as {@code x}
 * @return a two-element array: the correlation coefficient followed by the p-value
 */
public static double[] pearsonr(Array x, Array y) {
    final int sampleCount = x.getShape()[0];

    // One row per sample: column 0 holds the x value, column 1 the y value.
    final double[][] samples = new double[sampleCount][];
    for (int row = 0; row < sampleCount; row++) {
        samples[row] = new double[] { x.getDouble(row), y.getDouble(row) };
    }

    // "false" lets the matrix wrap the array instead of copying it.
    final PearsonsCorrelation pearson = new PearsonsCorrelation(new Array2DRowRealMatrix(samples, false));

    // Entry (0, 1) relates column 0 (x) to column 1 (y).
    final double coefficient = pearson.getCorrelationMatrix().getEntry(0, 1);
    final double significance = pearson.getCorrelationPValues().getEntry(0, 1);
    return new double[] { coefficient, significance };
}

From source file:restclient.service.DailyRecordFacadeREST.java

/**
 * Reports the Pearson correlation between the recorded pain level and one weather attribute
 * for all daily records whose record date lies within {@code [startDate, endDate]}.
 *
 * <p>The returned text consists of the pairwise correlations of all column combinations
 * (formatted with two decimals), followed by the correlation-matrix entry and the p-value
 * for pain level versus the weather attribute.
 *
 * @param startDate inclusive lower bound of the record date
 * @param endDate   inclusive upper bound of the record date
 * @param attribute name of the weather attribute; matched case-insensitively, with
 *                  "wind speed" and "atmospheric pressure" accepted as aliases
 * @return the textual correlation report
 * @throws IllegalArgumentException if {@code attribute} is not a plain (optionally dotted)
 *                                  identifier and therefore cannot be used as an entity path
 */
@GET
@Path("findCorrelationByStartAndEndDateAndWeatherVariable/{startDate}/{endDate}/{attribute}")
@Produces({ "application/json" })
public String findCorrelationByStartAndEndDateAndWeatherVariable(@PathParam("startDate") Date startDate,
        @PathParam("endDate") Date endDate, @PathParam("attribute") String attribute) {
    // Locale.ROOT keeps the lower-casing independent of the server's default locale.
    String attributeName = attribute.toLowerCase(java.util.Locale.ROOT);
    switch (attributeName) {
    case "windspeed":
    case "wind speed":
        attributeName = "windSpeed";
        break;
    case "atmosphericpressure":
    case "atmospheric pressure":
        attributeName = "atmosphericPressure";
        break;
    default:
        break;
    }

    // The attribute name comes from the request URL and is spliced into the query text (a
    // path expression cannot be bound as a parameter), so only identifier characters are allowed.
    if (!attributeName.matches("[A-Za-z_][A-Za-z0-9_]*(\\.[A-Za-z_][A-Za-z0-9_]*)*")) {
        throw new IllegalArgumentException("Invalid weather attribute: " + attribute);
    }

    String jpql = "SELECT NEW restclient.Result2(d.painLevel, d." + attributeName + ")"
            + " FROM restclient.DailyRecord d WHERE d.recordDate >= :startDate AND d.recordDate <= :endDate";
    TypedQuery<Result2> q = em.createQuery(jpql, Result2.class);
    q.setParameter("startDate", startDate);
    q.setParameter("endDate", endDate);
    List<Result2> result = q.getResultList();

    // One row per record: column 0 = pain level, column 1 = weather value.
    double[][] data = new double[result.size()][];
    for (int i = 0; i < result.size(); i++) {
        data[i] = new double[] { result.get(i).painLevel, result.get(i).weather };
    }
    RealMatrix m = MatrixUtils.createRealMatrix(data);

    StringBuilder report = new StringBuilder();
    PearsonsCorrelation pairwise = new PearsonsCorrelation();
    int columns = m.getColumnDimension();
    for (int i = 0; i < columns; i++) {
        for (int j = 0; j < columns; j++) {
            double cor = pairwise.correlation(m.getColumn(i), m.getColumn(j));
            // "%.2f" (two decimals) replaces the malformed ".%2f" pattern
            report.append(i).append(",").append(j).append("=[").append(String.format("%.2f", cor))
                    .append(",").append("]").append(";   ");
        }
    }

    PearsonsCorrelation pc = new PearsonsCorrelation(m);
    RealMatrix corM = pc.getCorrelationMatrix();
    report.append("!correlation:").append(corM.getEntry(0, 1)).append("   ");
    RealMatrix pM = pc.getCorrelationPValues();
    report.append("!p value:").append(pM.getEntry(0, 1));
    return report.toString();
}

From source file:restclient.service.RecordFacadeREST.java

/**
 * Computes the Pearson correlation between a user's pain level and one weather variable over
 * the records dated within {@code [sdate, edate]}.
 *
 * @param uid   id of the user whose records are evaluated
 * @param date1 inclusive start date in {@code yyyy-MM-dd} format
 * @param date2 inclusive end date in {@code yyyy-MM-dd} format
 * @param wv    weather variable: "temperature", "humidity" or "windspeed"; any other value
 *              selects the pressure
 * @return a single-element list holding the correlation coefficient (r value) and its
 *         p-value (s value)
 * @throws ParseException if either date cannot be parsed
 */
@GET
@Path("findCorrelation/{uid}/{sdate}/{edate}/{wvariable}")
@Produces({ "application/json" })
public List<Correlation> findCorrelation(@PathParam("uid") Integer uid, @PathParam("sdate") String date1,
        @PathParam("edate") String date2, @PathParam("wvariable") String wv) throws ParseException {
    SimpleDateFormat isoDate = new SimpleDateFormat("yyyy-MM-dd");
    Date sdate = isoDate.parse(date1);
    Date edate = isoDate.parse(date2);

    TypedQuery<Record> q = em.createQuery(
            "SELECT r FROM Record r WHERE r.date >= :sdate AND r.date <= :edate AND r.uid.uid = :uid order by r.date ASC",
            Record.class);
    q.setParameter("uid", uid);
    q.setParameter("sdate", sdate);
    q.setParameter("edate", edate);
    List<Record> records = q.getResultList();

    // One row per record: column 0 = pain level, column 1 = selected weather variable.
    double[][] data = new double[records.size()][2];
    for (int i = 0; i < records.size(); ++i) {
        Record r = records.get(i);
        data[i][0] = r.getPlevel();
        switch (wv) {
        case "temperature":
            data[i][1] = r.getTemp();
            break;
        case "humidity":
            data[i][1] = r.getHumidity();
            break;
        case "windspeed":
            data[i][1] = r.getWindspeed();
            break;
        default:
            // every unrecognized variable falls back to the pressure
            data[i][1] = r.getPressure();
            break;
        }
    }

    RealMatrix m = MatrixUtils.createRealMatrix(data);
    PearsonsCorrelation pc = new PearsonsCorrelation(m);

    // Entry (0, 1) relates column 0 (pain level) to column 1 (weather variable).
    Correlation cl = new Correlation();
    cl.setRvalue(pc.getCorrelationMatrix().getEntry(0, 1));
    cl.setSvalue(pc.getCorrelationPValues().getEntry(0, 1));

    List<Correlation> re = new ArrayList<>();
    re.add(cl);
    return re;
}