gov.nasa.jpl.mudrod.utils.MatrixUtil.java Source code

Java tutorial

Introduction

Here is the source code for gov.nasa.jpl.mudrod.utils.MatrixUtil.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package gov.nasa.jpl.mudrod.utils;

import gov.nasa.jpl.mudrod.driver.SparkDriver;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.Optional;
import org.apache.spark.api.java.function.*;
import org.apache.spark.mllib.feature.IDF;
import org.apache.spark.mllib.feature.IDFModel;
import org.apache.spark.mllib.linalg.*;
import org.apache.spark.mllib.linalg.distributed.IndexedRow;
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import scala.Tuple2;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;

/**
 * Matrix utility tool
 */
public final class MatrixUtil {

    /** Reciprocal condition number ({@code rCond}) passed to {@code RowMatrix.computeSVD}. */
    private static final double SVD_RCOND = 1.0E-9d;

    private MatrixUtil() {
    }

    /**
     * buildSVDMatrix: Generate SVD matrix from TF-IDF matrix. Please make sure
     * the TF-IDF matrix has been already built from the original documents.
     *
     * <p>The result is {@code U * diag(s)} of the truncated SVD of {@code tfidfMatrix}.
     *
     * @param tfidfMatrix
     *          each row is a term and each column is a document name and each
     *          cell is the TF-IDF value of the term in the corresponding
     *          document.
     * @param dimension
     *          requested column number of the SVD matrix; values larger than the
     *          number of columns of {@code tfidfMatrix} are clamped to that number.
     * @return RowMatrix, each row is a term and each column is a dimension in the
     *         feature space, each cell is value of the term in the corresponding
     *         dimension.
     */
    public static RowMatrix buildSVDMatrix(RowMatrix tfidfMatrix, int dimension) {
        // Never ask for more singular values than the matrix has columns.
        int k = Math.min(dimension, (int) tfidfMatrix.numCols());
        SingularValueDecomposition<RowMatrix, Matrix> svd = tfidfMatrix.computeSVD(k, true, SVD_RCOND);
        return svd.U().multiply(Matrices.diag(svd.s()));
    }

    /**
     * buildSVDMatrix: Generate SVD matrix from Vector RDD.
     *
     * <p>Each vector becomes one row of the input matrix; the work (including clamping of
     * {@code dimension} to the number of columns) is delegated to
     * {@link #buildSVDMatrix(RowMatrix, int)}.
     *
     * @param vecRDD
     *          vectors of terms in feature space
     * @param dimension
     *          requested column number of the SVD matrix
     * @return RowMatrix, each row is a term and each column is a dimension in the
     *         feature space, each cell is value of the term in the corresponding
     *         dimension.
     */
    public static RowMatrix buildSVDMatrix(JavaRDD<Vector> vecRDD, int dimension) {
        // Delegating gives this overload the same dimension clamping as the RowMatrix variant,
        // instead of failing inside computeSVD when dimension exceeds the column count.
        return buildSVDMatrix(new RowMatrix(vecRDD.rdd()), dimension);
    }

    /**
     * Create TF-IDF matrix from word-doc matrix.
     *
     * <p>NOTE(review): MLlib's {@code IDF} treats every input vector (here: every row) as one
     * document when computing document frequencies. With term rows and document columns the
     * weights are therefore computed per column — confirm this orientation is intended.
     *
     * @param wordDocMatrix
     *          each row is a term, each column is a document name and each cell
     *          is number of the term in the corresponding document.
     * @return RowMatrix, each row is a term and each column is a document name
     *         and each cell is the TF-IDF value of the term in the corresponding
     *         document.
     */
    public static RowMatrix createTFIDFMatrix(RowMatrix wordDocMatrix) {
        JavaRDD<Vector> countRDD = wordDocMatrix.rows().toJavaRDD();
        IDFModel idfModel = new IDF().fit(countRDD);
        return new RowMatrix(idfModel.transform(countRDD).rdd());
    }

