Java tutorial
/////////////////////////////////////////////////////////////////////////////// // For information as to what this class does, see the Javadoc, below. // // Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // // 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph // // Ramsey, and Clark Glymour. // // // // This program is free software; you can redistribute it and/or modify // // it under the terms of the GNU General Public License as published by // // the Free Software Foundation; either version 2 of the License, or // // (at your option) any later version. // // // // This program is distributed in the hope that it will be useful, // // but WITHOUT ANY WARRANTY; without even the implied warranty of // // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // // GNU General Public License for more details. // // // // You should have received a copy of the GNU General Public License // // along with this program; if not, write to the Free Software // // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // /////////////////////////////////////////////////////////////////////////////// package edu.cmu.tetrad.util; import cern.colt.matrix.impl.DenseDoubleMatrix2D; import org.apache.commons.math3.linear.*; import java.io.IOException; import java.io.ObjectInputStream; /** * Wraps the Apache math3 linear algebra library for most uses in Tetrad. * Specialized uses will still have to use the library directly. One issue * this fixes is that a BlockRealMatrix cannot represent a matrix with zero * rows; this uses an Array2DRowRealMatrix to represent that case. * * @author Joseph Ramsey */ public class TetradMatrix1 implements TetradSerializable { static final long serialVersionUID = 23L; private RealMatrix apacheData; private int m, n; public TetradMatrix1(double[][] data) { if (data.length == 0) { this.apacheData = new Array2DRowRealMatrix(); } else { // this.apacheData = new OpenMapRealMatrix(data.length, data[0].length); // // for (int i = 0; i < data.length; i++) { // for (int j = 0; j < data[0].length; j++) { // apacheData.setEntry(i, j, data[i][j]); // } // } this.apacheData = new BlockRealMatrix(data); } this.m = data.length; this.n = m == 0 ? 0 : data[0].length; } public TetradMatrix1(int m, int n) { if (m == 0 || n == 0) { this.apacheData = new Array2DRowRealMatrix(); } else { // this.apacheData = new OpenMapRealMatrix(m, n); this.apacheData = new BlockRealMatrix(m, n); } this.m = m; this.n = n; } public TetradMatrix1(TetradMatrix1 m) { this(m.apacheData.getData()); } public TetradMatrix1(RealMatrix matrix) { if (matrix == null) { throw new IllegalArgumentException("Null matrix."); } this.apacheData = matrix; this.m = matrix.getRowDimension(); this.n = matrix.getColumnDimension(); } public TetradMatrix1(RealMatrix matrix, int rows, int columns) { if (matrix == null) { throw new IllegalArgumentException("Null matrix."); } this.apacheData = matrix; this.m = rows; this.n = columns; int _rows = matrix.getRowDimension(); int _cols = matrix.getColumnDimension(); if (_rows != 0 && _rows != rows) throw new IllegalArgumentException(); if (_cols != 0 && _cols != columns) throw new IllegalArgumentException(); } public static TetradMatrix1 sparseMatrix(int m, int n) { return new TetradMatrix1(new OpenMapRealMatrix(m, n)); } /** * Generates a simple exemplar of this class to test serialization. */ public static TetradMatrix1 serializableInstance() { return new TetradMatrix1(0, 0); } public TetradMatrix1 sqrt() { SingularValueDecomposition svd = new SingularValueDecomposition(getRealMatrix()); RealMatrix U = svd.getU(); RealMatrix V = svd.getV(); double[] s = svd.getSingularValues(); for (int i = 0; i < s.length; i++) s[i] = 1.0 / s[i]; RealMatrix S = new BlockRealMatrix(s.length, s.length); for (int i = 0; i < s.length; i++) S.setEntry(i, i, s[i]); RealMatrix sqrt = U.multiply(S).multiply(V); return new TetradMatrix1(sqrt); } public int rows() { return m; } public int columns() { return n; } public TetradMatrix1 getSelection(int[] rows, int[] cols) { if (rows.length == 0 || cols.length == 0) { return new TetradMatrix1(rows.length, cols.length); } RealMatrix subMatrix = apacheData.getSubMatrix(rows, cols); return new TetradMatrix1(subMatrix, rows.length, cols.length); } public TetradMatrix1 copy() { if (zeroDimension()) return new TetradMatrix1(rows(), columns()); return new TetradMatrix1(apacheData.copy(), rows(), columns()); } public TetradVector getColumn(int j) { if (zeroDimension()) { return new TetradVector(rows()); } return new TetradVector(apacheData.getColumn(j)); } public TetradMatrix1 times(TetradMatrix1 m) { if (this.zeroDimension() || m.zeroDimension()) return new TetradMatrix1(this.rows(), m.columns()); else { return new TetradMatrix1(apacheData.multiply(m.apacheData), this.rows(), m.columns()); } } public TetradVector times(TetradVector v) { if (v.size() != apacheData.getColumnDimension()) { throw new IllegalArgumentException("Mismatched dimensions."); } double[] y = new double[apacheData.getRowDimension()]; for (int i = 0; i < apacheData.getRowDimension(); i++) { double sum = 0.0; for (int j = 0; j < apacheData.getColumnDimension(); j++) { sum += apacheData.getEntry(i, j) * v.get(j); } y[i] = sum; } return new TetradVector(y); } public double[][] toArray() { return apacheData.getData(); } public double get(int i, int j) { return apacheData.getEntry(i, j); } public TetradMatrix1 like() { return new TetradMatrix1(apacheData.getRowDimension(), apacheData.getColumnDimension()); } public void set(int i, int j, double v) { apacheData.setEntry(i, j, v); } public TetradVector getRow(int i) { if (zeroDimension()) { return new TetradVector(columns()); } return new TetradVector(apacheData.getRow(i)); } public TetradMatrix1 getPart(int i, int j, int k, int l) { return new TetradMatrix1(apacheData.getSubMatrix(i, j, k, l)); } public TetradMatrix1 inverse() { if (!isSquare()) throw new IllegalArgumentException("I can only invert square matrices."); // Trying for a speedup by not having to construct the matrix factorization. if (rows() == 0) { return new TetradMatrix1(0, 0); } else if (rows() == 1) { TetradMatrix1 m = new TetradMatrix1(1, 1); m.set(0, 0, 1.0 / apacheData.getEntry(0, 0)); return m; } else if (rows() == 2) { double a = apacheData.getEntry(0, 0); double b = apacheData.getEntry(0, 1); double c = apacheData.getEntry(1, 0); double d = apacheData.getEntry(1, 1); double delta = a * d - b * c; TetradMatrix1 inverse = new TetradMatrix1(2, 2); inverse.set(0, 0, d); inverse.set(0, 1, -b); inverse.set(1, 0, -c); inverse.set(1, 1, a); return inverse.scalarMult(1.0 / delta); } else if (rows() == 3) { RealMatrix m = apacheData; double a11 = m.getEntry(0, 0); double a12 = m.getEntry(0, 1); double a13 = m.getEntry(0, 2); double a21 = m.getEntry(1, 0); double a22 = m.getEntry(1, 1); double a23 = m.getEntry(1, 2); double a31 = m.getEntry(2, 0); double a32 = m.getEntry(2, 1); double a33 = m.getEntry(2, 2); final double denom = -a12 * a21 * a33 + a11 * a22 * a33 - a13 * a22 * a31 + a12 * a23 * a31 + a13 * a21 * a32 - a11 * a23 * a32; double[][] inverse = new double[][] { { (a22 * a33 - a23 * a32) / denom, (-a12 * a33 + a13 * a32) / denom, (-a13 * a22 + a12 * a23) / denom }, { (-a21 * a33 + a23 * a31) / denom, (a11 * a33 - a13 * a31) / denom, (a13 * a21 - a11 * a23) / denom }, { (-a22 * a31 + a21 * a32) / denom, (a12 * a31 - a11 * a32) / denom, (-a12 * a21 + a11 * a22) / denom } }; return new TetradMatrix1(inverse); } else if (rows() == 4) { RealMatrix m = apacheData; double a11 = m.getEntry(0, 0); double a12 = m.getEntry(0, 1); double a13 = m.getEntry(0, 2); double a14 = m.getEntry(0, 3); double a21 = m.getEntry(1, 0); double a22 = m.getEntry(1, 1); double a23 = m.getEntry(1, 2); double a24 = m.getEntry(1, 3); double a31 = m.getEntry(2, 0); double a32 = m.getEntry(2, 1); double a33 = m.getEntry(2, 2); double a34 = m.getEntry(2, 3); double a41 = m.getEntry(3, 0); double a42 = m.getEntry(3, 1); double a43 = m.getEntry(3, 2); double a44 = m.getEntry(3, 3); final double denom = a14 * a23 * a32 * a41 - a13 * a24 * a32 * a41 - a14 * a22 * a33 * a41 + a12 * a24 * a33 * a41 + a13 * a22 * a34 * a41 - a12 * a23 * a34 * a41 - a14 * a23 * a31 * a42 + a13 * a24 * a31 * a42 + a14 * a21 * a33 * a42 - a11 * a24 * a33 * a42 - a13 * a21 * a34 * a42 + a11 * a23 * a34 * a42 + a14 * a22 * a31 * a43 - a12 * a24 * a31 * a43 - a14 * a21 * a32 * a43 + a11 * a24 * a32 * a43 + a12 * a21 * a34 * a43 - a11 * a22 * a34 * a43 - a13 * a22 * a31 * a44 + a12 * a23 * a31 * a44 + a13 * a21 * a32 * a44 - a11 * a23 * a32 * a44 - a12 * a21 * a33 * a44 + a11 * a22 * a33 * a44; double[][] inverse = new double[][] { { (-a24 * a33 * a42 + a23 * a34 * a42 + a24 * a32 * a43 - a22 * a34 * a43 - a23 * a32 * a44 + a22 * a33 * a44) / denom, (a14 * a33 * a42 - a13 * a34 * a42 - a14 * a32 * a43 + a12 * a34 * a43 + a13 * a32 * a44 - a12 * a33 * a44) / denom, (-a14 * a23 * a42 + a13 * a24 * a42 + a14 * a22 * a43 - a12 * a24 * a43 - a13 * a22 * a44 + a12 * a23 * a44) / denom, (a14 * a23 * a32 - a13 * a24 * a32 - a14 * a22 * a33 + a12 * a24 * a33 + a13 * a22 * a34 - a12 * a23 * a34) / denom }, { (a24 * a33 * a41 - a23 * a34 * a41 - a24 * a31 * a43 + a21 * a34 * a43 + a23 * a31 * a44 - a21 * a33 * a44) / denom, (-a14 * a33 * a41 + a13 * a34 * a41 + a14 * a31 * a43 - a11 * a34 * a43 - a13 * a31 * a44 + a11 * a33 * a44) / denom, (a14 * a23 * a41 - a13 * a24 * a41 - a14 * a21 * a43 + a11 * a24 * a43 + a13 * a21 * a44 - a11 * a23 * a44) / denom, (-a14 * a23 * a31 + a13 * a24 * a31 + a14 * a21 * a33 - a11 * a24 * a33 - a13 * a21 * a34 + a11 * a23 * a34) / denom }, { (-a24 * a32 * a41 + a22 * a34 * a41 + a24 * a31 * a42 - a21 * a34 * a42 - a22 * a31 * a44 + a21 * a32 * a44) / denom, (a14 * a32 * a41 - a12 * a34 * a41 - a14 * a31 * a42 + a11 * a34 * a42 + a12 * a31 * a44 - a11 * a32 * a44) / denom, (-a14 * a22 * a41 + a12 * a24 * a41 + a14 * a21 * a42 - a11 * a24 * a42 - a12 * a21 * a44 + a11 * a22 * a44) / denom, (a14 * a22 * a31 - a12 * a24 * a31 - a14 * a21 * a32 + a11 * a24 * a32 + a12 * a21 * a34 - a11 * a22 * a34) / denom }, { (a23 * a32 * a41 - a22 * a33 * a41 - a23 * a31 * a42 + a21 * a33 * a42 + a22 * a31 * a43 - a21 * a32 * a43) / denom, (-a13 * a32 * a41 + a12 * a33 * a41 + a13 * a31 * a42 - a11 * a33 * a42 - a12 * a31 * a43 + a11 * a32 * a43) / denom, (a13 * a22 * a41 - a12 * a23 * a41 - a13 * a21 * a42 + a11 * a23 * a42 + a12 * a21 * a43 - a11 * a22 * a43) / denom, (-a13 * a22 * a31 + a12 * a23 * a31 + a13 * a21 * a32 - a11 * a23 * a32 - a12 * a21 * a33 + a11 * a22 * a33) / denom } }; return new TetradMatrix1(inverse); } else { // Using LUDecomposition. // other options: QRDecomposition, CholeskyDecomposition, EigenDecomposition, QRDecomposition, // RRQRDDecomposition, SingularValueDecomposition. Very cool. Also MatrixUtils.blockInverse, // though that can't handle matrices of size 1. Many ways to invert. // Note CholeskyDecomposition only takes inverses of symmetric matrices. // return new TetradMatrix(new CholeskyDecomposition(apacheData).getSolver().getInverse()); // return new TetradMatrix(new EigenDecomposition(apacheData).getSolver().getInverse()); // return new TetradMatrix(new QRDecomposition(apacheData).getSolver().getInverse()); // // return new TetradMatrix(new SingularValueDecomposition(apacheData).getSolver().getInverse()); return new TetradMatrix1(new LUDecomposition(apacheData, 1e-9).getSolver().getInverse()); } } public TetradMatrix1 symmetricInverse() { if (!isSquare()) throw new IllegalArgumentException(); if (rows() == 0) return new TetradMatrix1(0, 0); // Using LUDecomposition. // other options: QRDecomposition, CholeskyDecomposition, EigenDecomposition, QRDecomposition, // RRQRDDecomposition, SingularValueDecomposition. Very cool. Also MatrixUtils.blockInverse, // though that can't handle matrices of size 1. Many ways to invert. // Note CholeskyDecomposition only takes inverses of symmetric matrices. return new TetradMatrix1(new CholeskyDecomposition(apacheData).getSolver().getInverse()); // return new TetradMatrix(new EigenDecomposition(apacheData).getSolver().getInverse()); // return new TetradMatrix(new QRDecomposition(apacheData).getSolver().getInverse()); // return new TetradMatrix(new SingularValueDecomposition(apacheData).getSolver().getInverse()); // return new TetradMatrix(new LUDecomposition(apacheData).getSolver().getInverse()); } public TetradMatrix1 ginverse() { final double[][] data = apacheData.getData(); if (data.length == 0 || data[0].length == 0) { return new TetradMatrix1(data); } return new TetradMatrix1(MatrixUtils.pseudoInverse(data)); } public static TetradMatrix1 identity(int rows) { TetradMatrix1 m = new TetradMatrix1(rows, rows); for (int i = 0; i < rows; i++) m.set(i, i, 1); return m; } public void assignRow(int row, TetradVector doubles) { apacheData.setRow(row, doubles.toArray()); } public void assignColumn(int row, TetradVector doubles) { apacheData.setColumn(row, doubles.toArray()); } public double trace() { return apacheData.getTrace(); } public double det() { return new LUDecomposition(apacheData).getDeterminant(); } public TetradMatrix1 transpose() { if (zeroDimension()) return new TetradMatrix1(columns(), rows()); return new TetradMatrix1(apacheData.transpose(), columns(), rows()); } public TetradMatrix1 transposeWithoutCopy() { RealMatrix transpose = MatrixUtils.transposeWithoutCopy(apacheData); return new TetradMatrix1(transpose); } private boolean zeroDimension() { return rows() == 0 || columns() == 0; } public boolean equals(TetradMatrix1 m, double tolerance) { RealMatrix n = m.apacheData; for (int i = 0; i < apacheData.getRowDimension(); i++) { for (int j = 0; j < apacheData.getColumnDimension(); j++) { if (Math.abs(apacheData.getEntry(i, j) - n.getEntry(i, j)) > tolerance) { return false; } } } return true; } public boolean isSquare() { return rows() == columns(); } public boolean isSymmetric(double tolerance) { return edu.cmu.tetrad.util.MatrixUtils.isSymmetric(apacheData.getData(), tolerance); } public double zSum() { return new DenseDoubleMatrix2D(apacheData.getData()).zSum(); } public TetradMatrix1 minus(TetradMatrix1 mb) { if (mb.rows() == 0 || mb.columns() == 0) return this; return new TetradMatrix1(apacheData.subtract(mb.apacheData), rows(), columns()); } public TetradMatrix1 plus(TetradMatrix1 mb) { if (mb.rows() == 0 || mb.columns() == 0) return this; return new TetradMatrix1(apacheData.add(mb.apacheData), rows(), columns()); } public TetradMatrix1 scalarMult(double scalar) { return new TetradMatrix1(apacheData.copy().scalarMultiply(scalar), rows(), columns()); } public int rank() { // return new RRQRDecomposition(apacheData).getRank(10); SingularValueDecomposition singularValueDecomposition = new SingularValueDecomposition(apacheData); return singularValueDecomposition.getRank(); } public double norm1() { return apacheData.getNorm(); } public TetradVector diag() { double[] diag = new double[apacheData.getRowDimension()]; for (int i = 0; i < apacheData.getRowDimension(); i++) { diag[i] = apacheData.getEntry(i, i); } return new TetradVector(diag); } public String toString() { if (rows() == 0) { return "Empty"; } else { return MatrixUtils.toString(toArray()); } } /** * Adds semantic checks to the default deserialization method. This method * must have the standard signature for a readObject method, and the body of * the method must begin with "s.defaultReadObject();". Other than that, any * semantic checks can be specified and do not need to stay the same from * version to version. A readObject method of this form may be added to any * class, even if Tetrad sessions were previously saved out using a version * of the class that didn't include it. (That's what the * "s.defaultReadObject();" is for. See J. Bloch, Effective Java, for help. * * @throws java.io.IOException * @throws ClassNotFoundException */ private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException { s.defaultReadObject(); if (m == 0) m = apacheData.getRowDimension(); if (n == 0) n = apacheData.getColumnDimension(); } public RealMatrix getRealMatrix() { return apacheData; } public void assign(TetradMatrix1 matrix) { if (apacheData.getRowDimension() != matrix.rows() || apacheData.getColumnDimension() != matrix.columns()) { throw new IllegalArgumentException("Mismatched matrix size."); } for (int i = 0; i < apacheData.getRowDimension(); i++) { for (int j = 0; j < apacheData.getColumnDimension(); j++) { apacheData.setEntry(i, j, matrix.get(i, j)); } } } public TetradVector sum(int direction) { if (direction == 1) { TetradVector sums = new TetradVector(columns()); for (int j = 0; j < columns(); j++) { double sum = 0.0; for (int i = 0; i < rows(); i++) { sum += apacheData.getEntry(i, j); } sums.set(j, sum); } return sums; } else if (direction == 2) { TetradVector sums = new TetradVector(rows()); for (int i = 0; i < rows(); i++) { double sum = 0.0; for (int j = 0; j < columns(); j++) { sum += apacheData.getEntry(i, j); } sums.set(i, sum); } return sums; } else { throw new IllegalArgumentException("Expecting 1 (sum columns) or 2 (sum rows)."); } } }