/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.ensemble;

import ai.libs.jaicore.graph.LabeledGraph;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IGraphDependentPolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPolicy;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import org.api4.java.datastructure.graph.ILabeledPath;

/**
 * Tree policy that delegates every action decision to one member of a fixed list of
 * path-updatable policies and keeps, per member, the number of times it was chosen and the
 * running mean of the summed playout scores observed after it was chosen.
 *
 * <p>Every path update is forwarded to all member policies, regardless of which one made
 * the last decision.
 *
 * @param <N> node type
 * @param <A> action type
 */
public class EnsembleTreePolicy<N, A>
implements IPathUpdatablePolicy<N, A, Double>,
IGraphDependentPolicy<N, A> {
    /**
     * Threshold compared against {@code Random.nextDouble()} to decide whether the delegate
     * is drawn uniformly at random. {@code nextDouble()} is always below 1.0, so with the
     * value 1.1 the uniform draw is always taken and the bound-based selection is never
     * reached. The value is kept as found; lowering it below 1.0 would activate the
     * bound-based selection.
     */
    private static final double RANDOM_SELECTION_THRESHOLD = 1.1;

    private final List<IPathUpdatablePolicy<N, A, Double>> treePolicies;
    /* fixed seed: the sequence of chosen delegates is reproducible across runs */
    private final Random rand = new Random(0L);
    /* policy that answered the most recent getAction call; null until the first call */
    private IPolicy<N, A> lastPolicy;
    private final Map<IPolicy<N, A>, Double> meansOfObservations = new HashMap<>();
    private final Map<IPolicy<N, A>, Integer> numberOfTimesChosen = new HashMap<>();
    /* total number of getAction invocations */
    private int calls;

    /**
     * Creates the ensemble from a snapshot of the given policies.
     *
     * @param treePolicies member policies; the collection is copied, so later changes to it
     *                     do not affect this ensemble
     */
    public EnsembleTreePolicy(Collection<? extends IPathUpdatablePolicy<N, A, Double>> treePolicies) {
        this.treePolicies = new ArrayList<>(treePolicies);
    }

    /**
     * Picks a member policy and returns the action that policy proposes for the node.
     * The chosen policy is remembered so the next {@link #updatePath} can attribute the
     * observed scores to it.
     *
     * @param node    node for which an action is requested
     * @param actions actions available in the node
     * @return the action proposed by the chosen member policy
     * @throws ActionPredictionFailedException if the chosen member policy fails
     * @throws InterruptedException            if the chosen member policy is interrupted
     */
    @Override
    public A getAction(N node, Collection<A> actions) throws ActionPredictionFailedException, InterruptedException {
        ++this.calls;
        /* the draw is always performed so the random sequence does not depend on the threshold */
        if (this.rand.nextDouble() < RANDOM_SELECTION_THRESHOLD) {
            this.lastPolicy = this.treePolicies.get(this.rand.nextInt(this.treePolicies.size()));
        } else {
            this.lastPolicy = this.selectPolicyWithLowestBound();
        }
        return this.lastPolicy.getAction(node, actions);
    }

    /**
     * Returns the member policy with the smallest value of
     * {@code mean - sqrt(2) * sqrt(ln(calls) / timesChosen)}; a policy that has never been
     * chosen is scored 0.0. Ties keep the policy that comes first in the list.
     */
    private IPolicy<N, A> selectPolicyWithLowestBound() {
        double bestScore = Double.MAX_VALUE;
        IPolicy<N, A> bestPolicy = null;
        for (IPolicy<N, A> candidate : this.treePolicies) {
            double score;
            Integer timesChosen = this.numberOfTimesChosen.get(candidate);
            if (timesChosen == null) {
                score = 0.0;
            } else {
                double explorationTerm = -1.0 * Math.sqrt(2.0) * Math.sqrt(Math.log(this.calls) / (double) timesChosen.intValue());
                score = this.meansOfObservations.get(candidate) + explorationTerm;
            }
            if (score < bestScore) {
                bestScore = score;
                bestPolicy = candidate;
            }
        }
        return Objects.requireNonNull(bestPolicy, "no member policy available");
    }

    /**
     * Forwards the update to every member policy and then folds the sum of {@code scores}
     * into the running mean of the policy that made the last decision.
     *
     * <p>The statistics are left untouched when no decision has been made yet or when
     * {@code scores} is empty, because there is nothing to attribute in either case.
     *
     * @param path   the path that was played
     * @param scores the scores observed along the path
     */
    @Override
    public void updatePath(ILabeledPath<N, A> path, List<Double> scores) {
        for (IPathUpdatablePolicy<N, A, Double> policy : this.treePolicies) {
            policy.updatePath(path, scores);
        }
        if (this.lastPolicy == null || scores.isEmpty()) {
            return;
        }
        double playoutScore = 0.0;
        for (Double score : scores) {
            playoutScore += score;
        }
        int visits = this.numberOfTimesChosen.getOrDefault(this.lastPolicy, 0);
        double previousMean = this.meansOfObservations.getOrDefault(this.lastPolicy, 0.0);
        this.numberOfTimesChosen.put(this.lastPolicy, visits + 1);
        /* incremental mean: (oldMean * n + x) / (n + 1) */
        this.meansOfObservations.put(this.lastPolicy, (previousMean * (double) visits + playoutScore) / (double) (visits + 1));
    }

    /**
     * Passes the graph on to every member policy that is itself graph dependent.
     *
     * @param graph the graph to hand to the member policies
     */
    @Override
    public void setGraph(LabeledGraph<N, A> graph) {
        for (IPathUpdatablePolicy<N, A, Double> policy : this.treePolicies) {
            if (!(policy instanceof IGraphDependentPolicy)) {
                continue;
            }
            /* members are policies over <N, A>, so the parameterized view is assumed to be safe */
            @SuppressWarnings("unchecked")
            IGraphDependentPolicy<N, A> graphDependent = (IGraphDependentPolicy<N, A>) policy;
            graphDependent.setGraph(graph);
        }
    }
}

