/*
 * Decompiled with CFR 0.152.
 */
package com.simiacryptus.util.data;

import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefComparator;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefStream;
import com.simiacryptus.ref.wrappers.RefString;
import java.util.Comparator;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
 * Recursively partitions a set of points into a binary tree of axis-aligned regions.
 *
 * <p>Each {@link Node} considers every dimension and every threshold lying between two distinct
 * consecutive sorted values in that dimension. It scores each resulting two-way split with a
 * count-weighted log-volume-ratio "fitness" and, if the best candidate scores above
 * {@link #getMinFitness()}, splits on that {@link OrthoRule}. Non-finite coordinates are ignored
 * when computing bounding boxes.
 *
 * <p>The configuration setters return {@code this} for chaining. Because splitting happens eagerly
 * in the {@link Node} constructor, they only affect nodes constructed afterwards.
 */
public class DensityTree {
    /** Display names, indexed by dimension; used when rendering rules and bounds as text. */
    private final CharSequence[] columnNames;
    /** Each side of a split must contain more than {@code max(1, floor(minSplitFract * n))} points. */
    private double minSplitFract = 0.05;
    /** Nodes holding at most this many points are never split. */
    private int splitSizeThreshold = 10;
    /** A candidate split is only kept when its fitness is strictly greater than this value. */
    private double minFitness = 4.0;
    /** Nodes at this depth or deeper are never split (the root has depth 0). */
    private int maxDepth = Integer.MAX_VALUE;

    /**
     * Creates a tree configuration.
     *
     * @param columnNames display names for each dimension; the array is stored as-is, not copied
     */
    public DensityTree(CharSequence... columnNames) {
        this.columnNames = columnNames;
    }

    /** @return the column-name array passed to the constructor (not a copy) */
    public CharSequence[] getColumnNames() {
        return this.columnNames;
    }

    /** @return the depth at or beyond which nodes are not split */
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /**
     * @param maxDepth depth at or beyond which nodes are not split
     * @return this instance, for chaining
     */
    public @Nonnull DensityTree setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
        return this;
    }

    /** @return the fitness a candidate split must strictly exceed */
    public double getMinFitness() {
        return this.minFitness;
    }

    /**
     * @param minFitness fitness a candidate split must strictly exceed
     * @return this instance, for chaining
     */
    public @Nonnull DensityTree setMinFitness(double minFitness) {
        this.minFitness = minFitness;
        return this;
    }

    /** @return the minimum fraction of points each side of a split must exceed */
    public double getMinSplitFract() {
        return this.minSplitFract;
    }

    /**
     * @param minSplitFract minimum fraction of points each side of a split must exceed
     * @return this instance, for chaining
     */
    public @Nonnull DensityTree setMinSplitFract(double minSplitFract) {
        this.minSplitFract = minSplitFract;
        return this;
    }

    /** @return the point count at or below which nodes are not split */
    public int getSplitSizeThreshold() {
        return this.splitSizeThreshold;
    }

    /**
     * @param splitSizeThreshold point count at or below which nodes are not split
     * @return this instance, for chaining
     */
    public @Nonnull DensityTree setSplitSizeThreshold(int splitSizeThreshold) {
        this.splitSizeThreshold = splitSizeThreshold;
        return this;
    }

    /**
     * Computes the axis-aligned bounding box of the given points.
     *
     * <p>Non-finite coordinates (NaN or infinite) are ignored. A dimension with no finite value
     * gets {@code NaN} as both its minimum and its maximum.
     *
     * @param points non-empty array of points; the dimensionality is taken from {@code points[0]}
     * @return the bounding box of {@code points}
     * @throws ArrayIndexOutOfBoundsException if {@code points} is empty
     */
    public @Nonnull Bounds getBounds(@Nonnull double[][] points) {
        int dim = points[0].length;
        double[] max = RefIntStream.range(0, dim)
                .mapToDouble(d -> RefArrays.stream(points)
                        .mapToDouble(pt -> pt[d])
                        .filter(Double::isFinite)
                        .max()
                        .orElse(Double.NaN))
                .toArray();
        double[] min = RefIntStream.range(0, dim)
                .mapToDouble(d -> RefArrays.stream(points)
                        .mapToDouble(pt -> pt[d])
                        .filter(Double::isFinite)
                        .min()
                        .orElse(Double.NaN))
                .toArray();
        return new Bounds(max, min);
    }

    /**
     * A node of the partition tree.
     *
     * <p>Constructing a node immediately tries to split it, so the whole subtree below it is built
     * recursively by the constructor.
     */
    public class Node {
        /** The points contained in this node's region. */
        public final @Nonnull double[][] points;
        /** Bounding box of {@link #points}, computed over finite coordinates only. */
        public final @Nonnull Bounds bounds;
        private final int depth;
        /** Child holding the points for which {@link #rule} evaluates to {@code true}. */
        private @Nullable Node left = null;
        /** Child holding the points for which {@link #rule} evaluates to {@code false}. */
        private @Nullable Node right = null;
        /** Splitting rule; {@code null} for a leaf. When non-null, both children are set. */
        private @Nullable Rule rule = null;

        /**
         * Builds a root node (depth 0) over {@code points}.
         *
         * @param points non-empty array of points
         */
        public Node(double[][] points) {
            this(points, 0);
        }

        /**
         * Builds a node at the given depth over {@code points} and recursively splits it.
         *
         * @param points non-empty array of points
         * @param depth  depth of this node (the root is 0)
         */
        public Node(double[][] points, int depth) {
            this.points = points;
            this.bounds = DensityTree.this.getBounds(points);
            this.depth = depth;
            this.split();
        }

        /** @return this node's depth (the root is 0) */
        public int getDepth() {
            return this.depth;
        }

        /** @return the child for points satisfying the rule, or {@code null} for a leaf */
        public @Nullable Node getLeft() {
            return this.left;
        }

        protected @Nonnull Node setLeft(Node left) {
            this.left = left;
            return this;
        }

        /** @return the child for points not satisfying the rule, or {@code null} for a leaf */
        public @Nullable Node getRight() {
            return this.right;
        }

        protected @Nonnull Node setRight(Node right) {
            this.right = right;
            return this;
        }

        /** @return the splitting rule, or {@code null} for a leaf */
        public @Nullable Rule getRule() {
            return this.rule;
        }

        protected @Nonnull Node setRule(Rule rule) {
            this.rule = rule;
            return this;
        }

        /**
         * Encodes the path from this node to the leaf reached by {@code pt} as an integer.
         *
         * <p>Each decision contributes one bit, starting at the least-significant bit: {@code 1}
         * when the rule holds (left child) and {@code 0} otherwise (right child). A leaf
         * contributes nothing, so a node without a rule returns {@code 0}.
         * NOTE(review): paths longer than about 31 decisions overflow {@code int} silently.
         *
         * @param pt the point to route
         * @return the encoded leaf path
         */
        public int predict(double[] pt) {
            if (null == this.rule) {
                return 0;
            }
            if (this.rule.eval(pt)) {
                assert this.left != null;
                return 1 + 2 * this.left.predict(pt);
            }
            assert this.right != null;
            return 2 * this.right.predict(pt);
        }

        @Override
        public @Nonnull String toString() {
            return this.code();
        }

        /**
         * Renders this subtree as nested if/else pseudo-code. Each node is annotated with its point
         * count, volume and region, and each branch with its fitness.
         *
         * @return the rendered subtree
         */
        public @Nonnull String code() {
            if (null == this.rule) {
                return "// " + this.dataInfo();
            }
            assert this.right != null;
            assert this.left != null;
            // Indent nested lines by two spaces so child branches line up under their parent.
            String leftCode = this.left.code().replace("\n", "\n  ");
            String rightCode = this.right.code().replace("\n", "\n  ");
            return RefString.format("// %s\nif(%s) { // Fitness %s\n  %s\n} else {\n  %s\n}",
                    new Object[]{this.dataInfo(), this.rule, this.rule.fitness, leftCode, rightCode});
        }

        /**
         * Chooses the best axis-aligned split for this node and builds both children.
         *
         * <p>Does nothing when the node holds at most {@code splitSizeThreshold} points or has
         * reached {@code maxDepth}. Otherwise it evaluates {@link #split_ortho} for every
         * dimension and keeps the finite-fitness candidate with the highest fitness. If there is
         * no candidate, or the best one would leave a side empty, the node stays a leaf
         * ({@code rule == null}). This keeps {@link #predict} and {@link #code} from
         * dereferencing missing children.
         */
        public void split() {
            if (this.points.length <= DensityTree.this.splitSizeThreshold) {
                return;
            }
            if (DensityTree.this.maxDepth <= this.depth) {
                return;
            }
            @Nullable Rule best = RefIntStream.range(0, this.points[0].length)
                    .mapToObj(x -> x)
                    .flatMap(dim -> this.split_ortho(dim))
                    .filter(x -> Double.isFinite(x.fitness))
                    .max(RefComparator.comparingDouble(x -> x.fitness))
                    .orElse(null);
            if (null == best) {
                this.rule = null;
                return;
            }
            double[][] leftPts = RefArrays.stream(this.points)
                    .filter(pt -> best.eval(pt))
                    .toArray(i -> new double[i][]);
            double[][] rightPts = RefArrays.stream(this.points)
                    .filter(pt -> !best.eval(pt))
                    .toArray(i -> new double[i][]);
            assert leftPts.length + rightPts.length == this.points.length;
            if (rightPts.length == 0 || leftPts.length == 0) {
                // A degenerate partition cannot be represented: keep the node a leaf rather than
                // publishing a rule whose children are null.
                this.rule = null;
                return;
            }
            this.rule = best;
            this.left = new Node(leftPts, this.depth + 1);
            this.right = new Node(rightPts, this.depth + 1);
        }

        /**
         * Proposes threshold splits of this node along one dimension.
         *
         * <p>Points whose coordinate in {@code dim} is non-finite are excluded. The remaining
         * {@code n} points are sorted by that coordinate, and a candidate {@code pt[dim] < v} is
         * proposed at every index {@code i} where the value strictly increases. Each side must hold
         * more than {@code max(1, floor(minSplitFract * n))} points.
         *
         * <p>The fitness of a candidate is
         * {@code -(nL * ln(VL / V) + nR * ln(VR / V)) / (n * ln 2)}. Here {@code V} is this node's
         * volume and {@code VL}, {@code VR} are the volumes of the two sides' bounding boxes.
         * Only candidates whose fitness exceeds {@code minFitness} are returned.
         *
         * @param dim the dimension to split along
         * @return the qualifying candidate rules (possibly empty)
         */
        public @Nonnull RefStream<Rule> split_ortho(int dim) {
            double[][] sortedPoints = RefArrays.stream(this.points)
                    .filter(pt -> Double.isFinite(pt[dim]))
                    .sorted(RefComparator.comparingDouble(pt -> pt[dim]))
                    .toArray(i -> new double[i][]);
            int n = sortedPoints.length;
            if (0 == n) {
                return RefStream.empty();
            }
            int minSize = (int) Math.max((double) n * DensityTree.this.minSplitFract, 1.0);

            // prefix[k] bounds sortedPoints[0..k]; suffix[k] bounds sortedPoints[k..n-1].
            Bounds[] prefix = new Bounds[n];
            Bounds[] suffix = new Bounds[n];
            prefix[0] = DensityTree.this.getBounds(new double[][]{sortedPoints[0]});
            for (int k = 1; k < n; ++k) {
                prefix[k] = prefix[k - 1].union(sortedPoints[k]);
            }
            suffix[n - 1] = DensityTree.this.getBounds(new double[][]{sortedPoints[n - 1]});
            for (int k = n - 2; k >= 0; --k) {
                suffix[k] = suffix[k + 1].union(sortedPoints[k]);
            }

            // Loop-invariant terms of the fitness formula.
            double parentVolume = this.bounds.getVolume();
            double normalizer = (double) n * Math.log(2.0);

            return RefIntStream.range(1, n - 1)
                    .filter(i -> sortedPoints[i - 1][dim] < sortedPoints[i][dim])
                    .filter(i -> minSize < i && minSize < n - i)
                    .mapToObj(i -> {
                        int rightCount = n - i;
                        // Declared as Rule so the resulting stream is a RefStream<Rule>.
                        Rule candidate = new OrthoRule(dim, sortedPoints[i][dim]);
                        candidate.fitness = -((double) i * Math.log(prefix[i - 1].getVolume() / parentVolume)
                                + (double) rightCount * Math.log(suffix[i].getVolume() / parentVolume))
                                / normalizer;
                        return candidate;
                    })
                    .filter(r -> r.fitness > DensityTree.this.minFitness);
        }

        /** @return a one-line summary of this node's point count, volume and region */
        private @Nonnull CharSequence dataInfo() {
            return RefString.format("Count: %s Volume: %s Region: %s",
                    new Object[]{this.points.length, this.bounds.getVolume(), this.bounds});
        }
    }

    /** A predicate over points used to split a node; {@code true} routes a point to the left child. */
    public abstract class Rule {
        /** Human-readable description, also returned by {@link #toString()}. */
        public final String name;
        /** Score assigned when the rule was proposed as a split; higher is better. */
        public double fitness;

        /** @param name human-readable description of the rule */
        public Rule(String name) {
            this.name = name;
        }

        /**
         * @param pt the point to test
         * @return {@code true} if {@code pt} belongs to the left side of the split
         */
        public abstract boolean eval(double[] pt);

        @Override
        public String toString() {
            return this.name;
        }
    }

    /** Axis-aligned threshold rule: holds when {@code pt[dim] < value} (never for a NaN coordinate). */
    public class OrthoRule extends Rule {
        private final int dim;
        private final double value;

        /**
         * @param dim   the dimension to test; also indexes the column names used for the rule's name
         * @param value the exclusive upper threshold
         */
        public OrthoRule(int dim, double value) {
            super(RefString.format("%s < %s", new Object[]{DensityTree.this.columnNames[dim], value}));
            this.dim = dim;
            this.value = value;
        }

        @Override
        public boolean eval(double[] pt) {
            return pt[this.dim] < this.value;
        }
    }

    /**
     * Axis-aligned bounding box. A dimension whose {@code min}/{@code max} are {@code NaN} holds
     * no finite value yet.
     */
    public class Bounds {
        /** Per-dimension upper bound. */
        public final @Nonnull double[] max;
        /** Per-dimension lower bound. */
        public final @Nonnull double[] min;

        /**
         * Note the argument order: maximum first, then minimum.
         *
         * @param max per-dimension upper bounds
         * @param min per-dimension lower bounds; must have the same length as {@code max}
         */
        public Bounds(@Nonnull double[] max, double[] min) {
            this.max = max;
            this.min = min;
            assert max.length == min.length;
            assert RefIntStream.range(0, max.length)
                    .filter(i -> Double.isFinite(max[i]))
                    .allMatch(i -> max[i] >= min[i]);
        }

        /**
         * Product of the extents of all dimensions whose extent is finite and strictly positive.
         * Degenerate (zero-width) and NaN dimensions are skipped.
         *
         * @return the volume, or {@code NaN} if no dimension qualifies
         */
        public double getVolume() {
            return RefIntStream.range(0, this.min.length)
                    .mapToDouble(d -> this.max[d] - this.min[d])
                    .filter(x -> Double.isFinite(x) && x > 0.0)
                    .reduce((a, b) -> a * b)
                    .orElse(Double.NaN);
        }

        /**
         * Returns a new box that also contains {@code pt}, ignoring non-finite coordinates.
         *
         * <p>A dimension that has no finite value yet (bound is {@code NaN}) adopts the point's
         * coordinate. Without this, {@code Math.max/min} would propagate {@code NaN} forever and the
         * result would disagree with {@link DensityTree#getBounds} over the same points.
         *
         * @param pt the point to include
         * @return the enlarged bounding box
         */
        public @Nonnull Bounds union(@Nonnull double[] pt) {
            int dim = pt.length;
            double[] newMax = new double[dim];
            double[] newMin = new double[dim];
            for (int d = 0; d < dim; ++d) {
                double v = pt[d];
                if (!Double.isFinite(v)) {
                    newMax[d] = this.max[d];
                    newMin[d] = this.min[d];
                } else {
                    newMax[d] = Double.isNaN(this.max[d]) ? v : Math.max(this.max[d], v);
                    newMin[d] = Double.isNaN(this.min[d]) ? v : Math.min(this.min[d], v);
                }
            }
            return new Bounds(newMax, newMin);
        }

        @Override
        public @Nonnull String toString() {
            return "[" + RefUtil.get(RefIntStream.range(0, this.min.length)
                    .mapToObj(d -> RefString.format("%s: %s - %s",
                            new Object[]{DensityTree.this.columnNames[d], this.min[d], this.max[d]}))
                    .reduce((a, b) -> a + "; " + b)) + "]";
        }
    }
}

