/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.brue;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.api4.java.datastructure.graph.ILabeledPath;

/**
 * Tree policy with cyclic single-pair value updates.
 *
 * <p>{@link #getAction(Object, Collection)} greedily returns the action with the best current value
 * estimate in the given node. "Best" means largest if {@code maximize} is set, otherwise smallest.
 * Ties are broken uniformly at random with the supplied {@link Random}. Actions without an estimate
 * are treated as having the worst finite value, so they are only chosen when no estimated action is
 * available or they tie with it.
 *
 * <p>{@link #updatePath(ILabeledPath, List)} refreshes the estimate of exactly one (node, action)
 * pair per call. The pair sits at the depth given by {@link #getSwitchingPoint(int)}, which cycles
 * through {@code timeHorizon, timeHorizon - 1, ..., 1}. Its estimate is a running average of the sum
 * of the scores from that arc to the end of the path.
 *
 * @param <N> node type
 * @param <A> action (arc) type
 */
public class BRUEPolicy<N, A>
implements IPathUpdatablePolicy<N, A, Double> {
    /** How often each (node, action) pair has been returned by {@link #getAction}. */
    private final Map<Pair<N, A>, Integer> nCounter = new HashMap<>();
    /** Current value estimate per (node, action) pair; absent until the pair's first update. */
    private final Map<Pair<N, A>, Double> qHat = new HashMap<>();
    private final Random random;
    private final int timeHorizon;
    private final boolean maximize;
    /** Number of processed path updates, kept modulo {@link #timeHorizon} so it can never overflow. */
    private int n = 0;

    /**
     * Creates a policy.
     *
     * @param maximize whether larger scores are better ({@code true}) or smaller ones ({@code false})
     * @param timeHorizon length of the switching-point cycle; must be positive
     * @param random source of randomness for tie-breaking
     * @throws IllegalArgumentException if {@code timeHorizon <= 0}
     */
    public BRUEPolicy(boolean maximize, int timeHorizon, Random random) {
        if (timeHorizon <= 0) {
            throw new IllegalArgumentException("The time horizon must be positive but is " + timeHorizon);
        }
        this.maximize = maximize;
        this.timeHorizon = timeHorizon;
        this.random = random;
    }

    /**
     * Creates a policy with a time horizon of 1000 and a {@link Random} seeded with 0.
     *
     * @param maximize whether larger scores are better ({@code true}) or smaller ones ({@code false})
     */
    public BRUEPolicy(boolean maximize) {
        this(maximize, 1000, new Random(0L));
    }

    /**
     * Returns one of the actions with the best current estimate in {@code node}, breaking ties at
     * random, and increments that (node, action) pair's selection counter.
     *
     * @param node the node in which an action is to be chosen
     * @param actions the candidate actions; must not be empty
     * @return the chosen action
     * @throws IllegalArgumentException if {@code actions} is empty
     * @throws IllegalStateException if no action could be ranked (e.g. all estimates are NaN)
     */
    @Override
    public A getAction(N node, Collection<A> actions) throws ActionPredictionFailedException {
        if (actions.isEmpty()) {
            throw new IllegalArgumentException();
        }
        double worstValue = this.worstValue();
        double bestScore = worstValue;
        List<A> bestActions = new ArrayList<>();
        for (A action : actions) {
            double score = this.qHat.getOrDefault(new Pair<>(node, action), worstValue);
            if (this.isBetter(score, bestScore)) {
                bestActions.clear();
                bestScore = score;
                bestActions.add(action);
            } else if (score == bestScore) {
                bestActions.add(action);
            }
        }
        if (bestActions.isEmpty()) {
            throw new IllegalStateException();
        }
        if (bestActions.size() > 1) {
            // shuffle (rather than pick an index) to keep the RNG consumption of the original policy
            Collections.shuffle(bestActions, this.random);
        }
        A choice = bestActions.get(0);
        this.nCounter.merge(new Pair<>(node, choice), 1, Integer::sum);
        return choice;
    }

    /**
     * Updates the value estimate of the (node, action) pair at the current switching point.
     *
     * <p>With {@code sigma = getSwitchingPoint(n)}, the pair is formed from node {@code sigma - 1}
     * and arc {@code sigma - 1} of {@code path}. Its new observation is the sum of
     * {@code scores.get(sigma - 1) .. scores.get(l - 2)}, where {@code l} is the number of nodes.
     * If {@code sigma > l - 2} the path is too short and nothing is updated. The update counter
     * advances in either case.
     *
     * @param path the path that was just rolled out
     * @param scores one score per arc of {@code path}
     * @throws IllegalStateException if the pair to update was never returned by {@link #getAction}
     */
    @Override
    public void updatePath(ILabeledPath<N, A> path, List<Double> scores) {
        int sigmaN = this.getSwitchingPoint(this.n);
        this.n = (this.n + 1) % this.timeHorizon;
        int l = path.getNumberOfNodes();
        // NOTE(review): this also excludes sigmaN == l - 1, i.e. the last arc is never updated; confirm intended
        if (sigmaN > l - 2) {
            return;
        }
        int index = sigmaN - 1;
        Pair<N, A> updatedPair = new Pair<>(path.getNodes().get(index), path.getArcs().get(index));
        Integer visits = this.nCounter.get(updatedPair);
        if (visits == null) {
            throw new IllegalStateException("No visit stats for updated pair " + updatedPair + " available.");
        }
        double observedRewardsFromTheUpdatedAction = 0.0;
        for (int i = l - 2; i >= index; --i) {
            observedRewardsFromTheUpdatedAction += scores.get(i).doubleValue();
        }
        Double currentScore = this.qHat.get(updatedPair);
        double newScore;
        if (currentScore == null) {
            // First observation initializes the estimate. Blending it into the +/-MAX_VALUE sentinel
            // would round the result to 0.0 (or ~MAX_VALUE / visits) and lose the observation.
            newScore = observedRewardsFromTheUpdatedAction;
        } else {
            // NOTE(review): visits counts selections in getAction, not updates; confirm this is the intended step size
            newScore = currentScore + (observedRewardsFromTheUpdatedAction - currentScore) / visits.intValue();
        }
        this.qHat.put(updatedPair, newScore);
    }

    /**
     * Returns the depth (1-based) of the pair to update in the {@code n}-th path update.
     *
     * @param n index of the update
     * @return a value in {@code [1, timeHorizon]}, counting down from {@code timeHorizon} as {@code n} grows
     */
    public int getSwitchingPoint(int n) {
        // floorMod keeps the result within [1, timeHorizon] even for negative n
        return this.timeHorizon - Math.floorMod(n, this.timeHorizon);
    }

    /** Worst finite score under the configured optimization direction; used for unestimated pairs. */
    private double worstValue() {
        return this.maximize ? -Double.MAX_VALUE : Double.MAX_VALUE;
    }

    /** Whether {@code candidate} is strictly better than {@code incumbent} under the optimization direction. */
    private boolean isBetter(double candidate, double incumbent) {
        return this.maximize ? candidate > incumbent : candidate < incumbent;
    }
}

