/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.uuct;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.uuct.IUCBUtilityFunction;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.api4.java.datastructure.graph.ILabeledPath;

/**
 * Action-selection policy that records, for every (node, action) pair, the list of observed
 * playout scores in ascending order and chooses the action that maximizes
 * {@code utility(observations) + phiInverse(ALPHA * ln(t) / numberOfObservations)}.
 *
 * <p>The utility of a list of observations and the parameters {@code a}, {@code b} and {@code q}
 * used in the exploration term are supplied by the {@link IUCBUtilityFunction} passed to the
 * constructor. {@code t} counts the number of calls to {@link #updatePath(ILabeledPath, List)}.
 *
 * @param <N> type of the nodes (states)
 * @param <A> type of the actions (arcs)
 */
public class UUCBPolicy<N, A>
implements IPathUpdatablePolicy<N, A, Double> {
    /** Multiplier applied to {@code ln(t)} in the exploration term. */
    private static final double ALPHA = 3.0;

    private final IUCBUtilityFunction utilityFunction;
    private final double a;
    private final double b;
    private final double q;

    /** Observed playout scores per node and action; every list is kept sorted ascendingly. */
    private final Map<N, Map<A, DoubleList>> observations = new HashMap<>();

    /** Source of seeds for the random fallback choice. */
    private final Random random = new Random();

    /** Number of path updates seen so far. */
    private int t = 0;

    /**
     * Creates the policy and caches the parameters {@code a}, {@code b} and {@code q} of the
     * given utility function.
     *
     * @param utilityFunction function that maps a list of observations to a utility value
     */
    public UUCBPolicy(IUCBUtilityFunction utilityFunction) {
        this.utilityFunction = utilityFunction;
        this.a = utilityFunction.getA();
        this.b = utilityFunction.getB();
        this.q = utilityFunction.getQ();
    }

    /**
     * Chooses an action for the given node.
     *
     * <p>If nothing has been observed for the node, or none of the possible actions has any
     * observation, a random action is returned. Otherwise the action with the strictly highest
     * score wins; on ties the action encountered first in {@code possibleActions} is kept.
     * Actions without observations are not scored.
     *
     * @param node the node for which an action is requested
     * @param possibleActions the actions that may be chosen in {@code node}
     * @return the chosen action
     * @throws ActionPredictionFailedException declared by the interface; not thrown directly here
     */
    @Override
    public A getAction(N node, Collection<A> possibleActions) throws ActionPredictionFailedException {
        Map<A, DoubleList> observationsForActions = this.observations.get(node);
        if (observationsForActions == null) {
            return this.chooseRandomly(possibleActions);
        }

        double bestScore = -Double.MAX_VALUE;
        A bestAction = null;
        for (A succ : possibleActions) {
            DoubleList observationsOfChild = observationsForActions.get(succ);
            if (observationsOfChild == null) {
                continue;
            }
            double utility = this.utilityFunction.getUtility(observationsOfChild);
            double exploration = this.phiInverse(ALPHA * Math.log(this.t) / (double) observationsOfChild.size());
            double score = utility + exploration;
            if (score > bestScore) {
                bestScore = score;
                bestAction = succ;
            }
        }
        return bestAction != null ? bestAction : this.chooseRandomly(possibleActions);
    }

    /** Draws one of the given actions at random, using a freshly drawn seed. */
    @SuppressWarnings("unchecked")
    private A chooseRandomly(Collection<A> possibleActions) {
        return (A) SetUtil.getRandomElement(possibleActions, this.random.nextLong());
    }

    /**
     * Computes {@code max(2b * sqrt(x / a), 2b * (x / a)^(q / 2))}.
     *
     * @param x the argument of the function
     * @return the larger of the two terms
     */
    private double phiInverse(double x) {
        double ratio = x / this.a;
        return Math.max(2.0 * this.b * Math.sqrt(ratio), 2.0 * this.b * Math.pow(ratio, this.q / 2.0));
    }

    /**
     * Records the summed score of a playout for every node on the path except its head, under
     * the action that leaves that node on the path, and increments the update counter.
     *
     * @param path the path that was played out
     * @param scores the scores obtained along the path; only their sum is used
     */
    @Override
    public void updatePath(ILabeledPath<N, A> path, List<Double> scores) {
        final double playoutScore = SetUtil.sum(scores);
        path.getPathToParentOfHead().getNodes().forEach(n -> {
            DoubleList obs = this.observations
                    .computeIfAbsent(n, node -> new HashMap<>())
                    .computeIfAbsent(path.getOutArc(n), x -> new DoubleArrayList());
            insertSorted(obs, playoutScore);
        });
        ++this.t;
    }

    /**
     * Inserts {@code score} into the ascendingly sorted list so that it stays sorted.
     *
     * <p>The score is placed in front of the first element that is greater than or equal to it.
     * If there is no such element (in particular when the score exceeds every stored value), it
     * is appended, so that no observation is lost.
     */
    private static void insertSorted(DoubleList sortedObservations, double score) {
        int size = sortedObservations.size();
        for (int i = 0; i < size; ++i) {
            if (score <= sortedObservations.getDouble(i)) {
                sortedObservations.add(i, score);
                return;
            }
        }
        sortedObservations.add(score);
    }
}

