librec.data.SparseMatrix.java Source code

Java tutorial

Introduction

Here is the source code for librec.data.SparseMatrix.java

Source

// Copyright (C) 2014 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see <http://www.gnu.org/licenses/>.
//

package librec.data;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import librec.util.Stats;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import com.google.common.collect.Table.Cell;

/**
 * Sparse matrix data structure whose implementation is modified from the M4J library.
 * 
 * <p>
 * The non-zero entries are kept in two compressed storage formats at the same time:
 * </p>
 * 
 * <ul>
 * <li><a href="http://netlib.org/linalg/html_templates/node91.html">Compressed Row Storage (CRS)</a></li>
 * <li><a href="http://netlib.org/linalg/html_templates/node92.html">Compressed Col Storage (CCS)</a></li>
 * </ul>
 * 
 * @author guoguibing
 * 
 */
public class SparseMatrix implements Iterable<MatrixEntry>, Serializable {

    private static final long serialVersionUID = 8024536511172609539L;

    // matrix dimension
    protected int numRows, numColumns;

    // Compressed Row Storage (CRS): rowPtr[r]..rowPtr[r + 1] delimits the slots of row r in colInd/rowData;
    // column indices are sorted within each row. Every read in this class goes through rowData.
    protected double[] rowData;
    protected int[] rowPtr, colInd;

    // Compressed Col Storage (CCS): colPtr[c]..colPtr[c + 1] delimits the slots of column c in rowInd/colData;
    // row indices are sorted within each column. colData mirrors the values held in rowData.
    protected double[] colData;
    protected int[] colPtr, rowInd;

    /**
     * Construct a sparse matrix with both CRS and CCS structures
     * 
     * @param rows
     *            number of rows
     * @param cols
     *            number of columns
     * @param dataTable
     *            (row, column, value) cells of the matrix
     * @param colMap
     *            column structure as {column, rows}; if {@code null} it is derived from {@code dataTable}
     * @throws IllegalArgumentException
     *             if a column or row index lies outside the matrix dimension
     */
    public SparseMatrix(int rows, int cols, Table<Integer, Integer, ? extends Number> dataTable,
            Multimap<Integer, Integer> colMap) {
        numRows = rows;
        numColumns = cols;

        construct(dataTable, colMap);
    }

    /**
     * Construct a sparse matrix without a pre-computed column structure; the CCS structure is then derived from
     * the columns of {@code dataTable}
     * 
     * @param rows
     *            number of rows
     * @param cols
     *            number of columns
     * @param dataTable
     *            (row, column, value) cells of the matrix
     */
    public SparseMatrix(int rows, int cols, Table<Integer, Integer, ? extends Number> dataTable) {
        this(rows, cols, dataTable, null);
    }

    /**
     * Define a sparse matrix without data, only use for {@code transpose} method
     * 
     */
    private SparseMatrix(int rows, int cols) {
        numRows = rows;
        numColumns = cols;
    }

    /**
     * Construct a sparse matrix as a deep copy of another sparse matrix (both CRS and CCS structures)
     * 
     * @param mat
     *            the original sparse matrix
     */
    public SparseMatrix(SparseMatrix mat) {
        numRows = mat.numRows;
        numColumns = mat.numColumns;

        copyCRS(mat.rowData, mat.rowPtr, mat.colInd);

        copyCCS(mat.colData, mat.colPtr, mat.rowInd);
    }

    /**
     * Replace the CRS arrays of this matrix with copies of the given arrays
     */
    private void copyCRS(double[] data, int[] ptr, int[] idx) {
        rowData = Arrays.copyOf(data, data.length);
        rowPtr = Arrays.copyOf(ptr, ptr.length);
        colInd = Arrays.copyOf(idx, idx.length);
    }

    /**
     * Replace the CCS arrays of this matrix with copies of the given arrays
     */
    private void copyCCS(double[] data, int[] ptr, int[] idx) {
        colData = Arrays.copyOf(data, data.length);
        colPtr = Arrays.copyOf(ptr, ptr.length);
        rowInd = Arrays.copyOf(idx, idx.length);
    }

    /**
     * Make a deep clone of current matrix
     */
    @Override
    public SparseMatrix clone() {
        return new SparseMatrix(this);
    }

