com.davidbracewell.math.linear.MatrixMath.java Source code

Java tutorial

Introduction

Here is the source code for com.davidbracewell.math.linear.MatrixMath.java

Source

/*
 * (c) 2005 David B. Bracewell
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.davidbracewell.math.linear;

import com.davidbracewell.collection.CollectionUtils;
import com.davidbracewell.math.DoubleEntry;
import com.davidbracewell.math.functions.BinaryFunction;
import com.davidbracewell.math.functions.Functions;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

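/**
 * Static helper operations on {@link Matrix} instances: adding a scalar to every cell,
 * combining two equally sized matrices cell by cell with a {@link BinaryFunction}, and
 * matrix multiplication. The cell-by-cell and multiplication routines split their work
 * into one task per row and run those tasks on a {@link ForkJoinPool}.
 *
 * @author David B. Bracewell
 */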
public class MatrixMath {

    /**
     * Allocates an empty matrix to hold the result of an operation, picking the
     * implementation from the operands: a {@link JBlasMatrix} when no operand reports
     * itself as sparse, otherwise a {@link SparseMatrix}.
     *
     * @param m1 the first operand; must not be null
     * @param m2 the second operand, or {@code null} when the operation has only one operand
     * @param M the first constructor argument (callers in this class pass the row count)
     * @param N the second constructor argument (callers in this class pass the column count)
     * @return the newly constructed matrix
     */
    private static Matrix createMatrix(Matrix m1, Matrix m2, int M, int N) {
        boolean allDense = !m1.isSparse() && (m2 == null || !m2.isSparse());
        if (allDense) {
            return new JBlasMatrix(M, N);
        }
        return new SparseMatrix(M, N);
    }

    /**
     * Visits every (row, column) position of {@code src} and stores the value found
     * there plus {@code amount} at the same position in {@code target}.
     *
     * @param src the matrix that is read and whose dimensions drive the iteration
     * @param target the matrix that is written; may be the same object as {@code src}
     * @param amount the value added to each cell
     * @return {@code target}
     */
    private static Matrix increment(Matrix src, Matrix target, double amount) {
        int r = 0;
        while (r < src.rowDimension()) {
            int c = 0;
            while (c < src.columnDimension()) {
                double shifted = src.get(r, c) + amount;
                target.set(r, c, shifted);
                c++;
            }
            r++;
        }
        return target;
    }

    /**
     * Produces a new matrix in which every cell holds the corresponding cell of
     * {@code m1} plus {@code amount}. The input matrix is only read.
     *
     * @param m1 the matrix to read; must not be null
     * @param amount the value added to each cell
     * @return a newly allocated matrix with the same dimensions as {@code m1}
     * @throws NullPointerException if {@code m1} is null
     */
    public static Matrix increment(Matrix m1, double amount) {
        Matrix source = Preconditions.checkNotNull(m1);
        Matrix result = createMatrix(source, null, source.rowDimension(), source.columnDimension());
        return increment(source, result, amount);
    }

    /**
     * Adds {@code amount} to every cell of {@code m1} in place.
     *
     * @param m1 the matrix to modify; must not be null
     * @param amount the value added to each cell
     * @return {@code m1}, after modification
     * @throws NullPointerException if {@code m1} is null
     */
    public static Matrix incrementSelf(Matrix m1, double amount) {
        Preconditions.checkNotNull(m1);
        // Source and target are the same object, so the update happens in place.
        return increment(m1, m1, amount);
    }

    /**
     * Runs a {@link MatrixRowOperation} over all rows of {@code m1}, combining each
     * visited cell with the cell at the same position in {@code m2} and storing the
     * outcome in {@code target}.
     *
     * <p>No argument validation happens here; the public entry points check for nulls
     * and matching dimensions before delegating.</p>
     *
     * @param m1 the matrix whose rows are iterated; supplies the first function argument
     * @param m2 the matrix supplying the second function argument
     * @param target the matrix that receives the results; may be the same object as {@code m1}
     * @param function the function applied to each pair of cells
     * @return {@code target}
     */
    private static Matrix adjust(Matrix m1, Matrix m2, Matrix target, BinaryFunction function) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // A row index of -1 tells the task to fan out into one sub-task per row.
            pool.invoke(new MatrixRowOperation(m1, m2, target, -1, function));
        } finally {
            // The pool is private to this call; shut it down so its worker threads are
            // released instead of lingering after every invocation.
            pool.shutdown();
        }
        return target;
    }

    /**
     * Combines {@code m1} with {@code m2} cell by cell and writes the results back
     * into {@code m1}. For each visited cell the function receives the value from
     * {@code m1} first and the value from {@code m2} second.
     *
     * @param m1 the matrix that is both read and overwritten; must not be null
     * @param m2 the second operand; must not be null and must match {@code m1} in size
     * @param function the function applied to each pair of cells; must not be null
     * @return {@code m1}, after modification
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if the row or column counts of the matrices differ
     */
    public static Matrix adjustSelf(Matrix m1, Matrix m2, BinaryFunction function) {
        Preconditions.checkNotNull(m1);
        Preconditions.checkNotNull(m2);
        Preconditions.checkNotNull(function);

        boolean rowsAgree = m1.rowDimension() == m2.rowDimension();
        String rowComplaint =
                "Dimension mismatch [" + m1.rowDimension() + " != " + m2.rowDimension() + "]";
        Preconditions.checkArgument(rowsAgree, rowComplaint);

        boolean columnsAgree = m1.columnDimension() == m2.columnDimension();
        String columnComplaint =
                "Dimension mismatch [" + m1.columnDimension() + " != " + m2.columnDimension() + "]";
        Preconditions.checkArgument(columnsAgree, columnComplaint);

        // m1 doubles as the target, which makes the operation in-place.
        return adjust(m1, m2, m1, function);
    }

    /**
     * Combines {@code m1} with {@code m2} cell by cell and returns the results in a
     * newly allocated matrix; neither operand is written to. For each visited cell the
     * function receives the value from {@code m1} first and the value from {@code m2}
     * second.
     *
     * @param m1 the first operand; must not be null
     * @param m2 the second operand; must not be null and must match {@code m1} in size
     * @param function the function applied to each pair of cells; must not be null
     * @return a new matrix with the dimensions of {@code m1}
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if the row or column counts of the matrices differ
     */
    public static Matrix adjust(Matrix m1, Matrix m2, BinaryFunction function) {
        Preconditions.checkNotNull(m1);
        Preconditions.checkNotNull(m2);
        Preconditions.checkNotNull(function);

        boolean sameRows = m1.rowDimension() == m2.rowDimension();
        String rowsMessage =
                "Dimension mismatch [" + m1.rowDimension() + " != " + m2.rowDimension() + "]";
        Preconditions.checkArgument(sameRows, rowsMessage);

        boolean sameColumns = m1.columnDimension() == m2.columnDimension();
        String columnsMessage =
                "Dimension mismatch [" + m1.columnDimension() + " != " + m2.columnDimension() + "]";
        Preconditions.checkArgument(sameColumns, columnsMessage);

        Matrix result = createMatrix(m1, m2, m1.rowDimension(), m1.columnDimension());
        return adjust(m1, m2, result, function);
    }

    /**
     * Combines two equally sized matrices cell by cell using {@code Functions.ADD}
     * and returns the outcome as a new matrix. Argument validation is performed by
     * {@link #adjust(Matrix, Matrix, BinaryFunction)}.
     *
     * @param m1 the first operand
     * @param m2 the second operand
     * @return a new matrix holding the combined cells
     */
    public static Matrix add(Matrix m1, Matrix m2) {
        final Matrix combined = adjust(m1, m2, Functions.ADD);
        return combined;
    }

    /**
     * Combines two equally sized matrices cell by cell using {@code Functions.SUBTRACT}
     * and returns the outcome as a new matrix. The cell from {@code m1} is passed as the
     * function's first argument and the cell from {@code m2} as its second. Argument
     * validation is performed by {@link #adjust(Matrix, Matrix, BinaryFunction)}.
     *
     * @param m1 the first operand
     * @param m2 the second operand
     * @return a new matrix holding the combined cells
     */
    public static Matrix subtract(Matrix m1, Matrix m2) {
        final Matrix combined = adjust(m1, m2, Functions.SUBTRACT);
        return combined;
    }

    /**
     * Element-by-element combination of two equally sized matrices using
     * {@code Functions.MULTIPLY}; this is not the matrix product (see
     * {@link #multiply(Matrix, Matrix)} for that). Argument validation is performed by
     * {@link #adjust(Matrix, Matrix, BinaryFunction)}.
     *
     * @param m1 the first operand
     * @param m2 the second operand
     * @return a new matrix holding the combined cells
     */
    public static Matrix ebeMultiply(Matrix m1, Matrix m2) {
        final Matrix combined = adjust(m1, m2, Functions.MULTIPLY);
        return combined;
    }

    /**
     * Element-by-element combination of two equally sized matrices using
     * {@code Functions.DIVIDE}. The cell from {@code m1} is passed as the function's
     * first argument and the cell from {@code m2} as its second. Argument validation is
     * performed by {@link #adjust(Matrix, Matrix, BinaryFunction)}.
     *
     * @param m1 the first operand
     * @param m2 the second operand
     * @return a new matrix holding the combined cells
     */
    public static Matrix ebeDivide(Matrix m1, Matrix m2) {
        final Matrix combined = adjust(m1, m2, Functions.DIVIDE);
        return combined;
    }

    /**
     * Computes the matrix product {@code m1 * m2}.
     *
     * <p>The result matrix is allocated by {@code createMatrix} (sparse when either
     * operand reports itself as sparse) and filled in by a {@link MatrixMultiplier},
     * which creates one fork/join sub-task per row of {@code m1}.</p>
     *
     * @param m1 the left operand; must not be null
     * @param m2 the right operand; must not be null
     * @return a new matrix with {@code m1.rowDimension()} rows and
     *         {@code m2.columnDimension()} columns
     * @throws NullPointerException if either matrix is null
     * @throws IllegalArgumentException if the column count of {@code m1} differs from
     *         the row count of {@code m2}
     */
    public static Matrix multiply(Matrix m1, Matrix m2) {
        Preconditions.checkNotNull(m1);
        Preconditions.checkNotNull(m2);
        // The inner dimensions must agree: the multiplier uses each column index of m1
        // as a row index into m2, so m1's column count has to equal m2's row count.
        Preconditions.checkArgument(m1.columnDimension() == m2.rowDimension(),
                "Dimension mismatch [" + m1.columnDimension() + " != " + m2.rowDimension() + "]");
        Matrix result = createMatrix(m1, m2, m1.rowDimension(), m2.columnDimension());
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // A row index of -1 tells the task to fan out into one sub-task per row.
            pool.invoke(new MatrixMultiplier(m1, m2, result, -1));
        } finally {
            // The pool is private to this call; shut it down so its worker threads are
            // released instead of lingering after every invocation.
            pool.shutdown();
        }
        return result;
    }

    /**
     * Builds a {@link JBlasMatrix} and fills it, row by row, with values drawn as
     * {@code Random.nextDouble() * 100}. A draw that is not strictly positive is
     * skipped, leaving that cell at whatever value a fresh matrix holds.
     *
     * @param M the number of rows
     * @param N the number of columns
     * @return the populated matrix
     */
    public static Matrix randomMatrix(int M, int N) {
        Matrix filled = new JBlasMatrix(M, N);
        Random rng = new Random();
        for (int r = 0; r < M; r++) {
            for (int c = 0; c < N; c++) {
                double sample = rng.nextDouble() * 100;
                if (!(sample > 0)) {
                    // Only strictly positive draws are stored.
                    continue;
                }
                filled.set(r, c, sample);
            }
        }
        return filled;
    }

    /**
     * Ad-hoc timing harness: multiplies two random 10000 x 10000 matrices through the
     * instance method {@code Matrix.multiply} (not {@link #multiply(Matrix, Matrix)})
     * and prints the elapsed time measured by a Guava {@link Stopwatch}.
     *
     * <p>A small hand-checkable case that was previously kept here as commented-out
     * code: a 2 x 3 {@code SparseMatrix} with rows (1, 2, 3) and (4, 5, 6) multiplied
     * by a 3 x 2 {@code SparseMatrix} with rows (7, 8), (9, 10) and (11, 12), with the
     * product printed to standard output.</p>
     *
     * @param args the command line arguments; not used
     * @throws Exception declared for convenience; nothing in the body is known to throw
     *         a checked exception
     */
    public static void main(String[] args) throws Exception {
        final int size = 10000;
        Matrix left = randomMatrix(size, size);
        Matrix right = randomMatrix(size, size);

        Stopwatch timer = Stopwatch.createStarted();
        // Only the elapsed time is reported; the product itself is discarded.
        left.multiply(right);
        timer.stop();

        System.out.println(timer);
    }

    /**
     * Fork/join task that computes the product of two matrices one row at a time.
     *
     * <p>An instance created with a row index of {@code -1} is the root task: it spawns
     * one child task per row of the left matrix and waits for them via
     * {@code invokeAll}. Any other row index makes the instance a leaf that fills in
     * exactly that row of the result.</p>
     */
    private static class MatrixMultiplier extends RecursiveAction {
        private static final long serialVersionUID = -8570108050068635908L;

        /** Row index marking the root task that fans out over all rows. */
        private static final int ALL_ROWS = -1;

        final Matrix left;
        final Matrix right;
        final Matrix product;
        final int row;

        private MatrixMultiplier(Matrix a, Matrix b, Matrix c, int row) {
            this.left = a;
            this.right = b;
            this.product = c;
            this.row = row;
        }

        @Override
        protected void compute() {
            if (row != ALL_ROWS) {
                doCalculation();
                return;
            }
            List<MatrixMultiplier> perRow = Lists.newArrayList();
            int r = 0;
            while (r < left.rowDimension()) {
                perRow.add(new MatrixMultiplier(left, right, product, r));
                r++;
            }
            invokeAll(perRow);
        }

        /**
         * Fills row {@code row} of the result: for every column of the right matrix,
         * accumulates {@code left[row][k] * right[k][column]} over the entries yielded
         * by the left row's non-zero iterator and stores the total.
         */
        void doCalculation() {
            for (int column = 0; column < right.columnDimension(); column++) {
                double dot = 0d;
                // A fresh iterator is requested for every column because iterators are single-use.
                for (DoubleEntry cell : CollectionUtils.asIterable(left.getRow(row).nonZeroIterator())) {
                    dot += cell.value * right.get(cell.index, column);
                }
                product.set(row, column, dot);
            }
        }

    }

    /**
     * Fork/join task that applies a {@link BinaryFunction} to matching cells of two
     * matrices, one row at a time, and stores the outcome in a target matrix.
     *
     * <p>An instance created with a row index of {@code -1} is the root task: it spawns
     * one child task per row of the first matrix and waits for them via
     * {@code invokeAll}. Any other row index makes the instance a leaf that processes
     * exactly that row.</p>
     *
     * <p>NOTE(review): when reached through {@code adjustSelf} the target is the same
     * object as the first matrix, so a row is written while its own iterator is in use;
     * confirm that the {@code Matrix} implementations tolerate this.</p>
     */
    private static class MatrixRowOperation extends RecursiveAction {
        private static final long serialVersionUID = -8570108050068635908L;

        /** Row index marking the root task that fans out over all rows. */
        private static final int ALL_ROWS = -1;

        final Matrix first;
        final Matrix second;
        final Matrix target;
        final int row;
        final BinaryFunction function;

        private MatrixRowOperation(Matrix a, Matrix b, Matrix c, int row, BinaryFunction function) {
            this.first = a;
            this.second = b;
            this.target = c;
            this.row = row;
            this.function = function;
        }

        /**
         * Processes row {@code row}: for each entry yielded by the first matrix's row
         * iterator, applies the function to that entry's value and the value at the same
         * position in the second matrix, and stores the result in the target.
         */
        void doCalculation() {
            for (DoubleEntry cell : CollectionUtils.asIterable(first.getRow(row).iterator())) {
                double other = second.get(row, cell.index);
                target.set(row, cell.index, function.value(cell.value, other));
            }
        }

        @Override
        protected void compute() {
            if (row != ALL_ROWS) {
                doCalculation();
                return;
            }
            List<MatrixRowOperation> perRow = Lists.newArrayList();
            int r = 0;
            while (r < first.rowDimension()) {
                perRow.add(new MatrixRowOperation(first, second, target, r, function));
                r++;
            }
            invokeAll(perRow);
        }

    }

}//END OF MatrixMath