org.briljantframework.array.base.BaseArrayRoutines.java Source code

Java tutorial

Introduction

Here is the source code for org.briljantframework.array.base.BaseArrayRoutines.java

Source

/**
 * The MIT License (MIT)
 *
 * Copyright (c) 2016 Isak Karlsson
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
 * associated documentation files (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge, publish, distribute,
 * sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all copies or
 * substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
 * NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
package org.briljantframework.array.base;

import static org.briljantframework.array.StrideUtils.columnMajor;
import static org.briljantframework.array.StrideUtils.rowMajor;

import java.util.Comparator;

import org.apache.commons.math3.complex.Complex;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision;
import org.briljantframework.Check;
import org.briljantframework.array.Array;
import org.briljantframework.array.ArrayOperation;
import org.briljantframework.array.BaseArray;
import org.briljantframework.array.ComplexArray;
import org.briljantframework.array.DoubleArray;
import org.briljantframework.array.IntArray;
import org.briljantframework.array.LongArray;
import org.briljantframework.array.api.ArrayRoutines;
import org.briljantframework.complex.MutableComplex;
import org.briljantframework.exceptions.MultiDimensionMismatchException;
import org.briljantframework.statistics.FastStatistics;

/**
 * Base array routines implemented in Java.
 * 
 * @author Isak Karlsson
 */
public class BaseArrayRoutines implements ArrayRoutines {

    /** Natural logarithm of 2, used to compute base-2 logarithms via {@code log(x) / LOG_2}. */
    protected static final double LOG_2 = Math.log(2);

    /**
     * Tolerance used by {@link #pow(DoubleArray, double)} when deciding whether the requested power
     * is (approximately) 2, 3 or 4.
     */
    protected static final double EPS = 1e-10;

    /** Protected: instantiate through a subclass (or from within this package). */
    protected BaseArrayRoutines() {
    }

    /**
     * Returns the arithmetic mean of all elements of {@code x}: the sum of the elements divided by
     * {@code x.size()}. An empty array divides by zero (presumably yielding {@code NaN}).
     */
    @Override
    public double mean(DoubleArray x) {
        return x.reduce(0, Double::sum) / x.size();
    }

    /** Returns the mean of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public DoubleArray mean(int dim, DoubleArray x) {
        return x.reduceVectors(dim, this::mean);
    }

    /** Returns the variance of all elements of {@code x}, as computed by {@code FastStatistics#getVariance}. */
    @Override
    public double var(DoubleArray x) {
        FastStatistics s = new FastStatistics();
        x.forEachDouble(s::addValue);
        return s.getVariance();
    }

    /** Returns the variance of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public DoubleArray var(int dim, DoubleArray x) {
        return x.reduceVectors(dim, this::var);
    }

    /** Returns the standard deviation of {@code x}, i.e. the square root of {@link #var(DoubleArray)}. */
    @Override
    public double std(DoubleArray x) {
        return Math.sqrt(var(x));
    }

    /** Returns the standard deviation of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public DoubleArray std(int dim, DoubleArray x) {
        return x.reduceVectors(dim, this::std);
    }

    /** Returns the smallest element of {@code x}; the reduction is seeded with positive infinity. */
    @Override
    public double min(DoubleArray x) {
        return x.reduce(Double.POSITIVE_INFINITY, Math::min);
    }

    /** Returns the smallest element of {@code x}; the reduction is seeded with {@code Integer.MAX_VALUE}. */
    @Override
    public int min(IntArray x) {
        return x.reduce(Integer.MAX_VALUE, Math::min);
    }

    /** Returns the smallest element of {@code x}; the reduction is seeded with {@code Long.MAX_VALUE}. */
    @Override
    public long min(LongArray x) {
        return x.reduce(Long.MAX_VALUE, Math::min);
    }

    /** Returns the smallest element of {@code x} by natural ordering, or {@code null} if it is empty. */
    @Override
    public <T extends Comparable<T>> T min(Array<T> x) {
        return min(x, Comparable::compareTo);
    }

    /**
     * Returns the smallest element of {@code x} according to {@code cmp}.
     *
     * @param x the array
     * @param cmp the comparator defining the ordering
     * @return the minimum element (ties resolve to the later element), or {@code null} if {@code x}
     *         is empty
     */
    @Override
    public <T> T min(Array<T> x, Comparator<T> cmp) {
        if (x.size() < 1) {
            return null;
        }
        return x.reduce(x.get(0), (o, n) -> {
            if (cmp.compare(o, n) < 0) {
                return o;
            } else {
                return n;
            }
        });
    }

