org.hillview.sketches.JLProjection.java Source code

Java tutorial

Introduction

Here is the source code for org.hillview.sketches.JLProjection.java

Source

/*
 * Copyright (c) 2017 VMware Inc. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.hillview.sketches;

import org.apache.commons.lang.ArrayUtils;

import javax.annotation.Nullable;
import java.security.InvalidParameterException;
import java.util.LinkedHashMap;

/** Data structure used to store the results of a Johnson-Lindenstrauss (JL) sketch.
 * It contains a vector of doubles for each column, and some other information that can be used for
 * normalization. It implements the ICorrelation interface and can be used for computing norms,
 * inner products etc, but it is currently rather slow compared to sampling based methods.
 */
public class JLProjection implements ICorrelation {
    /**
     * The JL Sketch stores for every column, a vector of doubles of dimension lowDim, which is the
     * sketch of that column. The sketch is computed by multiplying th column (viewed as a vector of
     * doubles) by a random matrix of {-1, 1} values.
     */
    public final LinkedHashMap<String, double[]> hMap;
    /**
     * The dimension we are projecting down to. The error in estimating the norm is proportional to
     * 1/sqrt(lowDim).
     */
    private final int lowDim;
    /**
     * The length of each column which is the dimension of each column vector.
     */
    public int highDim;
    /**
     * The list of columns we want to sketch. Each column should be of type Int/Double.
     */
    public final String[] colNames;
    /**
     * The matrix of pairwise correlations (see ICorrelation for exact definition of correlation).
     */
    @Nullable
    private double[][] corrMatrix;

    public JLProjection(String[] colNames, int lowDim) {
        if (lowDim <= 0)
            throw new InvalidParameterException("LowDim has to be positive.");
        this.lowDim = lowDim;
        this.colNames = colNames;
        this.hMap = new LinkedHashMap<String, double[]>();
        for (String s : colNames)
            this.hMap.put(s, new double[this.lowDim]);
        this.highDim = 0;
        this.corrMatrix = null;
    }

    public void update(String s, int j, double val) {
        this.hMap.get(s)[j] += val;
    }

    public double get(String s, int j) {
        return this.hMap.get(s)[j];
    }

    public void scale(double f) {
        for (String s : this.hMap.keySet()) {
            for (int i = 0; i < this.lowDim; i++)
                this.hMap.get(s)[i] *= f;
        }
    }

    public double getNorm(String s) {
        if (!this.hMap.containsKey(s))
            throw new InvalidParameterException("No sketch found for column: " + s);
        if (this.highDim <= 0)
            throw new InvalidParameterException("Dimension must be positive.");
        double sum = 0;
        double a[] = this.hMap.get(s);
        for (int i = 0; i < this.lowDim; i++)
            sum += Math.pow(a[i], 2);
        return Math.sqrt(sum / (this.lowDim * this.highDim));
    }

    public double getInnerProduct(String s, String t) {
        if (!this.hMap.containsKey(s))
            throw new InvalidParameterException("No sketch found for column: " + s);
        if (!this.hMap.containsKey(t))
            throw new InvalidParameterException("No sketch found for column: " + t);
        if (this.highDim <= 0)
            throw new InvalidParameterException("Dimension must be positive.");
        double sum = 0;
        double a[] = this.hMap.get(s);
        double b[] = this.hMap.get(t);
        for (int i = 0; i < this.lowDim; i++)
            sum += a[i] * b[i];
        return (sum / (this.lowDim * this.highDim));
    }

    @Override
    public double[][] getCorrelationMatrix() {
        if (this.corrMatrix == null) {
            int d = this.colNames.length;
            this.corrMatrix = new double[d][d];
            for (int i = 0; i < d; i++)
                for (int j = i; j < d; j++) {
                    double sum = 0, first = 0, second = 0;
                    double a[] = this.hMap.get(this.colNames[i]);
                    double b[] = this.hMap.get(this.colNames[j]);
                    for (int k = 0; k < this.lowDim; k++) {
                        sum += a[k] * b[k];
                        first += Math.pow(a[k], 2);
                        second += Math.pow(b[k], 2);
                    }
                    if ((first == 0) || (second == 0))
                        this.corrMatrix[i][j] = 0;
                    else
                        this.corrMatrix[i][j] = sum / Math.sqrt(first * second);
                    this.corrMatrix[j][i] = this.corrMatrix[i][j];
                }
        }
        return this.corrMatrix;
    }

    @Override
    public double getCorrelation(String s, String t) {
        int sIndex = ArrayUtils.indexOf(this.colNames, s);
        int tIndex = ArrayUtils.indexOf(this.colNames, t);
        return this.getCorrelationMatrix()[sIndex][tIndex];
    }

    @Override
    public double[] getCorrelationWith(String s) {
        int sIndex = ArrayUtils.indexOf(this.colNames, s);
        return this.getCorrelationMatrix()[sIndex];
    }
}