/*
 * Decompiled with CFR 0.152.
 */
package com.simiacryptus.util.test;

import com.simiacryptus.util.ml.Tensor;
import com.simiacryptus.util.test.BinaryChunkIterator;
import com.simiacryptus.util.test.LabeledObject;
import com.simiacryptus.util.test.Spool;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.security.KeyManagementException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import java.util.zip.GZIPInputStream;
import org.apache.commons.io.IOUtils;

/**
 * Loads the MNIST image and label files as Java streams.
 *
 * <p>Each file is fetched via {@code Spool.load} relative to {@link #SOURCE}, gunzipped, and
 * buffered fully in memory before being split into fixed-size records.
 */
public class MNIST {
    /** Base location against which the MNIST file names are resolved before loading. */
    private static final URI SOURCE = URI.create("http://yann.lecun.com/exdb/mnist/");

    /** Side length (in pixels) of one image, as read by {@link #fillImage}. */
    private static final int IMAGE_SIDE = 28;
    /** Bytes per image record: one unsigned byte per pixel. */
    private static final int IMAGE_BYTES = IMAGE_SIDE * IMAGE_SIDE;
    /** Leading bytes of a decompressed image file skipped before the first record (IDX header). */
    private static final int IMAGE_HEADER_BYTES = 16;
    /** Leading bytes of a decompressed label file skipped before the first record (IDX header). */
    private static final int LABEL_HEADER_BYTES = 8;
    /** Bytes per label record. */
    private static final int LABEL_BYTES = 1;

    /**
     * Reads a gzipped binary file and streams it as fixed-size byte records.
     *
     * <p>The whole decompressed file is held in memory. All streams opened for the download are
     * closed before this method returns.
     *
     * @param name file name, resolved against {@link #SOURCE}
     * @param skip number of leading decompressed bytes to discard before the first record
     * @param recordSize record length in bytes, passed to {@code BinaryChunkIterator}
     * @return an ordered, sequential stream of unknown size over the records
     * @throws IOException if reading or decompressing the file fails
     * @throws RuntimeException wrapping TLS/keystore errors raised by {@code Spool.load}
     */
    public static Stream<byte[]> binaryStream(String name, int skip, int recordSize) throws IOException {
        byte[] fileData;
        // Two resources so the raw stream is still closed if the GZIP header cannot be read.
        try (InputStream raw = openResource(name);
             InputStream unzipped = new GZIPInputStream(new BufferedInputStream(raw))) {
            fileData = IOUtils.toByteArray(unzipped);
        }
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(fileData));
        // ByteArrayInputStream skips min(skip, remaining) bytes in a single call.
        in.skip(skip);
        return MNIST.toIterator(new BinaryChunkIterator(in, recordSize));
    }

    /**
     * Opens {@code name} via {@code Spool.load}, converting its security-related checked
     * exceptions into an unchecked {@link RuntimeException} (as before).
     */
    private static InputStream openResource(String name) throws IOException {
        try {
            return Spool.load(SOURCE.resolve(name));
        } catch (KeyManagementException | KeyStoreException | NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Wraps an iterator as an ordered, sequential stream whose size is unknown.
     *
     * <p>Despite its name this returns a {@link Stream}. The size is deliberately left unknown.
     * Reporting a fixed size would make {@code count()} return that value without traversal and
     * {@code toArray()} fail when the iterator yields more elements.
     *
     * @param iterator source of elements; consumed lazily by the stream
     * @param <T> element type
     * @return a sequential stream over the iterator's remaining elements
     */
    public static <T> Stream<T> toIterator(Iterator<T> iterator) {
        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false);
    }

    /**
     * Streams the MNIST training set as image tensors paired with their labels.
     *
     * @return an ordered, sequential stream of unknown size
     * @throws IOException if either file cannot be read
     */
    public static Stream<LabeledObject<Tensor>> trainingDataStream() throws IOException {
        return labeledStream("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz");
    }

    /**
     * Streams the MNIST test ("t10k") set as image tensors paired with their labels.
     *
     * @return an ordered, sequential stream of unknown size
     * @throws IOException if either file cannot be read
     */
    public static Stream<LabeledObject<Tensor>> validationDataStream() throws IOException {
        return labeledStream("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz");
    }

    /**
     * Zips an image file and a label file into a stream of labeled tensors.
     *
     * <p>The stream ends as soon as either file runs out of records. Each label is the
     * {@link Arrays#toString(byte[])} form of its 1-byte record, e.g. {@code "[5]"}.
     */
    private static Stream<LabeledObject<Tensor>> labeledStream(String imageFile, String labelFile) throws IOException {
        final Iterator<Tensor> images = MNIST.binaryStream(imageFile, IMAGE_HEADER_BYTES, IMAGE_BYTES)
                .map(b -> MNIST.fillImage(b, new Tensor(IMAGE_SIDE, IMAGE_SIDE, 1)))
                .iterator();
        final Iterator<byte[]> labels = MNIST.binaryStream(labelFile, LABEL_HEADER_BYTES, LABEL_BYTES).iterator();
        return MNIST.toIterator(new Iterator<LabeledObject<Tensor>>() {
            @Override
            public boolean hasNext() {
                return images.hasNext() && labels.hasNext();
            }

            @Override
            public LabeledObject<Tensor> next() {
                // Check both sides first so an image is never consumed without a matching label.
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                return new LabeledObject<Tensor>(images.next(), Arrays.toString(labels.next()));
            }
        });
    }

    /**
     * Copies a 28x28 pixel record into {@code tensor}.
     *
     * <p>Byte {@code b[x + 28 * y]} is written to coordinate {@code (x, y)} as an unsigned value
     * in {@code [0, 255]}. So {@code x} is the fastest-varying index in the byte record.
     *
     * @param b pixel record; must hold at least 784 bytes
     * @param tensor destination, presumably shaped (28, 28, 1) as created by the callers here
     * @return {@code tensor}, for chaining
     */
    public static Tensor fillImage(byte[] b, Tensor tensor) {
        for (int x = 0; x < IMAGE_SIDE; ++x) {
            for (int y = 0; y < IMAGE_SIDE; ++y) {
                tensor.set(new int[]{x, y}, (double) (b[x + y * IMAGE_SIDE] & 0xFF));
            }
        }
        return tensor;
    }

    /**
     * Wraps an iterator as an ordered, sequential stream of the given exact size.
     *
     * @see #toStream(Iterator, int, boolean)
     */
    public static <T> Stream<T> toStream(Iterator<T> iterator, int size) {
        return MNIST.toStream(iterator, size, false);
    }

    /**
     * Wraps an iterator as an ordered stream that reports {@code size} as its exact size.
     *
     * <p>The spliterator is {@code SIZED}. {@code size} must therefore equal the number of
     * elements the iterator yields; otherwise {@code count()} may return {@code size} without
     * traversing, and {@code toArray()} may throw. Use {@link #toIterator(Iterator)} when the
     * size is not known.
     *
     * @param iterator source of elements
     * @param size exact number of elements the iterator will produce
     * @param parallel whether the returned stream is parallel
     * @return a stream over the iterator's elements
     */
    public static <T> Stream<T> toStream(Iterator<T> iterator, int size, boolean parallel) {
        return StreamSupport.stream(Spliterators.spliterator(iterator, (long) size, Spliterator.ORDERED), parallel);
    }
}