    /**
     * Create matrix from doc-terms JavaPairRDD.
     *
     * <p>Column {@code j} of the result is the {@code j}-th document of {@code uniqueDocRDD}
     * (index assigned by {@code zipWithIndex}); each row is one distinct term and each cell holds
     * how many times the term occurs in that document.
     *
     * @param uniqueDocRDD
     *          doc-terms JavaPairRDD, in which each key is a doc name, and value
     *          is term list extracted from that doc
     * @return LabeledRowMatrix {@link LabeledRowMatrix} whose {@code rowkeys} are the terms and
     *         whose {@code colkeys} are the document names
     */
    public static LabeledRowMatrix createWordDocMatrix(JavaPairRDD<String, List<String>> uniqueDocRDD) {
        // Index documents with unique IDs (0..n-1, in RDD order).
        // NOTE(review): colkeys below are read in a separate uniqueDocRDD.keys() pass, so this
        // assumes uniqueDocRDD yields documents in a stable order (e.g. the caller cached it).
        JavaPairRDD<List<String>, Long> corpus = uniqueDocRDD.values().zipWithIndex();

        // Count occurrences of every (word, docIndex) pair.
        JavaPairRDD<Tuple2<String, Long>, Double> wordDocCountRDD = corpus.flatMapToPair(docWords -> {
            List<Tuple2<Tuple2<String, Long>, Double>> pairs = new ArrayList<>();
            for (String word : docWords._1) {
                pairs.add(new Tuple2<>(new Tuple2<>(word, docWords._2), 1.0));
            }
            return pairs.iterator();
        }).reduceByKey((a, b) -> a + b);

        // Group the counts by word and turn each group into a sparse vector over all documents.
        final int corpusSize = (int) uniqueDocRDD.keys().count();
        JavaPairRDD<String, Vector> wordVectorRDD = wordDocCountRDD
                .mapToPair(e -> new Tuple2<>(e._1._1, new Tuple2<Long, Double>(e._1._2, e._2)))
                .groupByKey()
                .mapValues(entries -> toSortedSparseVector(corpusSize, entries))
                // Cached so the matrix rows and the row labels below come from the same
                // evaluation; shuffle output order is not guaranteed to repeat on recomputation.
                .cache();

        LabeledRowMatrix labeledRowMatrix = new LabeledRowMatrix();
        labeledRowMatrix.rowMatrix = new RowMatrix(wordVectorRDD.values().rdd());
        labeledRowMatrix.rowkeys = wordVectorRDD.keys().collect();
        labeledRowMatrix.colkeys = uniqueDocRDD.keys().collect();
        return labeledRowMatrix;
    }

