/*
 * Decompiled with CFR 0.152.
 */
package com.simiacryptus.util.data;

import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
 * Builds a binary partition tree over a point set using axis-aligned splits chosen to shrink the
 * bounding-box volume occupied by the data, which separates dense regions from sparse ones.
 *
 * <p>Trees are built eagerly: constructing a {@link Node} computes its bounds and recursively
 * splits it while the configured thresholds allow. Non-finite coordinates (NaN / infinite) are
 * ignored when computing bounds and split candidates.
 */
public class DensityTree {
    private final CharSequence[] columnNames;
    private double minSplitFract = 0.05;
    private int splitSizeThreshold = 10;
    private double minFitness = 4.0;
    private int maxDepth = Integer.MAX_VALUE;

    /**
     * @param columnNames display names per dimension, used when rendering rules and bounds;
     *     dimensions without a supplied name are rendered as {@code col<index>}
     */
    public DensityTree(CharSequence ... columnNames) {
        this.columnNames = columnNames;
    }

    /** Returns the display name of dimension {@code d}, falling back to {@code col<d>}. */
    private CharSequence columnName(int d) {
        if (this.columnNames != null && d >= 0 && d < this.columnNames.length) {
            return this.columnNames[d];
        }
        return "col" + d;
    }

    /**
     * Widens an upper bound to include {@code value}. Non-finite values are ignored; a NaN bound
     * (meaning "no finite data seen yet") is replaced outright, because {@code Math.max} would
     * otherwise propagate the NaN forever.
     */
    private static double expandMax(double bound, double value) {
        if (!Double.isFinite(value)) {
            return bound;
        }
        return Double.isNaN(bound) ? value : Math.max(bound, value);
    }

    /** Lower-bound counterpart of {@link #expandMax(double, double)}. */
    private static double expandMin(double bound, double value) {
        if (!Double.isFinite(value)) {
            return bound;
        }
        return Double.isNaN(bound) ? value : Math.min(bound, value);
    }

    /**
     * Computes the per-dimension bounding box of {@code points}, ignoring non-finite coordinates.
     * A dimension with no finite values has NaN as both its min and max.
     *
     * @param points non-empty point set; the dimension count is taken from {@code points[0]}
     * @return the bounding box
     * @throws IllegalArgumentException if {@code points} is empty
     */
    @Nonnull
    public Bounds getBounds(@Nonnull double[][] points) {
        if (points.length == 0) {
            throw new IllegalArgumentException("Cannot compute bounds of an empty point set");
        }
        int dim = points[0].length;
        double[] max = new double[dim];
        double[] min = new double[dim];
        Arrays.fill(max, Double.NaN);
        Arrays.fill(min, Double.NaN);
        for (double[] pt : points) {
            for (int d = 0; d < dim; d++) {
                max[d] = expandMax(max[d], pt[d]);
                min[d] = expandMin(min[d], pt[d]);
            }
        }
        return new Bounds(max, min);
    }

    /** Minimum fraction of a node's points that each side of a split must exceed. */
    public double getMinSplitFract() {
        return this.minSplitFract;
    }

    /** Sets the minimum split fraction; returns this tree for chaining. */
    @Nonnull
    public DensityTree setMinSplitFract(double minSplitFract) {
        this.minSplitFract = minSplitFract;
        return this;
    }

    /** Nodes with at most this many points are never split. */
    public int getSplitSizeThreshold() {
        return this.splitSizeThreshold;
    }

    /** Sets the split size threshold; returns this tree for chaining. */
    @Nonnull
    public DensityTree setSplitSizeThreshold(int splitSizeThreshold) {
        this.splitSizeThreshold = splitSizeThreshold;
        return this;
    }

    /** Returns the column-name array supplied at construction (not a copy). */
    public CharSequence[] getColumnNames() {
        return this.columnNames;
    }

    /** Candidate splits must have fitness strictly greater than this value. */
    public double getMinFitness() {
        return this.minFitness;
    }

    /** Sets the minimum split fitness; returns this tree for chaining. */
    @Nonnull
    public DensityTree setMinFitness(double minFitness) {
        this.minFitness = minFitness;
        return this;
    }

