/*
 * Decompiled with CFR 0.152.
 */
package com.simiacryptus.util.binary.bitset;

import com.simiacryptus.ref.lang.RefAware;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefCollection;
import com.simiacryptus.ref.wrappers.RefHashSet;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.ref.wrappers.RefMaps;
import com.simiacryptus.ref.wrappers.RefNavigableMap;
import com.simiacryptus.ref.wrappers.RefTreeMap;
import com.simiacryptus.util.binary.BitInputStream;
import com.simiacryptus.util.binary.BitOutputStream;
import com.simiacryptus.util.binary.Bits;
import com.simiacryptus.util.binary.bitset.BitsCollection;
import com.simiacryptus.util.binary.codes.Gaussian;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
 * A {@link BitsCollection} that keeps an occurrence count per {@link Bits} code in a
 * {@link RefTreeMap} and serializes those counts as a recursive binary tree.
 *
 * <p>For every code prefix visited, the writer emits (depending on the code type reported by
 * {@code getType}) the number of entries that terminate exactly at that prefix, followed by the
 * number of entries that continue into the "zero" branch; the size of the "one" branch is
 * whatever remains. The zero branch is then written before the one branch. The reader mirrors
 * this order exactly.
 */
public class CountTreeBitsCollection
extends BitsCollection<RefTreeMap<Bits, AtomicInteger>> {
    /**
     * When {@code true}, {@link SerializationChecks} markers are written to the output (and
     * expected on the input) around every tree node, terminal count and branch count. A stream
     * written with markers can only be read back with this flag enabled, and vice versa.
     */
    public static boolean SERIALIZATION_CHECKS = false;

    /**
     * Selects how zero-branch sizes are coded: {@code true} uses
     * {@code Gaussian.fromBinomial(0.5, max)}, {@code false} uses a plain bounded long.
     */
    private boolean useBinomials = true;

    /** Creates an empty collection backed by a fresh {@link RefTreeMap}. */
    public CountTreeBitsCollection() {
        super(new RefTreeMap<Bits, AtomicInteger>());
    }

    /**
     * Creates a collection and populates it from {@code bitStream} via {@link #read(BitInputStream)}.
     *
     * @param bitStream the stream to deserialize from
     * @throws IOException if reading fails
     */
    public CountTreeBitsCollection(@Nonnull BitInputStream bitStream) throws IOException {
        this();
        this.read(bitStream);
    }

    /**
     * Creates a collection with the given bit depth and populates it from {@code bitStream}.
     *
     * @param bitStream the stream to deserialize from
     * @param bitDepth  the bit depth forwarded to the superclass constructor
     * @throws IOException if reading fails
     */
    public CountTreeBitsCollection(@Nonnull BitInputStream bitStream, int bitDepth) throws IOException {
        this(bitDepth);
        this.read(bitStream);
    }

    /**
     * Creates a collection from a serialized byte array.
     *
     * @param data bytes previously produced by a compatible writer
     * @throws IOException if reading fails
     */
    public CountTreeBitsCollection(@Nonnull byte[] data) throws IOException {
        this(BitInputStream.toBitStream(data));
    }

    /**
     * Creates a collection with the given bit depth from a serialized byte array.
     *
     * @param data     bytes previously produced by a compatible writer
     * @param bitDepth the bit depth forwarded to the superclass constructor
     * @throws IOException if reading fails
     */
    public CountTreeBitsCollection(@Nonnull byte[] data, int bitDepth) throws IOException {
        this(BitInputStream.toBitStream(data), bitDepth);
    }

    /**
     * Creates an empty collection with the given bit depth.
     *
     * @param bitDepth the bit depth forwarded to the superclass constructor
     */
    public CountTreeBitsCollection(int bitDepth) {
        super(bitDepth, new RefTreeMap<Bits, AtomicInteger>());
    }

    /**
     * Chooses the coding used for zero-branch sizes; must match between writer and reader
     * because both {@link #writeZeroBranchSize} and {@link #readZeroBranchSize} branch on it.
     *
     * @param useBinomials {@code true} for the binomial-derived Gaussian code, {@code false} for
     *                     a bounded long
     */
    public void setUseBinomials(boolean useBinomials) {
        this.useBinomials = useBinomials;
    }

    /**
     * Returns {@code value}, or {@code defaultValue} when {@code value} is {@code null}.
     *
     * @param value        the possibly-null value
     * @param defaultValue the fallback
     * @param <T>          the value type
     * @return {@code value} if non-null, otherwise {@code defaultValue}
     */
    public static <T> T isNull(@Nullable T value, T defaultValue) {
        return null == value ? defaultValue : value;
    }

    /**
     * Builds a map from each stored code to the running total of counts accumulated up to and
     * including that code, in the order the entry set is iterated.
     *
     * @return a new map of cumulative counts; the caller owns the returned reference
     */
    @Nonnull
    public RefTreeMap<Bits, Long> computeSums() {
        final RefTreeMap<Bits, Long> sums = new RefTreeMap<Bits, Long>();
        final AtomicLong total = new AtomicLong();
        assert this.map != null;
        // NOTE(review): the running totals are only consistent with the sorted key order used
        // by the tree writer if entrySet() iterates in ascending key order - confirm for RefTreeMap.
        final RefHashSet<Map.Entry<Bits, AtomicInteger>> entries = this.map.entrySet();
        entries.forEach(e -> {
            RefUtil.freeRef(sums.put(e.getKey(), total.addAndGet(e.getValue().get())));
            RefUtil.freeRef(e);
        });
        entries.freeRef();
        return sums;
    }

    /**
     * Clears the collection and repopulates it from {@code in}. The total entry count is read
     * from the stream as a var-long; nothing further is read when it is zero.
     *
     * @param in the stream to read from
     * @throws IOException if reading fails
     */
    @Override
    public void read(@Nonnull BitInputStream in) throws IOException {
        final RefMap<Bits, AtomicInteger> contents = this.getMap();
        contents.clear();
        contents.freeRef();
        final long size = in.readVarLong();
        if (0L < size) {
            this.read(in, Bits.NULL, size);
        }
    }

    /**
     * Clears the collection and repopulates it from {@code in}, using an externally supplied
     * total entry count instead of reading one from the stream.
     *
     * @param in   the stream to read from
     * @param size the total number of entries encoded in the stream
     * @throws IOException if reading fails
     */
    public void read(@Nonnull BitInputStream in, int size) throws IOException {
        final RefMap<Bits, AtomicInteger> contents = this.getMap();
        contents.clear();
        contents.freeRef();
        if (0 < size) {
            this.read(in, Bits.NULL, size);
        }
    }

    /**
     * Sums the given values and releases the collection.
     *
     * @param values the values to add up; this method frees the reference
     * @return the sum of all values
     */
    public long sum(@RefAware @Nonnull RefCollection<Long> values) {
        try {
            return values.stream().mapToLong(v -> v).sum();
        } finally {
            // Release the reference even if streaming fails.
            values.freeRef();
        }
    }

    /**
     * Serializes this collection with {@link #write(BitOutputStream)} into a byte array.
     *
     * @return the serialized bytes
     * @throws IOException if writing fails
     */
    @Nonnull
    public byte[] toBytes() throws IOException {
        final ByteArrayOutputStream outBuffer = new ByteArrayOutputStream();
        final BitOutputStream out = new BitOutputStream(outBuffer);
        this.write(out);
        out.flush();
        return outBuffer.toByteArray();
    }

    /**
     * @return whether zero-branch sizes use the binomial-derived Gaussian code
     */
    public boolean useBinomials() {
        return this.useBinomials;
    }

    /**
     * Writes the total entry count as a var-long, followed by the count tree when the total is
     * positive.
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void write(@Nonnull BitOutputStream out) throws IOException {
        final RefTreeMap<Bits, Long> sums = this.computeSums();
        try {
            final Map.Entry<Bits, Long> last = sums.lastEntry();
            final long value = 0 == sums.size() ? 0L : last.getValue();
            RefUtil.freeRef(last);
            out.writeVarLong(value);
            if (0L < value) {
                // The recursive writer takes ownership of the extra reference.
                this.write(out, Bits.NULL, (RefNavigableMap<Bits, Long>) RefUtil.addRef(sums));
            }
        } finally {
            // Release our own reference even when writing fails.
            sums.freeRef();
        }
    }

    /**
     * Writes the count tree without a leading total; the reader must be given {@code size}
     * separately (see {@link #read(BitInputStream, int)}).
     *
     * @param out  the stream to write to
     * @param size the total entry count the caller expects this collection to hold
     * @throws IOException      if writing fails
     * @throws RuntimeException if the collection's total count differs from {@code size}
     */
    public void write(@Nonnull BitOutputStream out, int size) throws IOException {
        final RefTreeMap<Bits, Long> sums = this.computeSums();
        try {
            final Map.Entry<Bits, Long> last = sums.lastEntry();
            final long value = 0 == sums.size() ? 0L : last.getValue();
            RefUtil.freeRef(last);
            if (value != (long) size) {
                throw new RuntimeException("Size mismatch: expected " + size + " entries but collection holds " + value);
            }
            if (0L < value) {
                // The recursive writer takes ownership of the extra reference.
                this.write(out, Bits.NULL, (RefNavigableMap<Bits, Long>) RefUtil.addRef(sums));
            }
        } finally {
            // Release our own reference on every exit path.
            sums.freeRef();
        }
    }

    /** Delegates to the superclass release logic. */
    @Override
    public void _free() {
        super._free();
    }

    /**
     * Adds a reference and returns this instance with its concrete type.
     *
     * @return this collection
     */
    @Override
    @Nonnull
    public CountTreeBitsCollection addRef() {
        return (CountTreeBitsCollection) super.addRef();
    }

    /**
     * Reads the terminal / zero / one split for a single tree node.
     *
     * <p>The terminal count is read from the stream only when the code type is
     * {@code Unknown}; for {@code Terminal} it equals {@code size} and otherwise it is zero.
     * The zero-branch size is read only when entries remain after the terminals; the one-branch
     * size is the remainder.
     *
     * @param in   the stream to read from
     * @param code the prefix this node represents
     * @param size the number of entries beneath this node
     * @return the populated branch counts
     * @throws IOException if reading fails
     */
    @Nonnull
    protected BranchCounts readBranchCounts(@Nonnull BitInputStream in, @Nonnull Bits code, long size) throws IOException {
        final BranchCounts branchCounts = new BranchCounts(code, size);
        final BitsCollection.CodeType currentCodeType = this.getType(code);
        if (currentCodeType == BitsCollection.CodeType.Unknown) {
            branchCounts.terminals = this.readTerminalCount(in, size);
        } else if (currentCodeType == BitsCollection.CodeType.Terminal) {
            branchCounts.terminals = size;
        } else {
            branchCounts.terminals = 0L;
        }
        final long remaining = size - branchCounts.terminals;
        if (remaining > 0L) {
            // Sanity guard on recursion depth; only evaluated when assertions are enabled.
            assert Thread.currentThread().getStackTrace().length < 100;
            branchCounts.zeroCount = this.readZeroBranchSize(in, remaining);
        }
        branchCounts.oneCount = remaining - branchCounts.zeroCount;
        return branchCounts;
    }

    /**
     * Reads a terminal count in the range {@code [0, size]}.
     *
     * @param in   the stream to read from
     * @param size the largest admissible value
     * @return the terminal count
     * @throws IOException if reading fails
     */
    protected long readTerminalCount(@Nonnull BitInputStream in, long size) throws IOException {
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.BeforeTerminal);
        }
        final long terminalCount = in.readBoundedLong(1L + size);
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.AfterTerminal);
        }
        return terminalCount;
    }

    /**
     * Reads the size of the zero branch, bounded by {@code max}. Returns zero without touching
     * the stream when {@code max} is zero.
     *
     * @param in  the stream to read from
     * @param max the largest admissible value
     * @return the zero-branch size
     * @throws IOException if reading fails
     */
    protected long readZeroBranchSize(@Nonnull BitInputStream in, long max) throws IOException {
        if (0L == max) {
            return 0L;
        }
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.BeforeCount);
        }
        final long value;
        if (this.useBinomials) {
            value = Gaussian.fromBinomial(0.5, max).decode(in, max);
        } else {
            value = in.readBoundedLong(1L + max);
        }
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.AfterCount);
        }
        return value;
    }

    /**
     * Writes the terminal / zero split for a single tree node; the counterpart of
     * {@link #readBranchCounts}.
     *
     * @param branch the counts for the node being written
     * @param out    the stream to write to
     * @throws IOException if writing fails
     */
    protected void writeBranchCounts(@Nonnull BranchCounts branch, @Nonnull BitOutputStream out) throws IOException {
        final BitsCollection.CodeType currentCodeType = this.getType(branch.path);
        long maximum = branch.size;
        assert maximum >= branch.terminals;
        if (currentCodeType == BitsCollection.CodeType.Unknown) {
            this.writeTerminalCount(out, branch.terminals, maximum);
        } else if (currentCodeType == BitsCollection.CodeType.Terminal) {
            assert branch.size == branch.terminals;
            assert 0L == branch.zeroCount;
            assert 0L == branch.oneCount;
        } else {
            assert currentCodeType != BitsCollection.CodeType.Prefix || 0L == branch.terminals;
        }
        // This subtraction must run unconditionally: keeping it inside an assert would skip it
        // when assertions are disabled and desynchronize the writer from readBranchCounts.
        maximum -= branch.terminals;
        assert maximum >= branch.zeroCount;
        if (0L < maximum) {
            this.writeZeroBranchSize(out, branch.zeroCount, maximum);
            maximum -= branch.zeroCount;
        } else {
            assert 0L == branch.zeroCount;
        }
        assert maximum == branch.oneCount;
    }

    /**
     * Writes a terminal count in the range {@code [0, max]}.
     *
     * @param out   the stream to write to
     * @param value the terminal count
     * @param max   the largest admissible value
     * @throws IOException if writing fails
     */
    protected void writeTerminalCount(@Nonnull BitOutputStream out, long value, long max) throws IOException {
        assert 0L <= value;
        assert max >= value;
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.BeforeTerminal);
        }
        out.writeBoundedLong(value, 1L + max);
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.AfterTerminal);
        }
    }

    /**
     * Writes the size of the zero branch, bounded by {@code max}, using the coding selected by
     * {@link #setUseBinomials(boolean)}.
     *
     * @param out   the stream to write to
     * @param value the zero-branch size
     * @param max   the largest admissible value
     * @throws IOException if writing fails
     */
    protected void writeZeroBranchSize(@Nonnull BitOutputStream out, long value, long max) throws IOException {
        assert 0L <= value;
        assert max >= value;
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.BeforeCount);
        }
        if (this.useBinomials) {
            Gaussian.fromBinomial(0.5, max).encode(out, value, max);
        } else {
            out.writeBoundedLong(value, 1L + max);
        }
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.AfterCount);
        }
    }

    /**
     * Recursively reads the subtree rooted at {@code code}, storing a count for every code with
     * a positive terminal count, then descending into the zero branch and the one branch.
     *
     * @param in   the stream to read from
     * @param code the prefix of the subtree
     * @param size the number of entries in the subtree
     * @throws IOException if reading fails or a terminal count does not fit in an {@code int}
     */
    private void read(@Nonnull BitInputStream in, @Nonnull Bits code, long size) throws IOException {
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.StartTree);
        }
        final BranchCounts branchCounts = this.readBranchCounts(in, code, size);
        if (0L < branchCounts.terminals) {
            assert this.map != null;
            // Counts are stored as ints; reject values that would silently wrap on narrowing.
            if (branchCounts.terminals > (long) Integer.MAX_VALUE) {
                throw new IOException("Terminal count " + branchCounts.terminals + " exceeds " + Integer.MAX_VALUE);
            }
            RefUtil.freeRef(this.map.put(code, new AtomicInteger((int) branchCounts.terminals)));
        }
        if (0L < branchCounts.zeroCount) {
            this.read(in, code.concatenate(Bits.ZERO), branchCounts.zeroCount);
        }
        if (branchCounts.oneCount > 0L) {
            this.read(in, code.concatenate(Bits.ONE), branchCounts.oneCount);
        }
        if (SERIALIZATION_CHECKS) {
            in.expect(SerializationChecks.EndTree);
        }
    }

    /**
     * Recursively writes the subtree rooted at {@code currentCode}.
     *
     * <p>{@code sums} must be non-empty: its first and last entries are dereferenced without a
     * null check. Entries strictly after {@code currentCode} are split at
     * {@code currentCode + ONE} into the zero branch (below the split) and the one branch (at or
     * above it).
     *
     * @param out         the stream to write to
     * @param currentCode the prefix of the subtree
     * @param sums        cumulative counts for the codes in this subtree; this method frees the
     *                    reference
     * @throws IOException if writing fails
     */
    private void write(@Nonnull BitOutputStream out, @Nonnull Bits currentCode, @RefAware @Nonnull RefNavigableMap<Bits, Long> sums) throws IOException {
        final Map.Entry<Bits, Long> firstEntry = sums.firstEntry();
        final RefNavigableMap<Bits, Long> remainder = sums.tailMap(currentCode, false);
        final Bits splitCode = currentCode.concatenate(Bits.ONE);
        final RefNavigableMap<Bits, Long> zeroMap = remainder.headMap(splitCode, false);
        final RefNavigableMap<Bits, Long> oneMap = remainder.tailMap(splitCode, true);
        remainder.freeRef();
        assert this.map != null;
        final AtomicInteger firstCounter = this.map.get(firstEntry.getKey());
        final int firstEntryCount = firstCounter.get();
        // Cumulative total accumulated before the first entry of this subtree.
        final long baseCount = firstEntry.getValue() - (long) firstEntryCount;
        final Map.Entry<Bits, Long> lastEntry = sums.lastEntry();
        final long endCount = lastEntry.getValue();
        RefUtil.freeRef(lastEntry);
        final long size = endCount - baseCount;
        final long terminals = firstEntry.getKey().equals(currentCode) ? (long) firstEntryCount : 0L;
        RefUtil.freeRef(firstEntry);
        final Map.Entry<Bits, Long> lastZeroEntry = zeroMap.lastEntry();
        final long zeroCount = 0 == zeroMap.size() ? 0L : lastZeroEntry.getValue() - baseCount - terminals;
        if (null != lastZeroEntry) {
            RefUtil.freeRef(lastZeroEntry);
        }
        final long oneCount = size - terminals - zeroCount;

        // Maps each code back to its individual (non-cumulative) count for the consistency asserts.
        final RefMaps.EntryTransformer<Bits, Long, Long> transformer = new RefMaps.EntryTransformer<Bits, Long, Long>() {
            @Override
            @Nonnull
            public Long transformEntry(@RefAware Bits key, Long value) {
                final AtomicInteger counter = CountTreeBitsCollection.this.map.get(key);
                // Widen before boxing: an int cannot be boxed to Long directly.
                return (long) counter.get();
            }
        };
        final RefMap<Bits, Long> allCounts = RefMaps.transformEntries((RefMap<Bits, Long>) RefUtil.addRef(sums), transformer);
        assert size == this.sum((RefCollection<Long>) allCounts.values());
        allCounts.freeRef();
        sums.freeRef();
        final RefMap<Bits, Long> zeroCounts = RefMaps.transformEntries((RefMap<Bits, Long>) RefUtil.addRef(zeroMap), transformer);
        assert zeroCount == this.sum((RefCollection<Long>) zeroCounts.values());
        zeroCounts.freeRef();
        final RefMap<Bits, Long> oneCounts = RefMaps.transformEntries((RefMap<Bits, Long>) RefUtil.addRef(oneMap), transformer);
        assert oneCount == this.sum((RefCollection<Long>) oneCounts.values());
        oneCounts.freeRef();

        final BranchCounts branchCounts = new BranchCounts(currentCode, size, terminals, zeroCount, oneCount);
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.StartTree);
        }
        this.writeBranchCounts(branchCounts, out);
        if (0L < zeroCount) {
            this.write(out, currentCode.concatenate(Bits.ZERO), (RefNavigableMap<Bits, Long>) RefUtil.addRef(zeroMap));
        }
        zeroMap.freeRef();
        if (0L < oneCount) {
            this.write(out, currentCode.concatenate(Bits.ONE), (RefNavigableMap<Bits, Long>) RefUtil.addRef(oneMap));
        }
        oneMap.freeRef();
        if (SERIALIZATION_CHECKS) {
            out.write(SerializationChecks.EndTree);
        }
    }

    /**
     * The entry counts for one node of the tree: how many entries end at the node's path and how
     * many continue into each child branch.
     */
    public static class BranchCounts {
        /** The code prefix this node represents. */
        public Bits path;
        /** Total number of entries beneath this node. */
        public long size;
        /** Number of entries whose code equals {@link #path}. */
        public long terminals;
        /** Number of entries in the zero branch. */
        public long zeroCount;
        /** Number of entries in the one branch. */
        public long oneCount;

        /**
         * Creates counts with only path and size set; the remaining fields default to zero.
         *
         * @param path the code prefix
         * @param size the total number of entries beneath the node
         */
        public BranchCounts(Bits path, long size) {
            this.path = path;
            this.size = size;
        }

        /**
         * Creates fully specified counts.
         *
         * @param path      the code prefix
         * @param size      the total number of entries beneath the node
         * @param terminals the number of entries ending at {@code path}
         * @param zeroCount the number of entries in the zero branch
         * @param oneCount  the number of entries in the one branch
         */
        public BranchCounts(Bits path, long size, long terminals, long zeroCount, long oneCount) {
            this.path = path;
            this.size = size;
            this.terminals = terminals;
            this.zeroCount = zeroCount;
            this.oneCount = oneCount;
        }
    }

    /**
     * Marker values written to and expected from the stream when
     * {@link CountTreeBitsCollection#SERIALIZATION_CHECKS} is enabled.
     */
    public enum SerializationChecks {
        StartTree,
        EndTree,
        BeforeCount,
        AfterCount,
        BeforeTerminal,
        AfterTerminal;

    }
}