    /**
     * Create a document-term matrix from doc-terms JavaPairRDD.
     *
     * <p>Each row of the result is one document and each column is one distinct term; each cell
     * holds how many times the term occurs in that document. Term column indices are assigned
     * with {@code zipWithIndex} over the distinct terms.
     *
     * @param uniqueDocRDD
     *          doc-terms JavaPairRDD, in which each key is a doc name, and value
     *          is term list extracted from that doc
     * @param sc
     *          not used by this method; kept for API compatibility
     * @return LabeledRowMatrix {@link LabeledRowMatrix} whose {@code rowkeys} are the document
     *         names and whose {@code colkeys} are the terms (position == column index)
     */
    public static LabeledRowMatrix createDocWordMatrix(JavaPairRDD<String, List<String>> uniqueDocRDD,
            JavaSparkContext sc) {
        // Index words with unique IDs. zipWithIndex over a shuffled (distinct) RDD may assign
        // different IDs if re-evaluated, and this RDD is used by count, join and collect below,
        // so it is cached to keep the IDs and the column labels consistent.
        JavaPairRDD<String, Long> wordIDRDD = uniqueDocRDD.values()
                .flatMap(words -> words.iterator())
                .distinct()
                .zipWithIndex()
                .cache();

        // Count occurrences of every (doc, word) pair.
        JavaPairRDD<Tuple2<String, String>, Double> docWordCountRDD = uniqueDocRDD.flatMapToPair(docWords -> {
            List<Tuple2<Tuple2<String, String>, Double>> pairs = new ArrayList<>();
            for (String word : docWords._2) {
                pairs.add(new Tuple2<>(new Tuple2<>(docWords._1, word), 1.0));
            }
            return pairs.iterator();
        }).reduceByKey((a, b) -> a + b);

        // Re-key each (doc, count) record by its word so it can be joined with the word IDs.
        JavaPairRDD<String, Tuple2<String, Double>> wordDocCountRDD = docWordCountRDD
                .mapToPair(e -> new Tuple2<>(e._1._2, new Tuple2<String, Double>(e._1._1, e._2)));

        final int wordSize = (int) wordIDRDD.count();
        JavaPairRDD<String, Vector> docVectorRDD = wordDocCountRDD
                .leftOuterJoin(wordIDRDD)
                .mapToPair(e -> {
                    // Every word was indexed above, so the ID should always be present;
                    // 0 is only a fallback.
                    Optional<Long> maybeId = e._2._2;
                    Long wordId = maybeId.isPresent() ? maybeId.get() : Long.valueOf(0L);
                    return new Tuple2<>(e._2._1._1, new Tuple2<Long, Double>(wordId, e._2._1._2));
                })
                .groupByKey()
                .mapValues(entries -> toSortedSparseVector(wordSize, entries))
                // Cached so the matrix rows and the row labels below come from the same evaluation.
                .cache();

        LabeledRowMatrix labeledRowMatrix = new LabeledRowMatrix();
        labeledRowMatrix.rowMatrix = new RowMatrix(docVectorRDD.values().rdd());
        labeledRowMatrix.rowkeys = docVectorRDD.keys().collect();
        labeledRowMatrix.colkeys = wordIDRDD.keys().collect();
        return labeledRowMatrix;
    }

