/*
 * Decompiled with CFR 0.152.
 */
package wf.core.game_engine.neural_network;

import java.util.Arrays;
import java.util.Random;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;
import wf.core.game_engine.neural_network.WrongDimensionException;
import wf.core.game_engine.neural_network.activationfunctions.ActivationFunction;
import wf.core.game_engine.neural_network.utilities.FileReaderAndWriter;
import wf.core.game_engine.neural_network.utilities.MatrixUtilities;

/**
 * A fully connected feed-forward neural network with one or more hidden layers.
 *
 * <p>Layer {@code i} is represented by a weight matrix {@code weights[i]} (rows = nodes of the
 * layer being computed, columns = nodes of the layer feeding it) and a single-column bias matrix
 * {@code biases[i]}. Both arrays hold {@code hiddenNodes.length + 1} entries: one per hidden layer
 * plus one for the output layer.
 *
 * <p>New networks start with a learning rate of {@code 0.1}, the
 * {@link ActivationFunction#SIGMOID} activation and weights/biases drawn from {@code [-1, 1]}.
 */
public class NeuralNetwork {
    private Random random = new Random();
    private int inputNodes;
    private int[] hiddenNodes;
    private int outputNodes;
    private SimpleMatrix[] weights;
    private SimpleMatrix[] biases;
    private double learningRate;
    private ActivationFunction activationFunction;

    /**
     * Creates a network from a flat list of layer sizes.
     *
     * @param nodes layer sizes in order: input layer, every hidden layer, output layer
     * @throws IllegalArgumentException if fewer than three layers are given
     */
    public NeuralNetwork(int[] nodes) {
        if (nodes.length < 3) {
            throw new IllegalArgumentException("Less than 3 layers are not allowed");
        }
        this.inputNodes = nodes[0];
        this.hiddenNodes = Arrays.copyOfRange(nodes, 1, nodes.length - 1);
        this.outputNodes = nodes[nodes.length - 1];
        this.reset();
    }

    /**
     * Creates a network from explicit input, hidden and output layer sizes.
     *
     * @param inputNodes number of input nodes
     * @param hiddenNodes size of every hidden layer, in order; copied, so later changes to the
     *     caller's array do not affect this network
     * @param outputNodes number of output nodes
     * @throws NullPointerException if {@code hiddenNodes} is null
     * @throws IllegalArgumentException if {@code hiddenNodes} is empty
     */
    public NeuralNetwork(int inputNodes, int[] hiddenNodes, int outputNodes) {
        if (hiddenNodes == null) {
            throw new NullPointerException("hiddenNodes must not be null");
        }
        // Same rule as the int[] constructor: input + at least one hidden + output.
        if (hiddenNodes.length == 0) {
            throw new IllegalArgumentException("Less than 3 layers are not allowed");
        }
        this.inputNodes = inputNodes;
        this.hiddenNodes = hiddenNodes.clone();
        this.outputNodes = outputNodes;
        this.reset();
    }

    /**
     * Creates a copy of another network. Weight and bias matrices are copied; the learning rate
     * and activation function are taken over as they are.
     *
     * @param nn the network to copy
     */
    public NeuralNetwork(NeuralNetwork nn) {
        this.inputNodes = nn.inputNodes;
        this.hiddenNodes = nn.hiddenNodes.clone();
        this.outputNodes = nn.outputNodes;
        // Size the arrays from the source arrays: they may have been replaced through
        // setWeights/setBiases and need not match hiddenNodes.length + 1.
        this.weights = new SimpleMatrix[nn.weights.length];
        for (int i = 0; i < nn.weights.length; ++i) {
            this.weights[i] = nn.weights[i].copy();
        }
        this.biases = new SimpleMatrix[nn.biases.length];
        for (int i = 0; i < nn.biases.length; ++i) {
            this.biases[i] = nn.biases[i].copy();
        }
        this.learningRate = nn.learningRate;
        this.activationFunction = nn.activationFunction;
    }

