/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.uct;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.algorithms.mdp.mcts.EBehaviorForNotFullyExploredStates;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.NodeLabel;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for tree policies that keep one {@link NodeLabel} per visited node and update these
 * labels from the scores observed along played paths.
 *
 * <p>Subclasses define how a (node, action) pair is scored ({@link #getScore(Object, Object)}) and
 * how an action is picked from a map of scores ({@link #getActionBasedOnScores(Map)}).
 *
 * @param <N> type of the nodes (states)
 * @param <A> type of the actions (arcs)
 */
public abstract class AUpdatingPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double>, ILoggingCustomizable {

    /* not final/static: the logger can be swapped at runtime via setLoggerName */
    private Logger logger = LoggerFactory.getLogger(AUpdatingPolicy.class);

    /* factor applied to the reward accumulated further down the path when propagating it upwards */
    private final double gamma;

    /* only stored and exposed through isMaximize(); this class itself does not interpret it */
    private final boolean maximize;

    /* stays null unless set explicitly; null is handled like any value other than EXCEPTION/BEST */
    private EBehaviorForNotFullyExploredStates behaviorWhenActionForNotFullyExploredStateIsRequested;

    private final Map<N, NodeLabel<A>> labels = new HashMap<>();

    /**
     * @param gamma factor by which the reward accumulated from later steps is multiplied in
     *        {@link #updatePath(ILabeledPath, List)}
     * @param maximize flag made available to subclasses and callers via {@link #isMaximize()}
     */
    public AUpdatingPolicy(double gamma, boolean maximize) {
        this.gamma = gamma;
        this.maximize = maximize;
    }

    /**
     * Returns the label stored for the given node.
     *
     * @param node the node whose label is requested
     * @return the label of the node (never {@code null})
     * @throws IllegalArgumentException if no label has been created for the node yet
     */
    public NodeLabel<A> getLabelOfNode(N node) {
        /* labels never holds null values, so a single lookup is enough */
        NodeLabel<A> label = this.labels.get(node);
        if (label == null) {
            throw new IllegalArgumentException("No label for node " + node);
        }
        return label;
    }

    /**
     * Computes the score of playing an action in a node. A result of {@code NaN} excludes the
     * action from the candidates handed to {@link #getActionBasedOnScores(Map)}.
     */
    public abstract double getScore(N var1, A var2);

    /**
     * Picks one action given the (non-NaN) scores of all candidate actions. Must not return
     * {@code null}.
     */
    public abstract A getActionBasedOnScores(Map<A, Double> var1);

    /**
     * Updates the labels of all inner nodes of the path with the observed scores.
     *
     * <p>The path is traversed from its second-last node back to its first node. At position
     * {@code i}, the reward {@code scores.get(i)} ({@code NaN} if that entry is {@code null}) is
     * added to {@code gamma} times the reward accumulated so far, and the result is recorded for
     * the action {@code arcs.get(i)} in the label of node {@code i}.
     *
     * @param path the played path; must consist of more than just the root
     * @param scores one score per arc of the path (entries may be {@code null})
     * @throws IllegalArgumentException if the path is a single point or if there are fewer scores
     *         than transitions on the path
     * @throws NullPointerException if {@code scores} is {@code null}
     */
    @Override
    public void updatePath(ILabeledPath<N, A> path, List<Double> scores) {
        this.logger.debug("Updating path {} with score {}", path, scores);
        if (path.isPoint()) {
            throw new IllegalArgumentException("Cannot update path consisting only of the root.");
        }
        List<N> nodes = path.getNodes();
        List<A> arcs = path.getArcs();
        int numTransitions = nodes.size() - 1;

        /* validate up front: failing inside the loop would leave a freshly created, empty label
         * behind, and getAction would then treat that node as already explored */
        Objects.requireNonNull(scores, "The list of scores must not be null.");
        if (scores.size() < numTransitions) {
            throw new IllegalArgumentException("Got only " + scores.size() + " scores for a path with " + numTransitions + " transitions.");
        }

        double accumulatedDiscountedReward = 0.0;
        for (int i = numTransitions - 1; i >= 0; --i) {
            N node = nodes.get(i);
            A action = arcs.get(i);
            NodeLabel<A> label = this.labels.computeIfAbsent(node, n -> new NodeLabel<>());

            /* a missing score is propagated as NaN (and hence poisons all rewards further up) */
            Double observedScore = scores.get(i);
            double rewardForThisAction = observedScore != null ? observedScore : Double.NaN;
            accumulatedDiscountedReward = rewardForThisAction + this.gamma * accumulatedDiscountedReward;

            label.addRewardForAction(action, accumulatedDiscountedReward);
            label.addPull(action);
            label.addVisit();
            this.logger.trace("Updated label of node {}. Visits now {}. Action pulls of {} now {}. Observed total rewards for this action: {}", node, label.getVisits(), action, label.getNumPulls(action), label.getAccumulatedRewardsOfAction(action));
        }
        this.logger.debug("Path update completed.");
    }

    /**
     * Chooses an action for the given node.
     *
     * <p>If the node has no label yet and at least one action is available, the configured
     * {@link EBehaviorForNotFullyExploredStates} decides: {@code EXCEPTION} and {@code BEST} raise
     * an exception, anything else (including an unset behavior) returns the first possible action.
     * Otherwise every action is scored with {@link #getScore(Object, Object)}, NaN scores are
     * dropped, and {@link #getActionBasedOnScores(Map)} selects among the rest. If no action has a
     * non-NaN score, an element drawn by {@code SetUtil.getRandomElement} with the fixed seed 0 is
     * returned.
     *
     * @param node the node for which an action is requested
     * @param possibleActions the actions applicable in the node
     * @return the chosen action
     * @throws IllegalStateException if the node has no label and the behavior is {@code EXCEPTION}
     * @throws UnsupportedOperationException if the node has no label and the behavior is {@code BEST}
     * @throws NullPointerException if {@link #getActionBasedOnScores(Map)} returns {@code null}
     */
    @Override
    public A getAction(N node, Collection<A> possibleActions) {
        this.logger.debug("Deriving action for node {}. The {} options are: {}", node, possibleActions.size(), possibleActions);

        /* NOTE(review): "not tried yet" is decided per node, not per action: as soon as the node
         * has a label, every action counts as tried, even one that was never pulled. This mirrors
         * the previous behavior; presumably a per-action check was intended - confirm. */
        boolean nodeHasNoLabel = !this.labels.containsKey(node);
        if (nodeHasNoLabel && !possibleActions.isEmpty()) {
            if (this.behaviorWhenActionForNotFullyExploredStateIsRequested == EBehaviorForNotFullyExploredStates.EXCEPTION) {
                throw new IllegalStateException("Tree policy should only be consulted for nodes for which each child has been used at least once.");
            }
            if (this.behaviorWhenActionForNotFullyExploredStateIsRequested == EBehaviorForNotFullyExploredStates.BEST) {
                throw new UnsupportedOperationException("Can currently only work with RANDOM or EXCEPTION");
            }
            A action = possibleActions.iterator().next();
            this.logger.info("Dictating action {}, because this was never played before.", action);
            return action;
        }

        /* score all actions; actions with NaN score are not eligible */
        NodeLabel<A> labelOfNode = this.labels.get(node);
        this.logger.debug("All actions have been tried. Label is: {}", labelOfNode);
        Map<A, Double> scores = new HashMap<>();
        for (A action : possibleActions) {
            assert (labelOfNode.getVisits() != 0) : "Visits of action " + action + " cannot be 0 if we already used this action before!";
            this.logger.trace("Considering action {}, which has {} visits and cummulative rewards {}.", action, labelOfNode.getNumPulls(action), labelOfNode.getAccumulatedRewardsOfAction(action));
            double score = this.getScore(node, action);
            if (!Double.isNaN(score)) {
                scores.put(action, score);
            }
        }

        if (scores.isEmpty()) {
            this.logger.warn("All children have score NaN. Returning a random one.");
            /* the seed is fixed to 0 here */
            return SetUtil.getRandomElement(possibleActions, 0L);
        }
        A choice = this.getActionBasedOnScores(scores);
        Objects.requireNonNull(choice, "Would return null, but this must not be the case! Check the method that chooses an action given the scores.");
        this.logger.info("Recommending action {}.", choice);
        return choice;
    }

    /**
     * @return the maximize flag passed to the constructor
     */
    public boolean isMaximize() {
        return this.maximize;
    }

    /**
     * @return the name of the logger currently used by this policy
     */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Replaces the logger of this policy by the one registered under the given name.
     *
     * @param name name of the logger to use from now on
     */
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
        this.logger.info("Set logger of {} to {}", this, name);
    }

    /**
     * @return the factor applied to rewards accumulated from later steps of a path
     */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * @return the configured behavior for nodes without a label, or {@code null} if none was set
     */
    public EBehaviorForNotFullyExploredStates getBehaviorWhenActionForNotFullyExploredStateIsRequested() {
        return this.behaviorWhenActionForNotFullyExploredStateIsRequested;
    }

    /**
     * @param behaviorWhenActionForNotFullyExploredStateIsRequested what {@link #getAction(Object, Collection)}
     *        does when asked for an action in a node that has no label yet
     */
    public void setBehaviorWhenActionForNotFullyExploredStateIsRequested(EBehaviorForNotFullyExploredStates behaviorWhenActionForNotFullyExploredStateIsRequested) {
        this.behaviorWhenActionForNotFullyExploredStateIsRequested = behaviorWhenActionForNotFullyExploredStateIsRequested;
    }
}

