/*
 * Decompiled with CFR 0.152.
 */
package com.simiacryptus.util.text;

import com.simiacryptus.util.text.CharTrieIndex;
import com.simiacryptus.util.text.IndexNode;
import java.io.PrintStream;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Trains a binary decision tree that classifies strings by whether they contain particular
 * substrings.
 *
 * <p>At each node the training strings are indexed in a {@link CharTrieIndex}. The indexed
 * substring with the best split score (see {@link #entropy}) is chosen, and the training strings
 * are partitioned by whether they contain it. Both halves are then grown recursively. Each leaf
 * predicts the category frequencies of the training strings that reached it.
 *
 * <p>Tuning parameters are mutable through the fluent setters and are read during training.
 */
public class ClassificationTree {
    /** Receives one line per chosen split when non-null; {@code null} disables logging. */
    private PrintStream verbose = null;
    /**
     * Minimum number of documents required on each side of a split. Smaller splits score negative
     * infinity and are never chosen.
     */
    private double minLeafWeight = 10.0;
    /** First argument to {@code CharTrieIndex.index}; presumably bounds the indexed substring length — TODO confirm. */
    private int maxLevels = 8;
    /** Second argument to {@code CharTrieIndex.index}; presumably a minimum node weight kept by the index — TODO confirm. */
    private int minWeight = 5;
    /** Bonus added to a candidate's split score per unit of trie-node depth, favoring deeper nodes. */
    private double depthBias = 5.0E-4;
    /** Additive pseudo-count applied to every category when estimating per-side category probabilities. */
    private int smoothing = 3;

    /**
     * Trains a classifier over the given labelled examples.
     *
     * @param categories training strings keyed by category label
     * @param depth maximum number of splits along any root-to-leaf path. {@code 0} yields a single
     *     leaf. A negative value is decremented away from zero and never reaches it, so growth is
     *     then limited only by the other stopping conditions.
     * @return a function mapping an input string to a fresh map from each category label to the
     *     fraction of training strings in the selected leaf carrying that label
     */
    public Function<String, Map<String, Double>> categorizationTree(Map<String, List<String>> categories, int depth) {
        return this.categorizationTree(categories, depth, "");
    }

    /**
     * Recursive worker for {@link #categorizationTree(Map, int)}.
     *
     * @param indent prefix for verbose log lines; grows by two spaces per level
     */
    private Function<String, Map<String, Double>> categorizationTree(Map<String, List<String>> categories, int depth, String indent) {
        if (0 == depth) {
            return leaf(categories);
        }
        // Nothing is left to separate once at most one category still has examples.
        long populated = categories.values().stream().filter(list -> !list.isEmpty()).count();
        if (populated <= 1L) {
            return leaf(categories);
        }
        NodeInfo best = this.categorizationSubstring(categories.values()).orElse(null);
        if (null == best) {
            return leaf(categories);
        }
        String split = best.node.getString();
        Map<String, List<String>> lSet = filterEach(categories, str -> str.contains(split));
        Map<String, List<String>> rSet = filterEach(categories, str -> !str.contains(split));
        // A split that sends every example the same way makes no progress.
        if (0 == totalSize(lSet) || 0 == totalSize(rSet)) {
            return leaf(categories);
        }
        if (null != this.verbose) {
            this.verbose.println(String.format(indent + "\"%s\" -> Contains=%s\tAbsent=%s\tEntropy=%5f",
                    split, sizes(lSet), sizes(rSet), best.entropy));
        }
        Function<String, Map<String, Double>> l = this.categorizationTree(lSet, depth - 1, indent + "  ");
        Function<String, Map<String, Double>> r = this.categorizationTree(rSet, depth - 1, indent + "  ");
        return str -> str.contains(split) ? l.apply(str) : r.apply(str);
    }

    /**
     * Builds a leaf predictor that ignores its input and returns each category's share of the
     * training strings.
     *
     * <p>The distribution is computed once, so later mutation of {@code categories} does not affect
     * the leaf. Each call returns a new mutable map. If there are no training strings at all, the
     * categories receive a uniform share instead of the NaN that {@code 0 / 0} would produce.
     */
    private static Function<String, Map<String, Double>> leaf(Map<String, List<String>> categories) {
        int total = totalSize(categories);
        Map<String, Double> distribution = new HashMap<>();
        for (Map.Entry<String, List<String>> e : categories.entrySet()) {
            double share = 0 == total
                    ? 1.0 / categories.size()
                    : (double) e.getValue().size() / (double) total;
            distribution.put(e.getKey(), share);
        }
        return str -> new HashMap<>(distribution);
    }

    /** Returns a new map keeping, per category, only the strings accepted by {@code keep}. */
    private static Map<String, List<String>> filterEach(Map<String, List<String>> categories, Predicate<String> keep) {
        return categories.entrySet().stream().collect(Collectors.toMap(
                e -> e.getKey(),
                e -> e.getValue().stream().filter(keep).collect(Collectors.toList())));
    }

    /** Total number of strings across all categories. */
    private static int totalSize(Map<String, List<String>> categories) {
        return categories.values().stream().mapToInt(List::size).sum();
    }

    /** Per-category string counts, used for verbose logging. */
    private static Map<String, Integer> sizes(Map<String, List<String>> categories) {
        return categories.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue().size()));
    }

    /**
     * Scores splitting the documents counted in {@code sum} into the subset counted in
     * {@code left} and the remainder.
     *
     * <p>The score is the additively smoothed base-2 log-likelihood of each document's category
     * given its side, averaged over all documents. It is therefore the negative of a conditional
     * entropy estimate, and higher is better.
     *
     * @param sum per-category document counts for the whole population
     * @param left per-category document counts for the "contains" side
     * @return the score, or negative infinity if either side has fewer than {@link #minLeafWeight}
     *     documents
     */
    private double entropy(Map<Integer, Long> sum, Map<Integer, Long> left) {
        double sumSum = sum.values().stream().mapToDouble(Long::doubleValue).sum();
        double leftSum = left.values().stream().mapToDouble(Long::doubleValue).sum();
        double rightSum = sumSum - leftSum;
        if (rightSum < this.minLeafWeight || leftSum < this.minLeafWeight) {
            return Double.NEGATIVE_INFINITY;
        }
        int categoryCount = sum.size();
        double leftScore = sum.keySet().stream().mapToDouble(category ->
                this.logLikelihoodTerm(left.getOrDefault(category, 0L), leftSum, categoryCount)).sum();
        double rightScore = sum.keySet().stream().mapToDouble(category ->
                this.logLikelihoodTerm(sum.getOrDefault(category, 0L) - left.getOrDefault(category, 0L), rightSum, categoryCount)).sum();
        return (leftScore + rightScore) / (sumSum * Math.log(2.0));
    }

    /** Returns {@code count * ln(p)}, where {@code p} is the smoothed share of one category on one side. */
    private double logLikelihoodTerm(long count, double sideTotal, int categoryCount) {
        return (double) count * Math.log((double) (count + (long) this.smoothing) / (sideTotal + (double) (this.smoothing * categoryCount)));
    }

    /**
     * Indexes all training strings and finds the best-scoring substring node.
     *
     * @param categories training strings grouped by category; categories are numbered by iteration order
     * @return the best candidate, or empty if no non-root node has a finite score
     */
    private Optional<NodeInfo> categorizationSubstring(Collection<List<String>> categories) {
        CharTrieIndex trie = new CharTrieIndex();
        // Maps a trie document id to the 1-based number of the category it came from.
        Map<Integer, Integer> categoryMap = new TreeMap<>();
        int categoryNumber = 0;
        for (List<String> category : categories) {
            ++categoryNumber;
            for (String text : category) {
                categoryMap.put(trie.addDocument(text), categoryNumber);
            }
        }
        trie.index(this.maxLevels, this.minWeight);
        IndexNode root = trie.root();
        Map<Integer, Long> sum = this.summarize(root, categoryMap);
        return this.categorizationSubstring(root, categoryMap, sum);
    }

    public PrintStream getVerbose() {
        return this.verbose;
    }

    /** Sets the stream that receives split diagnostics; {@code null} disables them. */
    public ClassificationTree setVerbose(PrintStream verbose) {
        this.verbose = verbose;
        return this;
    }

    public double getMinLeafWeight() {
        return this.minLeafWeight;
    }

    /** Sets the minimum document count required on each side of a split. */
    public ClassificationTree setMinLeafWeight(double minLeafWeight) {
        this.minLeafWeight = minLeafWeight;
        return this;
    }

    public int getMaxLevels() {
        return this.maxLevels;
    }

    /** Sets the first argument passed to {@code CharTrieIndex.index}. */
    public ClassificationTree setMaxLevels(int maxLevels) {
        this.maxLevels = maxLevels;
        return this;
    }

    public int getMinWeight() {
        return this.minWeight;
    }

    /** Sets the second argument passed to {@code CharTrieIndex.index}. */
    public ClassificationTree setMinWeight(int minWeight) {
        this.minWeight = minWeight;
        return this;
    }

    public double getDepthBias() {
        return this.depthBias;
    }

    /** Sets the per-depth bonus added to candidate split scores. */
    public ClassificationTree setDepthBias(double depthBias) {
        this.depthBias = depthBias;
        return this;
    }

    public int getSmoothing() {
        return this.smoothing;
    }

    /** Sets the additive pseudo-count used when estimating category probabilities. */
    public ClassificationTree setSmoothing(int smoothing) {
        this.smoothing = smoothing;
        return this;
    }

    /** Scores {@code node} as a split candidate, including the depth bonus. */
    private NodeInfo info(IndexNode node, Map<Integer, Long> sum, Map<Integer, Integer> categoryMap) {
        Map<Integer, Long> summary = this.summarize(node, categoryMap);
        return new NodeInfo(node, summary, this.entropy(sum, summary) + this.depthBias * (double) node.getDepth());
    }

    /**
     * Counts, per category number, the distinct documents referenced by the node's cursors.
     * The cursors are presumably the occurrences of the node's substring.
     */
    private Map<Integer, Long> summarize(IndexNode node, Map<Integer, Integer> categoryMap) {
        return node.getCursors()
                .map(cursor -> cursor.getDocumentId())
                .distinct()
                .map(documentId -> categoryMap.get(documentId))
                .collect(Collectors.groupingBy(category -> category, Collectors.counting()));
    }

    /**
     * Depth-first search for the best-scoring node in the subtree rooted at {@code node}.
     *
     * <p>Nodes with an empty string (the root) or a non-finite score are not candidates. On ties,
     * the parent wins over its children, and earlier children win over later ones.
     */
    private Optional<NodeInfo> categorizationSubstring(IndexNode node, Map<Integer, Integer> categoryMap, Map<Integer, Long> sum) {
        List<NodeInfo> childrenInfo = node.getChildren()
                .map(child -> this.categorizationSubstring((IndexNode) child, categoryMap, sum))
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(Collectors.toList());
        NodeInfo own = this.info(node, sum, categoryMap);
        boolean ownIsCandidate = !own.node.getString().isEmpty() && Double.isFinite(own.entropy);
        Stream<NodeInfo> candidates = ownIsCandidate
                ? Stream.concat(Stream.of(own), childrenInfo.stream())
                : childrenInfo.stream();
        return candidates.max(Comparator.comparingDouble(candidate -> candidate.entropy));
    }

    /** A scored split candidate. */
    private static final class NodeInfo {
        final IndexNode node;
        /** Per-category document counts for the node. Retained for inspection; not read by the search. */
        final Map<Integer, Long> categoryWeights;
        /** Split score including the depth bonus; higher is better. */
        final double entropy;

        NodeInfo(IndexNode node, Map<Integer, Long> categoryWeights, double entropy) {
            this.node = node;
            this.categoryWeights = categoryWeights;
            this.entropy = entropy;
        }
    }
}

