// Java tutorial
/* * (c) 2005 David B. Bracewell * * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package com.davidbracewell.math.linear; import com.davidbracewell.collection.CollectionUtils; import com.davidbracewell.math.DoubleEntry; import com.davidbracewell.math.functions.BinaryFunction; import com.davidbracewell.math.functions.Functions; import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.collect.Lists; import java.util.List; import java.util.Random; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.RecursiveAction; /** * The type Matrix math. * @author David B. Bracewell */ public class MatrixMath { private static Matrix createMatrix(Matrix m1, Matrix m2, int M, int N) { if (m1.isSparse() || (m2 != null && m2.isSparse())) { return new SparseMatrix(M, N); } else { return new JBlasMatrix(M, N); } } private static Matrix increment(Matrix src, Matrix target, double amount) { for (int row = 0; row < src.rowDimension(); row++) { for (int col = 0; col < src.columnDimension(); col++) { target.set(row, col, src.get(row, col) + amount); } } return target; } /** * Increment matrix. 
* * @param m1 the m 1 * @param amount the amount * @return the matrix */ public static Matrix increment(Matrix m1, double amount) { return increment(Preconditions.checkNotNull(m1), createMatrix(m1, null, m1.rowDimension(), m1.columnDimension()), amount); } /** * Increment self. * * @param m1 the m 1 * @param amount the amount * @return the matrix */ public static Matrix incrementSelf(Matrix m1, double amount) { return increment(Preconditions.checkNotNull(m1), m1, amount); } private static Matrix adjust(Matrix m1, Matrix m2, Matrix target, BinaryFunction function) { ForkJoinPool pool = new ForkJoinPool(); pool.invoke(new MatrixRowOperation(m1, m2, target, -1, function)); return target; } /** * Adjust self. * * @param m1 the m 1 * @param m2 the m 2 * @param function the function * @return the matrix */ public static Matrix adjustSelf(Matrix m1, Matrix m2, BinaryFunction function) { Preconditions.checkNotNull(m1); Preconditions.checkNotNull(m2); Preconditions.checkNotNull(function); Preconditions.checkArgument(m1.rowDimension() == m2.rowDimension(), "Dimension mismatch [" + m1.rowDimension() + " != " + m2.rowDimension() + "]"); Preconditions.checkArgument(m1.columnDimension() == m2.columnDimension(), "Dimension mismatch [" + m1.columnDimension() + " != " + m2.columnDimension() + "]"); return adjust(m1, m2, m1, function); } /** * Adjust matrix. 
* * @param m1 the m 1 * @param m2 the m 2 * @param function the function * @return the matrix */ public static Matrix adjust(Matrix m1, Matrix m2, BinaryFunction function) { Preconditions.checkNotNull(m1); Preconditions.checkNotNull(m2); Preconditions.checkNotNull(function); Preconditions.checkArgument(m1.rowDimension() == m2.rowDimension(), "Dimension mismatch [" + m1.rowDimension() + " != " + m2.rowDimension() + "]"); Preconditions.checkArgument(m1.columnDimension() == m2.columnDimension(), "Dimension mismatch [" + m1.columnDimension() + " != " + m2.columnDimension() + "]"); return adjust(m1, m2, createMatrix(m1, m2, m1.rowDimension(), m1.columnDimension()), function); } /** * Add matrix. * * @param m1 the m 1 * @param m2 the m 2 * @return the matrix */ public static Matrix add(Matrix m1, Matrix m2) { return adjust(m1, m2, Functions.ADD); } /** * Subtract matrix. * * @param m1 the m 1 * @param m2 the m 2 * @return the matrix */ public static Matrix subtract(Matrix m1, Matrix m2) { return adjust(m1, m2, Functions.SUBTRACT); } /** * Ebe multiply. * * @param m1 the m 1 * @param m2 the m 2 * @return the matrix */ public static Matrix ebeMultiply(Matrix m1, Matrix m2) { return adjust(m1, m2, Functions.MULTIPLY); } /** * Ebe divide. * * @param m1 the m 1 * @param m2 the m 2 * @return the matrix */ public static Matrix ebeDivide(Matrix m1, Matrix m2) { return adjust(m1, m2, Functions.DIVIDE); } /** * Multiply matrix. * * @param m1 the m 1 * @param m2 the m 2 * @return the matrix */ public static Matrix multiply(Matrix m1, Matrix m2) { Preconditions.checkNotNull(m1); Preconditions.checkNotNull(m2); Preconditions.checkArgument(m1.rowDimension() == m2.columnDimension(), "Dimension mismatch [" + m1.rowDimension() + " != " + m2.columnDimension() + "]"); Matrix result = createMatrix(m1, m2, m1.rowDimension(), m2.columnDimension()); ForkJoinPool pool = new ForkJoinPool(); pool.invoke(new MatrixMultiplier(m1, m2, result, -1)); return result; } /** * Random matrix. 
* * @param M the m * @param N the n * @return the matrix */ public static Matrix randomMatrix(int M, int N) { Matrix m = new JBlasMatrix(M, N); //new SparseMatrix(M, N); Random random = new Random(); for (int row = 0; row < M; row++) { for (int col = 0; col < N; col++) { double val = random.nextDouble() * 100; if (val > 0) { m.set(row, col, val); } } } return m; } /** * The entry point of application. * * @param args the input arguments * @throws Exception the exception */ public static void main(String[] args) throws Exception { Matrix m1 = randomMatrix(10000, 10000); Matrix m2 = randomMatrix(10000, 10000); Stopwatch sw = Stopwatch.createStarted(); Matrix m3 = m1.multiply(m2); sw.stop(); System.out.println(sw); // System.out.println(m3); // // m1 = new SparseMatrix(2, 3); // m1.set(0, 0, 1); // m1.set(0, 1, 2); // m1.set(0, 2, 3); // m1.set(1, 0, 4); // m1.set(1, 1, 5); // m1.set(1, 2, 6); // // m2 = new SparseMatrix(3, 2); // m2.set(0, 0, 7); // m2.set(0, 1, 8); // m2.set(1, 0, 9); // m2.set(1, 1, 10); // m2.set(2, 0, 11); // m2.set(2, 1, 12); // // m3 = m1.multiply(m2); // System.out.println(m3); } private static class MatrixMultiplier extends RecursiveAction { private static final long serialVersionUID = -8570108050068635908L; final Matrix a, b, c; final int row; private MatrixMultiplier(Matrix a, Matrix b, Matrix c, int row) { this.a = a; this.b = b; this.c = c; this.row = row; } @Override protected void compute() { if (row == -1) { List<MatrixMultiplier> tasks = Lists.newArrayList(); for (int row = 0; row < a.rowDimension(); row++) { tasks.add(new MatrixMultiplier(a, b, c, row)); } invokeAll(tasks); } else { doCalculation(); } } void doCalculation() { for (int col = 0; col < b.columnDimension(); col++) { double sum = 0d; for (DoubleEntry entry : CollectionUtils.asIterable(a.getRow(row).nonZeroIterator())) { sum += entry.value * b.get(entry.index, col); } c.set(row, col, sum); } } } private static class MatrixRowOperation extends RecursiveAction { private 
static final long serialVersionUID = -8570108050068635908L; final Matrix a, b, c; final int row; final BinaryFunction function; private MatrixRowOperation(Matrix a, Matrix b, Matrix c, int row, BinaryFunction function) { this.a = a; this.b = b; this.c = c; this.row = row; this.function = function; } void doCalculation() { for (DoubleEntry entry : CollectionUtils.asIterable(a.getRow(row).iterator())) { c.set(row, entry.index, function.value(entry.value, b.get(row, entry.index))); } } @Override protected void compute() { if (row == -1) { List<MatrixRowOperation> tasks = Lists.newArrayList(); for (int row = 0; row < a.rowDimension(); row++) { tasks.add(new MatrixRowOperation(a, b, c, row, function)); } invokeAll(tasks); } else { doCalculation(); } } } }//END OF MatrixMath