/*
 * Decompiled with CFR 0.152.
 */
package com.mwt.explorers;

import com.mwt.consumers.ConsumeScorer;
import com.mwt.explorers.Explorer;
import com.mwt.misc.DecisionTuple;
import com.mwt.scorers.Scorer;
import com.mwt.utilities.PRG;
import java.util.ArrayList;

/**
 * Explorer that picks an action by sampling from a softmax distribution built from the
 * scores returned by a {@link Scorer}.
 *
 * <p>Each action {@code i} receives the weight {@code exp(lambda * (score[i] - maxScore))};
 * the weights are normalized to sum to one and an action is drawn from that distribution.
 * When exploration is disabled the highest-scoring action is returned with probability 1.
 *
 * @param <T> the context type passed through to the scorer
 */
public class SoftmaxExplorer<T>
implements Explorer<T>,
ConsumeScorer<T> {
    private Scorer<T> defaultScorer;
    private boolean explore = true;
    private final float lambda;
    private final int numActions;

    /**
     * Creates a softmax explorer.
     *
     * @param defaultScorer scorer that produces one score per action
     * @param lambda multiplier applied to the (shifted) scores before exponentiation
     * @param numActions number of actions the scorer is expected to score; must be at least 1
     * @throws IllegalArgumentException if {@code numActions} is less than 1
     */
    public SoftmaxExplorer(Scorer<T> defaultScorer, float lambda, int numActions) {
        if (numActions < 1) {
            throw new IllegalArgumentException("Number of actions must be at least 1.");
        }
        this.defaultScorer = defaultScorer;
        this.lambda = lambda;
        this.numActions = numActions;
    }

    /**
     * Returns the number of actions expected for the given context. This implementation
     * ignores the context and returns the value supplied at construction time.
     *
     * @param context the decision context (unused here)
     * @return the expected number of actions
     */
    protected int getNumActions(T context) {
        return this.numActions;
    }

    /**
     * Replaces the scorer used by subsequent calls to {@link #chooseAction}.
     *
     * @param newScorer the scorer to use from now on
     */
    @Override
    public void updateScorer(Scorer<T> newScorer) {
        this.defaultScorer = newScorer;
    }

    /**
     * Chooses an action for the given context.
     *
     * @param saltedSeed seed for the pseudo-random generator used for the draw
     * @param context the decision context handed to the scorer
     * @return a decision tuple holding the 1-based action id and the probability with which
     *         it was chosen
     * @throws RuntimeException if the scorer does not return exactly
     *         {@code getNumActions(context)} scores
     */
    @Override
    public DecisionTuple chooseAction(long saltedSeed, T context) {
        PRG random = new PRG(saltedSeed);
        ArrayList<Float> scores = this.defaultScorer.scoreActions(context);
        int numScores = scores.size();
        if (numScores != this.getNumActions(context)) {
            throw new RuntimeException("The number of scores returned by the scorer must equal number of actions");
        }

        // Arg-max over the scores (first maximum wins on ties). The running maximum must
        // start at negative infinity: Float.MIN_VALUE is the smallest POSITIVE float, so
        // seeding with it would ignore every non-positive score.
        int actionIndex = 0;
        float actionProbability = 1.0f;
        float maxScore = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < numScores; ++i) {
            float score = scores.get(i).floatValue();
            if (maxScore < score) {
                maxScore = score;
                actionIndex = i;
            }
        }

        if (this.explore) {
            // Shift by the maximum before exponentiating so the largest weight is exp(0) = 1.
            float[] weights = new float[numScores];
            float total = 0.0f;
            for (int i = 0; i < numScores; ++i) {
                weights[i] = (float)Math.exp(this.lambda * (scores.get(i).floatValue() - maxScore));
                total += weights[i];
            }

            float draw = random.uniformUnitInterval();
            float cumulative = 0.0f;
            actionProbability = 0.0f;
            actionIndex = numScores - 1;
            for (int i = 0; i < numScores; ++i) {
                float probability = weights[i] / total;
                cumulative += probability;
                // The last action also absorbs the case where rounding keeps the cumulative
                // sum from exceeding the draw; it then reports its own probability rather
                // than 0.
                if (cumulative > draw || i == numScores - 1) {
                    actionIndex = i;
                    actionProbability = probability;
                    break;
                }
            }
        }

        // Action ids are 1-based, hence the +1 on the zero-based index.
        return new DecisionTuple(actionIndex + 1, actionProbability, true);
    }

    /**
     * Turns exploration on or off. When off, {@link #chooseAction} returns the
     * highest-scoring action with probability 1.
     *
     * @param explore {@code true} to sample from the softmax distribution, {@code false} to
     *        always take the top-scoring action
     */
    @Override
    public void enableExplore(boolean explore) {
        this.explore = explore;
    }
}