    /** Returns the minimum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public DoubleArray min(int dim, DoubleArray x) {
        return x.reduceVectors(dim, this::min);
    }

    /** Returns the minimum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public IntArray min(int dim, IntArray x) {
        return x.reduceVectors(dim, this::min);
    }

    /** Returns the minimum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public LongArray min(int dim, LongArray x) {
        return x.reduceVector(dim, this::min);
    }

    /** Returns the minimum (natural ordering) of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public <T extends Comparable<T>> Array<T> min(int dim, Array<T> x) {
        return x.reduceVector(dim, this::min);
    }

    /** Returns the minimum (according to {@code cmp}) of each vector of {@code x} along {@code dim}. */
    @Override
    public <T> Array<T> min(int dim, Array<T> x, Comparator<T> cmp) {
        return x.reduceVector(dim, v -> this.min(v, cmp));
    }

    /** Returns the largest element of {@code x}; the reduction is seeded with negative infinity. */
    @Override
    public double max(DoubleArray x) {
        return x.reduce(Double.NEGATIVE_INFINITY, Math::max);
    }

    /** Returns the largest element of {@code x}; the reduction is seeded with {@code Integer.MIN_VALUE}. */
    @Override
    public int max(IntArray x) {
        return x.reduce(Integer.MIN_VALUE, Math::max);
    }

    /** Returns the largest element of {@code x}; the reduction is seeded with {@code Long.MIN_VALUE}. */
    @Override
    public long max(LongArray x) {
        return x.reduce(Long.MIN_VALUE, Math::max);
    }

    /** Returns the largest element of {@code x} by natural ordering, or {@code null} if it is empty. */
    @Override
    public <T extends Comparable<T>> T max(Array<T> x) {
        return max(x, Comparable::compareTo);
    }

    /**
     * Returns the largest element of {@code x} according to {@code cmp}.
     *
     * @param x the array
     * @param cmp the comparator defining the ordering
     * @return the maximum element (ties resolve to the later element), or {@code null} if {@code x}
     *         is empty
     */
    @Override
    public <T> T max(Array<T> x, Comparator<T> cmp) {
        if (x.size() < 1) {
            return null;
        }
        return x.reduce(x.get(0), (o, n) -> {
            if (cmp.compare(o, n) > 0) {
                return o;
            } else {
                return n;
            }
        });
    }

    /** Returns the maximum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public DoubleArray max(int dim, DoubleArray x) {
        return x.reduceVectors(dim, this::max);
    }

    /** Returns the maximum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public IntArray max(int dim, IntArray x) {
        return x.reduceVectors(dim, this::max);
    }

    /** Returns the maximum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public LongArray max(int dim, LongArray x) {
        return x.reduceVector(dim, this::max);
    }

    /** Returns the maximum (natural ordering) of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public <T extends Comparable<T>> Array<T> max(int dim, Array<T> x) {
        return x.reduceVector(dim, this::max);
    }

    /** Returns the maximum (according to {@code cmp}) of each vector of {@code x} along {@code dim}. */
    @Override
    public <T> Array<T> max(int dim, Array<T> x, Comparator<T> cmp) {
        return x.reduceVector(dim, v -> this.max(v, cmp));
    }

    /** Returns the sum of all elements of {@code x}. */
    @Override
    public double sum(DoubleArray x) {
        return x.reduce(0, Double::sum);
    }

    /** Returns the sum of all elements of {@code x} (int arithmetic; may overflow silently). */
    @Override
    public int sum(IntArray x) {
        return x.reduce(0, Integer::sum);
    }

    /** Returns the sum of all elements of {@code x} (long arithmetic; may overflow silently). */
    @Override
    public long sum(LongArray x) {
        return x.reduce(0, Long::sum);
    }

    /** Returns the sum of all elements of {@code x}, accumulated in a {@link MutableComplex}. */
    @Override
    public Complex sum(ComplexArray x) {
        MutableComplex sum = new MutableComplex(0);
        for (int i = 0; i < x.size(); i++) {
            sum.plus(x.get(i));
        }
        return sum.toComplex();
    }