    /** Nodes at or beyond this depth are never split. */
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /** Sets the maximum depth; returns this tree for chaining. */
    @Nonnull
    public DensityTree setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
        return this;
    }

    /**
     * A tree node. Construction computes the node's bounds and immediately tries to split it,
     * recursively building the whole subtree.
     */
    public class Node {
        @Nonnull
        public final double[][] points;
        @Nonnull
        public final Bounds bounds;
        private final int depth;
        @Nullable
        private Node left = null;
        @Nullable
        private Node right = null;
        @Nullable
        private Rule rule = null;

        /** Builds a root node (depth 0) over {@code points}. */
        public Node(double[][] points) {
            this(points, 0);
        }

        /**
         * Builds a node at {@code depth} over {@code points} and recursively splits it.
         *
         * @throws IllegalArgumentException if {@code points} is empty
         */
        public Node(double[][] points, int depth) {
            this.points = points;
            this.bounds = DensityTree.this.getBounds(points);
            this.depth = depth;
            this.split();
        }

        /** A node is treated as a leaf unless it has a rule and both children. */
        private boolean isLeaf() {
            return this.rule == null || this.left == null || this.right == null;
        }

        /**
         * Encodes the path {@code pt} takes through the subtree as an int. Bit {@code k} is 1 when
         * the rule at relative depth {@code k} evaluated true (left branch); leaves contribute 0.
         * Paths longer than about 31 levels overflow the int. Cap the depth via
         * {@link DensityTree#setMaxDepth(int)} if codes must stay unique.
         */
        public int predict(double[] pt) {
            if (this.isLeaf()) {
                return 0;
            }
            return this.rule.eval(pt) ? 1 + 2 * this.left.predict(pt) : 2 * this.right.predict(pt);
        }

        @Override
        public String toString() {
            return this.code();
        }

        /** Renders the subtree as nested, commented if/else pseudo-code. */
        public String code() {
            if (this.isLeaf()) {
                return "// " + this.dataInfo();
            }
            String leftCode = this.left.code().replace("\n", "\n  ");
            String rightCode = this.right.code().replace("\n", "\n  ");
            return String.format("// %s\nif(%s) { // Fitness %s\n  %s\n} else {\n  %s\n}", this.dataInfo(), this.rule, this.rule.fitness, leftCode, rightCode);
        }

        private CharSequence dataInfo() {
            return String.format("Count: %s Volume: %s Region: %s", this.points.length, this.bounds.getVolume(), this.bounds);
        }

        /**
         * Chooses the best-fitness orthogonal split across all dimensions and, if it partitions
         * the points into two non-empty sets, creates both children. Otherwise the node stays a
         * leaf. This keeps the invariant that {@code rule != null} implies both children exist.
         */
        public void split() {
            if (this.points.length <= splitSizeThreshold || this.depth >= maxDepth) {
                return;
            }
            Rule best = IntStream.range(0, this.points[0].length)
                .boxed()
                .flatMap(this::split_ortho)
                .filter(r -> Double.isFinite(r.fitness))
                .max(Comparator.comparingDouble((Rule r) -> r.fitness))
                .orElse(null);
            if (best == null) {
                return;
            }
            double[][] leftPts = Arrays.stream(this.points).filter(best::eval).toArray(double[][]::new);
            double[][] rightPts = Arrays.stream(this.points).filter(pt -> !best.eval(pt)).toArray(double[][]::new);
            assert leftPts.length + rightPts.length == this.points.length;
            if (leftPts.length == 0 || rightPts.length == 0) {
                // A degenerate partition would leave a rule without children; remain a leaf.
                return;
            }
            this.rule = best;
            this.left = new Node(leftPts, this.depth + 1);
            this.right = new Node(rightPts, this.depth + 1);
        }

        /**
         * Enumerates candidate splits {@code x[dim] < value} for this node's points.
         *
         * <p>A candidate's fitness is
         * {@code -(nL*ln(volL/vol) + nR*ln(volR/vol)) / (n*ln 2)}: the count-weighted average
         * number of halvings of bounding-box volume the split achieves. Only candidates with
         * both sides larger than {@code n * minSplitFract} (at least 1) and fitness above
         * {@code minFitness} are returned. Points with a non-finite {@code dim} coordinate are
         * excluded from the candidate scan.
         */
        public Stream<Rule> split_ortho(int dim) {
            double[][] sorted = Arrays.stream(this.points)
                .filter(pt -> Double.isFinite(pt[dim]))
                .sorted(Comparator.comparingDouble((double[] pt) -> pt[dim]))
                .toArray(double[][]::new);
            int n = sorted.length;
            if (n == 0) {
                return Stream.empty();
            }
            int minSize = (int) Math.max(n * minSplitFract, 1.0);
            // prefix[i] bounds sorted[0..i]; suffix[i] bounds sorted[i..n-1].
            Bounds[] prefix = new Bounds[n];
            Bounds[] suffix = new Bounds[n];
            prefix[0] = DensityTree.this.getBounds(new double[][]{sorted[0]});
            suffix[n - 1] = DensityTree.this.getBounds(new double[][]{sorted[n - 1]});
            for (int i = 1; i < n; i++) {
                prefix[i] = prefix[i - 1].union(sorted[i]);
                suffix[n - 1 - i] = suffix[n - i].union(sorted[n - 1 - i]);
            }
            double parentVolume = this.bounds.getVolume();
            double log2 = Math.log(2.0);
            return IntStream.range(1, n - 1)
                // Only split between distinct values so the rule separates exactly at index i.
                .filter(i -> sorted[i - 1][dim] < sorted[i][dim])
                .filter(i -> minSize < i && minSize < n - i)
                .<Rule>mapToObj(i -> {
                    int leftCount = i;
                    int rightCount = n - i;
                    OrthoRule candidate = new OrthoRule(dim, sorted[i][dim]);
                    double leftRatio = prefix[i - 1].getVolume() / parentVolume;
                    double rightRatio = suffix[i].getVolume() / parentVolume;
                    candidate.fitness = -((double) leftCount * Math.log(leftRatio) + (double) rightCount * Math.log(rightRatio)) / ((double) n * log2);
                    return candidate;
                })
                .filter(r -> r.fitness > minFitness);
        }

        @Nullable
        public Rule getRule() {
            return this.rule;
        }

        @Nonnull
        protected Node setRule(Rule rule) {
            this.rule = rule;
            return this;
        }

        @Nullable
        public Node getRight() {
            return this.right;
        }

        @Nonnull
        protected Node setRight(Node right) {
            this.right = right;
            return this;
        }

        @Nullable
        public Node getLeft() {
            return this.left;
        }

        @Nonnull
        protected Node setLeft(Node left) {
            this.left = left;
            return this;
        }

        public int getDepth() {
            return this.depth;
        }
    }

    /** A named predicate over points; {@code true} routes a point to the left child. */
    public abstract class Rule {
        public final String name;
        public double fitness;

        public Rule(String name) {
            this.name = name;
        }

        public abstract boolean eval(double[] var1);

        @Override
        public String toString() {
            return this.name;
        }
    }

    /** Axis-aligned rule {@code pt[dim] < value}; NaN coordinates evaluate false. */
    public class OrthoRule
    extends Rule {
        private final int dim;
        private final double value;

        public OrthoRule(int dim, double value) {
            super(String.format("%s < %s", DensityTree.this.columnName(dim), value));
            this.dim = dim;
            this.value = value;
        }

        @Override
        public boolean eval(double[] pt) {
            return pt[this.dim] < this.value;
        }
    }

    /** Per-dimension axis-aligned bounds; NaN in a dimension means "no finite data". */
    public class Bounds {
        @Nonnull
        public final double[] max;
        @Nonnull
        public final double[] min;

        public Bounds(@Nonnull double[] max, double[] min) {
            this.max = max;
            this.min = min;
            assert (max.length == min.length);
            assert (IntStream.range(0, max.length).filter(i -> Double.isFinite(max[i])).allMatch(i -> max[i] >= min[i]));
        }

        /**
         * Returns new bounds that also contain {@code pt}. Non-finite coordinates are ignored,
         * and a NaN bound is replaced by the first finite value seen.
         */
        @Nonnull
        public Bounds union(@Nonnull double[] pt) {
            int dim = pt.length;
            double[] newMax = new double[dim];
            double[] newMin = new double[dim];
            for (int d = 0; d < dim; d++) {
                newMax[d] = expandMax(this.max[d], pt[d]);
                newMin[d] = expandMin(this.min[d], pt[d]);
            }
            return new Bounds(newMax, newMin);
        }

        /**
         * Product of the finite, strictly positive per-dimension widths. Returns NaN if no
         * dimension qualifies; zero-width or NaN dimensions are skipped.
         */
        public double getVolume() {
            int dim = this.min.length;
            return IntStream.range(0, dim).mapToDouble(d -> this.max[d] - this.min[d]).filter(x -> Double.isFinite(x) && x > 0.0).reduce((a, b) -> a * b).orElse(Double.NaN);
        }

        @Override
        @Nonnull
        public String toString() {
            String[] parts = IntStream.range(0, this.min.length)
                .mapToObj(d -> String.format("%s: %s - %s", DensityTree.this.columnName(d), this.min[d], this.max[d]))
                .toArray(String[]::new);
            return "[" + String.join("; ", parts) + "]";
        }
    }
}