    /**
     * Build a sparse vector of the given size from (index, value) entries.
     *
     * <p>The entries arrive from a shuffle in arbitrary order, while {@code Vectors.sparse}
     * expects strictly increasing indices, so they are sorted by index first.
     *
     * @param size
     *          dimension of the resulting vector
     * @param entries
     *          (index, value) pairs; indices are expected to be distinct and within [0, size)
     * @return the sparse vector
     */
    private static Vector toSortedSparseVector(int size, Iterable<Tuple2<Long, Double>> entries) {
        List<Tuple2<Long, Double>> sorted = new ArrayList<>();
        for (Tuple2<Long, Double> entry : entries) {
            sorted.add(entry);
        }
        sorted.sort((a, b) -> Long.compare(a._1, b._1));

        int[] indices = new int[sorted.size()];
        double[] values = new double[sorted.size()];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = sorted.get(i)._1.intValue();
            values[i] = sorted.get(i)._2; // keep the full double count (no int truncation)
        }
        return Vectors.sparse(size, indices, values);
    }

    /**
     * loadVectorFromCSV: Load term vector from csv file.
     *
     * <p>Each remaining line is split on {@code ','}; the first field is the term and the
     * following fields are parsed as doubles.
     *
     * @param spark
     *          spark instance
     * @param csvFileName
     *          csv matrix file
     * @param skipNum
     *          the number of rows to skip at the top of the csv file (e.g. 1 for a header line)
     * @return JavaPairRDD, each key is a term, and value is the vector of the
     *         term in feature space; {@code null} if no lines remain after skipping.
     */
    public static JavaPairRDD<String, Vector> loadVectorFromCSV(SparkDriver spark, String csvFileName,
            int skipNum) {
        // Skip the first skipNum lines (header), important!
        JavaPairRDD<String, Long> importIdRDD = spark.sc.textFile(csvFileName)
                .zipWithIndex()
                .filter(line -> line._2 >= skipNum);

        // isEmpty() stops after finding one element, unlike count(), which scans the whole file.
        if (importIdRDD.isEmpty()) {
            return null;
        }

        return importIdRDD.mapToPair(line -> parseVectorLine(line._1));
    }

    /**
     * Parse one csv line of the form {@code term,v1,v2,...} into a (term, dense vector) pair.
     *
     * @param line
     *          a single csv line
     * @return the term and its parsed vector
     */
    private static Tuple2<String, Vector> parseVectorLine(String line) {
        String[] fields = line.split(",");
        String word = fields[0];
        // NOTE(review): Arrays.copyOfRange excludes its upper bound, so the last field of every
        // line is never parsed ("w,1,2,3" -> [1.0, 2.0]; "w,1" -> empty vector). exportToCSV
        // writes no trailing column, so this looks like an off-by-one, but it is preserved
        // because fixing it changes the dimension of every loaded vector — confirm before fixing.
        String[] numFields = fields.length < 2
                ? new String[0]
                : Arrays.copyOfRange(fields, 1, fields.length - 1);
        double[] nums = Stream.of(numFields).mapToDouble(Double::parseDouble).toArray();
        return new Tuple2<>(word, Vectors.dense(nums));
    }

    /**
     * Convert vectorRDD to indexed row matrix.
     *
     * <p>Each vector becomes an {@link IndexedRow} whose index is its position assigned by
     * {@code zipWithIndex}.
     *
     * @param vecs
     *          Vector RDD
     * @return IndexedRowMatrix
     */
    public static IndexedRowMatrix buildIndexRowMatrix(JavaRDD<Vector> vecs) {
        JavaRDD<IndexedRow> indexedRows = vecs.zipWithIndex().map(t -> new IndexedRow(t._2, t._1));
        return new IndexedRowMatrix(indexedRows.rdd());
    }

    /**
     * Transpose matrix
     *
     * @param indexedMatrix
     *          spark indexed matrix
     * @return rowmatrix, each row is corresponding to the column in the original
     *         matrix and vice versa
     */
    public static RowMatrix transposeMatrix(IndexedRowMatrix indexedMatrix) {
        return indexedMatrix.toCoordinateMatrix().transpose().toRowMatrix();
    }

    /**
     * Output matrix to a CSV file.
     *
     * <p>The first line is {@code " Num"} followed by the quoted column names; each further line
     * is the row name followed by the row's values. An existing file is overwritten. Nothing is
     * written if the matrix has no rows. I/O errors are printed and otherwise ignored.
     *
     * @param matrix
     *          spark row matrix (collected to the driver)
     * @param rowKeys
     *          matrix row names; must have at least as many entries as the matrix has rows
     * @param colKeys
     *          matrix column names; must have at least as many entries as the matrix has columns
     * @param fileName
     *          csv file name
     */
    public static void exportToCSV(RowMatrix matrix, List<String> rowKeys, List<String> colKeys, String fileName) {
        if (matrix.rows().isEmpty()) {
            return;
        }

        int colnum = (int) matrix.numCols();
        List<Vector> rows = matrix.rows().toJavaRDD().collect();

        File file = new File(fileName);
        // FileWriter truncates an existing file; try-with-resources closes the writer even on error.
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(file.getAbsoluteFile()))) {
            StringBuilder header = new StringBuilder(" Num");
            for (int j = 0; j < colnum; j++) {
                header.append(",\"").append(colKeys.get(j)).append('"');
            }
            bw.write(header.append('\n').toString());

            for (int i = 0; i < rows.size(); i++) {
                double[] rowValues = rows.get(i).toArray();
                StringBuilder line = new StringBuilder().append(rowKeys.get(i));
                for (int j = 0; j < colnum; j++) {
                    line.append(',').append(rowValues[j]);
                }
                bw.write(line.append('\n').toString());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}