    /** Returns the sum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public DoubleArray sum(int dim, DoubleArray x) {
        return x.reduceVectors(dim, this::sum);
    }

    /** Returns the sum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public IntArray sum(int dim, IntArray x) {
        return x.reduceVectors(dim, this::sum);
    }

    /** Returns the sum of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public LongArray sum(int dim, LongArray x) {
        return x.reduceVector(dim, this::sum);
    }

    /**
     * Returns the sum along the specified dimension.
     *
     * @param dim the dimension
     * @param x the array
     * @return an array of sums
     */
    @Override
    public ComplexArray sum(int dim, ComplexArray x) {
        return x.reduceVectors(dim, this::sum);
    }

    /**
     * Returns the product of all elements of {@code x}.
     *
     * <p>
     * The product is seeded with {@code 1.0} (the empty product), so an empty array yields
     * {@code 1.0} instead of failing on {@code x.get(0)}. For non-empty input the result is
     * unchanged, since {@code 1.0 * v == v} exactly in IEEE arithmetic.
     */
    @Override
    public double prod(DoubleArray x) {
        double prod = 1.0;
        for (int i = 0; i < x.size(); i++) {
            prod *= x.get(i);
        }
        return prod;
    }

    /** Returns the product of each vector of {@code x} along dimension {@code dim}. */
    @Override
    public DoubleArray prod(int dim, DoubleArray x) {
        return x.reduceVectors(dim, this::prod);
    }

    /**
     * Returns the running (cumulative) sum of {@code x} in linear-index order.
     *
     * <p>
     * The result has the same shape as {@code x}. Element {@code i} of the result is the sum of
     * elements {@code 0..i} of {@code x}.
     */
    @Override
    public DoubleArray cumsum(DoubleArray x) {
        // Allocate with the full shape so that arrays that are not 2-d keep their shape and size.
        DoubleArray result = x.newEmptyArray(x.getShape());
        double sum = 0;
        for (int i = 0; i < result.size(); i++) {
            sum += x.get(i);
            result.set(i, sum);
        }
        return result;
    }

    /**
     * Returns the running (cumulative) sum of each vector of {@code x} along dimension {@code dim}.
     *
     * @param dim the dimension along which sums accumulate
     * @param x the input array
     * @return a new array of the same shape as {@code x}
     */
    @Override
    public DoubleArray cumsum(int dim, DoubleArray x) {
        DoubleArray result = x.newEmptyArray(x.getShape());
        int vectors = x.vectors(dim);
        for (int i = 0; i < vectors; i++) {
            // Read from the input x; the previous code read from the freshly allocated result.
            result.setVector(dim, i, cumsum(x.getVector(dim, i)));
        }
        return result;
    }

    /**
     * Returns the inner (dot) product {@code sum(a[i] * b[i])}.
     *
     * @throws RuntimeException as raised by {@code Check.size} if the sizes differ
     */
    @Override
    public double inner(DoubleArray a, DoubleArray b) {
        Check.size(a, b);
        double s = 0;
        for (int i = 0; i < a.size(); i++) {
            s += a.get(i) * b.get(i);
        }
        return s;
    }

    /**
     * Returns the unconjugated inner product {@code sum(a[i] * b[i])} of two complex arrays. This
     * is the analogue of BLAS {@code zdotu}.
     */
    @Override
    public Complex inner(ComplexArray a, ComplexArray b) {
        Check.size(a, b);
        Complex s = Complex.ZERO;
        for (int i = 0; i < a.size(); i++) {
            s = s.add(a.get(i).multiply(b.get(i)));
        }
        return s;
    }

    /**
     * Returns the conjugated inner product {@code sum(conj(a[i]) * b[i])} of two complex arrays.
     * This is the analogue of BLAS {@code zdotc}, which conjugates the first operand.
     */
    @Override
    public Complex conjugateInner(ComplexArray a, ComplexArray b) {
        Check.size(a, b);
        Complex s = Complex.ZERO;
        for (int i = 0; i < a.size(); i++) {
            s = s.add(a.get(i).conjugate().multiply(b.get(i)));
        }
        return s;
    }

    /** Returns the Euclidean norm {@code sqrt(sum(a[i]^2))} (computed without overflow scaling). */
    @Override
    public double norm2(DoubleArray a) {
        double sum = 0;
        for (int i = 0; i < a.size(); i++) {
            double v = a.get(i);
            sum += v * v;
        }

        return Math.sqrt(sum);
    }