    /**
     * @return the transpose of current matrix, as a new matrix that shares no array with this one
     */
    public SparseMatrix transpose() {
        SparseMatrix tr = new SparseMatrix(numColumns, numRows);

        // the rows of the transpose are the columns of this matrix, hence its CRS is our CCS structure;
        // the values are looked up through get() because rowData is what every reader of this class uses
        tr.rowPtr = Arrays.copyOf(colPtr, colPtr.length);
        tr.colInd = Arrays.copyOf(rowInd, rowInd.length);
        tr.rowData = new double[rowInd.length];
        for (int col = 0; col < numColumns; col++) {
            for (int k = colPtr[col]; k < colPtr[col + 1]; k++)
                tr.rowData[k] = get(rowInd[k], col);
        }

        // the columns of the transpose are the rows of this matrix, hence its CCS is our CRS
        tr.copyCCS(this.rowData, this.rowPtr, this.colInd);

        return tr;
    }

    /**
     * @return the row pointers of CRS structure (a reference, not a copy)
     */
    public int[] getRowPointers() {
        return rowPtr;
    }

    /**
     * @return the column indices of CRS structure (a reference, not a copy)
     */
    public int[] getColumnIndices() {
        return colInd;
    }

    /**
     * @return the cardinality of current matrix, i.e., the number of stored entries whose value is not zero
     */
    public int size() {
        int size = 0;

        // every CRS slot is one stored entry, so scanning the value array visits each entry exactly once
        for (double val : rowData)
            if (val != 0)
                size++;

        return size;
    }

    /**
     * @return the data table of this matrix as (row, column, value) cells; zero-valued entries are left out
     */
    public Table<Integer, Integer, Double> getDataTable() {
        Table<Integer, Integer, Double> res = HashBasedTable.create();

        for (MatrixEntry me : this) {
            if (me.get() != 0)
                res.put(me.row(), me.column(), me.get());
        }

        return res;
    }

    /**
     * Construct a sparse matrix: build the CRS and CCS index structures, then fill in the values
     * 
     * @param dataTable
     *            data table
     * @param columnStructure
     *            column structure as {column, rows}; may be {@code null}
     * @throws IllegalArgumentException
     *             if a column or row index lies outside the matrix dimension
     */
    private void construct(Table<Integer, Integer, ? extends Number> dataTable,
            Multimap<Integer, Integer> columnStructure) {
        int nnz = dataTable.size();

        // CRS
        rowPtr = new int[numRows + 1];
        colInd = new int[nnz];
        rowData = new double[nnz];

        int j = 0;
        for (int i = 1; i <= numRows; ++i) {
            Set<Integer> cols = dataTable.row(i - 1).keySet();
            rowPtr[i] = rowPtr[i - 1] + cols.size();

            for (int col : cols) {
                // validate before storing so that the reported slot is the one being filled
                if (col < 0 || col >= numColumns)
                    throw new IllegalArgumentException(
                            "colInd[" + j + "]=" + col + ", which is not a valid column index");
                colInd[j++] = col;
            }

            // binary search within a row requires sorted column indices
            Arrays.sort(colInd, rowPtr[i - 1], rowPtr[i]);
        }

        // CCS
        colPtr = new int[numColumns + 1];
        rowInd = new int[nnz];
        colData = new double[nnz];

        j = 0;
        for (int i = 1; i <= numColumns; ++i) {
            // dataTable.col(i-1) is more time-consuming than columnStructure.get(i-1)
            Collection<Integer> rows = columnStructure != null ? columnStructure.get(i - 1)
                    : dataTable.column(i - 1).keySet();
            colPtr[i] = colPtr[i - 1] + rows.size();

            for (int row : rows) {
                if (row < 0 || row >= numRows)
                    throw new IllegalArgumentException(
                            "rowInd[" + j + "]=" + row + ", which is not a valid row index");
                rowInd[j++] = row;
            }

            // binary search within a column requires sorted row indices
            Arrays.sort(rowInd, colPtr[i - 1], colPtr[i]);
        }

        // set data
        for (Cell<Integer, Integer, ? extends Number> en : dataTable.cellSet()) {
            int row = en.getRowKey();
            int col = en.getColumnKey();
            double val = en.getValue().doubleValue();

            set(row, col, val);
        }
    }

    /**
     * @return number of rows
     */
    public int numRows() {
        return numRows;
    }

    /**
     * @return number of columns
     */
    public int numColumns() {
        return numColumns;
    }

