/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.graphchi.apps.pig;

import edu.cmu.graphchi.ChiEdge;
import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.ChiVertex;
import edu.cmu.graphchi.GraphChiContext;
import edu.cmu.graphchi.GraphChiProgram;
import edu.cmu.graphchi.datablocks.FloatConverter;
import edu.cmu.graphchi.datablocks.IntConverter;
import edu.cmu.graphchi.engine.GraphChiEngine;
import edu.cmu.graphchi.engine.VertexInterval;
import edu.cmu.graphchi.hadoop.PigGraphChiBase;
import edu.cmu.graphchi.preprocessing.EdgeProcessor;
import edu.cmu.graphchi.preprocessing.FastSharder;
import edu.cmu.graphchi.preprocessing.VertexIdTranslate;
import edu.cmu.graphchi.util.HugeDoubleMatrix;
import java.io.IOException;
import java.util.logging.Logger;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.CholeskyDecompositionImpl;
import org.apache.commons.math.linear.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;

/**
 * Alternating least squares (ALS) matrix factorization run on GraphChi and exposed to Apache Pig
 * through {@link PigGraphChiBase}.
 *
 * <p>Every edge {@code from -> to} carries one observed value. The program learns a
 * {@code D}-dimensional latent factor row for every edge source ("left side") and every edge
 * target ("right side"). Each row is indexed by the original (un-translated) vertex id. After the
 * run, {@link #getNextResult} emits one tuple per factor row: first {@code ("U", leftId, x0..)}
 * for all left rows, then {@code ("V", rightId, x0..)} for all right rows.
 */