    /**
     * Resets the network: restores the default learning rate and activation function and
     * replaces all weights and biases with fresh random values.
     */
    public void clear() {
        this.reset();
    }

    /** Shared initialisation used by the constructors and {@link #clear()}. */
    private void reset() {
        this.initializeDefaultValues();
        this.initializeWeights();
        this.initializeBiases();
    }

    /** Sets the learning rate to 0.1 and the activation function to sigmoid. */
    private void initializeDefaultValues() {
        this.learningRate = 0.1;
        this.activationFunction = ActivationFunction.SIGMOID;
    }

    /** Allocates one random weight matrix with entries in [-1, 1] per non-input layer. */
    private void initializeWeights() {
        this.weights = new SimpleMatrix[this.hiddenNodes.length + 1];
        int last = this.weights.length - 1;
        for (int i = 0; i <= last; ++i) {
            // Rows: size of the layer being computed; columns: size of the layer feeding it.
            int rows = i == last ? this.outputNodes : this.hiddenNodes[i];
            int cols = i == 0 ? this.inputNodes : this.hiddenNodes[i - 1];
            this.weights[i] = SimpleMatrix.random64(rows, cols, -1.0, 1.0, this.random);
        }
    }

    /** Allocates one random single-column bias matrix with entries in [-1, 1] per non-input layer. */
    private void initializeBiases() {
        this.biases = new SimpleMatrix[this.hiddenNodes.length + 1];
        int last = this.biases.length - 1;
        for (int i = 0; i <= last; ++i) {
            int rows = i == last ? this.outputNodes : this.hiddenNodes[i];
            this.biases[i] = SimpleMatrix.random64(rows, 1, -1.0, 1.0, this.random);
        }
    }

    /**
     * Feeds an input through the network and returns the output layer's values.
     *
     * @param input one value per input node
     * @return the first column of the output layer's result matrix
     * @throws WrongDimensionException if {@code input.length} differs from the number of input nodes
     */
    public double[] guess(double[] input) {
        if (input.length != this.inputNodes) {
            throw new WrongDimensionException(input.length, this.inputNodes, "Input");
        }
        SimpleMatrix output = MatrixUtilities.arrayToMatrix(input);
        for (int i = 0; i < this.hiddenNodes.length + 1; ++i) {
            output = this.calculateLayer(this.weights[i], this.biases[i], output, this.activationFunction);
        }
        return MatrixUtilities.getColumnFromMatrixAsArray(output, 0);
    }

    /**
     * Performs one training step on a single sample, adjusting weights and biases in place.
     *
     * @param inputArray one value per input node
     * @param targetArray the desired value for every output node
     * @throws WrongDimensionException if either array's length does not match the corresponding layer
     */
    public void train(double[] inputArray, double[] targetArray) {
        if (inputArray.length != this.inputNodes) {
            throw new WrongDimensionException(inputArray.length, this.inputNodes, "Input");
        }
        if (targetArray.length != this.outputNodes) {
            throw new WrongDimensionException(targetArray.length, this.outputNodes, "Output");
        }
        int layerCount = this.hiddenNodes.length + 2;

        // Forward pass: keep every layer's activation, layers[0] being the raw input.
        SimpleMatrix[] layers = new SimpleMatrix[layerCount];
        layers[0] = MatrixUtilities.arrayToMatrix(inputArray);
        for (int j = 1; j < layerCount; ++j) {
            layers[j] = this.calculateLayer(this.weights[j - 1], this.biases[j - 1], layers[j - 1], this.activationFunction);
        }

        // Backward pass, from the output layer down to the first hidden layer.
        SimpleMatrix target = MatrixUtilities.arrayToMatrix(targetArray);
        for (int n = layerCount - 1; n > 0; --n) {
            SimpleMatrix errors = target.minus(layers[n]);
            SimpleMatrix gradients = this.calculateGradient(layers[n], errors, this.activationFunction);
            SimpleMatrix deltas = this.calculateDeltas(gradients, layers[n - 1]);
            this.biases[n - 1] = this.biases[n - 1].plus(gradients);
            this.weights[n - 1] = this.weights[n - 1].plus(deltas);
            // The error handed to the previous layer goes through the already updated weights.
            SimpleMatrix previousError = this.weights[n - 1].transpose().mult(errors);
            // Expressed as a target so that the next iteration's (target - layer) equals previousError.
            target = previousError.plus(layers[n - 1]);
        }
    }