    /**
     * Returns the Euclidean norm {@code sqrt(sum(|a[i]|^2))} of a complex array, as a real-valued
     * {@link Complex}.
     *
     * <p>
     * The previous implementation returned {@code sqrt(sum(a[i]^2))}, which is not a norm. For
     * example, it gave {@code i} for the vector {@code [i]}. It also failed on empty input, which
     * now yields zero.
     */
    @Override
    public Complex norm2(ComplexArray a) {
        double sum = 0;
        for (int i = 0; i < a.size(); i++) {
            Complex v = a.get(i);
            double re = v.getReal();
            double im = v.getImaginary();
            sum += re * re + im * im;
        }
        return new Complex(Math.sqrt(sum));
    }

    /** Returns the sum of absolute values of the elements of {@code a}. */
    @Override
    public double asum(DoubleArray a) {
        double sum = 0;
        for (int i = 0; i < a.size(); i++) {
            sum += Math.abs(a.get(i));
        }
        return sum;
    }

    /** Returns the sum of the moduli {@code |a[i]|} of the elements of {@code a}. */
    @Override
    public double asum(ComplexArray a) {
        double s = 0;
        for (int i = 0; i < a.size(); i++) {
            s += a.get(i).abs();
        }
        return s;
    }

    /**
     * Returns the (0-based) index of the first element with the largest absolute value.
     * {@code x} must be non-empty, because {@code x.get(0)} is read unconditionally.
     */
    @Override
    public int iamax(DoubleArray x) {
        int i = 0;
        double m = Math.abs(x.get(0));
        for (int j = 1; j < x.size(); j++) {
            double d = Math.abs(x.get(j));
            if (d > m) {
                i = j;
                m = d;
            }
        }
        return i;
    }