    /**
     * @return reference to the data (CRS value array) of current matrix; writing through this reference does not
     *         update the CCS value array
     */
    public double[] getData() {
        return rowData;
    }

    /**
     * Set a value to entry [row, column]
     * 
     * @param row
     *            row id
     * @param column
     *            column id
     * @param val
     *            value to set
     * @throws IndexOutOfBoundsException
     *             if the entry is not part of the matrix structure
     */
    public void set(int row, int column, double val) {
        int index = getCRSIndex(row, column);
        rowData[index] = val;

        index = getCCSIndex(row, column);
        colData[index] = val;
    }

    /**
     * Add a value to entry [row, column]
     * 
     * @param row
     *            row id
     * @param column
     *            column id
     * @param val
     *            value to add
     * @throws IndexOutOfBoundsException
     *             if the entry is not part of the matrix structure
     */
    public void add(int row, int column, double val) {
        int index = getCRSIndex(row, column);
        rowData[index] += val;

        index = getCCSIndex(row, column);
        colData[index] += val;
    }

    /**
     * Retrieve value at entry [row, column]
     * 
     * @param row
     *            row id
     * @param column
     *            column id
     * @return value at entry [row, column], or 0 if the entry is not part of the matrix structure
     */
    public double get(int row, int column) {

        int index = Arrays.binarySearch(colInd, rowPtr[row], rowPtr[row + 1], column);

        if (index >= 0)
            return rowData[index];
        else
            return 0;
    }

    /**
     * get a row sparse vector of a matrix
     * 
     * @param row
     *            row id
     * @return a sparse vector of {index, value}; empty if {@code row} is not smaller than the number of rows
     * 
     */
    public SparseVector row(int row) {

        SparseVector sv = new SparseVector(numColumns);

        if (row < numRows) {
            for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) {
                // slot j already holds the value of (row, colInd[j]); no search needed
                double val = rowData[j];
                if (val != 0.0)
                    sv.set(colInd[j], val);
            }
        } // return an empty vector if the row does not exist in training matrix