    /**
     * Returns a copy of this network.
     *
     * @return a new network created through the copy constructor
     */
    public NeuralNetwork copy() {
        return new NeuralNetwork(this);
    }

    /**
     * Merges this network with another one using a probability of {@code 0.5}.
     *
     * @param nn the network to merge with
     * @return the merged network
     * @throws WrongDimensionException if the two networks do not have the same layer sizes
     */
    public NeuralNetwork merge(NeuralNetwork nn) {
        return this.merge(nn, 0.5);
    }

    /**
     * Merges this network with another one into a new network; neither operand is modified.
     * Every weight and bias matrix of the result comes from
     * {@code MatrixUtilities.mergeMatrices(thisMatrix, otherMatrix, probability)}.
     *
     * @param nn the network to merge with
     * @param probability passed through to {@code MatrixUtilities.mergeMatrices}
     * @return the merged network
     * @throws WrongDimensionException if the two networks do not have the same layer sizes
     */
    public NeuralNetwork merge(NeuralNetwork nn, double probability) {
        if (!Arrays.equals(this.getDimensions(), nn.getDimensions())) {
            throw new WrongDimensionException(this.getDimensions(), nn.getDimensions());
        }
        // getDimensions() only carries the number of hidden layers, so networks whose hidden
        // layers differ in size would slip through the check above.
        if (!Arrays.equals(this.hiddenNodes, nn.hiddenNodes)) {
            throw new WrongDimensionException(this.layerSizes(), nn.layerSizes());
        }
        NeuralNetwork result = this.copy();
        for (int i = 0; i < result.weights.length; ++i) {
            result.weights[i] = MatrixUtilities.mergeMatrices(this.weights[i], nn.weights[i], probability);
        }
        for (int i = 0; i < result.biases.length; ++i) {
            result.biases[i] = MatrixUtilities.mergeMatrices(this.biases[i], nn.biases[i], probability);
        }
        return result;
    }

    /** Returns the size of every layer in order: input, each hidden layer, output. */
    private int[] layerSizes() {
        int[] sizes = new int[this.hiddenNodes.length + 2];
        sizes[0] = this.inputNodes;
        System.arraycopy(this.hiddenNodes, 0, sizes, 1, this.hiddenNodes.length);
        sizes[sizes.length - 1] = this.outputNodes;
        return sizes;
    }

    /**
     * Randomly perturbs the weights and biases of this network in place.
     *
     * @param probability chance for each individual weight or bias to be changed
     */
    public void mutate(double probability) {
        this.applyMutation(this.weights, probability);
        this.applyMutation(this.biases, probability);
    }

    /** Adds half of a standard Gaussian sample to each element that is selected by {@code probability}. */
    private void applyMutation(SimpleMatrix[] matrices, double probability) {
        for (SimpleMatrix matrix : matrices) {
            int elements = matrix.getNumElements();
            for (int j = 0; j < elements; ++j) {
                if (this.random.nextDouble() < probability) {
                    double offset = this.random.nextGaussian() / 2.0;
                    matrix.set(j, matrix.get(j) + offset);
                }
            }
        }
    }

    /** Computes {@code activation(weights * input + bias)} for a single layer. */
    private SimpleMatrix calculateLayer(SimpleMatrix weights, SimpleMatrix bias, SimpleMatrix input, ActivationFunction activationFunction) {
        SimpleMatrix weightedSum = weights.mult(input).plus(bias);
        return this.applyActivationFunction(weightedSum, false, activationFunction);
    }

