/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.graphchi.apps;

import edu.cmu.graphchi.ChiEdge;
import edu.cmu.graphchi.ChiFilenames;
import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.ChiVertex;
import edu.cmu.graphchi.GraphChiContext;
import edu.cmu.graphchi.GraphChiProgram;
import edu.cmu.graphchi.datablocks.FloatConverter;
import edu.cmu.graphchi.datablocks.IntConverter;
import edu.cmu.graphchi.engine.GraphChiEngine;
import edu.cmu.graphchi.engine.VertexInterval;
import edu.cmu.graphchi.preprocessing.EdgeProcessor;
import edu.cmu.graphchi.preprocessing.FastSharder;
import edu.cmu.graphchi.preprocessing.VertexIdTranslate;
import edu.cmu.graphchi.util.FileUtils;
import edu.cmu.graphchi.util.HugeDoubleMatrix;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.util.logging.Logger;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.CholeskyDecompositionImpl;
import org.apache.commons.math.linear.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;

public class ALSMatrixFactorization
implements GraphChiProgram<Integer, Float> {
    protected static Logger logger = ChiLogger.getLogger("ALS");
    protected HugeDoubleMatrix vertexValueMatrix;
    protected int D;
    protected String baseFilename;
    protected int numShards;
    protected double LAMBDA = 0.065;
    protected double rmse = 0.0;

    protected ALSMatrixFactorization(int n, String string, int n2) {
        this.D = n;
        this.numShards = n2;
        this.baseFilename = string;
    }

    public double predict(int n, int n2) {
        double[] dArray = new double[this.D];
        double[] dArray2 = new double[this.D];
        this.vertexValueMatrix.getRow(n, dArray);
        this.vertexValueMatrix.getRow(n2, dArray2);
        return new ArrayRealVector(dArray).dotProduct(new ArrayRealVector(dArray2));
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public void update(ChiVertex<Integer, Float> chiVertex, GraphChiContext graphChiContext) {
        block16: {
            if (chiVertex.numEdges() == 0) {
                return;
            }
            BlockRealMatrix blockRealMatrix = new BlockRealMatrix(this.D, this.D);
            ArrayRealVector arrayRealVector = new ArrayRealVector(this.D);
            try {
                int n;
                int n2;
                double[] dArray = new double[this.D];
                for (n2 = 0; n2 < chiVertex.numEdges(); ++n2) {
                    ChiEdge<Float> chiEdge = chiVertex.edge(n2);
                    float f = chiEdge.getValue().floatValue();
                    this.vertexValueMatrix.getRow(chiEdge.getVertexId(), dArray);
                    for (n = 0; n < this.D; ++n) {
                        arrayRealVector.setEntry(n, arrayRealVector.getEntry(n) + dArray[n] * (double)f);
                        for (int i = n; i < this.D; ++i) {
                            blockRealMatrix.setEntry(i, n, blockRealMatrix.getEntry(i, n) + dArray[n] * dArray[i]);
                        }
                    }
                }
                for (n2 = 0; n2 < this.D; ++n2) {
                    for (int i = n2 + 1; i < this.D; ++i) {
                        blockRealMatrix.setEntry(n2, i, blockRealMatrix.getEntry(i, n2));
                    }
                }
                for (n2 = 0; n2 < this.D; ++n2) {
                    blockRealMatrix.setEntry(n2, n2, blockRealMatrix.getEntry(n2, n2) + this.LAMBDA * (double)chiVertex.numEdges());
                }
                RealVector realVector = new CholeskyDecompositionImpl((RealMatrix)blockRealMatrix).getSolver().solve((RealVector)arrayRealVector);
                for (int i = 0; i < this.D; ++i) {
                    this.vertexValueMatrix.setValue(chiVertex.getId(), i, realVector.getEntry(i));
                }
                if (!graphChiContext.isLastIteration() || chiVertex.numInEdges() <= 0) break block16;
                if (chiVertex.numOutEdges() > 0) {
                    throw new IllegalStateException("Not a bipartite graph!");
                }
                double d = 0.0;
                for (n = 0; n < chiVertex.numInEdges(); ++n) {
                    ChiEdge<Float> chiEdge = chiVertex.edge(n);
                    float f = chiEdge.getValue().floatValue();
                    this.vertexValueMatrix.getRow(chiEdge.getVertexId(), dArray);
                    double d2 = new ArrayRealVector(dArray).dotProduct(realVector);
                    d += (d2 - (double)f) * (d2 - (double)f);
                }
                ALSMatrixFactorization aLSMatrixFactorization = this;
                synchronized (aLSMatrixFactorization) {
                    this.rmse += d;
                }
            }
            catch (NotPositiveDefiniteMatrixException notPositiveDefiniteMatrixException) {
                logger.warning("Matrix was not positive definite: " + blockRealMatrix);
            }
            catch (Exception exception) {
                throw new RuntimeException(exception);
            }
        }
    }

    @Override
    public void beginIteration(GraphChiContext graphChiContext) {
        if (graphChiContext.getIteration() == 0) {
            logger.info("Initializing latent factors for " + graphChiContext.getNumVertices() + " vertices");
            this.vertexValueMatrix = new HugeDoubleMatrix(graphChiContext.getNumVertices(), this.D);
            this.vertexValueMatrix.randomize(0.0, 1.0);
        }
    }

    @Override
    public void endIteration(GraphChiContext graphChiContext) {
    }

    @Override
    public void beginInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    @Override
    public void endInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    @Override
    public void beginSubInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    @Override
    public void endSubInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    protected static FastSharder createSharder(String string, int n) throws IOException {
        return new FastSharder<Integer, Float>(string, n, null, new EdgeProcessor<Float>(){

            @Override
            public Float receiveEdge(int n, int n2, String string) {
                return Float.valueOf(string == null ? 0.0f : Float.parseFloat(string));
            }
        }, new IntConverter(), new FloatConverter());
    }

    public static void main(String[] stringArray) throws Exception {
        if (stringArray.length < 2) {
            throw new IllegalArgumentException("Usage: java edu.cmu.graphchi.ALSMatrixFactorization <input-file> <nshards> <D>");
        }
        String string = stringArray[0];
        int n = Integer.parseInt(stringArray[1]);
        int n2 = 20;
        if (stringArray.length == 3) {
            n2 = Integer.parseInt(stringArray[2]);
        }
        ALSMatrixFactorization aLSMatrixFactorization = ALSMatrixFactorization.computeALS(string, n, n2, 5);
        aLSMatrixFactorization.writeOutputMatrices();
    }

    public static ALSMatrixFactorization computeALS(String string, int n, int n2, int n3) throws IOException {
        FastSharder fastSharder = ALSMatrixFactorization.createSharder(string, n);
        if (!new File(ChiFilenames.getFilenameIntervals(string, n)).exists() || !new File(string + ".matrixinfo").exists()) {
            fastSharder.shard((InputStream)new FileInputStream(new File(string)), FastSharder.GraphInputFormat.MATRIXMARKET);
        } else {
            logger.info("Found shards -- no need to preprocess");
        }
        ALSMatrixFactorization aLSMatrixFactorization = new ALSMatrixFactorization(n2, string, n);
        logger.info("Set latent factor dimension to: " + aLSMatrixFactorization.D);
        GraphChiEngine graphChiEngine = new GraphChiEngine(string, n);
        graphChiEngine.setEdataConverter(new FloatConverter());
        graphChiEngine.setEnableDeterministicExecution(false);
        graphChiEngine.setVertexDataConverter(null);
        graphChiEngine.setModifiesInedges(false);
        graphChiEngine.setModifiesOutedges(false);
        graphChiEngine.run(aLSMatrixFactorization, n3);
        double d = Math.sqrt(aLSMatrixFactorization.rmse / (1.0 * (double)graphChiEngine.numEdges()));
        logger.info("Train RMSE: " + d + ", total edges:" + graphChiEngine.numEdges());
        return aLSMatrixFactorization;
    }

    public BipartiteGraphInfo getGraphInfo() {
        String string = this.baseFilename + ".matrixinfo";
        try {
            String string2 = FileUtils.readToString(string);
            String[] stringArray = string2.split("\t");
            int n = Integer.parseInt(stringArray[0]);
            int n2 = Integer.parseInt(stringArray[1]);
            return new BipartiteGraphInfo(n, n2);
        }
        catch (IOException iOException) {
            throw new RuntimeException("Could not load matrix info! File: " + string);
        }
    }

    private void writeOutputMatrices() throws Exception {
        int n;
        int n2;
        BipartiteGraphInfo bipartiteGraphInfo = this.getGraphInfo();
        int n3 = bipartiteGraphInfo.getNumLeft();
        int n4 = bipartiteGraphInfo.getNumRight();
        VertexIdTranslate vertexIdTranslate = VertexIdTranslate.fromFile(new File(ChiFilenames.getVertexTranslateDefFile(this.baseFilename, this.numShards)));
        String string = this.baseFilename + "_U.mm";
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(string));
        bufferedWriter.write("%%MatrixMarket matrix array real general\n");
        bufferedWriter.write(this.D + " " + n3 + "\n");
        for (int i = 0; i < n3; ++i) {
            n2 = vertexIdTranslate.forward(i);
            for (n = 0; n < this.D; ++n) {
                bufferedWriter.write(this.vertexValueMatrix.getValue(n2, n) + "\n");
            }
        }
        bufferedWriter.close();
        String string2 = this.baseFilename + "_V.mm";
        bufferedWriter = new BufferedWriter(new FileWriter(string2));
        bufferedWriter.write("%%MatrixMarket matrix array real general\n");
        bufferedWriter.write(this.D + " " + n4 + "\n");
        for (n2 = 0; n2 < n4; ++n2) {
            n = vertexIdTranslate.forward(n3 + n2);
            for (int i = 0; i < this.D; ++i) {
                bufferedWriter.write(this.vertexValueMatrix.getValue(n, i) + "\n");
            }
        }
        bufferedWriter.close();
        logger.info("Latent factor matrices saved: " + this.baseFilename + "_U.mm" + ", " + this.baseFilename + "_V.mm");
    }

    public class BipartiteGraphInfo {
        private int numLeft;
        private int numRight;

        public BipartiteGraphInfo(int n, int n2) {
            this.numLeft = n;
            this.numRight = n2;
        }

        public int getNumLeft() {
            return this.numLeft;
        }

        public int getNumRight() {
            return this.numRight;
        }
    }
}

