// Source code
//
// Java tutorial
//
//
// Here is the source code for com/simiacryptus/mindseye/lang/Tensor.java


/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.lang;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import com.simiacryptus.mindseye.test.TestUtil;
import org.apache.commons.lang3.ArrayUtils;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.awt.image.BufferedImage;
import java.io.Serializable;
import java.util.*;
import java.util.function.*;
import java.util.stream.*;

/**
 * A multi-dimensional array of data. Represented internally as a single double[] array. This class is central to data
 * handling in MindsEye, and may have some odd-looking or surprising optimizations.
 */
public final class Tensor extends ReferenceCountingBase implements Serializable {

     * The constant json_precision.
    public static DataSerializer json_precision = SerialPrecision.Float;
     * The Dimensions.
    protected final int[] dimensions;
     * The Strides.
    protected final int[] strides;
     * The Data.
    protected volatile double[] data;

    protected volatile UUID id;

     * Instantiates a new Tensor.
    private Tensor() {
        data = null;
        strides = null;
        dimensions = null;

     * Instantiates a new Tensor.
     * @param ds the ds
    public Tensor(@Nonnull final double... ds) {
        this(ds, ds.length);

     * Instantiates a new Tensor.
     * @param data the data
     * @param dims the dims
    public Tensor(@Nullable final double[] data, @Nonnull final int... dims) {
        if (Tensor.length(dims) > Integer.MAX_VALUE)
            throw new IllegalArgumentException();
        if (null != data && Tensor.length(dims) != data.length)
            throw new IllegalArgumentException(Arrays.toString(dims) + " != " + data.length);
        dimensions = (null == dims || 0 == dims.length) ? new int[] {} : Arrays.copyOf(dims, dims.length);
        strides = Tensor.getSkips(dims);
        // = data;// Arrays.copyOf(data, data.length);
        if (null != data) {
   = RecycleBin.DOUBLES.copyOf(data, data.length);
        assert isValid();
        //assert (null == data || Tensor.length(dims) == data.length);

    /**
     * Wraps existing arrays without copying them; strides are derived from {@code dims}.
     *
     * @param dims the dimensions array, used as-is
     * @param data the backing data, used as-is, or {@code null} for lazy allocation
     */
    private Tensor(int[] dims, @Nullable double[] data) {
        this(dims, Tensor.getSkips(dims), data);
    }

    /**
     * Wraps existing dimension, stride and data arrays without copying any of them.
     *
     * @param dimensions the dimensions array, used as-is
     * @param strides    the stride array, used as-is
     * @param data       the backing data, used as-is, or {@code null} for lazy allocation
     * @throws IllegalArgumentException if the element count reaches {@code Integer.MAX_VALUE}
     */
    private Tensor(int[] dimensions, int[] strides, @Nullable double[] data) {
        // Multiply in long arithmetic: an int product would wrap and the check could never trigger.
        if (Arrays.stream(dimensions).asLongStream().reduce(1L, Math::multiplyExact) >= Integer.MAX_VALUE)
            throw new IllegalArgumentException();
        assert null == data || data.length == Tensor.length(dimensions);
        this.dimensions = dimensions;
        this.strides = strides;
        this.data = data;
        assert isValid();
    }

     * Instantiates a new Tensor.
     * @param data the data
     * @param dims the dims
    public Tensor(@Nullable final float[] data, @Nonnull final int... dims) {
        if (Tensor.length(dims) >= Integer.MAX_VALUE)
            throw new IllegalArgumentException();
        dimensions = Arrays.copyOf(dims, dims.length);
        strides = Tensor.getSkips(dims);
        if (null != data) {
   = RecycleBin.DOUBLES.obtain(data.length);// Arrays.copyOf(data, data.length);
            Arrays.parallelSetAll(, i -> {
                final double v = data[i];
                return Double.isFinite(v) ? v : 0;
            assert -> Double.isFinite(v));
        assert isValid();
        //assert (null == data || Tensor.length(dims) == data.length);

     * Instantiates a new Tensor.
     * @param dims the dims
    public Tensor(@Nonnull final int... dims) {
        this((double[]) null, dims);
        assert dims.length > 0;

     * From json tensor.
     * @param json      the json
     * @param resources the resources
     * @return the tensor
    public static Tensor fromJson(@Nullable final JsonElement json, @Nullable Map<CharSequence, byte[]> resources) {
        if (null == json)
            return null;
        if (json.isJsonArray()) {
            final JsonArray array = json.getAsJsonArray();
            final int size = array.size();
            if (array.get(0).isJsonPrimitive()) {
                final double[] doubles = IntStream.range(0, size).mapToObj(i -> {
                    return array.get(i);
                }).mapToDouble(element -> {
                    return element.getAsDouble();
                Tensor tensor = new Tensor(doubles);
                assert tensor.isValid();
                return tensor;
            } else {
                final List<Tensor> elements = IntStream.range(0, size).mapToObj(i -> {
                    return array.get(i);
                }).map(element -> {
                    return Tensor.fromJson(element, resources);
                final int[] dimensions = elements.get(0).getDimensions();
                if (! -> Arrays.equals(dimensions, t.getDimensions()))) {
                    throw new IllegalArgumentException();
                final int[] newDdimensions = Arrays.copyOf(dimensions, dimensions.length + 1);
                newDdimensions[dimensions.length] = size;
                final Tensor tensor = new Tensor(newDdimensions);
                final double[] data = tensor.getData();
                for (int i = 0; i < size; i++) {
                    final double[] e = elements.get(i).getData();
                    System.arraycopy(e, 0, data, i * e.length, e.length);
                for (@Nonnull
                Tensor t : elements) {
                assert tensor.isValid();
                return tensor;
        } else if (json.isJsonObject()) {
            JsonObject jsonObject = json.getAsJsonObject();
            int[] dims = fromJsonArray(jsonObject.getAsJsonArray("length"));
            Tensor tensor = new Tensor(dims);
            SerialPrecision precision = SerialPrecision
            JsonElement base64 = jsonObject.get("base64");
            if (null == base64) {
                if (null == resources)
                    throw new IllegalArgumentException("No Data Resources");
                CharSequence resourceId = jsonObject.getAsJsonPrimitive("resource").getAsString();
                tensor.setBytes(resources.get(resourceId), precision);
            } else {
                tensor.setBytes(Base64.getDecoder().decode(base64.getAsString()), precision);
            assert tensor.isValid();
            JsonElement id = jsonObject.get("id");
            if (null != id) {
            return tensor;
        } else {
            Tensor tensor = new Tensor(json.getAsJsonPrimitive().getAsDouble());
            assert tensor.isValid();
            return tensor;

    private static double bound8bit(final double value) {
        final int max = 0xFF;
        final int min = 0;
        return value < min ? min : value > max ? max : value;

    private static int bound8bit(final int value) {
        final int max = 0xFF;
        final int min = 0;
        return value < min ? min : value > max ? max : value;

     * Dim l long.
     * @param dims the dims
     * @return the long
    public static int length(@Nonnull int... dims) {
        long total = 1;
        for (final int dim : dims) {
            total *= dim;
        return (int) total;

     * From rgb tensor.
     * @param img the img
     * @return the tensor
    public static Tensor fromRGB(@Nonnull final BufferedImage img) {
        final int width = img.getWidth();
        final int height = img.getHeight();
        final Tensor a = new Tensor(width, height, 3);
        IntStream.range(0, width).parallel().forEach(x -> {
            final int[] coords = { 0, 0, 0 };
            IntStream.range(0, height).forEach(y -> {
                coords[0] = x;
                coords[1] = y;
                coords[2] = 0;
                a.set(coords, img.getRGB(x, y) & 0xFF);
                coords[2] = 1;
                a.set(coords, img.getRGB(x, y) >> 8 & 0xFF);
                coords[2] = 2;
                a.set(coords, img.getRGB(x, y) >> 16 & 0x0FF);
        return a;

     * Get doubles double [ ].
     * @param stream the stream
     * @param dim    the length
     * @return the double [ ]
    public static double[] getDoubles(@Nonnull final DoubleStream stream, final int dim) {
        final double[] doubles = RecycleBin.DOUBLES.obtain(dim);
        stream.forEach(new DoubleConsumer() {
            int j = 0;

            public void accept(final double value) {
                doubles[j++] = value;
        return doubles;

    private static int[] getSkips(@Nonnull final int[] dims) {
        final int[] skips = new int[dims.length];
        for (int i = 0; i < skips.length; i++) {
            if (i == 0) {
                skips[0] = 1;
            } else {
                skips[i] = skips[i - 1] * dims[i - 1];
        return skips;

     * Product tensor.
     * @param left  the left
     * @param right the right
     * @return the tensor
    public static Tensor product(@Nonnull final Tensor left, @Nonnull final Tensor right) {
        if (left.length() == 1 && right.length() != 1)
            return Tensor.product(right, left);
        assert left.length() == right.length() || 1 == right.length();
        final Tensor result = new Tensor(left.getDimensions());
        final double[] resultData = result.getData();
        final double[] leftData = left.getData();
        final double[] rightData = right.getData();
        for (int i = 0; i < resultData.length; i++) {
            final double l = leftData[i];
            final double r = rightData[1 == rightData.length ? 0 : i];
            resultData[i] = l * r;
        return result;

     * To doubles double [ ].
     * @param data the data
     * @return the double [ ]
    public static double[] toDoubles(@Nonnull final float[] data) {
        final double[] buffer = RecycleBin.DOUBLES.obtain(data.length);
        for (int i = 0; i < data.length; i++) {
            buffer[i] = data[i];
        return buffer;

     * To floats float [ ].
     * @param data the data
     * @return the float [ ]
    public static float[] toFloats(@Nonnull final double[] data) {
        final float[] buffer = new float[data.length];
        for (int i = 0; i < data.length; i++) {
            buffer[i] = (float) data[i];
        return buffer;

     * To json array json array.
     * @param ints the ints
     * @return the json array
    public static JsonArray toJsonArray(@Nonnull int[] ints) {
        JsonArray dim = new JsonArray();
        for (int i = 0; i < ints.length; i++) {
            dim.add(new JsonPrimitive(ints[i]));
        return dim;

     * From json array int [ ].
     * @param ints the ints
     * @return the int [ ]
    public static int[] fromJsonArray(@Nonnull JsonArray ints) {
        int[] array = new int[ints.size()];
        for (int i = 0; i < ints.size(); i++) {
            array[i] = ints.get(i).getAsInt();
        return array;

     * Reverse dimensions tensor.
     * @param tensor the tensor
     * @return the tensor
    public static Tensor reverseDimensions(@Nonnull Tensor tensor) {
        return tensor.rearrange(Tensor::reverse);

     * Permute int [ ].
     * @param key        the key
     * @param data       the data
     * @param dimensions the dimensions
     * @return the int [ ]
    public static int[] permute(@Nonnull int[] key, int[] data, final int[] dimensions) {
        int[] copy = new int[key.length];
        for (int i = 0; i < key.length; i++) {
            int k = key[i];
            if (k == Integer.MAX_VALUE) {
                copy[i] = dimensions[0] - data[0] - 1;
            } else if (k < 0) {
                copy[i] = dimensions[-k] - data[-k] - 1;
            } else {
                copy[i] = data[k];
        return copy;

     * Reverse int [ ].
     * @param dimensions the dimensions
     * @return the int [ ]
    public static int[] reverse(@Nonnull int[] dimensions) {
        int[] copy = Arrays.copyOf(dimensions, dimensions.length);
        return copy;

     * Pretty print string.
     * @param doubles the doubles
     * @return the string
    public static CharSequence prettyPrint(double[] doubles) {
        Tensor t = new Tensor(doubles);
        String prettyPrint = t.prettyPrint();
        return prettyPrint;

     * Get pixel double [ ].
     * @param tensor the tensor
     * @param x      the x
     * @param y      the y
     * @param bands  the bands
     * @return the double [ ]
    public static double[] getPixel(final Tensor tensor, final int x, final int y, final int bands) {
        return IntStream.range(0, bands).mapToDouble(band -> tensor.get(x, y, band)).toArray();

     * Reduce tensor.
     * @return the tensor
    public Tensor sumChannels() {
        int[] dimensions = getDimensions();
        Tensor self = this;
        return new Tensor(dimensions[0], dimensions[1], 1).setByCoord(c -> {
            int[] coords = c.getCoords();
            return IntStream.range(0, dimensions[2]).mapToDouble(j -> self.get(coords[0], coords[1], j)).sum();

     * Gets pixel stream.
     * @return the pixel stream
    public Stream<double[]> getPixelStream() {
        int[] dimensions = getDimensions();
        int width = dimensions[0];
        int height = dimensions[1];
        int bands = dimensions[2];
        return IntStream.range(0, width).mapToObj(x -> x).parallel().flatMap(x -> {
            return IntStream.range(0, height).mapToObj(y -> y).map(y -> {
                return getPixel(this, x, y, bands);

     * Rescale rms tensor.
     * @param rms the rms
     * @return the tensor
    public Tensor rescaleRms(final double rms) {
        return scale(rms / rms());

     * Normalize distribution tensor.
     * @return the tensor
    public Tensor normalizeDistribution() {
        double[] sortedValues =;
        Tensor result = map(
                v -> Math.abs(((double) Arrays.binarySearch(sortedValues, v)) / ((double) sortedValues.length)));
        return result;

     * Reorder dimensions tensor.
     * @param fn the fn
     * @return the tensor
    public Tensor rearrange(@Nonnull UnaryOperator<int[]> fn) {
        return rearrange(fn, fn.apply(getDimensions()));

     * Reorder dimensions tensor.
     * @param fn         the fn
     * @param outputDims the output dims
     * @return the tensor
    public Tensor rearrange(@Nonnull UnaryOperator<int[]> fn, int[] outputDims) {
        Tensor result = new Tensor(outputDims);
        coordStream(false).forEach(c -> {
            int[] inCoords = c.getCoords();
            int[] outCoords = fn.apply(inCoords);
            result.set(outCoords, get(c));
        return result;

     * Is valid boolean.
     * @return the boolean
    public boolean isValid() {
        return !isFinalized() && (null == || == Tensor.length(dimensions));

     * Accum.
     * @param tensor the tensor
    public void addInPlace(@Nonnull final Tensor tensor) {
        assert Arrays.equals(getDimensions(), tensor.getDimensions()) : Arrays.toString(getDimensions()) + " != "
                + Arrays.toString(tensor.getDimensions());
        setParallelByIndex(c -> get(c) + tensor.get(c));

     * Add.
     * @param coords the coords
     * @param value  the value
    public void add(@Nonnull final Coordinate coords, final double value) {
        add(coords.getIndex(), value);

     * Add tensor.
     * @param index the index
     * @param value the value
     * @return the tensor
    public final Tensor add(final int index, final double value) {
        getData()[index] += value;
        return this;

     * Add.
     * @param coords the coords
     * @param value  the value
    public void add(@Nonnull final int[] coords, final double value) {
        add(index(coords), value);

     * Add right.
     * @param right the right
     * @return the right
    public Tensor add(@Nonnull final Tensor right) {
        assert Arrays.equals(getDimensions(), right.getDimensions());
        return mapCoords((c) -> get(c) + right.get(c));

     * Add and free tensor.
     * @param right the right
     * @return the tensor
    public Tensor addAndFree(@Nonnull final Tensor right) {
        if (1 == currentRefCount()) {
            return this;
        } else {
            assert Arrays.equals(getDimensions(), right.getDimensions());
            return mapCoordsAndFree((c) -> get(c) + right.get(c));

     * Coord stream stream.
     * @param parallel the safe
     * @return the stream
    public Stream<Coordinate> coordStream(boolean parallel) {
        //ConcurrentHashSet<Object> distinctBuffer = new ConcurrentHashSet<>();
        //assert distinctBuffer.add(coordinate.copy()) : String.format("Duplicate: %s in %s", coordinate, distinctBuffer);
        return Iterator<Coordinate>() {

            int cnt = 0;
            Coordinate coordinate = new Coordinate();
            int[] val = new int[dimensions.length];
            int[] safeCopy = new int[dimensions.length];

            public boolean hasNext() {
                return cnt < length();

            public synchronized Coordinate next() {
                if (0 < cnt) {
                    for (int i = 0; i < val.length; i++) {
                        if (++val[i] >= dimensions[i]) {
                            val[i] = 0;
                        } else {
                System.arraycopy(val, 0, safeCopy, 0, val.length);
                return parallel ? coordinate.copy() : coordinate;
        }, length(), Spliterator.ORDERED), parallel);

     * Dim int.
     * @return the int
    public int length() {
        if (null != data) {
            return data.length;
        } else {
            return Tensor.length(dimensions);

     * Copy tensor.
     * @return the tensor
    public Tensor copy() {
        return new Tensor(RecycleBin.DOUBLES.copyOf(getData(), getData().length),
                Arrays.copyOf(dimensions, dimensions.length));

    protected void _free() {
        if (null != data) {
            if (RecycleBin.DOUBLES.want(data.length)) {
                RecycleBin.DOUBLES.recycle(data, data.length);
            data = null;

    public boolean equals(@Nullable final Object obj) {
        if (this == obj) {
            return true;
        if (obj == null) {
            return false;
        if (getClass() != obj.getClass()) {
            return false;
        final Tensor other = (Tensor) obj;
        if (0 == currentRefCount())
            return false;
        if (0 == other.currentRefCount())
            return false;
        if (!Arrays.equals(dimensions, other.dimensions)) {
            return false;
        return Arrays.equals(getData(), other.getData());

     * Get double.
     * @param coords the coords
     * @return the double
    public double get(@Nonnull final Coordinate coords) {
        final double v = getData()[coords.getIndex()];
        return v;

     * Get double.
     * @param index the index
     * @return the double
    public double get(final int index) {
        return getData()[index];

     * Get double.
     * @param c1 the c 1
     * @param c2 the c 2
     * @return the double
    public double get(final int c1, final int c2) {
        return getData()[index(c1, c2)];

     * Get double.
     * @param c1 the c 1
     * @param c2 the c 2
     * @param c3 the c 3
     * @return the double
    public double get(final int c1, final int c2, final int c3) {
        return getData()[index(c1, c2, c3)];

     * Get double.
     * @param c1     the c 1
     * @param c2     the c 2
     * @param c3     the c 3
     * @param c4     the c 4
     * @param coords the coords
     * @return the double
    public double get(final int c1, final int c2, final int c3, final int c4, final int... coords) {
        return getData()[index(c1, c2, c3, c4, coords)];

     * Get.
     * @param bufferArray the buffer array
    public void get(@Nonnull final double[] bufferArray) {
        System.arraycopy(getData(), 0, bufferArray, 0, length());

     * Get double.
     * @param coords the coords
     * @return the double
    public double get(@Nonnull final int[] coords) {
        return getData()[index(coords)];

     * Get data double [ ].
     * @return the double [ ]
    public double[] getData() {
        if (null == data) {
            synchronized (this) {
                if (null == data) {
                    final int length = Tensor.length(dimensions);
                    data = RecycleBin.DOUBLES.obtain(length);
                    assert null != data;
                    assert length == data.length;
        assert isValid();
        assert null != data;
        return data;

     * Get dimensions int [ ].
     * @return the int [ ]
    public final int[] getDimensions() {
        return Arrays.copyOf(dimensions, dimensions.length);

    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + Arrays.hashCode(getData());
        result = prime * result + Arrays.hashCode(dimensions);
        return result;

     * Get data as floats float [ ].
     * @return the float [ ]
    public float[] getDataAsFloats() {
        return Tensor.toFloats(getData());

     * Index int.
     * @param c1 the c 1
     * @return the int
    public int index(final int c1) {
        int v = 0;
        v += strides[0] * c1;
        return v;
        // return IntStream.range(0, strides.length).mapCoords(i->strides[i]*coords[i]).sum();

     * Index int.
     * @param c1 the c 1
     * @param c2 the c 2
     * @return the int
    public int index(final int c1, final int c2) {
        int v = 0;
        v += strides[0] * c1;
        v += strides[1] * c2;
        return v;
        // return IntStream.range(0, strides.length).mapCoords(i->strides[i]*coords[i]).sum();

     * Index int.
     * @param c1 the c 1
     * @param c2 the c 2
     * @param c3 the c 3
     * @return the int
    public int index(final int c1, final int c2, final int c3) {
        int v = 0;
        v += strides[0] * c1;
        v += strides[1] * c2;
        v += strides[2] * c3;
        return v;
        // return IntStream.range(0, strides.length).mapCoords(i->strides[i]*coords[i]).sum();

     * Index int.
     * @param coords the coords
     * @return the int
    public int index(@Nonnull final Coordinate coords) {
        return coords.getIndex();

     * Index int.
     * @param c1     the c 1
     * @param c2     the c 2
     * @param c3     the c 3
     * @param c4     the c 4
     * @param coords the coords
     * @return the int
    public int index(final int c1, final int c2, final int c3, final int c4, @Nullable final int... coords) {
        int v = 0;
        v += strides[0] * c1;
        v += strides[1] * c2;
        v += strides[2] * c3;
        v += strides[3] * c4;
        if (null != coords && 0 < coords.length) {
            for (int i = 0; 4 + i < strides.length && i < coords.length; i++) {
                v += strides[4 + i] * coords[4 + i];
        return v;
        // return IntStream.range(0, strides.length).mapCoords(i->strides[i]*coords[i]).sum();

     * L 1 double.
     * @return the double
    public double l1() {

     * L 2 double.
     * @return the double
    public double l2() {
        return Math.sqrt( -> x * x).sum());

     * Index int.
     * @param coords the coords
     * @return the int
    public int index(@Nonnull final int[] coords) {
        int v = 0;
        for (int i = 0; i < strides.length && i < coords.length; i++) {
            v += strides[i] * coords[i];
        return v;
        // return IntStream.range(0, strides.length).mapCoords(i->strides[i]*coords[i]).sum();

     * Map tensor.
     * @param f the f
     * @return the tensor
    public Tensor map(@Nonnull final DoubleUnaryOperator f) {
        final double[] data = getData();
        Tensor tensor = new Tensor(dimensions);
        final double[] cpy = tensor.getData();
        IntStream.range(0, data.length).parallel().forEach(i -> cpy[i] = f.applyAsDouble(data[i]));
        return tensor;

     * Map and free tensor.
     * @param f the f
     * @return the tensor
    public Tensor mapAndFree(@Nonnull final DoubleUnaryOperator f) {
        final double[] data = getData();
        final double[] cpy = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            final double x = data[i];
            // assert Double.isFinite(x);
            final double v = f.applyAsDouble(x);
            // assert Double.isFinite(v);
            cpy[i] = v;
        Tensor tensor = new Tensor(cpy, dimensions);
        return tensor;

     * Map coords tensor.
     * @param f the f
     * @return the tensor
    public Tensor mapCoords(@Nonnull final ToDoubleFunction<Coordinate> f) {
        return mapCoords(f, false);

     * Map coords and free tensor.
     * @param f the f
     * @return the tensor
    public Tensor mapCoordsAndFree(@Nonnull final ToDoubleFunction<Coordinate> f) {
        return mapCoordsAndFree(f, false);

     * Map coords tensor.
     * @param f        the f
     * @param parallel the parallel
     * @return the tensor
    public Tensor mapCoords(@Nonnull final ToDoubleFunction<Coordinate> f, boolean parallel) {
        return new Tensor(Tensor.getDoubles(coordStream(parallel).mapToDouble(i -> f.applyAsDouble(i)), length()),

     * Map coords and free tensor.
     * @param f        the f
     * @param parallel the parallel
     * @return the tensor
    public Tensor mapCoordsAndFree(@Nonnull final ToDoubleFunction<Coordinate> f, boolean parallel) {
        Tensor tensor = new Tensor(
                Tensor.getDoubles(coordStream(parallel).mapToDouble(i -> f.applyAsDouble(i)), length()),
        return tensor;

     * Map index tensor.
     * @param f the f
     * @return the tensor
    public Tensor mapIndex(@Nonnull final TupleOperator f) {
        return new Tensor(
                Tensor.getDoubles(IntStream.range(0, length()).mapToDouble(i -> f.eval(get(i), i)), length()),

     * Mean double.
     * @return the double
    public double mean() {
        return sum() / length();

     * Map parallel tensor.
     * @param f the f
     * @return the tensor
    public Tensor mapParallel(@Nonnull final DoubleUnaryOperator f) {
        final double[] data = getData();
        return new Tensor(Tensor.getDoubles(IntStream.range(0, length()).mapToDouble(i -> f.applyAsDouble(data[i])),
                length()), dimensions);

     * Minus tensor.
     * @param right the right
     * @return the tensor
    public Tensor minus(@Nonnull final Tensor right) {
        if (!Arrays.equals(getDimensions(), right.getDimensions())) {
            throw new IllegalArgumentException(
                    Arrays.toString(getDimensions()) + " != " + Arrays.toString(right.getDimensions()));
        final Tensor copy = new Tensor(getDimensions());
        final double[] thisData = getData();
        final double[] rightData = right.getData();
        Arrays.parallelSetAll(copy.getData(), i -> thisData[i] - rightData[i]);
        return copy;

     * Pretty printGroups string.
     * @return the string
    public String prettyPrint() {
        return toString(true);

     * Pretty print and free string.
     * @return the string
    public String prettyPrintAndFree() {
        String prettyPrint = prettyPrint();
        return prettyPrint;

     * Multiply tensor.
     * @param d the d
     * @return the tensor
    public Tensor multiply(final double d) {
        final Tensor tensor = new Tensor(getDimensions());
        final double[] resultData = tensor.getData();
        final double[] thisData = getData();
        for (int i = 0; i < thisData.length; i++) {
            resultData[i] = d * thisData[i];
        return tensor;

     * Rms double.
     * @return the double
    public double rms() {
        return Math.sqrt(sumSq() / length());

     * Reduce parallel tensor.
     * @param right the right
     * @param f     the f
     * @return the tensor
    public Tensor reduceParallel(@Nonnull final Tensor right, @Nonnull final DoubleBinaryOperator f) {
        if (!Arrays.equals(right.getDimensions(), getDimensions())) {
            throw new IllegalArgumentException(
                    Arrays.toString(right.getDimensions()) + " != " + Arrays.toString(getDimensions()));
        final double[] dataL = getData();
        final double[] dataR = right.getData();
        return new Tensor(Tensor.getDoubles(
                IntStream.range(0, length()).mapToDouble(i -> f.applyAsDouble(dataL[i], dataR[i])), length()),

     * Round tensor.
     * @param precision the precision
     * @return the tensor
    public Tensor round(final int precision) {
        if (precision > 8)
            return this;
        if (precision < 1)
            throw new IllegalArgumentException();
        return round(precision, 10);

     * Round tensor.
     * @param precision the precision
     * @param base      the base
     * @return the tensor
    public Tensor round(final int precision, final int base) {
        return map(v -> {
            final double units = Math.pow(base, Math.ceil(Math.log(v) / Math.log(base)) - precision);
            return Math.round(v / units) * units;

     * Scale tensor.
     * @param d the d
     * @return the tensor
    public Tensor scale(final double d) {
        return map(v -> v * d);

     * Scale tensor.
     * @param d the d
     * @return the tensor
    public Tensor scaleInPlace(final double d) {
        final double[] data = getData();
        for (int i = 0; i < data.length; i++) {
            data[i] *= d;
        return this;

     * Set.
     * @param coords the coords
     * @param value  the value
    public void set(@Nonnull final Coordinate coords, final double value) {
        if (Double.isFinite(value))
            set(coords.getIndex(), value);

     * Set tensor.
     * @param data the data
     * @return the tensor
    public Tensor set(final double[] data) {
        for (int i = 0; i < getData().length; i++) {
            getData()[i] = data[i];
        return this;

     * Fill tensor.
     * @param f the f
     * @return the tensor
    public Tensor set(@Nonnull final DoubleSupplier f) {
        Arrays.setAll(getData(), i -> f.getAsDouble());
        return this;

     * Set.
     * @param coord1 the coord 1
     * @param coord2 the coord 2
     * @param value  the value
    public void set(final int coord1, final int coord2, final double value) {
        assert Double.isFinite(value);
        set(index(coord1, coord2), value);

     * Set.
     * @param coord1 the coord 1
     * @param coord2 the coord 2
     * @param coord3 the coord 3
     * @param value  the value
    public void set(final int coord1, final int coord2, final int coord3, final double value) {
        assert Double.isFinite(value);
        set(index(coord1, coord2, coord3), value);

     * Set.
     * @param coord1 the coord 1
     * @param coord2 the coord 2
     * @param coord3 the coord 3
     * @param coord4 the coord 4
     * @param value  the value
    public void set(final int coord1, final int coord2, final int coord3, final int coord4, final double value) {
        assert Double.isFinite(value);
        set(index(coord1, coord2, coord3, coord4), value);

     * Set tensor.
     * @param index the index
     * @param value the value
     * @return the tensor
    public Tensor set(final int index, final double value) {
        // assert Double.isFinite(value);
        getData()[index] = value;
        return this;

     * Set.
     * @param coords the coords
     * @param value  the value
    public void set(@Nonnull final int[] coords, final double value) {
        assert Double.isFinite(value);
        set(index(coords), value);

     * Set tensor.
     * @param f the f
     * @return the tensor
    public Tensor set(@Nonnull final IntToDoubleFunction f) {
        Arrays.parallelSetAll(getData(), f);
        return this;

     * Set.
     * @param right the right
     * @return the tensor
    public Tensor set(@Nonnull final Tensor right) {
        assert length() == right.length();
        final double[] rightData = right.getData();
        Arrays.parallelSetAll(getData(), i -> rightData[i]);
        return this;

     * Sets all.
     * @param v the v
     * @return the all
    public Tensor setAll(final double v) {
        final double[] data = getData();
        for (int i = 0; i < data.length; i++) {
            data[i] = v;
        return this;

     * Fill by coord tensor.
     * @param f the f
     * @return the tensor
    public Tensor setByCoord(@Nonnull final ToDoubleFunction<Coordinate> f) {
        return setByCoord(f, true);

     * Fill by coord tensor.
     * @param f        the f
     * @param parallel the parallel
     * @return the tensor
    public Tensor setByCoord(@Nonnull final ToDoubleFunction<Coordinate> f, boolean parallel) {
        coordStream(parallel).forEach(c -> set(c, f.applyAsDouble(c)));
        return this;

     * Sum double.
     * @return the double
    public double sum() {
        double v = 0;
        for (final double element : getData()) {
            v += element;
        // assert Double.isFinite(v);
        return v;

     * Sum sq double.
     * @return the double
    public double sumSq() {
        double v = 0;
        for (final double element : getData()) {
            v += element * element;
        // assert Double.isFinite(v);
        return v;

     * Sets parallel by index.
     * @param f the f
    public void setParallelByIndex(@Nonnull final IntToDoubleFunction f) {
        IntStream.range(0, length()).parallel().forEach(c -> set(c, f.applyAsDouble(c)));

     * To gray png buffered png.
     * @return the buffered png
    public BufferedImage toGrayImage() {
        return toGrayImage(0);

     * To gray png buffered png.
     * @param band the band
     * @return the buffered png
    public BufferedImage toGrayImage(final int band) {
        final int width = getDimensions()[0];
        final int height = getDimensions()[1];
        final BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                final double v = get(x, y, band);
                image.getRaster().setSample(x, y, 0, v < 0 ? 0 : v > 255 ? 255 : v);
        return image;

     * To png buffered png.
     * @return the buffered png
    public BufferedImage toImage() {
        final int[] dims = getDimensions();
        if (3 == dims.length) {
            if (3 == dims[2]) {
                return toRgbImage();
            } else {
                assert 1 == dims[2];
                return toGrayImage();
        } else {
            assert 2 == dims.length;
            return toGrayImage();

     * To images list.
     * @return the list
    public List<BufferedImage> toImages() {
        final int[] dims = getDimensions();
        if (3 == dims.length) {
            if (3 == dims[2]) {
                return Arrays.asList(toRgbImage());
            } else if (0 == dims[2] % 3) {
                final ArrayList<BufferedImage> list = new ArrayList<>();
                for (int i = 0; i < dims[2]; i += 3) {
                    list.add(toRgbImage(i, i + 1, i + 2));
                return list;
            } else if (1 == dims[2]) {
                return Arrays.asList(toGrayImage());
            } else {
                final ArrayList<BufferedImage> list = new ArrayList<>();
                for (int i = 0; i < dims[2]; i++) {
                return list;
        } else {
            assert 2 == dims.length : "order: " + dims.length;
            return Arrays.asList(toGrayImage());

     * To json json element.
     * @param resources      the resources
     * @param dataSerializer the data serializer
     * @return the json element
    public JsonElement toJson(@Nullable Map<CharSequence, byte[]> resources,
            @Nonnull DataSerializer dataSerializer) {
        if (length() > 1024) {
            JsonObject obj = new JsonObject();
            int[] dimensions = getDimensions();
            obj.add("length", toJsonArray(dimensions));
            if (null != id)
                obj.addProperty("id", id.toString());
            byte[] bytes = getBytes(dataSerializer);
            obj.addProperty("precision", ((SerialPrecision) dataSerializer).name());
            if (null != resources) {
                String id = UUID.randomUUID().toString();
                obj.addProperty("resource", id);
                resources.put(id, bytes);
            } else {
                obj.addProperty("base64", Base64.getEncoder().encodeToString(bytes));
            return obj;
        } else {
            return toJson(new int[] {});

     * Sets bytes.
     * @param bytes the bytes
     * @return the bytes
    public Tensor setBytes(byte[] bytes) {
        return setBytes(bytes, json_precision);

     * Get bytes byte [ ].
     * @param precision the precision
     * @return the byte [ ]
    public byte[] getBytes(@Nonnull DataSerializer precision) {
        return precision.toBytes(getData());

     * Sets bytes.
     * @param bytes     the bytes
     * @param precision the precision
     * @return the bytes
    public Tensor setBytes(byte[] bytes, @Nonnull DataSerializer precision) {
        precision.copy(bytes, getData());
        return this;

    private JsonElement toJson(@Nonnull final int[] coords) {
        if (coords.length == dimensions.length) {
            final double d = get(coords);
            return new JsonPrimitive(d);
        } else {
            final JsonArray jsonArray = new JsonArray();
            IntStream.range(0, dimensions[dimensions.length - (coords.length + 1)]).mapToObj(i -> {
                final int[] newCoord = new int[coords.length + 1];
                System.arraycopy(coords, 0, newCoord, 1, coords.length);
                newCoord[0] = i;
                return toJson(newCoord);
            }).forEach(l -> jsonArray.add(l));
            return jsonArray;

     * To rgb png buffered png.
     * @return the buffered png
    public BufferedImage toRgbImage() {
        return toRgbImage(0, 1, 2);

     * To rgb png buffered png.
     * @param redBand   the red band
     * @param greenBand the green band
     * @param blueBand  the blue band
     * @return the buffered png
    public BufferedImage toRgbImage(final int redBand, final int greenBand, final int blueBand) {
        final int[] dims = getDimensions();
        final BufferedImage img = new BufferedImage(dims[0], dims[1], BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < img.getWidth(); x++) {
            for (int y = 0; y < img.getHeight(); y++) {
                if (getDimensions()[2] == 1) {
                    final double value = this.get(x, y, 0);
                    img.setRGB(x, y, Tensor.bound8bit((int) value) * 0x010101);
                } else {
                    final double red = Tensor.bound8bit(this.get(x, y, redBand));
                    final double green = Tensor.bound8bit(this.get(x, y, greenBand));
                    final double blue = Tensor.bound8bit(this.get(x, y, blueBand));
                    img.setRGB(x, y, (int) (red + ((int) green << 8) + ((int) blue << 16)));
        return img;

     * To rgb png buffered png.
     * @param redBand   the red band
     * @param greenBand the green band
     * @param blueBand  the blue band
     * @param alphaMask the alphaList mask
     * @return the buffered png
    public BufferedImage toRgbImageAlphaMask(final int redBand, final int greenBand, final int blueBand,
            Tensor alphaMask) {
        assert alphaMask.getDimensions()[0] == getDimensions()[0];
        assert alphaMask.getDimensions()[1] == getDimensions()[1];
        final int[] dims = getDimensions();
        final BufferedImage img = new BufferedImage(dims[0], dims[1], BufferedImage.TYPE_INT_ARGB);
        for (int x = 0; x < img.getWidth(); x++) {
            for (int y = 0; y < img.getHeight(); y++) {
                final double red = Tensor.bound8bit(this.get(x, y, redBand));
                final double green = Tensor.bound8bit(this.get(x, y, greenBand));
                final double blue = Tensor.bound8bit(this.get(x, y, blueBand));
                final double alpha = Tensor.bound8bit(alphaMask.get(x, y, 0));
                img.setRGB(x, y, (int) (red + ((int) green << 8) + ((int) blue << 16) + ((int) alpha << 24)));
        return img;

    public String toString() {
        return (null == data ? "0" : Integer.toHexString(System.identityHashCode(data))) + "@" + toString(false);

    /**
     * Recursively renders the sub-tensor addressed by the leading coordinate {@code coords} as a bracketed string.
     * <p>
     * Each level appends one index for dimension {@code coords.length}. A full-length coordinate renders the
     * scalar via {@code Double.toString}. When more than 10 children exist, only the first 8 are kept (as far as
     * visible here).
     * <p>
     * In pretty-print mode, outer levels (all but the last two dimensions) are tab-indented and joined with ",\n"
     * inside "[\n ... \n]", while the innermost two levels are joined with ", " inside "[ ... ]". In compact mode
     * children are joined with "," inside "[ ... ]".
     * <p>
     * NOTE(review): several lines of this method appear truncated in this copy (stream terminators and reduction
     * prefixes are missing), so it does not compile as shown. Restore them from the original source.
     *
     * @param prettyPrint whether to emit the multi-line indented layout
     * @param coords      leading coordinate fixed so far
     * @return the rendered sub-tensor
     */
    private String toString(final boolean prettyPrint, @Nonnull final int... coords) {
        if (coords.length == dimensions.length) {
            // Fully specified coordinate: render the single element.
            return Double.toString(get(coords));
        } else {
            // Render each index along dimension coords.length by recursing with one more coordinate.
            List<CharSequence> list = IntStream.range(0, dimensions[coords.length]).mapToObj(i -> {
                final int[] newCoord = Arrays.copyOf(coords, coords.length + 1);
                newCoord[coords.length] = i;
                return toString(prettyPrint, newCoord);
            // NOTE(review): the closing of this lambda and the terminal operation producing `list` appear to be
            // missing here (presumably a collect into a List) -- confirm against the original source.
            if (list.size() > 10) {
                // Truncate long axes to their first 8 rendered entries.
                list = list.subList(0, 8);
            if (prettyPrint) {
                if (coords.length < dimensions.length - 2) {
                    // Outer level: indent each child with a tab (including its internal newlines), one child per line.
                    // NOTE(review): the stream source before `.map(...)` and any terminal `orElse(...)` after the
                    // reduce appear to be missing -- presumably a stream over `list`; confirm.
                    final CharSequence str =
                            .map(s -> "\t" + s.toString().replaceAll("\n", "\n\t")).reduce((a, b) -> a + ",\n" + b)
                    return "[\n" + str + "\n]";
                } else {
                    // Innermost two levels: join children with ", " on a single line.
                    // NOTE(review): the start of this reduction expression appears truncated; confirm.
                    final CharSequence str =, b) -> a + ", " + b).orElse("");
                    return "[ " + str + " ]";
            } else {
                // Compact mode: join children with "," and no spaces between elements.
                // NOTE(review): the start of this reduction expression appears truncated; confirm.
                final CharSequence str =, b) -> a + "," + b).orElse("");
                return "[ " + str + " ]";

     * Reverse dimensions tensor.
     * @return the tensor
    public Tensor reverseDimensions() {
        return reverseDimensions(this);

     * Permute dimensions tensor.
     * @param key the key
     * @return the tensor
    public Tensor permuteDimensions(int... key) {
        int[] inputDims = getDimensions();
        int[] absKey = -> a == Integer.MAX_VALUE ? 0 : Math.abs(a)).toArray();
        int[] outputDims = permute(absKey, inputDims, inputDims);
        return rearrange(in -> permute(key, in, inputDims), outputDims);

     * Permute dimensions and free tensor.
     * @param key the key
     * @return the tensor
    public Tensor permuteDimensionsAndFree(int... key) {
        Tensor result = permuteDimensions(key);
        return result;

     * Reshape cast tensor.
     * @param dims the dims
     * @return the tensor
    public Tensor reshapeCast(@Nonnull int... dims) {
        if (0 == dims.length)
            throw new IllegalArgumentException();
        if (length(dims) != length())
            throw new IllegalArgumentException(Arrays.toString(dims) + " != " + length());
        double[] data = getData();
        return new Tensor(dims, null == data ? null : RecycleBin.DOUBLES.copyOf(data, data.length));

     * Reshape cast and free tensor.
     * @param dims the dims
     * @return the tensor
    public Tensor reshapeCastAndFree(@Nonnull int... dims) {
        Tensor tensor = reshapeCast(dims);
        return tensor;

     * For each.
     * @param fn       the fn
     * @param parallel the parallel
    public void forEach(@Nonnull CoordOperator fn, boolean parallel) {
        coordStream(parallel).forEach(c -> {
            fn.eval(get(c), c);

     * Dot double.
     * @param right the right
     * @return the double
    public double dot(final Tensor right) {
        double[] l = getData();
        double[] r = right.getData();
        double v = 0;
        for (int i = 0; i < l.length; i++) {
            v += l[i] * r[i];
        return v;

     * Unit tensor.
     * @return the tensor
    public Tensor unit() {
        return scale(1.0 / Math.sqrt(sumSq()));

     * Select band tensor.
     * @param band the band
     * @return the tensor
    public Tensor selectBand(final int band) {
        assert band >= 0;
        int[] dimensions = getDimensions();
        assert 3 == dimensions.length;
        assert band < dimensions[2];
        return new Tensor(dimensions[0], dimensions[1], 1).setByCoord(c -> {
            int[] coords = c.getCoords();
            return get(coords[0], coords[1], band);

     * To image and free buffered image.
     * @return the buffered image
    public BufferedImage toImageAndFree() {
        BufferedImage image = toImage();
        return image;

     * Copy and free tensor.
     * @return the tensor
    public Tensor copyAndFree() {
        if (currentRefCount() == 1)
            return this;
        Tensor copy = copy();
        return copy;

     * Resize as img tensor.
     * @param width  the width
     * @param height the height
     * @return the tensor
    public Tensor resizeAsImg(final int width, final int height) {
        if (getDimensions()[0] == width && getDimensions()[1] == height) {
            return this;
        return Tensor.fromRGB(TestUtil.resize(toImage(), width, height));

    public UUID getId() {
        if (id == null) {
            synchronized (this) {
                if (id == null) {
                    id = UUID.randomUUID();
        return id;

    public Tensor setId(@Nullable UUID id) { = id;
        return this;

     * The interface Coord operator.
    public interface CoordOperator {
         * Eval double.
         * @param value the value
         * @param index the index
        void eval(double value, Coordinate index);

     * The interface Tuple operator.
    public interface TupleOperator {
         * Eval double.
         * @param value the value
         * @param index the index
         * @return the double
        double eval(double value, int index);