/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.probleminputs;

import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPolicy;
import ai.libs.jaicore.search.model.other.EvaluatedSearchGraphPath;
import ai.libs.jaicore.search.model.other.SearchGraphPath;
import ai.libs.jaicore.search.probleminputs.IMDP;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IEvaluatedPath;
import org.api4.java.common.attributedobjects.ObjectEvaluationFailedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MDPUtils
implements ILoggingCustomizable {
    private Logger logger = LoggerFactory.getLogger(MDPUtils.class);

    public static <N, A> Collection<N> getStates(IMDP<N, A, ?> mdp) throws InterruptedException {
        HashSet states = new HashSet();
        ArrayDeque<N> open = new ArrayDeque<N>();
        open.add(mdp.getInitState());
        while (!open.isEmpty()) {
            Object next = open.pop();
            if (states.contains(next)) continue;
            states.add(next);
            for (A a : mdp.getApplicableActions(next)) {
                open.addAll(mdp.getProb(next, a).keySet());
            }
        }
        return states;
    }

    public <N, A> N drawSuccessorState(IMDP<N, A, ?> mdp, N state, A action) throws InterruptedException {
        return this.drawSuccessorState(mdp, state, action, new Random());
    }

    public <N, A> N drawSuccessorState(IMDP<N, A, ?> mdp, N state, A action, Random rand) throws InterruptedException {
        if (!mdp.isActionApplicableInState(state, action)) {
            throw new IllegalArgumentException("Action " + action + " is not applicable in " + state);
        }
        Map<N, Double> dist = mdp.getProb(state, action);
        double p = rand.nextDouble();
        double s = 0.0;
        for (Map.Entry<N, Double> neighborWithProb : dist.entrySet()) {
            if (!((s += neighborWithProb.getValue().doubleValue()) >= p)) continue;
            return neighborWithProb.getKey();
        }
        throw new IllegalStateException("The accumulated probability of all the " + dist.size() + " successors is only " + s + " instead of 1.\n\tState: " + state + "\n\tAction: " + action + "\nConsidered successor states: " + dist.entrySet().stream().map(e -> "\n\t" + e.toString()).collect(Collectors.joining()));
    }

    public <N, A> IEvaluatedPath<N, A, Double> getRun(IMDP<N, A, Double> mdp, double gamma, IPolicy<N, A> policy, Random random, Predicate<ILabeledPath<N, A>> stopCriterion) throws InterruptedException, ActionPredictionFailedException, ObjectEvaluationFailedException {
        double score = 0.0;
        SearchGraphPath path = new SearchGraphPath(mdp.getInitState());
        Object current = path.getRoot();
        Collection<A> possibleActions = mdp.getApplicableActions(current);
        double discount = 1.0;
        while (!possibleActions.isEmpty() && !stopCriterion.test(path)) {
            A action = policy.getAction(current, possibleActions);
            assert (possibleActions.contains(action));
            Object nextState = this.drawSuccessorState(mdp, current, action, random);
            this.logger.debug("Choosing action {}. Next state is {} (probability is {})", new Object[]{action, nextState, mdp.getProb(current, action, nextState)});
            score += discount * mdp.getScore(current, action, nextState);
            discount *= gamma;
            current = nextState;
            path.extend(current, action);
            possibleActions = mdp.getApplicableActions(current);
        }
        return new EvaluatedSearchGraphPath(path, score);
    }

    public String getLoggerName() {
        return this.logger.getName();
    }

    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger((String)name);
    }

    public static int getTimeHorizon(double gamma, double epsilon) {
        return gamma < 1.0 ? (int)Math.ceil(Math.log(epsilon) / Math.log(gamma)) : Integer.MAX_VALUE;
    }
}

