/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.tag;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Path-updatable policy that scores the actions of a node from a bounded memory of accumulated
 * path scores.
 *
 * <p>For every node the policy keeps
 * <ul>
 *   <li>a threshold that is moved by {@code thresholdIncrement} (upwards when maximizing,
 *       downwards when minimizing) until at most {@code s} remembered scores survive,</li>
 *   <li>per action a queue of remembered scores (at most {@code s} entries are inserted per queue),</li>
 *   <li>per action a pull counter and per node a visit counter.</li>
 * </ul>
 *
 * <p>The class uses plain {@link HashMap}s and performs no synchronization.
 *
 * @param <T> type of the nodes (states)
 * @param <A> type of the actions
 */
public class TAGPolicy<T, A> implements IPathUpdatablePolicy<T, A, Double>, ILoggingCustomizable {
    private String loggerName;
    private Logger logger = LoggerFactory.getLogger(TAGPolicy.class);
    /* Only stored and exposed through getter/setter; it is not read by the score computation of this class. */
    private double explorationConstant = Math.sqrt(2.0);
    private final int s;
    private final Map<T, Double> thresholdPerNode = new HashMap<T, Double>();
    private final double delta;
    private final double thresholdIncrement;
    private final boolean isMaximize;
    private final Map<T, Map<A, PriorityQueue<Double>>> statsPerNode = new HashMap<T, Map<A, PriorityQueue<Double>>>();
    private final Map<T, Map<A, Integer>> pullsPerNodeAction = new HashMap<T, Map<A, Integer>>();
    private final Map<T, Integer> visitsPerNode = new HashMap<T, Integer>();

    /**
     * Creates a minimizing policy with the default parameters of {@link #TAGPolicy(boolean)}.
     */
    public TAGPolicy() {
        this(false);
    }

    /**
     * Creates a fully parameterized policy.
     *
     * @param explorationConstant value returned by {@link #getExplorationConstant()}
     * @param s maximum number of scores inserted into a single action queue, and the number of
     *        remembered scores per node that {@link #adjustThreshold(Object)} prunes down to
     * @param delta divisor inside the logarithm of {@link #getUtilityOfAction(Object, Object, int)}
     * @param thresholdIncrement step by which the threshold is moved in each pruning round
     * @param isMaximize {@code true} if higher scores are better, {@code false} if lower scores are better
     */
    public TAGPolicy(double explorationConstant, int s, double delta, double thresholdIncrement, boolean isMaximize) {
        this.explorationConstant = explorationConstant;
        this.s = s;
        this.delta = delta;
        this.thresholdIncrement = thresholdIncrement;
        this.isMaximize = isMaximize;
    }

    /**
     * Creates a policy with exploration constant sqrt(2), s = 10, delta = 1.0 and threshold increment 0.01.
     *
     * @param maximize {@code true} if higher scores are better
     */
    public TAGPolicy(boolean maximize) {
        this(Math.sqrt(2.0), 10, 1.0, 0.01, maximize);
    }

    /**
     * Chooses one of the given actions for the node.
     *
     * <p>Registers a visit of the node, makes sure every action has a pull counter (starting at 1),
     * adjusts the node's threshold and then picks the action whose utility is best with respect to
     * the optimization direction. If no action has a non-NaN utility, a random action is returned;
     * in that case no pull counter is incremented.
     *
     * @param node the node for which an action is requested
     * @param actions the candidate actions
     * @return the chosen action
     * @throws ActionPredictionFailedException declared by the policy contract; this method body contains no explicit throw of it
     */
    @Override
    public A getAction(T node, Collection<A> actions) throws ActionPredictionFailedException {
        this.logger.info("Getting action for node {}", node);
        Map<A, Integer> pullMap = this.pullsPerNodeAction.computeIfAbsent(node, n -> new HashMap<A, Integer>());
        for (A candidate : actions) {
            pullMap.putIfAbsent(candidate, 1);
        }
        this.visitsPerNode.merge(node, 1, Integer::sum);

        this.logger.debug("Adjusting threshold.");
        this.adjustThreshold(node);
        this.logger.debug("Threshold adjusted. Is now {}", this.thresholdPerNode.get(node));

        A choice = null;
        double best = this.isMaximize ? -Double.MAX_VALUE : Double.MAX_VALUE;
        int k = actions.size();
        for (A action : actions) {
            double score = this.getUtilityOfAction(node, action, k);
            if (this.isBetter(score, best)) {
                this.logger.trace("Updating best choice {} with {} since it is better than the current solution with performance {}", choice, action, best);
                best = score;
                choice = action;
            } else {
                this.logger.trace("Skipping current solution {} since its score {} is not better than the currently best {}.", action, score, best);
            }
        }
        if (choice == null) {
            this.logger.warn("All options have score NaN. Returning random element.");
            return (A)SetUtil.getRandomElement(actions, 0L);
        }
        pullMap.merge(choice, 1, Integer::sum);
        return choice;
    }