    /** Computes {@code derivative(layer)} element-wise multiplied by {@code error}, scaled by the learning rate. */
    private SimpleMatrix calculateGradient(SimpleMatrix layer, SimpleMatrix error, ActivationFunction activationFunction) {
        SimpleMatrix derivative = this.applyActivationFunction(layer, true, activationFunction);
        return derivative.elementMult(error).scale(this.learningRate);
    }

    /** Computes the weight change {@code gradient * layer^T}. */
    private SimpleMatrix calculateDeltas(SimpleMatrix gradient, SimpleMatrix layer) {
        return gradient.mult(layer.transpose());
    }

    /** Applies either the activation function or, if {@code derivative} is set, its derivative to a matrix. */
    private SimpleMatrix applyActivationFunction(SimpleMatrix input, boolean derivative, ActivationFunction activationFunction) {
        if (derivative) {
            return activationFunction.applyDerivativeOfActivationFunctionToMatrix(input);
        }
        return activationFunction.applyActivationFunctionToMatrix(input);
    }

    /**
     * Writes this network through {@link FileReaderAndWriter#writeToFile}, passing {@code null}
     * as the file name (presumably selecting a default file; see {@code FileReaderAndWriter}).
     */
    public void writeToFile() {
        FileReaderAndWriter.writeToFile(this, null);
    }

    /**
     * Writes this network through {@link FileReaderAndWriter#writeToFile}.
     *
     * @param fileName the file name handed to {@code FileReaderAndWriter}
     */
    public void writeToFile(String fileName) {
        FileReaderAndWriter.writeToFile(this, fileName);
    }

    /**
     * Reads a network through {@link FileReaderAndWriter#readFromFile}, passing {@code null} as
     * the file name (presumably selecting a default file; see {@code FileReaderAndWriter}).
     *
     * @return whatever {@code FileReaderAndWriter.readFromFile} returns
     */
    public static NeuralNetwork readFromFile() {
        return FileReaderAndWriter.readFromFile(null);
    }

    /**
     * Reads a network through {@link FileReaderAndWriter#readFromFile}.
     *
     * @param fileName the file name handed to {@code FileReaderAndWriter}
     * @return whatever {@code FileReaderAndWriter.readFromFile} returns
     */
    public static NeuralNetwork readFromFile(String fileName) {
        return FileReaderAndWriter.readFromFile(fileName);
    }

    /** @return the activation function used by every layer */
    public ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    /** @param activationFunction the activation function to use for every layer */
    public void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    /** @return the factor by which gradients are scaled during training */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @param learningRate the factor by which gradients are scaled during training */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the number of input nodes */
    public int getInputNodes() {
        return this.inputNodes;
    }

    /** @return the number of hidden layers */
    public int getHiddenLayers() {
        return this.hiddenNodes.length;
    }

    /** @return a copy of the hidden layer sizes, in order */
    public int[] getHiddenNodes() {
        return this.hiddenNodes.clone();
    }

    /** @return the number of output nodes */
    public int getOutputNodes() {
        return this.outputNodes;
    }

    /** @return the internal weight matrices (not copied) */
    public SimpleMatrix[] getWeights() {
        return this.weights;
    }

    /** @param weights the weight matrices to use; stored as is, without copying or validation */
    public void setWeights(SimpleMatrix[] weights) {
        this.weights = weights;
    }

    /** @return the internal bias matrices (not copied) */
    public SimpleMatrix[] getBiases() {
        return this.biases;
    }

    /** @param biases the bias matrices to use; stored as is, without copying or validation */
    public void setBiases(SimpleMatrix[] biases) {
        this.biases = biases;
    }

    /**
     * Returns a three-element summary of the network shape.
     *
     * @return {@code {inputNodes, number of hidden layers, outputNodes}}; note that the middle
     *     element is the layer count, not the individual hidden layer sizes
     */
    public int[] getDimensions() {
        return new int[]{this.inputNodes, this.hiddenNodes.length, this.outputNodes};
    }
}