        return sv;
    }

    /**
     * get columns of a specific row where (row, column) entries are non-zero
     * 
     * @param row
     *            row id
     * @return a list of column index
     */
    public List<Integer> getColumns(int row) {
        List<Integer> res = new ArrayList<>();

        if (row < numRows) {
            for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) {
                if (rowData[j] != 0.0)
                    res.add(colInd[j]);
            }
        }

        return res;
    }

    /**
     * create a row cache of a matrix in {row, row-specific vector}
     * 
     * @param cacheSpec
     *            cache specification
     * @return a matrix row cache in {row, row-specific vector}
     */
    public LoadingCache<Integer, SparseVector> rowCache(String cacheSpec) {
        LoadingCache<Integer, SparseVector> cache = CacheBuilder.from(cacheSpec)
                .build(new CacheLoader<Integer, SparseVector>() {

                    @Override
                    public SparseVector load(Integer rowId) throws Exception {
                        return row(rowId);
                    }
                });

        return cache;
    }

    /**
     * create a row cache of a matrix in {row, row-specific columns}
     * 
     * @param cacheSpec
     *            cache specification
     * @return a matrix row cache in {row, row-specific columns}
     */
    public LoadingCache<Integer, List<Integer>> rowColumnsCache(String cacheSpec) {
        LoadingCache<Integer, List<Integer>> cache = CacheBuilder.from(cacheSpec)
                .build(new CacheLoader<Integer, List<Integer>>() {

                    @Override
                    public List<Integer> load(Integer rowId) throws Exception {
                        return getColumns(rowId);
                    }
                });

        return cache;
    }

    /**
     * create a column cache of a matrix in {column, column-specific vector}
     * 
     * @param cacheSpec
     *            cache specification
     * @return a matrix column cache in {column, column-specific vector}
     */
    public LoadingCache<Integer, SparseVector> columnCache(String cacheSpec) {
        LoadingCache<Integer, SparseVector> cache = CacheBuilder.from(cacheSpec)
                .build(new CacheLoader<Integer, SparseVector>() {

                    @Override
                    public SparseVector load(Integer columnId) throws Exception {
                        return column(columnId);
                    }
                });

        return cache;
    }

    /**
     * create a column cache of a matrix in {column, column-specific rows}
     * 
     * @param cacheSpec
     *            cache specification
     * @return a matrix column cache in {column, column-specific rows}
     */
    public LoadingCache<Integer, List<Integer>> columnRowsCache(String cacheSpec) {
        LoadingCache<Integer, List<Integer>> cache = CacheBuilder.from(cacheSpec)
                .build(new CacheLoader<Integer, List<Integer>>() {

                    @Override
                    public List<Integer> load(Integer colId) throws Exception {
                        return getRows(colId);
                    }
                });

        return cache;
    }

    /**
     * get a row sparse vector of a matrix, leaving out one column
     * 
     * @param row
     *            row id
     * @param except
     *            column id to be excluded
     * @return a sparse vector of {index, value}
     * 
     */
    public SparseVector row(int row, int except) {

        SparseVector sv = new SparseVector(numColumns);

        for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) {
            int col = colInd[j];
            if (col != except) {
                double val = rowData[j];
                if (val != 0.0)
                    sv.set(col, val);
            }
        }
        return sv;
    }

    /**
     * query the size of a specific row
     * 
     * @param row
     *            row id
     * @return the size of non-zero elements of a row
     */
    public int rowSize(int row) {

        int size = 0;
        for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) {
            if (rowData[j] != 0.0)
                size++;
        }

        return size;
    }

    /**
     * @return a list of rows which have at least one non-empty entry
     */
    public List<Integer> rows() {
        List<Integer> list = new ArrayList<>();

        for (int row = 0; row < numRows; row++) {
            for (int j = rowPtr[row]; j < rowPtr[row + 1]; j++) {
                if (rowData[j] != 0.0) {
                    list.add(row);
                    break;
                }
            }
        }

        return list;
    }

    /**
     * get a col sparse vector of a matrix
     * 
     * @param col
     *            col id
     * @return a sparse vector of {index, value}; empty if {@code col} is not smaller than the number of columns
     * 
     */
    public SparseVector column(int col) {

        SparseVector sv = new SparseVector(numRows);

        if (col < numColumns) {
            for (int j = colPtr[col]; j < colPtr[col + 1]; j++) {
                int row = rowInd[j];
                // read through get(): rowData is what callers may have modified via getData()
                double val = get(row, col);
                if (val != 0.0)
                    sv.set(row, val);
            }
        } // return an empty vector if the column does not exist in training
          // matrix

        return sv;
    }

    /**
     * query the size of a specific col
     * 
     * @param col
     *            col id
     * @return the size of non-zero elements of a column
     */
    public int columnSize(int col) {

        int size = 0;

        for (int j = colPtr[col]; j < colPtr[col + 1]; j++) {
            int row = rowInd[j];
            double val = get(row, col);
            if (val != 0.0)
                size++;
        }

        return size;
    }

    /**
     * get rows of a specific column where (row, column) entries are non-zero
     * 
     * @param col
     *            column id
     * @return a list of row index
     */
    public List<Integer> getRows(int col) {

        List<Integer> res = new ArrayList<>();

        if (col < numColumns) {
            for (int j = colPtr[col]; j < colPtr[col + 1]; j++) {
                int row = rowInd[j];
                double val = get(row, col);
                if (val != 0.0)
                    res.add(row);
            }
        }

        return res;
    }

    /**
     * @return a list of columns which have at least one non-empty entry
     */
    public List<Integer> columns() {
        List<Integer> list = new ArrayList<>();

        for (int col = 0; col < numColumns; col++) {
            for (int j = colPtr[col]; j < colPtr[col + 1]; j++) {
                int row = rowInd[j];
                double val = get(row, col);
                if (val != 0.0) {
                    list.add(col);
                    break;
                }
            }
        }

        return list;
    }

    /**
     * @return sum of matrix data
     */
    public double sum() {
        return Stats.sum(rowData);
    }

    /**
     * @return mean of matrix data, i.e., {@code sum() / size()}; not a finite number when the matrix has no
     *         non-zero entry
     */
    public double mean() {
        return sum() / size();
    }

    /**
     * Normalize the matrix entries to (0, 1) by (x-min)/(max-min); zero-valued entries are left untouched
     * 
     * @param min
     *            minimum value
     * @param max
     *            maximum value
     */
    public void normalize(double min, double max) {
        assert max > min;

        for (MatrixEntry me : this) {
            double entry = me.get();
            if (entry != 0)
                me.set((entry - min) / (max - min));
        }
    }

    /**
     * Normalize the matrix entries to (0, 1) by (x/max)
     * 
     * @param max
     *            maximum value
     */
    public void normalize(double max) {
        normalize(0, max);
    }

    /**
     * Standardize the matrix entries by row- or column-wise z-scores (z=(x-u)/sigma)
     * 
     * @param isByRow
     *            standardize by row if true; otherwise by column
     */
    public void standardize(boolean isByRow) {

        int iters = isByRow ? numRows : numColumns;
        for (int iter = 0; iter < iters; iter++) {
            // row()/column() return a separate vector, so writing to the matrix while iterating it is safe
            SparseVector vec = isByRow ? row(iter) : column(iter);

            if (vec.getCount() > 0) {

                double[] data = vec.getData();
                double mu = Stats.mean(data);
                double sigma = Stats.sd(data, mu);

                for (VectorEntry ve : vec) {
                    int idx = ve.index();
                    double val = ve.get();
                    // NOTE(review): a zero sigma makes this division yield NaN or an infinite value
                    double z = (val - mu) / sigma;

                    if (isByRow)
                        this.set(iter, idx, z);
                    else
                        this.set(idx, iter, z);
                }
            }
        }
    }

    /**
     * remove zero entries of the given matrix; its CRS and CCS arrays are replaced by compacted ones
     * 
     * @param mat
     *            the matrix to compact in place
     */
    public static void reshape(SparseMatrix mat) {

        int nnz = mat.size();

        // Compressed Row Storage (CRS)
        double[] rowData = new double[nnz];
        int[] colInd = new int[nnz];
        int[] rowPtr = new int[mat.numRows + 1];

        // handle row data
        int index = 0;
        for (int i = 1; i < mat.rowPtr.length; i++) {

            for (int j = mat.rowPtr[i - 1]; j < mat.rowPtr[i]; j++) {
                // row i-1, row 0 always starts with 0

                double val = mat.rowData[j];
                if (val != 0) {
                    rowData[index] = val;
                    colInd[index] = mat.colInd[j];

                    index++;
                }
            }
            rowPtr[i] = index;

        }

        // Compressed Col Storage (CCS)
        double[] colData = new double[nnz];
        int[] rowInd = new int[nnz];
        int[] colPtr = new int[mat.numColumns + 1];

        // handle column data
        index = 0;
        for (int j = 1; j < mat.colPtr.length; j++) {
            for (int i = mat.colPtr[j - 1]; i < mat.colPtr[j]; i++) {
                // column j-1, index i

                int row = mat.rowInd[i];
                // nnz was counted on the CRS values, so the column pass must filter on the same values;
                // mat's own arrays are still the old ones here, so get() sees the original data
                double val = mat.get(row, j - 1);
                if (val != 0) {
                    colData[index] = val;
                    rowInd[index] = row;

                    index++;
                }
            }
            colPtr[j] = index;
        }

        // write back to the given matrix, note that here mat is just a reference copy of the original matrix
        mat.rowData = rowData;
        mat.colInd = colInd;
        mat.rowPtr = rowPtr;

        mat.colData = colData;
        mat.rowInd = rowInd;
        mat.colPtr = colPtr;
    }

    /**
     * @param rows
     *            number of rows of the new matrix
     * @param cols
     *            number of columns of the new matrix
     * @return a new matrix with shape (rows, cols) with data from the current matrix; every non-zero entry keeps
     *         its row-major position
     */
    public SparseMatrix reshape(int rows, int cols) {

        Table<Integer, Integer, Double> data = HashBasedTable.create();
        Multimap<Integer, Integer> colMap = HashMultimap.create();

        for (int i = 1; i < rowPtr.length; i++) {
            int row = i - 1;
            for (int j = rowPtr[i - 1]; j < rowPtr[i]; j++) {
                int col = colInd[j];
                double val = rowData[j]; // (row, col, val)

                if (val != 0) {
                    // row-major position; computed in long because row * numColumns may exceed the int range
                    long oldIndex = (long) row * numColumns + col;

                    int rowIndex = (int) (oldIndex / cols);
                    int colIndex = (int) (oldIndex % cols);

                    data.put(rowIndex, colIndex, val);
                    colMap.put(colIndex, rowIndex);
                }
            }
        }

        return new SparseMatrix(rows, cols, data, colMap);
    }

    /**
     * @return a header line "rows, columns, size" followed by one "row, column, value" line per non-zero entry,
     *         all tab-separated
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("%d\t%d\t%d\n", numRows, numColumns, size()));

        for (MatrixEntry me : this)
            if (me.get() != 0)
                sb.append(String.format("%d\t%d\t%f\n", me.row(), me.column(), me.get()));

        return sb.toString();
    }

    /**
     * @return a matrix format string: a dimension line followed by all entries (zeros included), one row per line
     */
    public String matString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Dimension: ").append(numRows).append(" x ").append(numColumns).append("\n");

        for (int i = 0; i < numRows; i++) {
            for (int j = 0; j < numColumns; j++) {
                sb.append(get(i, j));
                if (j < numColumns - 1)
                    sb.append("\t");
            }
            sb.append("\n");
        }

        return sb.toString();
    }

    /**
     * Finds the index of entry (row, col) in the CRS arrays
     * 
     * @throws IndexOutOfBoundsException
     *             if the entry is not part of the matrix structure
     */
    private int getCRSIndex(int row, int col) {
        int i = Arrays.binarySearch(colInd, rowPtr[row], rowPtr[row + 1], col);

        if (i >= 0 && colInd[i] == col)
            return i;
        else
            throw new IndexOutOfBoundsException(
                    "Entry (" + (row + 1) + ", " + (col + 1) + ") is not in the matrix structure");
    }

    /**
     * Finds the index of entry (row, col) in the CCS arrays
     * 
     * @throws IndexOutOfBoundsException
     *             if the entry is not part of the matrix structure
     */
    private int getCCSIndex(int row, int col) {
        int i = Arrays.binarySearch(rowInd, colPtr[col], colPtr[col + 1], row);

        if (i >= 0 && rowInd[i] == row)
            return i;
        else
            throw new IndexOutOfBoundsException(
                    "Entry (" + (row + 1) + ", " + (col + 1) + ") is not in the matrix structure");
    }

    /**
     * @return an iterator over all stored entries in row-major order; the same entry object is reused for every
     *         step, so an entry must not be kept after the next call of {@code next()}
     */
    @Override
    public Iterator<MatrixEntry> iterator() {
        return new MatrixIterator();
    }

    /**
     * Entry of a compressed row matrix; a movable view onto one CRS slot
     */
    private class SparseMatrixEntry implements MatrixEntry {

        private int row, cursor;

        /**
         * Updates the entry
         */
        public void update(int row, int cursor) {
            this.row = row;
            this.cursor = cursor;
        }

        @Override
        public int row() {
            return row;
        }

        @Override
        public int column() {
            return colInd[cursor];
        }

        @Override
        public double get() {
            return rowData[cursor];
        }

        @Override
        public void set(double value) {
            rowData[cursor] = value;
            // keep the CCS copy of the value in step with the CRS one
            colData[getCCSIndex(row, colInd[cursor])] = value;
        }
    }

    /**
     * Iterator over the CRS slots of the enclosing matrix, row by row
     */
    private class MatrixIterator implements Iterator<MatrixEntry> {

        private int row, cursor;

        // whether next() has been called, i.e., whether the shared entry points at a visited slot
        private boolean entryValid;

        private final SparseMatrixEntry entry = new SparseMatrixEntry();

        public MatrixIterator() {
            // Find first non-empty row
            nextNonEmptyRow();
        }

        /**
         * Locates the first non-empty row, starting at the current. After the new row has been found, the cursor is
         * also updated
         */
        private void nextNonEmptyRow() {
            while (row < numRows && rowPtr[row] == rowPtr[row + 1])
                row++;
            // when no row is left, row == numRows and the cursor ends up past the last slot
            cursor = rowPtr[row];
        }

        @Override
        public boolean hasNext() {
            return cursor < rowData.length;
        }

        @Override
        public MatrixEntry next() {
            if (!hasNext())
                throw new java.util.NoSuchElementException();

            entry.update(row, cursor);
            entryValid = true;

            // Next position is in the same row
            if (cursor < rowPtr[row + 1] - 1)
                cursor++;

            // Next position is at the following (non-empty) row
            else {
                row++;
                nextNonEmptyRow();
            }

            return entry;
        }

        /**
         * Sets the value of the last returned entry to zero; the entry stays in the matrix structure
         */
        @Override
        public void remove() {
            if (!entryValid)
                throw new IllegalStateException("next() has not been called yet");

            entry.set(0);
        }

    }
}