/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.uct;

import ai.libs.jaicore.search.algorithms.mdp.mcts.NodeLabel;
import ai.libs.jaicore.search.algorithms.mdp.mcts.uct.AUpdatingPolicy;
import java.util.Map;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Upper Confidence Bound (UCB) policy for Monte Carlo tree search.
 *
 * <p>The score of an action at a node is its empirical mean reward plus an exploration bonus
 * {@code c * sqrt(log(visits) / pulls)}, where {@code c} is the exploration constant
 * (default {@code sqrt(2)}). When minimizing, the bonus is subtracted instead of added and the
 * lowest score is preferred.
 *
 * @param <T> node type
 * @param <A> action type
 */
public class UCBPolicy<T, A> extends AUpdatingPolicy<T, A> implements ILoggingCustomizable {

    /** Name most recently passed to {@link #setLoggerName(String)}. */
    private String loggerName;

    /** Logger for this policy; replaced whenever {@link #setLoggerName(String)} is called. */
    private Logger logger = LoggerFactory.getLogger(UCBPolicy.class);

    /** Weight {@code c} of the exploration bonus. */
    private double explorationConstant;

    /**
     * Creates a UCB policy with an explicit exploration constant.
     *
     * @param gamma value forwarded unchanged to the {@link AUpdatingPolicy} constructor
     *     (presumably the discount factor — confirm in the superclass)
     * @param explorationConstant weight {@code c} of the exploration bonus
     * @param maximize {@code true} to prefer high scores, {@code false} to prefer low scores
     */
    public UCBPolicy(double gamma, double explorationConstant, boolean maximize) {
        super(gamma, maximize);
        this.explorationConstant = explorationConstant;
    }

    /**
     * Creates a UCB policy whose exploration constant is {@code sqrt(2)}.
     *
     * @param gamma value forwarded unchanged to the {@link AUpdatingPolicy} constructor
     * @param maximize {@code true} to prefer high scores, {@code false} to prefer low scores
     */
    public UCBPolicy(double gamma, boolean maximize) {
        this(gamma, Math.sqrt(2.0), maximize);
    }

    /** @return the current exploration constant {@code c} */
    public double getExplorationConstant() {
        return this.explorationConstant;
    }

    /** @param explorationConstant new weight {@code c} for the exploration bonus */
    public void setExplorationConstant(double explorationConstant) {
        this.explorationConstant = explorationConstant;
    }

    @Override
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Switches this policy to the logger called {@code name}.
     *
     * <p>The superclass is renamed to {@code name + "._updating"}.
     */
    @Override
    public void setLoggerName(String name) {
        this.loggerName = name;
        super.setLoggerName(name + "._updating");
        this.logger = LoggerFactory.getLogger(name);
    }

    /**
     * Average accumulated reward of {@code action} at {@code node}.
     *
     * @return the mean reward, or the unexplored sentinel ({@code -Double.MAX_VALUE} when
     *     maximizing, {@code Double.MAX_VALUE} when minimizing) if the node has no label or the
     *     action has never been pulled
     */
    public double getEmpiricalMean(T node, A action) {
        final NodeLabel<A> label = this.getLabelOfNode(node);
        if (label == null) {
            return this.ucbUnexploredScore();
        }
        final int pulls = label.getNumPulls(action);
        return pulls == 0 ? this.ucbUnexploredScore() : this.ucbMean(label, action, pulls);
    }

    /**
     * Signed exploration bonus of {@code action} at {@code node}.
     *
     * @return {@code ±c * sqrt(log(visits) / pulls)} (positive when maximizing, negative when
     *     minimizing), or the unexplored sentinel if the node has no label or the action has
     *     never been pulled
     */
    public double getExplorationTerm(T node, A action) {
        final NodeLabel<A> label = this.getLabelOfNode(node);
        if (label == null) {
            return this.ucbUnexploredScore();
        }
        final int pulls = label.getNumPulls(action);
        return pulls == 0 ? this.ucbUnexploredScore() : this.ucbExplorationBonus(label, pulls);
    }

