/*
 * Decompiled with CFR 0.152.
 */
package com.simiacryptus.util.ml;

import com.simiacryptus.util.ml.Coordinate;
import java.awt.image.BufferedImage;
import java.lang.ref.SoftReference;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.function.DoubleSupplier;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntToDoubleFunction;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Dense n-dimensional array of {@code double} values stored in one flat
 * {@code double[]}.
 *
 * <p>The first dimension varies fastest: the flat index of a coordinate is
 * {@code sum(skips[i] * coords[i])} with {@code skips[0] == 1} and
 * {@code skips[i] == skips[i - 1] * dims[i - 1]}.
 *
 * <p>The backing array is allocated lazily by {@link #getData()}. Arrays that
 * were allocated that way are offered to a static, per-length pool when the
 * tensor is finalized and are handed out again (zero-filled) to later tensors
 * of the same length.
 *
 * <p>NOTE: {@link #getData()} exposes the backing array itself. A caller that
 * keeps that array after the tensor has become unreachable may see it zeroed
 * and reused by another tensor once it has been recycled.
 */
public class Tensor {
    /** Arrays with at least this many elements are never pooled. */
    private static final int MAX_RECYCLED_LENGTH = 0x100000;
    /** Maximum number of pooled arrays kept per array length. */
    private static final int RECYCLE_BIN_CAPACITY = 1000;
    /** Pool of reusable backing arrays, keyed by array length. */
    private static final ConcurrentHashMap<Integer, BlockingQueue<SoftReference<double[]>>> recycling = new ConcurrentHashMap<>();

    private volatile double[] data;
    /** The pool entry {@link #data} was taken from, if any; reused when the array is pooled again. */
    private volatile SoftReference<double[]> dataSoftref;
    /**
     * True only when {@link #data} was allocated (or taken from the pool) by
     * {@link #getData()} and has not been shared through {@link #reformat(int...)}.
     * Caller-supplied and shared arrays must never be pooled.
     */
    private volatile boolean recyclable;
    protected final int[] dims;
    protected final int[] skips;

    /**
     * Returns the number of elements of a tensor with the given dimensions,
     * i.e. the product of all entries (1 for an empty list).
     *
     * @param dims the dimension sizes
     * @return the product of {@code dims}
     * @throws ArithmeticException if the product overflows an {@code int}
     */
    public static int dim(int ... dims) {
        int total = 1;
        for (int dim : dims) {
            total = Math.multiplyExact(total, dim);
        }
        return total;
    }

    /**
     * Creates a tensor without dimensions or data. Both {@code dims} and
     * {@code skips} are {@code null}, so the coordinate- and data-based methods
     * of this class cannot be used on such an instance.
     */
    protected Tensor() {
        this.data = null;
        this.skips = null;
        this.dims = null;
    }

    /**
     * Offers the backing array to the recycling pool, provided this tensor
     * allocated it itself, has not shared it, and it is shorter than
     * {@code MAX_RECYCLED_LENGTH}.
     */
    @Override
    public void finalize() throws Throwable {
        try {
            double[] array = this.data;
            if (this.recyclable && null != array && array.length < MAX_RECYCLED_LENGTH) {
                // computeIfAbsent is atomic, so two finalizations cannot each install a bin.
                BlockingQueue<SoftReference<double[]>> bin = recycling.computeIfAbsent(
                        array.length, key -> new ArrayBlockingQueue<>(RECYCLE_BIN_CAPACITY));
                SoftReference<double[]> ref = this.dataSoftref;
                // offer() silently drops the array when the bin is full.
                bin.offer(null != ref ? ref : new SoftReference<>(array));
                this.data = null;
            }
        } finally {
            super.finalize();
        }
    }

    /**
     * Creates a tensor with the given dimensions; its data is allocated on the
     * first call to {@link #getData()}.
     *
     * @param dims the dimension sizes
     */
    public Tensor(int ... dims) {
        this(dims, (double[])null);
    }

    /**
     * Builds a {@code width x height x 3} tensor from an image. Channel 0 holds
     * bits 0-7 of {@link BufferedImage#getRGB(int, int)}, channel 1 bits 8-15
     * and channel 2 bits 16-23. {@link #toRgbImage()} uses the same channel order.
     *
     * @param img the source image
     * @return a new tensor holding the three colour bytes of every pixel
     */
    public static Tensor fromRGB(BufferedImage img) {
        int width = img.getWidth();
        int height = img.getHeight();
        Tensor a = new Tensor(width, height, 3);
        for (int x = 0; x < width; ++x) {
            for (int y = 0; y < height; ++y) {
                int rgb = img.getRGB(x, y);
                a.set(new int[]{x, y, 0}, (double)(rgb & 0xFF));
                a.set(new int[]{x, y, 1}, (double)(rgb >> 8 & 0xFF));
                a.set(new int[]{x, y, 2}, (double)(rgb >> 16 & 0xFF));
            }
        }
        return a;
    }

    /**
     * Creates a tensor over the given array. The dimensions are copied; the
     * data array is used directly (not copied) and is never recycled.
     *
     * @param dims the dimension sizes
     * @param data the backing array, or {@code null} to allocate lazily; when
     *             assertions are enabled it must be non-empty and have
     *             {@code dim(dims)} elements
     */
    public Tensor(int[] dims, double[] data) {
        this.dims = Arrays.copyOf(dims, dims.length);
        this.skips = new int[dims.length];
        for (int i = 0; i < this.skips.length; ++i) {
            this.skips[i] = i == 0 ? 1 : this.skips[i - 1] * dims[i - 1];
        }
        assert (null == data || Tensor.dim(dims) == data.length);
        assert (null == data || 0 < data.length);
        this.data = data;
    }

    /**
     * Creates a one-dimensional tensor over the given array (not copied).
     *
     * @param ds the backing array
     */
    public Tensor(double[] ds) {
        this(new int[]{ds.length}, ds);
    }

    /** Returns a new array consisting of {@code base} followed by {@code extra}. */
    private static int[] _add(int[] base, int ... extra) {
        int[] copy = Arrays.copyOf(base, base.length + extra.length);
        System.arraycopy(extra, 0, copy, base.length, extra.length);
        return copy;
    }

    /** Adds {@code value} to the element at the flat index carried by {@code coords}. */
    public void add(Coordinate coords, double value) {
        this.add(coords.index, value);
    }

    /**
     * Adds {@code value} to the element at the given flat index.
     *
     * @return this tensor
     */
    public final Tensor add(int index, double value) {
        this.getData()[index] += value;
        return this;
    }

    /** Adds {@code value} to the element at the given coordinates. */
    public void add(int[] coords, double value) {
        this.add(this.index(coords), value);
    }

    /** Returns a sequential stream over all coordinates, in flat-index order. */
    public Stream<Coordinate> coordStream() {
        return this.coordStream(false);
    }

    /**
     * Returns a stream over all coordinates of this tensor. Each
     * {@link Coordinate} is built from its flat index and a fresh copy of its
     * per-dimension coordinates.
     *
     * @param paralell whether the returned stream is parallel
     */
    public Stream<Coordinate> coordStream(boolean paralell) {
        final int size = this.dim();
        final int[] shape = this.dims;
        Iterator<Coordinate> iterator = new Iterator<Coordinate>(){
            private int cnt = 0;
            private final int[] val = new int[shape.length];

            @Override
            public boolean hasNext() {
                return this.cnt < size;
            }

            @Override
            public Coordinate next() {
                if (!this.hasNext()) {
                    throw new NoSuchElementException();
                }
                int[] last = Arrays.copyOf(this.val, this.val.length);
                // Odometer-style increment: bump the fastest dimension and carry on overflow.
                for (int i = 0; i < this.val.length; ++i) {
                    if (++this.val[i] < shape[i]) break;
                    this.val[i] = 0;
                }
                return new Coordinate(this.cnt++, last);
            }
        };
        return StreamSupport.stream(Spliterators.spliterator(iterator, (long)size, Spliterator.ORDERED), paralell);
    }

    /** Returns a new tensor with the same dimensions and a copy of the data. */
    public Tensor copy() {
        double[] source = this.getData();
        return new Tensor(this.dims, Arrays.copyOf(source, source.length));
    }

    /** Returns the number of elements, allocating the backing array if necessary. */
    public int dim() {
        return this.getData().length;
    }

    /** Two tensors are equal when they have the same class, dimensions and element values. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        Tensor other = (Tensor)obj;
        return Arrays.equals(this.getData(), other.getData()) && Arrays.equals(this.dims, other.dims);
    }

    /**
     * Sets every element to a value drawn from {@code f}. The elements are
     * filled with {@link Arrays#parallelSetAll}, so {@code f} may be invoked
     * from several threads and in no particular order.
     *
     * @return this tensor
     */
    public Tensor fill(DoubleSupplier f) {
        Arrays.parallelSetAll(this.getData(), i -> f.getAsDouble());
        return this;
    }

    /**
     * Sets the element at each flat index {@code i} to {@code f.applyAsDouble(i)},
     * using {@link Arrays#parallelSetAll}.
     *
     * @return this tensor
     */
    public Tensor fill(IntToDoubleFunction f) {
        Arrays.parallelSetAll(this.getData(), f);
        return this;
    }

    /** Returns the element at the flat index carried by {@code coords}. */
    public double get(Coordinate coords) {
        return this.getData()[coords.index];
    }

    /** Returns the element at the given coordinates. */
    public double get(int ... coords) {
        return this.getData()[this.index(coords)];
    }

    /**
     * Returns the backing array itself (not a copy), allocating it on first use.
     * A lazily allocated array is taken from the recycling pool when one of the
     * right length is available (and is then zero-filled); otherwise a new
     * array is created.
     *
     * @return the backing array, of length {@code dim(dims)}
     */
    public final double[] getData() {
        double[] current = this.data;
        if (null == current) {
            synchronized (this) {
                current = this.data;
                if (null == current) {
                    int length = Tensor.dim(this.dims);
                    double[] recycled = null;
                    SoftReference<double[]> recycledRef = null;
                    BlockingQueue<SoftReference<double[]>> bin = recycling.get(length);
                    if (null != bin) {
                        // Skip pool entries whose array has already been garbage collected.
                        SoftReference<double[]> candidate;
                        while (null == recycled && null != (candidate = bin.poll())) {
                            recycled = candidate.get();
                            if (null != recycled) {
                                recycledRef = candidate;
                            }
                        }
                    }
                    if (null == recycled) {
                        current = new double[length];
                    } else {
                        Arrays.fill(recycled, 0.0);
                        current = recycled;
                    }
                    this.dataSoftref = recycledRef;
                    this.recyclable = true;
                    // Publish the array last so readers that see it also see the fields above.
                    this.data = current;
                }
            }
        }
        return current;
    }

    /** Returns the internal dimension array itself (not a copy). */
    public final int[] getDims() {
        return this.dims;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = 31 * result + Arrays.hashCode(this.getData());
        result = 31 * result + Arrays.hashCode(this.dims);
        return result;
    }

    /** Returns the flat index carried by {@code coords}. */
    public int index(Coordinate coords) {
        return coords.index;
    }

    /**
     * Converts coordinates to a flat index. Only the first
     * {@code min(coords.length, dims.length)} coordinates are used, so missing
     * trailing coordinates count as 0. No bounds checking is performed.
     */
    public int index(int ... coords) {
        int v = 0;
        for (int i = 0; i < this.skips.length && i < coords.length; ++i) {
            v += this.skips[i] * coords[i];
        }
        return v;
    }

    /** Returns the L1 norm: the sum of the absolute values of all elements. */
    public double l1() {
        return Arrays.stream(this.getData()).map(Math::abs).sum();
    }

    /** Returns the L2 norm: the square root of the sum of squared elements. */
    public double l2() {
        return Math.sqrt(Arrays.stream(this.getData()).map((double x) -> x * x).sum());
    }

    /** Returns a new tensor of the same dimensions with {@code f} applied to every element. */
    public Tensor map(DoubleUnaryOperator f) {
        double[] source = this.getData();
        double[] cpy = new double[source.length];
        for (int i = 0; i < source.length; ++i) {
            cpy[i] = f.applyAsDouble(source[i]);
        }
        return new Tensor(this.dims, cpy);
    }

    /**
     * Returns a new tensor of the same dimensions whose elements are
     * {@code f(value, coordinate)}, evaluated sequentially in flat-index order.
     */
    public Tensor map(ToDoubleBiFunction<Double, Coordinate> f) {
        double[] mapped = this.coordStream(false).mapToDouble(c -> f.applyAsDouble(this.get(c), c)).toArray();
        return new Tensor(this.dims, mapped);
    }

    /** Returns a new tensor holding the element-wise difference {@code this - right}. */
    public Tensor minus(Tensor right) {
        assert (Arrays.equals(this.getDims(), right.getDims()));
        Tensor copy = new Tensor(this.getDims());
        double[] thisData = this.getData();
        double[] rightData = right.getData();
        Arrays.parallelSetAll(copy.getData(), i -> thisData[i] - rightData[i]);
        return copy;
    }

    /**
     * Returns a tensor with the given dimensions that shares this tensor's
     * backing array: writes through either tensor are visible in the other.
     * Because the array is shared, it is excluded from recycling.
     */
    public Tensor reformat(int ... dims) {
        double[] shared = this.getData();
        this.recyclable = false;
        return new Tensor(dims, shared);
    }

    /** Returns a new tensor holding every element multiplied by {@code d}. */
    public Tensor multiply(double d) {
        Tensor tensor = new Tensor(this.getDims());
        double[] resultData = tensor.getData();
        double[] thisData = this.getData();
        for (int i = 0; i < thisData.length; ++i) {
            resultData[i] = d * thisData[i];
        }
        return tensor;
    }

    /**
     * Multiplies every element by {@code d} in place.
     *
     * @return this tensor
     */
    public Tensor scale(double d) {
        double[] values = this.getData();
        for (int i = 0; i < values.length; ++i) {
            values[i] *= d;
        }
        return this;
    }

    /** Stores {@code value} at the flat index carried by {@code coords}. */
    public void set(Coordinate coords, double value) {
        assert (Double.isFinite(value));
        this.set(coords.index, value);
    }

    /**
     * Copies the first {@code dim()} values of {@code data} into this tensor.
     *
     * @return this tensor
     * @throws IndexOutOfBoundsException if {@code data} has fewer than
     *         {@code dim()} elements; nothing is copied in that case
     */
    public Tensor set(double[] data) {
        double[] target = this.getData();
        System.arraycopy(data, 0, target, 0, target.length);
        return this;
    }

    /**
     * Renders the first two dimensions as a {@code TYPE_BYTE_GRAY} image,
     * clamping every value to the range 0..255.
     */
    public BufferedImage toGrayImage() {
        int width = this.getDims()[0];
        int height = this.getDims()[1];
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        for (int x = 0; x < width; ++x) {
            for (int y = 0; y < height; ++y) {
                // Clamp on both sides: values above 255 would otherwise wrap in the byte raster.
                image.getRaster().setSample(x, y, 0, Tensor.bounds(this.get(x, y)));
            }
        }
        return image;
    }

    /** Clamps {@code value} to the range 0..255. */
    private static int bounds(int value) {
        return value < 0 ? 0 : (value > 255 ? 255 : value);
    }

    /** Clamps {@code value} to the range 0..255; NaN is returned unchanged. */
    private static double bounds(double value) {
        return value < 0.0 ? 0.0 : (value > 255.0 ? 255.0 : value);
    }

    /**
     * Renders this tensor as a {@code TYPE_INT_RGB} image, clamping every
     * channel to 0..255. A tensor with fewer than three dimensions, or whose
     * third dimension is 1, is rendered as gray. Otherwise channel 0 becomes
     * bits 0-7 of the pixel, channel 1 bits 8-15 and channel 2 bits 16-23,
     * which is the inverse of {@link #fromRGB(BufferedImage)}.
     */
    public BufferedImage toRgbImage() {
        int[] dims = this.getDims();
        boolean gray = dims.length < 3 || dims[2] == 1;
        BufferedImage img = new BufferedImage(dims[0], dims[1], BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < img.getWidth(); ++x) {
            for (int y = 0; y < img.getHeight(); ++y) {
                if (gray) {
                    double value = this.get(x, y, 0);
                    // 0x010101 replicates the gray level into all three colour bytes.
                    img.setRGB(x, y, Tensor.bounds((int)value) * 0x010101);
                    continue;
                }
                int low = (int)Tensor.bounds(this.get(x, y, 0));
                int mid = (int)Tensor.bounds(this.get(x, y, 1));
                int high = (int)Tensor.bounds(this.get(x, y, 2));
                img.setRGB(x, y, low | mid << 8 | high << 16);
            }
        }
        return img;
    }

    /**
     * Stores {@code value} at the given flat index.
     *
     * @return this tensor
     */
    public Tensor set(int index, double value) {
        this.getData()[index] = value;
        return this;
    }

    /** Stores {@code value} at the given coordinates. */
    public void set(int[] coords, double value) {
        assert (Double.isFinite(value));
        this.set(this.index(coords), value);
    }

    /** Copies all element values of {@code right} into this tensor. */
    public void set(Tensor right) {
        assert (this.dim() == right.dim());
        double[] target = this.getData();
        System.arraycopy(right.getData(), 0, target, 0, target.length);
    }

    /** Returns the plain (signed) sum of all elements, accumulated left to right. */
    public double sum() {
        double v = 0.0;
        for (double element : this.getData()) {
            v += element;
        }
        return v;
    }

    /**
     * Returns a nested, bracketed rendering of the elements. Along any
     * dimension with more than 10 entries only the first 8 are shown, followed
     * by {@code ...}.
     */
    @Override
    public String toString() {
        return this.toString(new int[0]);
    }

    /** Renders the sub-tensor selected by the leading coordinates {@code coords}. */
    private String toString(int ... coords) {
        if (coords.length == this.dims.length) {
            return Double.toString(this.get(coords));
        }
        int size = this.dims[coords.length];
        boolean truncated = size > 10;
        // Only render the entries that will actually be shown.
        int shown = truncated ? 8 : size;
        String body = IntStream.range(0, shown)
                .mapToObj(i -> this.toString(Tensor._add(coords, i)))
                .collect(Collectors.joining(","));
        if (truncated) {
            body = body + ",...";
        }
        return "[ " + body + " ]";
    }

    /** Sets every element to {@code v}. */
    public void setAll(double v) {
        Arrays.fill(this.getData(), v);
    }
}

