/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.comparison;

import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.ToDoubleFunction;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.api4.java.datastructure.graph.ILabeledPath;

/**
 * Tree policy that first commits a fixed number of observations to every action of a node and then
 * greedily exploits.
 *
 * <p>For a given node:
 * <ul>
 * <li>While some candidate action has fewer than {@code k} observations, the action with the fewest
 * observations is returned. Ties go to the first such action in the iteration order of the candidate
 * collection.</li>
 * <li>Once every candidate has at least {@code k} observations, the action whose observation statistics
 * yield the lowest {@code metric} value is returned.</li>
 * </ul>
 *
 * <p>Observations are the accumulated path scores recorded by {@link #updatePath(ILabeledPath, List)}.
 *
 * <p>State is kept in an unsynchronized {@link HashMap}; this class performs no synchronization of its own.
 *
 * @param <N> node type
 * @param <A> action (arc) type
 */
public class FixedCommitmentPolicy<N, A>
implements IPathUpdatablePolicy<N, A, Double> {
    /** Per node, the statistics of all scores observed for each action taken from that node. */
    private final Map<N, Map<A, DescriptiveStatistics>> observationsPerNode = new HashMap<>();
    /** Number of observations every action must have before the policy switches to exploitation. */
    private final int k;
    /** Maps the observations of an action to a score; lower scores are preferred. */
    private final ToDoubleFunction<DescriptiveStatistics> metric;

    /**
     * Creates the policy.
     *
     * @param k number of observations each action must receive before the policy starts exploiting
     * @param metric function turning an action's observation statistics into a score; lower is better
     * @throws NullPointerException if {@code metric} is {@code null}
     */
    public FixedCommitmentPolicy(int k, ToDoubleFunction<DescriptiveStatistics> metric) {
        this.k = k;
        this.metric = Objects.requireNonNull(metric, "metric");
    }

    /**
     * Chooses the next action to take from {@code node}.
     *
     * <p>Side effect: an empty statistics entry is registered for every (node, action) pair not seen before.
     *
     * @param node the node for which an action is requested
     * @param actions the candidate actions; must not be empty
     * @return the least-observed action while it has fewer than {@code k} observations, otherwise the action
     *         with the lowest metric value
     * @throws ActionPredictionFailedException declared by the interface; not thrown by this implementation
     * @throws NullPointerException if {@code actions} is empty
     */
    @Override
    public A getAction(N node, Collection<A> actions) throws ActionPredictionFailedException {
        Map<A, DescriptiveStatistics> observationsPerAction = this.observationsPerNode.computeIfAbsent(node, n -> new HashMap<>());
        A actionWithLeastVisits = null;
        A actionWithBestScore = null;
        long leastVisits = Long.MAX_VALUE;
        double bestScore = Double.NaN;
        for (A action : actions) {
            DescriptiveStatistics observations = observationsPerAction.computeIfAbsent(action, a -> new DescriptiveStatistics());
            long visits = observations.getN();
            if (visits < leastVisits) {
                actionWithLeastVisits = action;
                leastVisits = visits;
            }
            double score = this.metric.applyAsDouble(observations);
            // Seed with the first candidate so a best action always exists. Double.compare ranks NaN above
            // every number, so a NaN score (e.g. the mean of an empty DescriptiveStatistics) never beats a
            // real score. Ties keep the earlier action.
            if (actionWithBestScore == null || Double.compare(score, bestScore) < 0) {
                actionWithBestScore = action;
                bestScore = score;
            }
        }
        // Null only if no candidate actions were given.
        Objects.requireNonNull(actionWithLeastVisits, "No actions available for node " + node);
        if (leastVisits < this.k) {
            return actionWithLeastVisits;
        }
        return actionWithBestScore;
    }

    /**
     * Back-propagates the scores of a rolled-out path.
     *
     * <p>Walking from the last arc towards the root, the score of each arc is added to a running sum. That
     * sum is then recorded as a new observation of the (node, arc) pair. Each pair therefore observes its
     * own score plus the scores of all arcs after it on the path.
     *
     * @param path the rolled-out path; arc {@code i} is paired with node {@code i}
     * @param scores per-arc scores; {@code scores.get(i)} is read for every arc index {@code i}
     */
    @Override
    public void updatePath(ILabeledPath<N, A> path, List<Double> scores) {
        List<N> nodes = path.getNodes();
        List<A> arcs = path.getArcs();
        double accumulatedScore = 0.0;
        for (int i = nodes.size() - 2; i >= 0; --i) {
            accumulatedScore += scores.get(i);
            this.observationsPerNode.computeIfAbsent(nodes.get(i), n -> new HashMap<>())
                    .computeIfAbsent(arcs.get(i), a -> new DescriptiveStatistics())
                    .addValue(accumulatedScore);
        }
    }
}

