/*
 * Copyright (c) 2017 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.util.test;

import com.simiacryptus.util.Util;
import com.simiacryptus.util.io.BinaryChunkIterator;
import com.simiacryptus.util.io.DataLoader;
import com.simiacryptus.util.ml.Tensor;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.input.BoundedInputStream;

import java.awt.image.BufferedImage;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.security.KeyManagementException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;
import java.util.zip.GZIPInputStream;

/**
 * Loader for the binary version of the CIFAR-10 image dataset.
 *
 * <p>The archive {@code cifar-10-binary.tar.gz} is opened through {@link Util#cache}. Each record
 * of the training batch files is decoded into a 32x32 RGB image, converted with
 * {@link Tensor#fromRGB}, and paired with a string label.
 */
public class CIFAR10 {

  /** Base URI against which the archive name is resolved. */
  private final static URI source = URI.create("https://www.cs.toronto.edu/~kriz/");

  /** Width and height, in pixels, of every CIFAR-10 image. */
  private static final int IMAGE_SIZE = 32;

  /** Number of bytes in one colour plane (one byte per pixel). */
  private static final int PLANE_SIZE = IMAGE_SIZE * IMAGE_SIZE;

  /** One binary record: a single label byte followed by the red, green and blue planes. */
  private static final int RECORD_SIZE = 1 + 3 * PLANE_SIZE;

  /** Shared loader that decodes the training batches into the queue it is given. */
  private static final DataLoader<LabeledObject<Tensor>> training = new DataLoader<LabeledObject<Tensor>>() {
    /**
     * Streams the archive and appends one labelled tensor per training record to {@code queue}.
     *
     * <p>Only {@code data_batch_*.bin} entries are decoded. The archive also contains the test
     * batch and text metadata, which must not end up in the training stream. Loading stops early
     * when the current thread is interrupted. The interrupt flag is left set so the owner can
     * observe it.
     *
     * @param queue sink receiving decoded records
     * @throws RuntimeException wrapping any I/O or TLS/keystore failure while reading the archive
     */
    @Override
    protected void read(List<LabeledObject<Tensor>> queue) {
      // Closing the tar stream closes the gzip and raw streams beneath it; the raw stream is also
      // a resource in its own right in case the GZIPInputStream constructor throws.
      try (InputStream raw = openArchive();
           TarArchiveInputStream tar = new TarArchiveInputStream(new GZIPInputStream(raw))) {
        TarArchiveEntry entry;
        while (!Thread.currentThread().isInterrupted() && null != (entry = tar.getNextTarEntry())) {
          if (!isTrainingBatch(entry)) continue;
          BinaryChunkIterator records = new BinaryChunkIterator(
              new DataInputStream(new BoundedInputStream(tar, entry.getSize())), RECORD_SIZE);
          // Poll for interruption per record so halt() takes effect without finishing a batch.
          while (!Thread.currentThread().isInterrupted() && records.hasNext()) {
            queue.add(toImage(records.next()).map(Tensor::fromRGB));
          }
        }
        System.err.println("Done loading");
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  };

  /**
   * Opens the dataset archive via {@link Util#cache}.
   *
   * <p>NOTE(review): {@code Util.cache} presumably serves a locally cached download; confirm.
   * Its security-related checked exceptions are wrapped in {@link RuntimeException}.
   *
   * @return the raw (still gzip-compressed) archive stream; the caller must close it
   * @throws IOException if the archive cannot be opened
   */
  private static InputStream openArchive() throws IOException {
    try {
      return Util.cache(source.resolve("cifar-10-binary.tar.gz"));
    } catch (NoSuchAlgorithmException | KeyStoreException | KeyManagementException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Checks whether a tar entry is one of the training batch files.
   *
   * @param entry tar entry to inspect
   * @return true only for regular {@code data_batch_*.bin} entries, at any directory depth
   */
  private static boolean isTrainingBatch(final TarArchiveEntry entry) {
    if (entry.isDirectory()) return false;
    final String path = entry.getName();
    final String fileName = path.substring(path.lastIndexOf('/') + 1);
    return fileName.startsWith("data_batch_") && fileName.endsWith(".bin");
  }

  /**
   * Returns a stream of the CIFAR-10 training images, obtained from the shared loader.
   *
   * <p>Whether records are produced lazily or in the background depends on {@code DataLoader},
   * which is not shown here.
   *
   * @return stream of labelled image tensors
   * @throws IOException declared for API compatibility
   */
  public static Stream<LabeledObject<Tensor>> trainingDataStream() throws IOException {
    return training.stream();
  }

  /**
   * Asks the shared loader to stop.
   *
   * <p>NOTE(review): presumably this interrupts the loading thread, which {@code read} polls
   * between records; confirm against {@code DataLoader.stop()}.
   */
  public static void halt() {
    training.stop();
  }

  /**
   * Decodes one binary CIFAR-10 record into a labelled RGB image.
   *
   * <p>Record layout: byte 0 is the label, followed by three row-major 32x32 planes (red, green,
   * blue).
   *
   * @param record raw record bytes; at least {@link #RECORD_SIZE} long
   * @return the image, labelled with {@code Arrays.toString} of the label byte, e.g. {@code "[3]"}
   * @throws IllegalArgumentException if the record is shorter than {@link #RECORD_SIZE}
   */
  private static LabeledObject<BufferedImage> toImage(final byte[] record) {
    if (record.length < RECORD_SIZE) {
      throw new IllegalArgumentException(
          "CIFAR-10 record must be " + RECORD_SIZE + " bytes, got " + record.length);
    }
    final BufferedImage img = new BufferedImage(IMAGE_SIZE, IMAGE_SIZE, BufferedImage.TYPE_INT_RGB);
    for (int y = 0; y < IMAGE_SIZE; y++) {
      for (int x = 0; x < IMAGE_SIZE; x++) {
        final int offset = 1 + x + y * IMAGE_SIZE;
        final int red = 0xFF & record[offset];
        final int green = 0xFF & record[offset + PLANE_SIZE];
        final int blue = 0xFF & record[offset + 2 * PLANE_SIZE];
        img.setRGB(x, y, (red << 16) | (green << 8) | blue);
      }
    }
    // Label string format "[k]" is kept as-is for compatibility with existing consumers.
    return new LabeledObject<>(img, Arrays.toString(new byte[]{record[0]}));
  }
}