    /**
     * UCB score of {@code action} at {@code node}: empirical mean plus the signed exploration bonus.
     *
     * @return the UCB score, or the unexplored sentinel if the node has no label or
     *     {@code isVirgin(action)} holds
     */
    @Override
    public double getScore(T node, A action) {
        final NodeLabel<A> label = this.getLabelOfNode(node);
        if (label == null || label.isVirgin(action)) {
            return this.ucbUnexploredScore();
        }
        // NOTE(review): assumes a non-virgin action always has at least one pull; with zero pulls
        // the divisions below produce non-finite values.
        final int pulls = label.getNumPulls(action);
        final double mean = this.ucbMean(label, action, pulls);
        final double bonus = this.ucbExplorationBonus(label, pulls);
        final double score = mean + bonus;
        this.logger.trace("Computed UCB score {} = {} + {} * {} * sqrt(log({})/{}). That is, exploration term is {}", new Object[]{score, mean, this.ucbSign(), this.explorationConstant, label.getVisits(), pulls, bonus});
        return score;
    }

    /**
     * Picks the action with the best score.
     *
     * <p>"Best" means the highest score when maximizing and the lowest score when minimizing. Ties
     * keep the entry met first in the map's iteration order.
     *
     * @param scores score per candidate action; a {@code null} value fails on unboxing
     * @return the selected action
     * @throws IllegalArgumentException if {@code scores} is empty
     * @throws IllegalStateException if the selected key is {@code null}
     */
    @Override
    public A getActionBasedOnScores(Map<A, Double> scores) {
        if (scores.isEmpty()) {
            throw new IllegalArgumentException("An empty set of scored actions has been given to UCB to decide!");
        }
        this.logger.debug("Getting action for scores {}", scores);
        final boolean maximizing = this.isMaximize();
        A bestAction = null;
        double bestScore = this.ucbUnexploredScore();
        for (Map.Entry<A, Double> candidate : scores.entrySet()) {
            final A action = candidate.getKey();
            final double score = candidate.getValue();
            final boolean improves = maximizing ? score > bestScore : score < bestScore;
            if (bestAction != null && !improves) {
                this.logger.trace("Skipping current solution {} since its score {} is not better than the currently best {}.", new Object[]{action, score, bestScore});
                continue;
            }
            // A null bestAction means nothing is selected yet, so the current candidate is always taken.
            this.logger.trace("Updating best choice {} with {} since it is better than the current solution with performance {}", new Object[]{bestAction, action, bestScore});
            bestScore = score;
            bestAction = action;
        }
        if (bestAction == null) {
            throw new IllegalStateException("UCB would return NULL action, which must not be the case!");
        }
        return bestAction;
    }

    /**
     * Returns the sentinel for actions without statistics.
     *
     * <p>This is the worst finite score for the optimisation direction.
     */
    private double ucbUnexploredScore() {
        return this.isMaximize() ? -Double.MAX_VALUE : Double.MAX_VALUE;
    }

    /** Returns {@code +1} when maximising and {@code -1} when minimising; this sets the sign of the exploration bonus. */
    private int ucbSign() {
        return this.isMaximize() ? 1 : -1;
    }

    /** Accumulated reward of {@code action} divided by its pull count. */
    private double ucbMean(NodeLabel<A> label, A action, int pulls) {
        return label.getAccumulatedRewardsOfAction(action) / (double) pulls;
    }

    /** Signed exploration bonus {@code ±c * sqrt(log(visits) / pulls)}. */
    private double ucbExplorationBonus(NodeLabel<A> label, int pulls) {
        return (double) this.ucbSign() * this.explorationConstant * Math.sqrt(Math.log(label.getVisits()) / (double) pulls);
    }
}