    /** Tells whether a (non-NaN) candidate score beats the incumbent in the configured optimization direction. */
    private boolean isBetter(double candidate, double incumbent) {
        if (Double.isNaN(candidate)) {
            return false;
        }
        return this.isMaximize ? candidate > incumbent : candidate < incumbent;
    }

    /**
     * Moves the threshold of the node until at most {@code s} remembered scores survive.
     *
     * <p>The threshold starts at the stored value (0.0 when maximizing, 100.0 when minimizing if none
     * is stored yet). Scores on the wrong side of the threshold are removed from the action queues;
     * while more than {@code s} scores remain, the threshold is moved by {@code thresholdIncrement}
     * and the pruning is repeated. If the node has no remembered scores, only the initial threshold
     * is stored.
     *
     * @param node the node whose threshold is adjusted
     */
    public void adjustThreshold(T node) {
        Map<A, PriorityQueue<Double>> observations = this.statsPerNode.get(node);
        double t = this.thresholdPerNode.computeIfAbsent(node, n -> this.isMaximize ? 0.0 : 100.0);
        this.logger.debug("Initial value for threshold is {}. Observations are: {}", t, observations);
        if (observations == null) {
            return;
        }
        double step = this.isMaximize ? this.thresholdIncrement : -this.thresholdIncrement;
        int remaining = this.pruneObservations(observations, t);
        while (remaining > this.s) {
            double next = t + step;
            /* Without this guard a non-positive/NaN increment (or one too small to change t) loops forever. */
            if (!(this.thresholdIncrement > 0.0) || next == t) {
                this.logger.warn("Threshold cannot be moved any further from {} with increment {}; {} observations remain.", t, this.thresholdIncrement, remaining);
                break;
            }
            t = next;
            remaining = this.pruneObservations(observations, t);
        }
        this.logger.debug("Setting threshold to {}", t);
        this.thresholdPerNode.put(node, t);
    }

    /** Removes all scores on the wrong side of the threshold and returns how many scores remain over all actions. */
    private int pruneObservations(Map<A, PriorityQueue<Double>> observations, double threshold) {
        int remaining = 0;
        for (PriorityQueue<Double> scoresOfAction : observations.values()) {
            scoresOfAction.removeIf(d -> this.isMaximize ? d < threshold : d > threshold);
            remaining += scoresOfAction.size();
        }
        return remaining;
    }