    /**
     * Returns the (0-based) index of the first element with the largest modulus {@code |x[i]|}.
     *
     * <p>
     * The modulus is used for consistency with {@link #asum(ComplexArray)}; BLAS {@code izamax}
     * uses {@code |re| + |im|} instead. Returns {@code 0} for an empty array. Elements whose
     * modulus is {@code NaN} are never selected.
     */
    @Override
    public int iamax(ComplexArray x) {
        int index = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < x.size(); j++) {
            double d = x.get(j).abs();
            if (d > max) {
                index = j;
                max = d;
            }
        }
        return index;
    }

    /** Scales {@code x} in place by {@code alpha}; a no-op when {@code alpha == 1}. */
    @Override
    public void scal(double alpha, DoubleArray x) {
        if (alpha == 1) {
            return;
        }
        final int n = x.size();
        for (int i = 0; i < n; i++) {
            x.set(i, x.get(i) * alpha);
        }
    }

    /** Returns the sum of the main diagonal {@code x[i, i]} for {@code i < min(rows, columns)}. */
    @Override
    public double trace(DoubleArray x) {
        int min = Math.min(x.rows(), x.columns());
        double sum = 0;
        for (int i = 0; i < min; i++) {
            sum += x.get(i, i);
        }
        return sum;
    }

    /** Computes {@code y := alpha * x + y} in place; a no-op when {@code alpha == 0}. */
    @Override
    public void axpy(double alpha, DoubleArray x, DoubleArray y) {
        Check.size(x, y);
        if (alpha == 0) {
            return;
        }
        int size = x.size();
        for (int i = 0; i < size; i++) {
            y.set(i, alpha * x.get(i) + y.get(i));
        }
    }

    /**
     * Computes the matrix-vector product {@code y := alpha * op(a) * x + beta * y} in place. Here
     * {@code op(a)} is {@code a} transposed when {@code transA.isTrue()}, and {@code a} otherwise.
     *
     * <p>
     * As in BLAS {@code dgemv}, {@code y} is not read when {@code beta == 0}, so it need not be
     * initialized on input.
     *
     * @param transA whether to transpose {@code a}
     * @param alpha scalar multiplier for {@code op(a) * x}
     * @param a the matrix
     * @param x a vector of length {@code columns(op(a))}
     * @param beta scalar multiplier for {@code y}
     * @param y a vector of length {@code rows(op(a))}; overwritten with the result
     */
    @Override
    public void gemv(ArrayOperation transA, double alpha, DoubleArray a, DoubleArray x, double beta,
            DoubleArray y) {
        Check.all(x, y).argument(BaseArray::isVector, "x and y must be vectors");
        boolean transpose = transA.isTrue();
        int rows = transpose ? a.columns() : a.rows();
        int cols = transpose ? a.rows() : a.columns();
        Check.dimension(x.size(), cols);
        Check.dimension(y.size(), rows);
        for (int i = 0; i < rows; i++) {
            double sum = 0;
            for (int j = 0; j < cols; j++) {
                double aij = transpose ? a.get(j, i) : a.get(i, j);
                sum += aij * x.get(j);
            }
            double scaled = alpha * sum;
            y.set(i, beta == 0 ? scaled : scaled + beta * y.get(i));
        }
    }

    /**
     * Computes the rank-1 update {@code a := alpha * x * y' + a} in place, as in BLAS {@code dger}.
     *
     * <p>
     * The previous implementation overwrote {@code a} instead of accumulating into it. This method
     * is a no-op when {@code alpha == 0}, which matches BLAS and {@link #axpy}.
     *
     * @param alpha scalar multiplier
     * @param x a vector of length {@code a.rows()}
     * @param y a vector of length {@code a.columns()}
     * @param a the matrix to update
     */
    @Override
    public void ger(double alpha, DoubleArray x, DoubleArray y, DoubleArray a) {
        Check.all(x, y).argument(BaseArray::isVector, "x and y must be vectors");
        Check.dimension(x.size(), a.rows());
        Check.dimension(y.size(), a.columns());
        if (alpha == 0) {
            return;
        }
        for (int i = 0; i < x.size(); i++) {
            double scaledXi = alpha * x.get(i);
            for (int j = 0; j < y.size(); j++) {
                a.set(i, j, scaledXi * y.get(j) + a.get(i, j));
            }
        }
    }

    /**
     * Computes the matrix product {@code c := alpha * op(a) * op(b) + beta * c} in place. Here
     * {@code op(m)} is {@code m} transposed when the corresponding operation {@code isTrue()}.
     *
     * <p>
     * As in BLAS {@code dgemm}, {@code c} is not read when {@code beta == 0}, so NaN/Inf in an
     * uninitialized {@code c} cannot leak into the result.
     *
     * @throws MultiDimensionMismatchException if {@code columns(op(a)) != rows(op(b))}
     * @throws IllegalArgumentException if {@code c} is not {@code rows(op(a)) x columns(op(b))}
     */
    @Override
    public void gemm(ArrayOperation transA, ArrayOperation transB, double alpha, DoubleArray a, DoubleArray b,
            double beta, DoubleArray c) {

        // Logical (post-transpose) dimensions of op(a) and op(b).
        int thisRows = a.rows();
        int thisCols = a.columns();
        if (transA.isTrue()) {
            thisRows = a.columns();
            thisCols = a.rows();
        }
        int otherRows = b.rows();
        int otherColumns = b.columns();
        if (transB.isTrue()) {
            otherRows = b.columns();
            otherColumns = b.rows();
        }

        if (thisCols != otherRows) {
            throw new MultiDimensionMismatchException(thisRows, thisCols, otherRows, otherColumns);
        }
        int m = a.size(transA == ArrayOperation.KEEP ? 0 : 1);
        int n = b.size(transB == ArrayOperation.KEEP ? 1 : 0);
        int dk = a.size(transA == ArrayOperation.KEEP ? 1 : 0);
        if (m != c.size(0) || n != c.size(1)) {
            throw new IllegalArgumentException(
                    String.format("a has size (%d,%d), b has size (%d,%d), c has size (%d, %d)", m, dk, dk, n,
                            c.size(0), c.size(1)));
        }

        for (int row = 0; row < thisRows; row++) {
            for (int col = 0; col < otherColumns; col++) {
                double sum = 0.0;
                for (int k = 0; k < thisCols; k++) {
                    // For a transposed operand, a row-major index over op(m)'s dimensions addresses
                    // the same element as a column-major index into m itself (assuming StrideUtils'
                    // usual definitions of rowMajor/columnMajor).
                    int thisIndex = transA.isTrue() ? rowMajor(row, k, thisRows, thisCols)
                            : columnMajor(0, row, k, thisRows, thisCols);
                    int otherIndex = transB.isTrue() ? rowMajor(k, col, otherRows, otherColumns)
                            : columnMajor(0, k, col, otherRows, otherColumns);
                    sum += a.get(thisIndex) * b.get(otherIndex);
                }
                double scaled = alpha * sum;
                c.set(row, col, beta == 0 ? scaled : scaled + beta * c.get(row, col));
            }
        }
    }

    /** Copies every element of {@code from} into {@code to} (by linear index); sizes must match. */
    @Override
    public <T extends BaseArray<T>> void copy(T from, T to) {
        Check.size(from, to);
        for (int i = 0; i < from.size(); i++) {
            to.set(i, from, i);
        }
    }

    /**
     * Swaps the contents of {@code a} and {@code b} element by element (by linear index), using a
     * one-element temporary array. The dimensions must match.
     */
    @Override
    public <T extends BaseArray<T>> void swap(T a, T b) {
        Check.dimension(a, b);
        T tmp = a.newEmptyArray(1);
        for (int i = 0; i < a.size(); i++) {
            tmp.set(0, a, i);
            a.set(i, b, i);
            b.set(i, tmp, 0);
        }
    }

    /** Element-wise {@link Math#sin}. */
    @Override
    public DoubleArray sin(DoubleArray array) {
        return array.map(Math::sin);
    }

    /** Element-wise {@link Complex#sin}. */
    @Override
    public ComplexArray sin(ComplexArray array) {
        return array.map(Complex::sin);
    }

    /** Element-wise {@link Math#cos}. */
    @Override
    public DoubleArray cos(DoubleArray array) {
        return array.map(Math::cos);
    }

    /** Element-wise {@link Complex#cos}. */
    @Override
    public ComplexArray cos(ComplexArray array) {
        return array.map(Complex::cos);
    }

    /** Element-wise {@link Math#tan}. */
    @Override
    public DoubleArray tan(DoubleArray array) {
        return array.map(Math::tan);
    }

    /** Element-wise {@link Complex#tan}. */
    @Override
    public ComplexArray tan(ComplexArray array) {
        return array.map(Complex::tan);
    }

    /** Element-wise {@link Math#asin}. */
    @Override
    public DoubleArray asin(DoubleArray array) {
        return array.map(Math::asin);
    }

    /** Element-wise {@link Complex#asin}. */
    @Override
    public ComplexArray asin(ComplexArray array) {
        return array.map(Complex::asin);
    }

    /** Element-wise {@link Math#acos}. */
    @Override
    public DoubleArray acos(DoubleArray array) {
        return array.map(Math::acos);
    }

    /** Element-wise {@link Complex#acos}. */
    @Override
    public ComplexArray acos(ComplexArray array) {
        return array.map(Complex::acos);
    }

    /** Element-wise {@link Math#atan}. */
    @Override
    public DoubleArray atan(DoubleArray array) {
        return array.map(Math::atan);
    }

    /** Element-wise {@link Complex#atan}. */
    @Override
    public ComplexArray atan(ComplexArray array) {
        return array.map(Complex::atan);
    }

    /** Element-wise {@link Math#sinh}. */
    @Override
    public DoubleArray sinh(DoubleArray array) {
        return array.map(Math::sinh);
    }

    /** Element-wise {@link Complex#sinh}. */
    @Override
    public ComplexArray sinh(ComplexArray array) {
        return array.map(Complex::sinh);
    }

    /** Element-wise {@link Math#cosh}. */
    @Override
    public DoubleArray cosh(DoubleArray array) {
        return array.map(Math::cosh);
    }

    /** Element-wise {@link Complex#cosh}. */
    @Override
    public ComplexArray cosh(ComplexArray array) {
        return array.map(Complex::cosh);
    }

    /** Element-wise {@link Math#tanh}. */
    @Override
    public DoubleArray tanh(DoubleArray array) {
        return array.map(Math::tanh);
    }

    /** Element-wise {@link Complex#tanh}. */
    @Override
    public ComplexArray tanh(ComplexArray array) {
        return array.map(Complex::tanh);
    }

    /** Element-wise {@link Math#exp}. */
    @Override
    public DoubleArray exp(DoubleArray array) {
        return array.map(Math::exp);
    }

    /** Element-wise {@link Complex#exp}. */
    @Override
    public ComplexArray exp(ComplexArray array) {
        return array.map(Complex::exp);
    }

    /** Element-wise {@link Math#cbrt}. */
    @Override
    public DoubleArray cbrt(DoubleArray array) {
        return array.map(Math::cbrt);
    }

    /** Element-wise {@link Math#ceil}. */
    @Override
    public DoubleArray ceil(DoubleArray array) {
        return array.map(Math::ceil);
    }

    /** Element-wise ceiling, applied separately to the real and imaginary parts. */
    @Override
    public ComplexArray ceil(ComplexArray array) {
        return array.map(v -> new Complex(Math.ceil(v.getReal()), Math.ceil(v.getImaginary())));
    }

    /** Element-wise {@link Math#floor}. */
    @Override
    public DoubleArray floor(DoubleArray array) {
        return array.map(Math::floor);
    }

    /** Element-wise floor, applied separately to the real and imaginary parts. */
    @Override
    public ComplexArray floor(ComplexArray array) {
        return array.map(v -> new Complex(Math.floor(v.getReal()), Math.floor(v.getImaginary())));
    }

    /** Element-wise {@link Math#abs(int)}; note {@code abs(Integer.MIN_VALUE) == Integer.MIN_VALUE}. */
    @Override
    public IntArray abs(IntArray array) {
        return array.map(Math::abs);
    }

    /** Element-wise {@link Math#abs(long)}; note {@code abs(Long.MIN_VALUE) == Long.MIN_VALUE}. */
    @Override
    public LongArray abs(LongArray array) {
        return array.map(Math::abs);
    }

    /** Element-wise {@link Math#abs(double)}. */
    @Override
    public DoubleArray abs(DoubleArray array) {
        return array.map(Math::abs);
    }

    /** Element-wise modulus {@link Complex#abs}, returned as a {@link DoubleArray}. */
    @Override
    public DoubleArray abs(ComplexArray array) {
        return array.mapToDouble(Complex::abs);
    }

    /** Element-wise {@code Math.scalb(v, scaleFactor)}, i.e. {@code v * 2^scaleFactor}. */
    @Override
    public DoubleArray scalb(DoubleArray array, int scaleFactor) {
        return array.map(v -> Math.scalb(v, scaleFactor));
    }

    /** Element-wise {@link Math#sqrt}. */
    @Override
    public DoubleArray sqrt(DoubleArray array) {
        return array.map(Math::sqrt);
    }

    /** Element-wise {@link Complex#sqrt}. */
    @Override
    public ComplexArray sqrt(ComplexArray array) {
        return array.map(Complex::sqrt);
    }

    /** Element-wise natural logarithm {@link Math#log}. */
    @Override
    public DoubleArray log(DoubleArray array) {
        return array.map(Math::log);
    }

    /** Element-wise {@link Complex#log}. */
    @Override
    public ComplexArray log(ComplexArray array) {
        return array.map(Complex::log);
    }

    /** Element-wise base-2 logarithm, computed as {@code log(x) / LOG_2}. */
    @Override
    public DoubleArray log2(DoubleArray array) {
        return array.map(x -> Math.log(x) / LOG_2);
    }

    /**
     * Raises every element of {@code in} to {@code power}.
     *
     * <p>
     * Powers within {@link #EPS} of 2, 3 or 4 are computed by repeated multiplication. Such powers
     * are therefore treated as exactly 2, 3 or 4. All other powers use {@link FastMath#pow}.
     */
    @Override
    public DoubleArray pow(DoubleArray in, double power) {
        if (Precision.equals(power, 2, EPS)) {
            return in.map(x -> x * x);
        } else if (Precision.equals(power, 3, EPS)) {
            return in.map(x -> x * x * x);
        } else if (Precision.equals(power, 4, EPS)) {
            return in.map(x -> x * x * x * x);
        } else {
            return in.map(x -> FastMath.pow(x, power));
        }
    }

    /** Element-wise {@link Math#log10}. */
    @Override
    public DoubleArray log10(DoubleArray in) {
        return in.map(Math::log10);
    }

    /** Element-wise {@link Math#signum(double)}. */
    @Override
    public DoubleArray signum(DoubleArray in) {
        return in.map(Math::signum);
    }

    /** Returns a new {@link LongArray} of the same shape holding {@link Math#round(double)} of each element. */
    @Override
    public LongArray round(DoubleArray in) {
        return in.asLong().newEmptyArray(in.getShape()).assign(in, Math::round);
    }
}