public class PigALSMatrixFactorization
extends PigGraphChiBase
implements GraphChiProgram<Integer, Float> {
    private static final Logger logger = ChiLogger.getLogger("ALS");
    /** Number of GraphChi iterations executed by {@link #runGraphChi()}. */
    private static final int NUM_ITERATIONS = 5;
    /** Latent factors of edge sources, indexed by original vertex id. */
    private HugeDoubleMatrix leftSideMatrix;
    /** Latent factors of edge targets, indexed by original vertex id. */
    private HugeDoubleMatrix rightSideMatrix;
    /** Number of latent dimensions per factor row. */
    private int D = 5;
    /** Regularization weight; multiplied by the number of observations of the row being solved. */
    double LAMBDA = 0.065;
    /** Sum of squared training errors, accumulated during the last iteration only. */
    double rmse = 0.0;
    private static final int LEFTSIDE = 0;
    private static final int RIGHTSIDE = 1;
    /** Largest edge-source id seen while sharding; sizes {@link #leftSideMatrix}. */
    private int maxLeftVertexId = 0;
    /** Largest edge-target id seen while sharding; sizes {@link #rightSideMatrix}. */
    private int maxRightVertexId = 0;
    /** Number of result tuples already returned by {@link #getNextResult}. */
    private int outputCounter = 0;

    /**
     * Re-solves the latent factors of {@code chiVertex}.
     *
     * <p>Out-edges make the vertex a left-side row and in-edges make it a right-side row. The two
     * sides are solved independently, so an id that is both a source and a target updates both
     * matrices.
     */
    @Override
    public void update(ChiVertex<Integer, Float> chiVertex, GraphChiContext graphChiContext) {
        if (chiVertex.numEdges() == 0) {
            return;
        }
        VertexIdTranslate idTranslate = graphChiContext.getVertexIdTranslate();
        solveSide(chiVertex, graphChiContext, idTranslate, LEFTSIDE);
        solveSide(chiVertex, graphChiContext, idTranslate, RIGHTSIDE);
    }

    /**
     * Solves the regularized least-squares problem for one side of {@code vertex} and stores the
     * result in that side's factor matrix.
     *
     * <p>Builds the normal equations {@code (sum y y^T + LAMBDA * n * I) x = sum(r * y)}. Here
     * {@code y} ranges over the factor rows of the opposite side reached by this side's edges,
     * {@code r} is the edge value, and {@code n} is the number of those edges. The system is then
     * solved with a Cholesky decomposition. If the matrix is not positive definite, a warning is
     * logged and the row is left unchanged. During the last iteration, the right-side pass also
     * adds its squared training error to {@link #rmse}.
     *
     * @throws RuntimeException if an edge value is below 1.0, or the decomposition fails for any
     *     reason other than a non-positive-definite matrix
     */
    private void solveSide(ChiVertex<Integer, Float> vertex, GraphChiContext context,
                           VertexIdTranslate idTranslate, int side) {
        boolean leftSide = side == LEFTSIDE;
        int numRatings = leftSide ? vertex.numOutEdges() : vertex.numInEdges();
        if (numRatings == 0) {
            return;
        }
        HugeDoubleMatrix ownFactors = leftSide ? this.leftSideMatrix : this.rightSideMatrix;
        HugeDoubleMatrix otherFactors = leftSide ? this.rightSideMatrix : this.leftSideMatrix;

        RealMatrix xtx = new BlockRealMatrix(this.D, this.D);
        RealVector xty = new ArrayRealVector(this.D);
        double[] row = new double[this.D];
        for (int e = 0; e < numRatings; ++e) {
            ChiEdge<Float> edge = leftSide ? vertex.outEdge(e) : vertex.inEdge(e);
            float rating = edge.getValue().floatValue();
            if (rating < 1.0) {
                throw new RuntimeException("Had invalid observation: " + rating + " on edge "
                        + idTranslate.backward(vertex.getId()) + "->"
                        + idTranslate.backward(edge.getVertexId()));
            }
            otherFactors.getRow(idTranslate.backward(edge.getVertexId()), row);
            for (int a = 0; a < this.D; ++a) {
                xty.setEntry(a, xty.getEntry(a) + row[a] * rating);
                // Only the lower triangle is accumulated here; it is mirrored below.
                for (int b = a; b < this.D; ++b) {
                    xtx.setEntry(b, a, xtx.getEntry(b, a) + row[a] * row[b]);
                }
            }
        }
        for (int a = 0; a < this.D; ++a) {
            for (int b = a + 1; b < this.D; ++b) {
                xtx.setEntry(a, b, xtx.getEntry(b, a));
            }
            // Regularize by this side's own observation count, not the vertex's total degree,
            // so an id that is both a source and a target is not over-regularized.
            xtx.setEntry(a, a, xtx.getEntry(a, a) + this.LAMBDA * numRatings);
        }

        RealVector solution;
        try {
            solution = new CholeskyDecompositionImpl(xtx).getSolver().solve(xty);
        } catch (NotPositiveDefiniteMatrixException notPositiveDefinite) {
            logger.warning("Matrix was not positive definite: " + xtx);
            return;
        } catch (RuntimeException unchecked) {
            throw unchecked;
        } catch (Exception checked) {
            // Any other checked failure from the decomposition is treated as fatal.
            throw new RuntimeException(checked);
        }

        int ownId = idTranslate.backward(vertex.getId());
        for (int k = 0; k < this.D; ++k) {
            ownFactors.setValue(ownId, k, solution.getEntry(k));
        }

        // Training error is measured once per edge: from the right side, on the last iteration.
        if (context.isLastIteration() && side == RIGHTSIDE) {
            double squaredError = 0.0;
            for (int e = 0; e < numRatings; ++e) {
                ChiEdge<Float> edge = vertex.inEdge(e);
                float rating = edge.getValue().floatValue();
                otherFactors.getRow(idTranslate.backward(edge.getVertexId()), row);
                double prediction = new ArrayRealVector(row).dotProduct(solution);
                squaredError += (prediction - rating) * (prediction - rating);
            }
            synchronized (this) {
                this.rmse += squaredError;
            }
        }
    }

    /** On the first iteration, allocates both factor matrices and fills them with random values in [0, 1]. */
    @Override
    public void beginIteration(GraphChiContext graphChiContext) {
        if (graphChiContext.getIteration() == 0) {
            logger.info("Initializing latent factors for " + (1 + this.maxLeftVertexId) + " vertices on the left side");
            logger.info("Initializing latent factors for " + (1 + this.maxRightVertexId) + " vertices on the right side");
            this.leftSideMatrix = new HugeDoubleMatrix(this.maxLeftVertexId + 1, this.D);
            this.rightSideMatrix = new HugeDoubleMatrix(this.maxRightVertexId + 1, this.D);
            this.leftSideMatrix.randomize(0.0, 1.0);
            this.rightSideMatrix.randomize(0.0, 1.0);
        }
    }

    /** No per-iteration work is needed after an iteration. */
    @Override
    public void endIteration(GraphChiContext graphChiContext) {
    }

    /** No per-interval setup is needed. */
    @Override
    public void beginInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /** No per-interval teardown is needed. */
    @Override
    public void endInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /** No per-sub-interval setup is needed. */
    @Override
    public void beginSubInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /** No per-sub-interval teardown is needed. */
    @Override
    public void endSubInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /**
     * Creates a sharder that parses each edge value as a float; a missing value becomes 0.0.
     * While sharding, it records the largest source and target ids so that
     * {@link #beginIteration} can size the factor matrices.
     */
    @Override
    protected FastSharder createSharder(String graphName, int numShards) throws IOException {
        return new FastSharder<Integer, Float>(graphName, numShards, null, new EdgeProcessor<Float>() {

            @Override
            public Float receiveEdge(int from, int to, String token) {
                PigALSMatrixFactorization.this.maxLeftVertexId = Math.max(from, PigALSMatrixFactorization.this.maxLeftVertexId);
                PigALSMatrixFactorization.this.maxRightVertexId = Math.max(to, PigALSMatrixFactorization.this.maxRightVertexId);
                return token == null ? 0.0f : Float.parseFloat(token);
            }
        }, new IntConverter(), new FloatConverter());
    }

    /**
     * Returns the output schema: a factor label, the row id, and one column {@code x0..x{D-1}}
     * per latent dimension.
     */
    // NOTE(review): the x columns carry no type although getNextResult emits numeric values, and
    // "string" may not be a Pig type name (Pig uses chararray) -- confirm against PigGraphChiBase.
    @Override
    protected String getSchemaString() {
        StringBuilder schema = new StringBuilder("{factor:string,id:int");
        for (int i = 0; i < this.D; ++i) {
            schema.append(",x").append(i);
        }
        return schema.append('}').toString();
    }

    /** Number of shards used both when sharding and when running the engine. */
    @Override
    protected int getNumShards() {
        return 20;
    }

    /**
     * Runs {@value #NUM_ITERATIONS} ALS iterations over the sharded graph and logs the training
     * RMSE measured in the last iteration.
     */
    @Override
    protected void runGraphChi() throws Exception {
        GraphChiEngine<Integer, Float> engine = new GraphChiEngine<Integer, Float>(this.getGraphName(), this.getNumShards());
        engine.setEdataConverter(new FloatConverter());
        engine.setEnableDeterministicExecution(false);
        engine.setVertexDataConverter(null);
        engine.setModifiesInedges(false);
        engine.setModifiesOutedges(false);
        // Start from a clean accumulator so a repeated run does not mix in stale errors.
        this.rmse = 0.0;
        engine.run(this, NUM_ITERATIONS);
        double trainRmse = Math.sqrt(this.rmse / (double) engine.numEdges());
        logger.info("Train RMSE: " + trainRmse + ", total edges:" + engine.numEdges());
    }

    /**
     * Returns the next factor row as a tuple {@code (label, id, x0, ..., x{D-1})}.
     *
     * <p>All left-side rows ({@code "U"}) are emitted first, then all right-side rows
     * ({@code "V"}).
     *
     * @return the next tuple, or {@code null} once every row of both matrices has been emitted
     */
    @Override
    protected Tuple getNextResult(TupleFactory tupleFactory) throws ExecException {
        // The left matrix holds rows 0..maxLeftVertexId inclusive; every one of them is a valid
        // "U" output, so the switch to "V" happens only after its full row count.
        long numLeftRows = this.leftSideMatrix.getNumRows();
        HugeDoubleMatrix factors;
        String factorName;
        int id;
        if (this.outputCounter < numLeftRows) {
            factors = this.leftSideMatrix;
            id = this.outputCounter;
            factorName = "U";
        } else {
            factors = this.rightSideMatrix;
            id = (int) (this.outputCounter - numLeftRows);
            factorName = "V";
            if ((long) id >= factors.getNumRows()) {
                return null;
            }
        }
        Tuple tuple = tupleFactory.newTuple(2 + this.D);
        tuple.set(0, factorName);
        tuple.set(1, id);
        for (int k = 0; k < this.D; ++k) {
            tuple.set(2 + k, factors.getValue(id, k));
        }
        ++this.outputCounter;
        return tuple;
    }
}

