/*
 * Decompiled with CFR 0.152.
 */
package com.mwt.explorers;

import com.mwt.consumers.ConsumePolicy;
import com.mwt.explorers.Explorer;
import com.mwt.misc.DecisionTuple;
import com.mwt.policies.Policy;
import com.mwt.utilities.PRG;

/**
 * Epsilon-greedy explorer: most of the time it returns the action picked by a
 * default policy, and with probability {@code epsilon} it instead draws an
 * action uniformly from {@code 1..numActions}.
 *
 * <p>Actions are treated as 1-based: a default-policy action outside
 * {@code [1, numActions]} is rejected in {@link #chooseAction(long, Object)}.
 *
 * @param <T> the context type handed to the default policy
 */
public class EpsilonGreedyExplorer<T>
implements Explorer<T>,
ConsumePolicy<T> {
    /** Policy consulted for the greedy action; replaceable via {@link #updatePolicy(Policy)}. */
    private Policy<T> defaultPolicy;
    /** Exploration rate, validated to lie in {@code [0, 1]} at construction. */
    private final float epsilon;
    /** When {@code false}, exploration is disabled and epsilon is treated as 0. */
    private boolean explore = true;
    /** Number of available actions, validated to be at least 1 at construction. */
    private final int numActions;

    /**
     * Creates an epsilon-greedy explorer.
     *
     * @param defaultPolicy policy used to pick the greedy action
     * @param epsilon       exploration rate; must be a number in {@code [0, 1]}
     * @param numActions    number of actions; must be at least 1
     * @throws IllegalArgumentException if {@code numActions < 1}, or if
     *         {@code epsilon} is NaN or outside {@code [0, 1]}
     */
    public EpsilonGreedyExplorer(Policy<T> defaultPolicy, float epsilon, int numActions) {
        if (numActions < 1) {
            throw new IllegalArgumentException("Number of actions must be at least 1.");
        }
        // NaN compares false against every bound, so it must be rejected explicitly;
        // otherwise it would slip through and poison every reported probability.
        if (Float.isNaN(epsilon) || epsilon < 0.0f || epsilon > 1.0f) {
            throw new IllegalArgumentException("Epsilon must be between 0 and 1.");
        }
        this.defaultPolicy = defaultPolicy;
        this.epsilon = epsilon;
        this.numActions = numActions;
    }

    /**
     * Returns the number of actions available for the given context.
     *
     * <p>This implementation ignores {@code context} and returns the fixed count
     * supplied at construction.
     *
     * @param context the decision context (unused here)
     * @return the number of actions
     */
    protected int getNumActions(T context) {
        return this.numActions;
    }

    /**
     * Replaces the default policy used for subsequent decisions.
     *
     * @param newPolicy the policy to use from now on
     */
    @Override
    public void updatePolicy(Policy<T> newPolicy) {
        this.defaultPolicy = newPolicy;
    }

    /**
     * Chooses an action for {@code context} and reports the probability with
     * which that action was selected.
     *
     * <p>The default policy is always consulted first. A random generator seeded
     * with {@code saltedSeed} then decides whether to keep that action or to draw
     * a replacement via {@code uniformInt(1, numActions)}.
     *
     * @param saltedSeed seed for the per-decision random generator
     * @param context    the decision context passed to the default policy
     * @return a decision tuple holding the chosen action and its probability
     * @throws RuntimeException if the default policy returns an action outside
     *         {@code [1, numActions]}
     */
    @Override
    public DecisionTuple chooseAction(long saltedSeed, T context) {
        int numActionsForContext = this.getNumActions(context);
        PRG random = new PRG(saltedSeed);
        int chosenAction = this.defaultPolicy.chooseAction(context);
        if (chosenAction <= 0 || chosenAction > numActionsForContext) {
            throw new RuntimeException("Action chosen by default policy is not within valid range.");
        }

        // With exploration disabled the greedy action is always kept (probability 1).
        float effectiveEpsilon = this.explore ? this.epsilon : 0.0f;
        // Share of probability every action receives from the uniform draw.
        float baseProbability = effectiveEpsilon / (float)numActionsForContext;
        // Probability of the greedy action: exploit mass plus its uniform share.
        float greedyProbability = 1.0f - effectiveEpsilon + baseProbability;

        float actionProbability;
        if (random.uniformUnitInterval() < 1.0f - effectiveEpsilon) {
            // Exploit: keep the default policy's action.
            actionProbability = greedyProbability;
        } else {
            // Explore: the uniform draw may still land on the greedy action.
            int actionId = random.uniformInt(1, numActionsForContext);
            actionProbability = actionId == chosenAction ? greedyProbability : baseProbability;
            chosenAction = actionId;
        }
        // NOTE(review): the third argument is always true here; presumably a
        // "record this decision" flag - confirm against DecisionTuple.
        return new DecisionTuple(chosenAction, actionProbability, true);
    }

    /**
     * Turns exploration on or off.
     *
     * @param explore {@code true} to explore with the configured epsilon,
     *                {@code false} to always follow the default policy
     */
    @Override
    public void enableExplore(boolean explore) {
        this.explore = explore;
    }
}

