/*
 * Decompiled with CFR 0.152.
 */
package com.simiacryptus.util.binary.bitset;

import com.google.common.collect.Maps;
import com.simiacryptus.util.binary.BitInputStream;
import com.simiacryptus.util.binary.BitOutputStream;
import com.simiacryptus.util.binary.Bits;
import com.simiacryptus.util.binary.bitset.BitsCollection;
import com.simiacryptus.util.binary.codes.Gaussian;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A multiset of {@link Bits} codes, held as a sorted map from each code to its occurrence count,
 * that serializes itself as a binary prefix tree.
 *
 * <p>{@link #write(BitOutputStream)} first emits the total number of codes as a var-long. Starting
 * at {@link Bits#NULL}, each tree node then records:
 * <ul>
 *   <li>how many codes terminate exactly at the node, but only when {@code getType} reports the
 *       node as {@code Unknown};</li>
 *   <li>how many codes continue into the node's zero branch, but only when at least one code
 *       continues past the node.</li>
 * </ul>
 * The one-branch count is never written because it is implied by the other two.
 *
 * <p>Zero-branch sizes use a binomial-derived {@code Gaussian} code by default. A plain bounded
 * integer code is used instead after {@code setUseBinomials(false)}. The reader and the writer
 * must agree on this setting and on {@link #SERIALIZATION_CHECKS}.
 */
public class CountTreeBitsCollection
extends BitsCollection<TreeMap<Bits, AtomicInteger>> {
    /**
     * When {@code true}, {@link SerializationChecks} markers are written around every tree node,
     * terminal count and zero-branch count, and reading consumes them with {@code expect}.
     * Both sides of a stream must use the same setting.
     */
    public static boolean SERIALIZATION_CHECKS = false;
    /** Selects the binomial Gaussian code (default) or a bounded integer code for zero-branch sizes. */
    private boolean useBinomials = true;

    /** Creates an empty collection. */
    public CountTreeBitsCollection() {
        super(new TreeMap<Bits, AtomicInteger>());
    }

    /**
     * Creates a collection populated from a stream produced by {@link #write(BitOutputStream)}.
     *
     * @param bitStream source stream, positioned at the var-long total
     * @throws IOException if reading fails
     */
    public CountTreeBitsCollection(BitInputStream bitStream) throws IOException {
        this();
        this.read(bitStream);
    }

    /**
     * Creates a collection with the given bit depth, populated from a stream produced by
     * {@link #write(BitOutputStream)}.
     *
     * @param bitStream source stream, positioned at the var-long total
     * @param bitDepth  passed to {@code BitsCollection}; presumably the fixed code length that
     *                  {@code getType} relies on (TODO confirm)
     * @throws IOException if reading fails
     */
    public CountTreeBitsCollection(BitInputStream bitStream, int bitDepth) throws IOException {
        this(bitDepth);
        this.read(bitStream);
    }

    /**
     * Creates a collection from bytes produced by {@link #toBytes()}.
     *
     * @param data serialized bytes
     * @throws IOException if reading fails
     */
    public CountTreeBitsCollection(byte[] data) throws IOException {
        this(BitInputStream.toBitStream(data));
    }

    /**
     * Creates a collection with the given bit depth from bytes produced by {@link #toBytes()}.
     *
     * @param data     serialized bytes
     * @param bitDepth passed to {@code BitsCollection} (see the stream constructor)
     * @throws IOException if reading fails
     */
    public CountTreeBitsCollection(byte[] data, int bitDepth) throws IOException {
        this(BitInputStream.toBitStream(data), bitDepth);
    }

    /**
     * Creates an empty collection with the given bit depth.
     *
     * @param bitDepth passed to {@code BitsCollection}
     */
    public CountTreeBitsCollection(int bitDepth) {
        super(bitDepth, new TreeMap<Bits, AtomicInteger>());
    }

    /**
     * Returns {@code value} unless it is {@code null}, in which case returns {@code defaultValue}.
     */
    public static <T> T isNull(T value, T defaultValue) {
        return null == value ? defaultValue : value;
    }

    /** Returns the backing code-to-count map with its concrete generic type. */
    @SuppressWarnings("unchecked")
    private TreeMap<Bits, AtomicInteger> typedCountMap() {
        return (TreeMap<Bits, AtomicInteger>) this.map;
    }

    /**
     * Computes cumulative counts in key order.
     *
     * @return a new map whose value for each code is the sum of the counts of that code and all
     *         codes sorting before it; the last value is the total number of codes
     */
    public TreeMap<Bits, Long> computeSums() {
        TreeMap<Bits, Long> sums = new TreeMap<Bits, Long>();
        long runningTotal = 0L;
        for (Map.Entry<Bits, AtomicInteger> entry : this.typedCountMap().entrySet()) {
            runningTotal += entry.getValue().get();
            sums.put(entry.getKey(), runningTotal);
        }
        return sums;
    }

    /**
     * Replaces the contents of this collection with a tree read from {@code in}.
     *
     * <p>The tree must be preceded by its var-long total, as written by
     * {@link #write(BitOutputStream)}.
     *
     * @param in source stream
     * @throws IOException if reading fails
     */
    @Override
    public void read(BitInputStream in) throws IOException {
        this.getMap().clear();
        long size = in.readVarLong();
        if (0L < size) {
            this.read(in, Bits.NULL, size);
        }
    }

    /**
     * Recursively reads the subtree rooted at {@code code}, which holds {@code size} codes.
     */
    private void read(BitInputStream in, Bits code, long size) throws IOException {
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.StartTree);
        }
        BranchCounts branchCounts = this.readBranchCounts(in, code, size);
        if (0L < branchCounts.terminals) {
            this.typedCountMap().put(code, new AtomicInteger((int) branchCounts.terminals));
        }
        if (0L < branchCounts.zeroCount) {
            this.read(in, code.concatenate(Bits.ZERO), branchCounts.zeroCount);
        }
        if (branchCounts.oneCount > 0L) {
            this.read(in, code.concatenate(Bits.ONE), branchCounts.oneCount);
        }
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.EndTree);
        }
    }

    /**
     * Replaces the contents of this collection with a tree of a known total size.
     *
     * <p>No var-long header is read. This is the counterpart of
     * {@link #write(BitOutputStream, int)}.
     *
     * @param in   source stream, positioned at the tree root
     * @param size total number of codes in the tree
     * @throws IOException if reading fails
     */
    public void read(BitInputStream in, int size) throws IOException {
        this.getMap().clear();
        if (0 < size) {
            this.read(in, Bits.NULL, size);
        }
    }

    /**
     * Reads the branch counts of a single node. This mirrors
     * {@link #writeBranchCounts(BranchCounts, BitOutputStream)}.
     *
     * @param in   source stream
     * @param code path of the node
     * @param size number of codes in the node's subtree
     * @return terminal, zero-branch and one-branch counts; the one-branch count is derived
     * @throws IOException if reading fails
     */
    protected BranchCounts readBranchCounts(BitInputStream in, Bits code, long size) throws IOException {
        BranchCounts branchCounts = new BranchCounts(code, size);
        BitsCollection.CodeType currentCodeType = this.getType(code);
        if (currentCodeType == BitsCollection.CodeType.Unknown) {
            branchCounts.terminals = this.readTerminalCount(in, size);
        } else if (currentCodeType == BitsCollection.CodeType.Terminal) {
            // Every code in a Terminal node ends here; nothing was written for it.
            branchCounts.terminals = size;
        } else {
            branchCounts.terminals = 0L;
        }
        // No recursion-depth assertion here: codes may legitimately be arbitrarily long, and the
        // writer imposes no such limit.
        long remaining = size - branchCounts.terminals;
        if (remaining > 0L) {
            branchCounts.zeroCount = this.readZeroBranchSize(in, remaining, code);
        }
        branchCounts.oneCount = remaining - branchCounts.zeroCount;
        return branchCounts;
    }

    /**
     * Reads the number of codes that terminate at a node.
     *
     * @param in   source stream
     * @param size upper bound of the value; the bounded code spans {@code 1 + size} values
     * @return the terminal count
     * @throws IOException if reading fails
     */
    protected long readTerminalCount(BitInputStream in, long size) throws IOException {
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.BeforeTerminal);
        }
        long terminalCount = in.readBoundedLong(1L + size);
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.AfterTerminal);
        }
        return terminalCount;
    }

    /**
     * Reads the number of codes that continue into a node's zero branch.
     *
     * @param in   source stream
     * @param max  number of non-terminal codes under the node; returns 0 without reading if 0
     * @param code the node's path; not used by this implementation
     * @return the zero-branch count
     * @throws IOException if reading fails
     */
    protected long readZeroBranchSize(BitInputStream in, long max, Bits code) throws IOException {
        if (0L == max) {
            return 0L;
        }
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.BeforeCount);
        }
        long value = this.useBinomials ? Gaussian.fromBinomial(0.5, max).decode(in, max) : in.readBoundedLong(1L + max);
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.AfterCount);
        }
        return value;
    }

    /**
     * Chooses the code used for zero-branch sizes.
     *
     * @param useBinomials {@code true} for the binomial Gaussian code, {@code false} for a bounded
     *                     integer code
     * @return this collection, for chaining
     */
    public CountTreeBitsCollection setUseBinomials(boolean useBinomials) {
        this.useBinomials = useBinomials;
        return this;
    }

    /**
     * Sums a collection of longs.
     *
     * @param values values to add; elements must be non-null
     * @return their total
     */
    public long sum(Collection<Long> values) {
        long total = 0L;
        for (Long v : values) {
            total += v.longValue();
        }
        return total;
    }

    /** Sums the stored occurrence counts of every code that is a key of {@code range}; used by assertions. */
    private long sumStoredCounts(NavigableMap<Bits, Long> range) {
        TreeMap<Bits, AtomicInteger> counts = this.typedCountMap();
        long total = 0L;
        for (Bits key : range.keySet()) {
            total += counts.get(key).get();
        }
        return total;
    }

    /**
     * Serializes this collection with {@link #write(BitOutputStream)}.
     *
     * @return the flushed bytes
     * @throws IOException if writing fails
     */
    public byte[] toBytes() throws IOException {
        ByteArrayOutputStream outBuffer = new ByteArrayOutputStream();
        BitOutputStream out = new BitOutputStream(outBuffer);
        this.write(out);
        out.flush();
        return outBuffer.toByteArray();
    }

    /** Returns whether zero-branch sizes use the binomial Gaussian code. */
    public boolean useBinomials() {
        return this.useBinomials;
    }

    /**
     * Writes the total number of codes as a var-long, followed by the tree if it is non-empty.
     *
     * @param out destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void write(BitOutputStream out) throws IOException {
        TreeMap<Bits, Long> sums = this.computeSums();
        long total = sums.isEmpty() ? 0L : sums.lastEntry().getValue();
        out.writeVarLong(total);
        if (0L < total) {
            this.write(out, Bits.NULL, sums);
        }
    }

    /**
     * Recursively writes the subtree rooted at {@code currentCode}.
     *
     * @param sums cumulative counts restricted to the codes in this subtree; must be non-empty
     */
    private void write(BitOutputStream out, Bits currentCode, NavigableMap<Bits, Long> sums) throws IOException {
        TreeMap<Bits, AtomicInteger> counts = this.typedCountMap();
        Map.Entry<Bits, Long> firstEntry = sums.firstEntry();
        NavigableMap<Bits, Long> remainder = sums.tailMap(currentCode, false);
        Bits splitCode = currentCode.concatenate(Bits.ONE);
        NavigableMap<Bits, Long> zeroMap = remainder.headMap(splitCode, false);
        NavigableMap<Bits, Long> oneMap = remainder.tailMap(splitCode, true);
        int firstEntryCount = counts.get(firstEntry.getKey()).get();
        // Cumulative total of everything sorting before this subtree.
        long baseCount = firstEntry.getValue() - firstEntryCount;
        long endCount = sums.lastEntry().getValue();
        long size = endCount - baseCount;
        long terminals = firstEntry.getKey().equals(currentCode) ? (long) firstEntryCount : 0L;
        long zeroCount = zeroMap.isEmpty() ? 0L : zeroMap.lastEntry().getValue() - baseCount - terminals;
        long oneCount = size - terminals - zeroCount;
        assert size == this.sumStoredCounts(sums);
        assert zeroCount == this.sumStoredCounts(zeroMap);
        assert oneCount == this.sumStoredCounts(oneMap);
        BranchCounts branchCounts = new BranchCounts(currentCode, size, terminals, zeroCount, oneCount);
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.StartTree);
        }
        this.writeBranchCounts(branchCounts, out);
        if (0L < zeroCount) {
            this.write(out, currentCode.concatenate(Bits.ZERO), zeroMap);
        }
        if (0L < oneCount) {
            this.write(out, currentCode.concatenate(Bits.ONE), oneMap);
        }
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.EndTree);
        }
    }

    /**
     * Writes the tree without a var-long header. The caller must convey {@code size} out of band;
     * this is the counterpart of {@link #read(BitInputStream, int)}.
     *
     * @param out  destination stream
     * @param size expected total number of codes
     * @throws IllegalArgumentException if {@code size} differs from the stored total
     * @throws IOException              if writing fails
     */
    public void write(BitOutputStream out, int size) throws IOException {
        TreeMap<Bits, Long> sums = this.computeSums();
        long total = sums.isEmpty() ? 0L : sums.lastEntry().getValue();
        if (total != (long) size) {
            throw new IllegalArgumentException("Expected " + size + " codes but collection holds " + total);
        }
        if (0L < total) {
            this.write(out, Bits.NULL, sums);
        }
    }

    /**
     * Writes the branch counts of a single node. This mirrors
     * {@link #readBranchCounts(BitInputStream, Bits, long)}.
     *
     * <p>The terminal count is written only for {@code Unknown} nodes. The zero-branch count is
     * written only when non-terminal codes remain. The one-branch count is never written.
     *
     * @param branch counts of the node
     * @param out    destination stream
     * @throws IOException if writing fails
     */
    protected void writeBranchCounts(BranchCounts branch, BitOutputStream out) throws IOException {
        BitsCollection.CodeType currentCodeType = this.getType(branch.path);
        assert branch.size >= branch.terminals;
        if (currentCodeType == BitsCollection.CodeType.Unknown) {
            this.writeTerminalCount(out, branch.terminals, branch.size);
        } else if (currentCodeType == BitsCollection.CodeType.Terminal) {
            assert branch.size == branch.terminals;
            assert 0L == branch.zeroCount;
            assert 0L == branch.oneCount;
        } else if (currentCodeType == BitsCollection.CodeType.Prefix) {
            assert 0L == branch.terminals;
        }
        // Computed outside any assert: the encoded stream must not depend on whether -ea is
        // enabled, and the reader always subtracts the terminals before bounding the zero branch.
        long remaining = branch.size - branch.terminals;
        assert remaining >= branch.zeroCount;
        if (0L < remaining) {
            this.writeZeroBranchSize(out, branch.zeroCount, remaining, branch.path);
            remaining -= branch.zeroCount;
        } else {
            assert 0L == branch.zeroCount;
        }
        assert remaining == branch.oneCount;
    }

    /**
     * Writes the number of codes that terminate at a node.
     *
     * @param out   destination stream
     * @param value terminal count, in {@code [0, max]}
     * @param max   upper bound; the bounded code spans {@code 1 + max} values
     * @throws IOException if writing fails
     */
    protected void writeTerminalCount(BitOutputStream out, long value, long max) throws IOException {
        assert 0L <= value;
        assert max >= value;
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.BeforeTerminal);
        }
        out.writeBoundedLong(value, 1L + max);
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.AfterTerminal);
        }
    }

    /**
     * Writes the number of codes that continue into a node's zero branch.
     *
     * @param out   destination stream
     * @param value zero-branch count, in {@code [0, max]}
     * @param max   number of non-terminal codes under the node
     * @param bits  the node's path; not used by this implementation
     * @throws IOException if writing fails
     */
    protected void writeZeroBranchSize(BitOutputStream out, long value, long max, Bits bits) throws IOException {
        assert 0L <= value;
        assert max >= value;
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.BeforeCount);
        }
        if (this.useBinomials) {
            Gaussian.fromBinomial(0.5, max).encode(out, value, max);
        } else {
            out.writeBoundedLong(value, 1L + max);
        }
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.AfterCount);
        }
    }

    /**
     * Counts describing one tree node: the total codes in its subtree ({@code size}), the codes
     * ending at the node ({@code terminals}), and the codes continuing into each child branch.
     */
    public static class BranchCounts {
        public Bits path;
        public long size;
        public long terminals;
        public long zeroCount;
        public long oneCount;

        /** Creates counts with only the path and subtree size set; the rest start at 0. */
        public BranchCounts(Bits path, long size) {
            this.path = path;
            this.size = size;
        }

        /** Creates fully populated counts. */
        public BranchCounts(Bits path, long size, long terminals, long zeroCount, long oneCount) {
            this.path = path;
            this.size = size;
            this.terminals = terminals;
            this.zeroCount = zeroCount;
            this.oneCount = oneCount;
        }
    }

    /** Marker values emitted and expected when {@link #SERIALIZATION_CHECKS} is enabled. */
    public static enum SerializationChecks {
        StartTree,
        EndTree,
        BeforeCount,
        AfterCount,
        BeforeTerminal,
        AfterTerminal;

    }
}

