Example usage for the org.apache.commons.math3.linear.ConjugateGradient constructor

List of usage examples for the org.apache.commons.math3.linear.ConjugateGradient constructor

Introduction

On this page you can find example usages of the org.apache.commons.math3.linear.ConjugateGradient constructor.

Prototype

public ConjugateGradient(final IterationManager manager, final double delta, final boolean check)
        throws NullArgumentException 

Source Link

Document

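Creates a new instance of this class, with the default stopping criterion and a custom iteration manager. Note that all of the examples below call the related overload ConjugateGradient(int maxIterations, double delta, boolean check), which takes a maximum iteration count instead of an IterationManager.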

Usage

From source file:fp.overlapr.algorithmen.StressMajorization.java

/**
 * Performs stress majorization on the given graph. See: Gansner, Koren, North:
 * Graph Drawing by Stress Majorization, 2004.
 *
 * <p>The weights are w_ij = 1 / d_ij^2 (0 where d_ij == 0). Each of the
 * {@code ITER} iterations builds the matrix L^X from the current layout and
 * solves {@code L^w x = L^X x_old} and {@code L^w y = L^X y_old} with a
 * Jacobi-preconditioned conjugate gradient solver.
 *
 * @param graph
 *            graph whose node positions are to be recomputed
 * @param d
 *            matrix containing the ideal distances (d_ij) between the nodes;
 *            presumably square with one row/column per node (TODO confirm)
 * @return matrix containing the new x and y values of each node (row i holds
 *         {x, y} of node i), or {@code null} if no iteration was performed
 *         ({@code ITER <= 0})
 */
public static double[][] doStressMajorization(Graph graph, double[][] d) {

    int iter = 0;

    // Current layout X: one row per node, columns = (x, y)
    Array2DRowRealMatrix X = new Array2DRowRealMatrix(graph.getKnotenAnzahl(), 2);
    for (int i = 0; i < graph.getKnotenAnzahl(); i++) {
        X.setEntry(i, 0, graph.getKnoten().get(i).getX());
        X.setEntry(i, 1, graph.getKnoten().get(i).getY());
    }

    // Ideal distances D
    Array2DRowRealMatrix D = new Array2DRowRealMatrix(d);

    // Weights W: w_ij = 1 / d_ij^2, or 0 where the ideal distance is 0
    Array2DRowRealMatrix W = new Array2DRowRealMatrix(D.getRowDimension(), D.getColumnDimension());
    W.walkInRowOrder(new DefaultRealMatrixChangingVisitor() {
        @Override
        public double visit(int row, int column, double value) {
            double dij = D.getEntry(row, column);
            return dij == 0 ? 0.0 : 1.0 / (dij * dij);
        }
    });

    // Weighted Laplacian L^w: off-diagonal entries are -w_ij, each diagonal entry
    // is the sum of the row's off-diagonal weights. It does not depend on X.
    Array2DRowRealMatrix LW = new Array2DRowRealMatrix(D.getRowDimension(), D.getColumnDimension());
    LW.walkInRowOrder(new DefaultRealMatrixChangingVisitor() {
        @Override
        public double visit(int row, int column, double value) {
            if (row != column) {
                return (-1) * W.getEntry(row, column);
            }
            double sum = 0;
            for (int k = 0; k < W.getColumnDimension(); k++) {
                if (k != row) {
                    sum = sum + W.getEntry(row, k);
                }
            }
            return sum;
        }
    });

    // LW is constant across iterations, so the solver and its Jacobi
    // preconditioner are built once here instead of inside the loop.
    ConjugateGradient cg = new ConjugateGradient(Integer.MAX_VALUE, TOL, false);
    JacobiPreconditioner preconditioner = JacobiPreconditioner.create(LW);

    while (iter < ITER) {

        iter++;

        // L^X for the current layout: off-diagonal entries are
        // -w_ij * d_ij * inv(||X_i - X_j||); inv is defined elsewhere
        // (presumably 1/x with special handling of 0 — TODO confirm).
        Array2DRowRealMatrix LX = new Array2DRowRealMatrix(D.getRowDimension(), D.getColumnDimension());
        LX.walkInRowOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(int row, int column, double value) {
                if (row == column) {
                    return value;
                }
                // Euclidean distance between node row and node column
                double term1 = FastMath.pow((X.getEntry(row, 0) - X.getEntry(column, 0)), 2);
                double term2 = FastMath.pow((X.getEntry(row, 1) - X.getEntry(column, 1)), 2);
                double abst = Math.sqrt(term1 + term2);

                return (-1) * W.getEntry(row, column) * D.getEntry(row, column) * inv(abst);
            }
        });
        // Second pass: each diagonal entry is the negated sum of its row's
        // off-diagonal entries, which must all be filled in first.
        LX.walkInRowOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(int row, int column, double value) {
                if (row != column) {
                    return value;
                }
                double sum = 0;
                for (int k = 0; k < LX.getColumnDimension(); k++) {
                    if (k != row) {
                        sum = sum + LX.getEntry(row, k);
                    }
                }
                return (-1) * sum;
            }
        });

        // Current x and y coordinates
        ArrayRealVector xWerte = new ArrayRealVector(X.getColumn(0));
        ArrayRealVector yWerte = new ArrayRealVector(X.getColumn(1));

        // Right-hand sides b_x = L^X x_old and b_y = L^X y_old
        ArrayRealVector b_x = (ArrayRealVector) LX.operate(xWerte);
        ArrayRealVector b_y = (ArrayRealVector) LX.operate(yWerte);

        // New coordinates via preconditioned conjugate gradient
        xWerte = (ArrayRealVector) cg.solve(LW, preconditioner, b_x);
        yWerte = (ArrayRealVector) cg.solve(LW, preconditioner, b_y);

        // The new positions become the layout for the next iteration
        for (int i = 0; i < X.getRowDimension(); i++) {
            X.setEntry(i, 0, xWerte.getEntry(i));
            X.setEntry(i, 1, yWerte.getEntry(i));
        }
    }

    if (iter == 0) {
        // No iteration ran: keep the historical null result.
        return null;
    }
    // getData() returns a fresh copy of X's entries, one {x, y} row per node.
    return X.getData();
}

From source file:edu.duke.cs.osprey.tupexp.IterativeCGTupleFitter.java

@Override
double[] doFit() {
    // Returns the fit tuple coefficients. The normal equations (AtA) c = Atb are
    // solved repeatedly by conjugate gradient, recomputing the right-hand side
    // each time, until the residual stops improving.

    // Arguments: max iterations; delta (target ratio of residual norm to the
    // norm of the true values); whether to check positive definiteness.
    ConjugateGradient cg = new ConjugateGradient(100000, 1e-6, false);

    long startTime = System.currentTimeMillis();

    while (true) {
        long iterStartTime = System.currentTimeMillis();

        Atb = calcRHS();
        RealVector ans = cg.solve(AtA, Atb);
        double[] newFitVals = calcFitVals(ans);

        System.out.println(
                "Conjugate gradient fitting time (ms): " + (System.currentTimeMillis() - iterStartTime));

        double resid = calcResidual(newFitVals);
        System.out.println("Step residual: " + resid);

        // A NaN residual compares false against everything. Without the explicit
        // check it would be accepted as an improvement and stored in curResid.
        boolean gotWorse = Double.isNaN(resid) || resid > curResid;
        boolean converged = !gotWorse && resid > curResid - 1e-4;

        if (gotWorse || converged) {
            System.out.println("Iterative conjugate gradient fitting time (ms): "
                    + (System.currentTimeMillis() - startTime));
            // Got worse: fall back to the previous step's coefficients.
            // NOTE(review): assumes curCoeffs holds a usable fallback whenever this
            // triggers (e.g. on the first step) — confirm its initialization.
            return gotWorse ? curCoeffs.toArray() : ans.toArray();
        }

        // Improved by more than 1e-4: keep going from the new solution.
        curCoeffs = ans;
        curFitVals = newFitVals;
        curResid = resid;
    }
}

From source file:org.eclipse.dataset.LinearAlgebra.java

/**
 * Calculation A x = v by conjugate gradient method with the stopping criterion being
 * that the estimated residual r = v - A x satisfies ||r|| < delta ||v||
 * @param a/*from  ww  w . j  a  va 2 s  .  com*/
 * @param v
 * @param maxIterations
 * @param delta parameter used by stopping criterion
 * @return solution of A^-1 v by conjugate gradient method
 */
public static Dataset calcConjugateGradient(Dataset a, Dataset v, int maxIterations, double delta) {
    ConjugateGradient cg = new ConjugateGradient(maxIterations, delta, false);
    return createDataset(cg.solve((RealLinearOperator) createRealMatrix(a), createRealVector(v)));
}