    /**
     * Computes the utility of an action in a node.
     *
     * <p>With {@code alpha = ln(2 * visits(node) * k / delta)} and {@code sChild} the number of
     * remembered scores of the action, the result is
     * {@code (sChild + alpha + sqrt(2 * sChild * alpha + alpha^2)) / pulls(node, action)}.
     *
     * @param node the node in which the action is evaluated
     * @param action the action to evaluate
     * @param k the number of candidate actions in the node
     * @return the utility, or {@code NaN} if no score queue exists for the node/action pair
     * @throws IllegalStateException if no visit was recorded for the node or if alpha is negative
     * @throws IllegalArgumentException if the action has no pulls recorded
     */
    public double getUtilityOfAction(T node, A action, int k) {
        Map<A, PriorityQueue<Double>> statsOfNode = this.statsPerNode.get(node);
        if (statsOfNode == null || !statsOfNode.containsKey(action)) {
            return Double.NaN;
        }
        Integer visits = this.visitsPerNode.get(node);
        if (visits == null) {
            throw new IllegalStateException("No visit has been recorded for node " + node + ". Call getAction for the node first.");
        }
        /* Multiply in double arithmetic so that large visit counts cannot overflow int. */
        double alpha = Math.log(2.0 * visits * k / this.delta);
        Queue<Double> memorizedScoresForArm = statsOfNode.get(action);
        int sChild = memorizedScoresForArm.size();
        if (alpha < 0.0) {
            throw new IllegalStateException("Alpha must not be negative. Check delta value (must be smaller than 1)");
        }
        double nominator = (double)sChild + alpha + Math.sqrt(2.0 * sChild * alpha + Math.pow(alpha, 2.0));
        Map<A, Integer> pullsOfNode = this.pullsPerNodeAction.get(node);
        Integer armPulls = pullsOfNode == null ? null : pullsOfNode.get(action);
        if (armPulls == null || armPulls == 0) {
            throw new IllegalArgumentException("Cannot compute score for child with no visits!");
        }
        double h = nominator / (double)armPulls;
        this.logger.trace("Compute TAG score of {}", h);
        return h;
    }

    /**
     * @return the currently stored exploration constant
     */
    public double getExplorationConstant() {
        return this.explorationConstant;
    }

    /**
     * @param explorationConstant the new exploration constant to store
     */
    public void setExplorationConstant(double explorationConstant) {
        this.explorationConstant = explorationConstant;
    }

    /**
     * Propagates the scores of a path backwards into the per-node/per-action score queues.
     *
     * <p>Walking from the second-to-last node back to the first, the score of each step is added to
     * a running sum and the sum is offered to the queue of the (node, action) pair. The update stops
     * as soon as a step score is {@code null} or the running sum becomes NaN; the queue for that
     * pair has already been created at that point.
     *
     * @param path the path whose nodes and arcs are updated
     * @param scores the scores per step, indexed like the arcs of the path
     */
    @Override
    public void updatePath(ILabeledPath<T, A> path, List<Double> scores) {
        List<T> nodes = path.getNodes();
        List<A> arcs = path.getArcs();
        double accumulatedScores = 0.0;
        for (int i = path.getNumberOfNodes() - 2; i >= 0; --i) {
            T node = nodes.get(i);
            A action = arcs.get(i);
            Map<A, PriorityQueue<Double>> actionMap = this.statsPerNode.computeIfAbsent(node, n -> new HashMap<A, PriorityQueue<Double>>());
            PriorityQueue<Double> bestScores = actionMap.computeIfAbsent(action, a -> this.newScoreQueue());
            assert (!bestScores.contains(Double.NaN));

            Double stepScore = scores.get(i);
            if (stepScore == null) {
                return;
            }
            accumulatedScores += stepScore.doubleValue();
            if (Double.isNaN(accumulatedScores)) {
                return;
            }

            if (bestScores.size() < this.s) {
                bestScores.add(accumulatedScores);
            } else {
                /*
                 * NOTE(review): the head is the largest element when maximizing and the smallest when
                 * minimizing, so this replaces the head whenever the new score is larger. That looks
                 * inverted for keeping the "best" scores; behavior is kept as is - confirm the intent.
                 */
                Double head = bestScores.peek();
                if (head != null && head < accumulatedScores) {
                    bestScores.poll();
                    bestScores.add(accumulatedScores);
                }
            }
            assert (!bestScores.contains(Double.NaN));
        }
    }

    /** Creates an empty score queue: descending order when maximizing, natural (ascending) order otherwise. */
    private PriorityQueue<Double> newScoreQueue() {
        if (this.isMaximize) {
            return new PriorityQueue<Double>((c1, c2) -> Double.compare(c2, c1));
        }
        return new PriorityQueue<Double>();
    }

    /**
     * @return the logger name set via {@link #setLoggerName(String)}, or {@code null} if none was set
     */
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Stores the name and switches the logger of this policy to the logger with that name.
     *
     * @param name the name of the logger to use
     */
    public void setLoggerName(String name) {
        this.loggerName = name;
        this.logger = LoggerFactory.getLogger(name);
    }
}

