/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.spuct;

import ai.libs.jaicore.search.algorithms.mdp.mcts.NodeLabel;
import ai.libs.jaicore.search.algorithms.mdp.mcts.uct.UCBPolicy;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * UCB policy that augments the parent's UCB score with an additional square-root term built from
 * the squared discounted returns observed so far.
 *
 * <p>For every node label touched by {@link #updatePath(ILabeledPath, List)} the policy keeps the
 * running sum of the squared discounted returns. {@link #getScore(Object, Object)} combines that
 * sum with the empirical mean, the number of pulls of the action, and the constant {@code bigD}.
 *
 * @param <N> node type
 * @param <A> action type
 */
public class SPUCBPolicy<N, A> extends UCBPolicy<N, A> implements ILoggingCustomizable {
    private String loggerName;
    private Logger logger = LoggerFactory.getLogger(SPUCBPolicy.class);

    /** Constant added to the radicand of the extra term in {@link #getScore(Object, Object)}. */
    private final double bigD;

    /** Sum of squared discounted returns, keyed by the label of the node they were observed at. */
    private final Map<NodeLabel<A>, Double> squaredObservations = new HashMap<>();

    /**
     * Creates a maximizing policy.
     *
     * @param gamma discount factor forwarded to the parent policy
     * @param bigD constant added inside the square root of the extra score term
     */
    public SPUCBPolicy(double gamma, double bigD) {
        this(gamma, true, bigD);
    }

    /**
     * Creates a policy.
     *
     * @param gamma discount factor forwarded to the parent policy
     * @param maximize forwarded to the parent policy; also decides the sign of the extra score term
     * @param bigD constant added inside the square root of the extra score term
     */
    public SPUCBPolicy(double gamma, boolean maximize, double bigD) {
        super(gamma, maximize);
        this.bigD = bigD;
    }

    @Override
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Switches this policy's logger to {@code name} and the parent's logger to
     * {@code name + "._updating"}.
     *
     * @param name new logger name
     */
    @Override
    public void setLoggerName(String name) {
        this.loggerName = name;
        super.setLoggerName(name + "._updating");
        this.logger = LoggerFactory.getLogger(name);
    }

    /**
     * Lets the parent policy process the path and then records, for every node except the last
     * one, the square of the discounted return observed from that node onwards.
     *
     * <p>The path is walked from its second-to-last node back to its first node. If a score entry
     * is {@code null}, the running return becomes {@code NaN} and stays {@code NaN} for all
     * remaining (earlier) nodes; the stored sums of those nodes consequently become {@code NaN}.
     *
     * @param path path whose nodes are updated
     * @param scores scores read by index {@code i} for the node at position {@code i} of the path
     */
    @Override
    public void updatePath(ILabeledPath<N, A> path, List<Double> scores) {
        super.updatePath(path, scores);
        List<N> nodes = path.getNodes();
        double discountedReturn = 0.0;
        for (int i = nodes.size() - 2; i >= 0; --i) {
            NodeLabel<A> label = this.getLabelOfNode(nodes.get(i));

            // Once the return is NaN it can never recover, so the scores are no longer consulted.
            if (!Double.isNaN(discountedReturn)) {
                Double stepScore = scores.get(i);
                discountedReturn = stepScore != null
                        ? stepScore + this.getGamma() * discountedReturn
                        : Double.NaN;
            }
            this.squaredObservations.merge(label, Math.pow(discountedReturn, 2.0), Double::sum);
        }
    }

    /**
     * Computes the parent's UCB score plus the additional term
     * {@code sign * sqrt((squaredResults - pulls * mean^2 + bigD) / pulls)}, where {@code sign} is
     * {@code 1} when maximizing and {@code -1} otherwise.
     *
     * <p>No guard is applied to the extra term: with zero pulls or a negative radicand the result
     * is whatever IEEE arithmetic yields (infinite or {@code NaN}).
     *
     * @param node node at which the action is evaluated
     * @param action action to score
     * @return UCB score of the parent policy plus the additional term
     */
    @Override
    public double getScore(N node, A action) {
        // Take the complete UCB score from the parent. Adding the empirical mean to itself, as the
        // code did before, doubled the mean and left the exploration bonus out of the score.
        double ucb = super.getScore(node, action);
        double ucbMean = super.getEmpiricalMean(node, action);

        NodeLabel<A> labelOfNode = this.getLabelOfNode(node);
        int visitsOfChild = labelOfNode.getNumPulls(action);
        double squaredResults = this.squaredObservations.getOrDefault(labelOfNode, 0.0);
        double expectedResults = visitsOfChild * Math.pow(ucbMean, 2.0);
        int sign = this.isMaximize() ? 1 : -1;
        double spTerm = sign * Math.sqrt((squaredResults - expectedResults + this.bigD) / visitsOfChild);

        double score = ucb + spTerm;
        this.logger.debug("Computed score for action {}: {} = {} + {}", action, score, ucb, spTerm);
        return score;
    }
}

