Java tutorial
// Copyright (C) 2014 Guibing Guo // // This file is part of LibRec. // // LibRec is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // LibRec is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with LibRec. If not, see <http://www.gnu.org/licenses/>. // package librec.data; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Set; import librec.util.Stats; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.HashBasedTable; import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Table; import com.google.common.collect.Table.Cell; /** * Data Structure: Sparse Matrix whose implementation is modified from M4J library * * <ul> * <li><a href="http://netlib.org/linalg/html_templates/node91.html">Compressed Row Storage (CRS)</a></li> * <li><a href="http://netlib.org/linalg/html_templates/node92.html">Compressed Col Storage (CCS)</a></li> * </ul> * * @author guoguibing * */ public class SparseMatrix implements Iterable<MatrixEntry>, Serializable { private static final long serialVersionUID = 8024536511172609539L; // matrix dimension protected int numRows, numColumns; // Compressed Row Storage (CRS) protected double[] rowData; protected int[] rowPtr, colInd; // Compressed Col Storage (CCS) protected double[] colData; protected int[] colPtr, rowInd; /** * Construct a sparse matrix with both CRS and CCS structures */ public SparseMatrix(int rows, int cols, Table<Integer, Integer, ? extends Number> dataTable, Multimap<Integer, Integer> colMap) { numRows = rows; numColumns = cols; construct(dataTable, colMap); } /** * Construct a sparse matrix with only CRS structures */ public SparseMatrix(int rows, int cols, Table<Integer, Integer, ? extends Number> dataTable) { this(rows, cols, dataTable, null); } /** * Define a sparse matrix without data, only use for {@code transpose} method * */ private SparseMatrix(int rows, int cols) { numRows = rows; numColumns = cols; } /** * Construct a sparse matrix from another sparse matrix * * @param mat * the original sparse matrix * @param deap * whether to copy the CCS structures */ public SparseMatrix(SparseMatrix mat) { numRows = mat.numRows; numColumns = mat.numColumns; copyCRS(mat.rowData, mat.rowPtr, mat.colInd); copyCCS(mat.colData, mat.colPtr, mat.rowInd); } private void copyCRS(double[] data, int[] ptr, int[] idx) { rowData = new double[data.length]; for (int i = 0; i < rowData.length; i++) rowData[i] = data[i]; rowPtr = new int[ptr.length]; for (int i = 0; i < rowPtr.length; i++) rowPtr[i] = ptr[i]; colInd = new int[idx.length]; for (int i = 0; i < colInd.length; i++) colInd[i] = idx[i]; } private void copyCCS(double[] data, int[] ptr, int[] idx) { colData = new double[data.length]; for (int i = 0; i < colData.length; i++) colData[i] = data[i]; colPtr = new int[ptr.length]; for (int i = 0; i < colPtr.length; i++) colPtr[i] = ptr[i]; rowInd = new int[idx.length]; for (int i = 0; i < rowInd.length; i++) rowInd[i] = idx[i]; } /** * Make a deep clone of current matrix */ public SparseMatrix clone() { return new SparseMatrix(this); } /** * @return the transpose of current matrix */ public SparseMatrix transpose() { SparseMatrix tr = new SparseMatrix(numColumns, numRows); tr.copyCRS(this.rowData, this.rowPtr, this.colInd); tr.copyCCS(this.colData, this.colPtr, this.rowInd); return tr; } /** * @return the row pointers of CRS structure */ public int[] getRowPointers() { return rowPtr; } /** * @return the column indices of CCS structure */ public int[] getColumnIndices() { return colInd; } /** * @return the cardinary of current matrix */ public int size() { int size = 0; for (MatrixEntry me : this) if (me.get() != 0) size++; return size; } /** * @return the data table of this matrix as (row, column, value) cells */ public Table<Integer, Integer, Double> getDataTable() { Table<Integer, Integer, Double> res = HashBasedTable.create(); for (MatrixEntry me : this) { if (me.get() != 0) res.put(me.row(), me.column(), me.get()); } return res; } /** * Construct a sparse matrix * * @param dataTable * data table * @param columnStructure * column structure */ private void construct(Table<Integer, Integer, ? extends Number> dataTable, Multimap<Integer, Integer> columnStructure) { int nnz = dataTable.size(); // CRS rowPtr = new int[numRows + 1]; colInd = new int[nnz]; rowData = new double[nnz]; int j = 0; for (int i = 1; i <= numRows; ++i) { Set<Integer> cols = dataTable.row(i - 1).keySet(); rowPtr[i] = rowPtr[i - 1] + cols.size(); for (int col : cols) { colInd[j++] = col; if (col < 0 || col >= numColumns) throw new IllegalArgumentException( "colInd[" + j + "]=" + col + ", which is not a valid column index"); } Arrays.sort(colInd, rowPtr[i - 1], rowPtr[i]); } // CCS colPtr = new int[numColumns + 1]; rowInd = new int[nnz]; colData = new double[nnz]; j = 0; for (int i = 1; i <= numColumns; ++i) { // dataTable.col(i-1) is more time-consuming than columnStructure.get(i-1) Collection<Integer> rows = columnStructure != null ? columnStructure.get(i - 1) : dataTable.column(i - 1).keySet(); colPtr[i] = colPtr[i - 1] + rows.size(); for (int row : rows) { rowInd[j++] = row; if (row < 0 || row >= numRows) throw new IllegalArgumentException( "rowInd[" + j + "]=" + row + ", which is not a valid row index"); } Arrays.sort(rowInd, colPtr[i - 1], colPtr[i]); } // set data for (Cell<Integer, Integer, ? extends Number> en : dataTable.cellSet()) { int row = en.getRowKey(); int col = en.getColumnKey(); double val = en.getValue().doubleValue(); set(row, col, val); } } /** * @return number of rows */ public int numRows() { return numRows; } /** * @return number of columns */ public int numColumns() { return numColumns; } /** * @return referce to the data of current matrix */ public double[] getData() { return rowData; } /** * Set a value to entry [row, column] * * @param row * row id * @param column * column id * @param val * value to set */ public void set(int row, int column, double val) { int index = getCRSIndex(row, column); rowData[index] = val; index = getCCSIndex(row, column); colData[index] = val; } /** * Add a value to entry [row, column] * * @param row * row id * @param column * column id * @param val * value to add */ public void add(int row, int column, double val) { int index = getCRSIndex(row, column); rowData[index] += val; index = getCCSIndex(row, column); colData[index] += val; } /** * Retrieve value at entry [row, column] * * @param row * row id * @param column * column id * @return value at entry [row, column] */ public double get(int row, int column) { int index = Arrays.binarySearch(colInd, rowPtr[row], rowPtr[row + 1], column); if (index >= 0) return rowData[index]; else return 0; } /** * get a row sparse vector of a matrix * * @param row * row id * @return a sparse vector of {index, value} * */ public SparseVector row(int row) { SparseVector sv = new SparseVector(numColumns); if (row < numRows) { for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) { int col = colInd[j]; double val = get(row, col); if (val != 0.0) sv.set(col, val); } } // return an empty vector if the row does not exist in training matrix return sv; } /** * get columns of a specific row where (row, column) entries are non-zero * * @param row * row id * @return a list of column index */ public List<Integer> getColumns(int row) { List<Integer> res = new ArrayList<>(); if (row < numRows) { for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) { int col = colInd[j]; double val = get(row, col); if (val != 0.0) res.add(col); } } return res; } /** * create a row cache of a matrix in {row, row-specific vector} * * @param cacheSpec * cache specification * @return a matrix row cache in {row, row-specific vector} */ public LoadingCache<Integer, SparseVector> rowCache(String cacheSpec) { LoadingCache<Integer, SparseVector> cache = CacheBuilder.from(cacheSpec) .build(new CacheLoader<Integer, SparseVector>() { @Override public SparseVector load(Integer rowId) throws Exception { return row(rowId); } }); return cache; } /** * create a row cache of a matrix in {row, row-specific columns} * * @param cacheSpec * cache specification * @return a matrix row cache in {row, row-specific columns} */ public LoadingCache<Integer, List<Integer>> rowColumnsCache(String cacheSpec) { LoadingCache<Integer, List<Integer>> cache = CacheBuilder.from(cacheSpec) .build(new CacheLoader<Integer, List<Integer>>() { @Override public List<Integer> load(Integer rowId) throws Exception { return getColumns(rowId); } }); return cache; } /** * create a column cache of a matrix * * @param cacheSpec * cache specification * @return a matrix column cache */ public LoadingCache<Integer, SparseVector> columnCache(String cacheSpec) { LoadingCache<Integer, SparseVector> cache = CacheBuilder.from(cacheSpec) .build(new CacheLoader<Integer, SparseVector>() { @Override public SparseVector load(Integer columnId) throws Exception { return column(columnId); } }); return cache; } /** * create a row cache of a matrix in {row, row-specific columns} * * @param cacheSpec * cache specification * @return a matrix row cache in {row, row-specific columns} */ public LoadingCache<Integer, List<Integer>> columnRowsCache(String cacheSpec) { LoadingCache<Integer, List<Integer>> cache = CacheBuilder.from(cacheSpec) .build(new CacheLoader<Integer, List<Integer>>() { @Override public List<Integer> load(Integer colId) throws Exception { return getRows(colId); } }); return cache; } /** * get a row sparse vector of a matrix * * @param row * row id * @param except * row id to be excluded * @return a sparse vector of {index, value} * */ public SparseVector row(int row, int except) { SparseVector sv = new SparseVector(numColumns); for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) { int col = colInd[j]; if (col != except) { double val = get(row, col); if (val != 0.0) sv.set(col, val); } } return sv; } /** * query the size of a specific row * * @param row * row id * @return the size of non-zero elements of a row */ public int rowSize(int row) { int size = 0; for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) { int col = colInd[j]; if (get(row, col) != 0.0) size++; } return size; } /** * @return a list of rows which have at least one non-empty entry */ public List<Integer> rows() { List<Integer> list = new ArrayList<>(); for (int row = 0; row < numRows; row++) { for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) { int col = colInd[j]; if (get(row, col) != 0.0) { list.add(row); break; } } } return list; } /** * get a col sparse vector of a matrix * * @param col * col id * @return a sparse vector of {index, value} * */ public SparseVector column(int col) { SparseVector sv = new SparseVector(numRows); if (col < numColumns) { for (int j = colPtr[col]; j < colPtr[col + 1]; j++) { int row = rowInd[j]; double val = get(row, col); if (val != 0.0) sv.set(row, val); } } // return an empty vector if the column does not exist in training // matrix return sv; } /** * query the size of a specific col * * @param col * col id * @return the size of non-zero elements of a row */ public int columnSize(int col) { int size = 0; for (int j = colPtr[col]; j < colPtr[col + 1]; j++) { int row = rowInd[j]; double val = get(row, col); if (val != 0.0) size++; } return size; } /** * get rows of a specific column where (row, column) entries are non-zero * * @param col * column id * @return a list of column index */ public List<Integer> getRows(int col) { List<Integer> res = new ArrayList<>(); if (col < numColumns) { for (int j = colPtr[col]; j < colPtr[col + 1]; j++) { int row = rowInd[j]; double val = get(row, col); if (val != 0.0) res.add(row); } } return res; } /** * @return a list of columns which have at least one non-empty entry */ public List<Integer> columns() { List<Integer> list = new ArrayList<>(); for (int col = 0; col < numColumns; col++) { for (int j = colPtr[col]; j < colPtr[col + 1]; j++) { int row = rowInd[j]; double val = get(row, col); if (val != 0.0) { list.add(col); break; } } } return list; } /** * @return sum of matrix data */ public double sum() { return Stats.sum(rowData); } /** * @return mean of matrix data */ public double mean() { return sum() / size(); } /** * Normalize the matrix entries to (0, 1) by (x-min)/(max-min) * * @param min * minimum value * @param max * maximum value */ public void normalize(double min, double max) { assert max > min; for (MatrixEntry me : this) { double entry = me.get(); if (entry != 0) me.set((entry - min) / (max - min)); } } /** * Normalize the matrix entries to (0, 1) by (x/max) * * @param max * maximum value */ public void normalize(double max) { normalize(0, max); } /** * Standardize the matrix entries by row- or column-wise z-scores (z=(x-u)/sigma) * * @param isByRow * standardize by row if true; otherwise by column */ public void standardize(boolean isByRow) { int iters = isByRow ? numRows : numColumns; for (int iter = 0; iter < iters; iter++) { SparseVector vec = isByRow ? row(iter) : column(iter); if (vec.getCount() > 0) { double[] data = vec.getData(); double mu = Stats.mean(data); double sigma = Stats.sd(data, mu); for (VectorEntry ve : vec) { int idx = ve.index(); double val = ve.get(); double z = (val - mu) / sigma; if (isByRow) this.set(iter, idx, z); else this.set(idx, iter, z); } } } } /** * remove zero entries of the given matrix */ public static void reshape(SparseMatrix mat) { SparseMatrix res = new SparseMatrix(mat.numRows, mat.numColumns); int nnz = mat.size(); // Compressed Row Storage (CRS) res.rowData = new double[nnz]; res.colInd = new int[nnz]; res.rowPtr = new int[mat.numRows + 1]; // handle row data int index = 0; for (int i = 1; i < mat.rowPtr.length; i++) { for (int j = mat.rowPtr[i - 1]; j < mat.rowPtr[i]; j++) { // row i-1, row 0 always starts with 0 double val = mat.rowData[j]; int col = mat.colInd[j]; if (val != 0) { res.rowData[index] = val; res.colInd[index] = col; index++; } } res.rowPtr[i] = index; } // Compressed Col Storage (CCS) res.colData = new double[nnz]; res.rowInd = new int[nnz]; res.colPtr = new int[mat.numColumns + 1]; // handle column data index = 0; for (int j = 1; j < mat.colPtr.length; j++) { for (int i = mat.colPtr[j - 1]; i < mat.colPtr[j]; i++) { // column j-1, index i double val = mat.colData[i]; int row = mat.rowInd[i]; if (val != 0) { res.colData[index] = val; res.rowInd[index] = row; index++; } } res.colPtr[j] = index; } // write back to the given matrix, note that here mat is just a reference copy of the original matrix mat.rowData = res.rowData; mat.colInd = res.colInd; mat.rowPtr = res.rowPtr; mat.colData = res.colData; mat.rowInd = res.rowInd; mat.colPtr = res.colPtr; } /** * @return a new matrix with shape (rows, cols) with data from the current matrix */ public SparseMatrix reshape(int rows, int cols) { Table<Integer, Integer, Double> data = HashBasedTable.create(); Multimap<Integer, Integer> colMap = HashMultimap.create(); int rowIndex, colIndex; for (int i = 1; i < rowPtr.length; i++) { for (int j = rowPtr[i - 1]; j < rowPtr[i]; j++) { int row = i - 1; int col = colInd[j]; double val = rowData[j]; // (row, col, val) if (val != 0) { int oldIndex = row * numColumns + col; rowIndex = oldIndex / cols; colIndex = oldIndex % cols; data.put(rowIndex, colIndex, val); colMap.put(colIndex, rowIndex); } } } return new SparseMatrix(rows, cols, data, colMap); } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append(String.format("%d\t%d\t%d\n", new Object[] { numRows, numColumns, size() })); for (MatrixEntry me : this) if (me.get() != 0) sb.append(String.format("%d\t%d\t%f\n", new Object[] { me.row(), me.column(), me.get() })); return sb.toString(); } /** * @return a matrix format string */ public String matString() { StringBuilder sb = new StringBuilder(); sb.append("Dimension: ").append(numRows).append(" x ").append(numColumns).append("\n"); for (int i = 0; i < numRows; i++) { for (int j = 0; j < numColumns; j++) { sb.append(get(i, j)); if (j < numColumns - 1) sb.append("\t"); } sb.append("\n"); } return sb.toString(); } /** * Finds the insertion index of CRS */ private int getCRSIndex(int row, int col) { int i = Arrays.binarySearch(colInd, rowPtr[row], rowPtr[row + 1], col); if (i >= 0 && colInd[i] == col) return i; else throw new IndexOutOfBoundsException( "Entry (" + (row + 1) + ", " + (col + 1) + ") is not in the matrix structure"); } /** * Finds the insertion index of CCS */ private int getCCSIndex(int row, int col) { int i = Arrays.binarySearch(rowInd, colPtr[col], colPtr[col + 1], row); if (i >= 0 && rowInd[i] == row) return i; else throw new IndexOutOfBoundsException( "Entry (" + (row + 1) + ", " + (col + 1) + ") is not in the matrix structure"); } public Iterator<MatrixEntry> iterator() { return new MatrixIterator(); } /** * Entry of a compressed row matrix */ private class SparseMatrixEntry implements MatrixEntry { private int row, cursor; /** * Updates the entry */ public void update(int row, int cursor) { this.row = row; this.cursor = cursor; } public int row() { return row; } public int column() { return colInd[cursor]; } public double get() { return rowData[cursor]; } public void set(double value) { rowData[cursor] = value; } } private class MatrixIterator implements Iterator<MatrixEntry> { private int row, cursor; private SparseMatrixEntry entry = new SparseMatrixEntry(); public MatrixIterator() { // Find first non-empty row nextNonEmptyRow(); } /** * Locates the first non-empty row, starting at the current. After the new row has been found, the cursor is * also updated */ private void nextNonEmptyRow() { while (row < numRows && rowPtr[row] == rowPtr[row + 1]) row++; cursor = rowPtr[row]; } public boolean hasNext() { return cursor < rowData.length; } public MatrixEntry next() { entry.update(row, cursor); // Next position is in the same row if (cursor < rowPtr[row + 1] - 1) cursor++; // Next position is at the following (non-empty) row else { row++; nextNonEmptyRow(); } return entry; } public void remove() { entry.set(0); } } }