/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hugegraph.computer.algorithm.centrality.betweenness;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.hugegraph.computer.algorithm.centrality.betweenness.BetweennessMessage;
import org.apache.hugegraph.computer.algorithm.centrality.betweenness.BetweennessValue;
import org.apache.hugegraph.computer.core.common.exception.ComputerException;
import org.apache.hugegraph.computer.core.config.Config;
import org.apache.hugegraph.computer.core.graph.edge.Edge;
import org.apache.hugegraph.computer.core.graph.id.Id;
import org.apache.hugegraph.computer.core.graph.value.DoubleValue;
import org.apache.hugegraph.computer.core.graph.value.IdList;
import org.apache.hugegraph.computer.core.graph.value.IdSet;
import org.apache.hugegraph.computer.core.graph.value.Value;
import org.apache.hugegraph.computer.core.graph.vertex.Vertex;
import org.apache.hugegraph.computer.core.worker.Computation;
import org.apache.hugegraph.computer.core.worker.ComputationContext;
import org.apache.hugegraph.util.Log;
import org.slf4j.Logger;

public class BetweennessCentrality
implements Computation<BetweennessMessage> {
    private static final Logger LOG = Log.logger(BetweennessCentrality.class);
    public static final String OPTION_SAMPLE_RATE = "betweenness_centrality.sample_rate";
    private double sampleRate;
    private Map<Id, SeqCount> seqTable;

    public String name() {
        return "betweenness_centrality";
    }

    public String category() {
        return "centrality";
    }

    public void init(Config config) {
        this.sampleRate = config.getDouble(OPTION_SAMPLE_RATE, 1.0);
        if (this.sampleRate <= 0.0 || this.sampleRate > 1.0) {
            throw new ComputerException("The param %s must be in (0.0, 1.0], actual got '%s'", new Object[]{OPTION_SAMPLE_RATE, this.sampleRate});
        }
        this.seqTable = new HashMap<Id, SeqCount>();
    }

    public void close(Config config) {
    }

    public void compute0(ComputationContext context, Vertex vertex) {
        BetweennessValue initialValue = new BetweennessValue(0.0);
        initialValue.arrivedVertices().add(vertex.id());
        vertex.value((Value)initialValue);
        if (vertex.numEdges() == 0) {
            return;
        }
        IdList sequence = new IdList();
        sequence.add((Value.Tvalue)vertex.id());
        context.sendMessageToAllEdges(vertex, (Value)new BetweennessMessage(sequence));
        LOG.info("Finished compute-0 step");
    }

    public void compute(ComputationContext context, Vertex vertex, Iterator<BetweennessMessage> messages) {
        boolean active;
        BetweennessValue value = (BetweennessValue)vertex.value();
        DoubleValue betweenness = value.betweenness();
        IdSet arrivingVertices = new IdSet();
        while (messages.hasNext()) {
            BetweennessMessage message = messages.next();
            DoubleValue vote = message.vote();
            betweenness.value(betweenness.value() + vote.value());
            this.forward(context, vertex, message.sequence(), arrivingVertices);
        }
        value.arrivedVertices().addAll(arrivingVertices);
        boolean bl = active = !this.seqTable.isEmpty();
        if (active) {
            this.sendMessage(context);
            this.seqTable.clear();
        } else {
            vertex.inactivate();
        }
    }

    private void forward(ComputationContext context, Vertex vertex, IdList sequence, IdSet arrivingVertices) {
        Id source;
        if (sequence.size() == 0) {
            return;
        }
        BetweennessValue value = (BetweennessValue)vertex.value();
        IdSet arrivedVertices = value.arrivedVertices();
        if (!arrivedVertices.contains(source = (Id)sequence.getFirst())) {
            arrivingVertices.add(source);
            SeqCount seqCount = this.seqTable.computeIfAbsent(source, k -> new SeqCount());
            ++seqCount.totalCount;
            for (int i = 1; i < sequence.size(); ++i) {
                Id id = (Id)sequence.get(i);
                Map<Id, Integer> idCounts = seqCount.idCount;
                idCounts.put(id, idCounts.getOrDefault(id, 0) + 1);
            }
            Id selfId = vertex.id();
            sequence.add((Value.Tvalue)selfId);
            BetweennessMessage newMessage = new BetweennessMessage(sequence);
            for (Edge edge : vertex.edges()) {
                Id targetId = edge.targetId();
                if (!this.sample(selfId, targetId, edge) || sequence.contains((Value.Tvalue)targetId)) continue;
                context.sendMessage(targetId, (Value)newMessage);
            }
        }
    }

    private void sendMessage(ComputationContext context) {
        for (SeqCount seqCount : this.seqTable.values()) {
            for (Map.Entry<Id, Integer> entry : seqCount.idCount.entrySet()) {
                double vote = (double)entry.getValue().intValue() / (double)seqCount.totalCount;
                BetweennessMessage voteMessage = new BetweennessMessage(new DoubleValue(vote));
                context.sendMessage(entry.getKey(), (Value)voteMessage);
            }
        }
    }

    private boolean sample(Id sourceId, Id targetId, Edge edge) {
        return Math.random() <= this.sampleRate;
    }

    private static class SeqCount {
        private final Map<Id, Integer> idCount = new HashMap<Id, Integer>();
        private int totalCount = 0;
    }
}

