/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts;

import ai.libs.jaicore.basic.algorithm.AAlgorithm;
import ai.libs.jaicore.basic.algorithm.AlgorithmFinishedEvent;
import ai.libs.jaicore.basic.algorithm.AlgorithmInitializedEvent;
import ai.libs.jaicore.graphvisualizer.events.graph.GraphInitializedEvent;
import ai.libs.jaicore.graphvisualizer.events.graph.NodeAddedEvent;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IRolloutLimitDependentPolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTSIterationCompletedEvent;
import ai.libs.jaicore.search.algorithms.mdp.mcts.UniformRandomPolicy;
import ai.libs.jaicore.search.model.other.SearchGraphPath;
import ai.libs.jaicore.search.probleminputs.IMDP;
import ai.libs.jaicore.search.probleminputs.MDPUtils;
import ai.libs.jaicore.timing.TimedComputation;
import com.google.common.eventbus.Subscribe;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.api4.java.algorithm.IAlgorithm;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.attributedobjects.ObjectEvaluationFailedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IEvent;
import org.api4.java.common.event.IRelaxedEventEmitter;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MCTS<N, A>
extends AAlgorithm<IMDP<N, A, Double>, IPolicy<N, A>> {
    /* Logger of this instance; replaced when setLoggerName is called. */
    private Logger logger = LoggerFactory.getLogger(MCTS.class);
    /* Only used to report the currently used memory after each iteration. */
    private static final Runtime runtime = Runtime.getRuntime();
    /* The MDP that is explored (the algorithm input). */
    private final IMDP<N, A, Double> mdp;
    /* Playouts are extended only while the path has fewer nodes than this; derived from gamma and epsilon via MDPUtils.getTimeHorizon. */
    private final int maxDepth;
    /* Helper used to draw successor states from the MDP. */
    private final MDPUtils utils = new MDPUtils();
    /* Policy used inside the tree; it is updated with the scores of every playout. */
    private final IPathUpdatablePolicy<N, A, Double> treePolicy;
    /* Policy used to finish a playout after the tree has been left. */
    private final IPolicy<N, A> defaultPolicy;
    /* True iff the default policy is a UniformRandomPolicy; then actions are sampled directly from the MDP. */
    private final boolean uniformSamplingDefaultPolicy;
    /* Random source of the uniform default policy, or null if the default policy is not a UniformRandomPolicy. */
    private final Random randomSourceOfUniformSamplyPolicy;
    /* The algorithm terminates once this many iterations have been started. */
    private final int maxIterations;
    /* Number of iterations (playouts) started so far. */
    private int iterations = 0;
    /* States whose applicable actions have all been tried at least once; only in these states the tree policy is asked. */
    private final Collection<N> tpReadyStates = new HashSet<N>();
    /* Cache of the applicable actions of every state that has been expanded so far. */
    private final Map<N, Collection<A>> applicableActionsPerState = new HashMap<N, Collection<A>>();
    /* Actions not yet tried for states that were expanded but are not yet ready for the tree policy. */
    private final Map<N, List<A>> untriedActionsOfIncompleteStates = new HashMap<N, List<A>>();
    /* Last progress percentage that has been written to the log. */
    private int lastProgressReport = 0;
    /* Accumulated wall-clock time (ms) of all iterations; stored as int, i.e. narrowed on every update. */
    private int msSpentInRollouts;
    /* Accumulated time (ms) spent in queries to the tree policy. */
    private int msSpentInTreePolicyQueries;
    /* Accumulated time (ms) spent in updates of the tree policy. */
    private int msSpentInTreePolicyUpdates;
    /* If true, actions that lead into exhausted parts of the tree are tabooed and hidden from the tree policy. */
    private final boolean tabooExhaustedNodes;
    /* Per state, the actions that must not be offered to the tree policy anymore. */
    private Map<N, Collection<A>> tabooActions = new HashMap<N, Collection<A>>();
    /* Prefix registered via enforcePrefixPathOnAllRollouts; null if none has been set. */
    private ILabeledPath<N, A> enforcedPrefixPath = null;

    /**
     * Creates an MCTS run over the given MDP.
     *
     * @param input the MDP to explore; must not be {@code null}
     * @param treePolicy the policy used (and updated) inside the tree; must not be {@code null}
     * @param defaultPolicy the policy used to complete playouts; must not be {@code null}
     * @param maxIterations the number of iterations after which the algorithm terminates
     * @param gamma passed, together with {@code epsilon}, to {@code MDPUtils.getTimeHorizon} to derive the maximum playout depth
     * @param epsilon see {@code gamma}
     * @param tabooExhaustedNodes whether actions leading into exhausted sub-trees are hidden from the tree policy
     * @throws NullPointerException if the MDP or one of the policies is {@code null}
     */
    public MCTS(IMDP<N, A, Double> input, IPathUpdatablePolicy<N, A, Double> treePolicy, IPolicy<N, A> defaultPolicy, int maxIterations, double gamma, double epsilon, boolean tabooExhaustedNodes) {
        super(input);
        Objects.requireNonNull(input);
        Objects.requireNonNull(treePolicy);
        Objects.requireNonNull(defaultPolicy);

        this.mdp = input;
        this.treePolicy = treePolicy;
        this.defaultPolicy = defaultPolicy;

        /* a uniform default policy lets us sample single actions directly from the MDP later on */
        if (defaultPolicy instanceof UniformRandomPolicy) {
            this.uniformSamplingDefaultPolicy = true;
            this.randomSourceOfUniformSamplyPolicy = ((UniformRandomPolicy)defaultPolicy).getRandom();
        } else {
            this.uniformSamplingDefaultPolicy = false;
            this.randomSourceOfUniformSamplyPolicy = null;
        }

        this.maxIterations = maxIterations;
        this.maxDepth = MDPUtils.getTimeHorizon(gamma, epsilon);
        this.tabooExhaustedNodes = tabooExhaustedNodes;

        /* forward every event emitted by the tree policy to the listeners of this algorithm */
        if (treePolicy instanceof IRelaxedEventEmitter) {
            Object eventForwarder = new Object(){

                @Subscribe
                public void receiveEvent(IEvent event) {
                    MCTS.this.post(event);
                }
            };
            ((IRelaxedEventEmitter)treePolicy).registerListener(eventForwarder);
        }
    }

    /**
     * Determines the actions that may be offered to the tree policy at the head of the given path.
     *
     * <p>The result is a fresh list holding the applicable actions. If tabooing of exhausted nodes is enabled,
     * all actions tabooed for the head are removed from it. If that leaves no action and the path has more
     * than one node, the action that led into the head is tabooed as well.
     *
     * @param path the path whose head is the state in question
     * @param applicableActions the actions applicable in the head of the path
     * @return a new list with the non-tabooed applicable actions; empty if there are none
     */
    public List<A> getPotentialActions(ILabeledPath<N, A> path, Collection<A> applicableActions) {
        N current = path.getHead();
        List<A> possibleActions = new ArrayList<>(applicableActions);
        if (possibleActions.isEmpty()) {
            this.logger.warn("Computing potential actions for an empty set of applicable actions makes no sense! Returning an empty set for node {}.", current);
            return possibleActions;
        }
        this.logger.debug("Computing potential actions based on {} applicable ones for state {}", applicableActions.size(), current);
        if (!this.tabooExhaustedNodes) {
            return possibleActions;
        }

        /* drop everything that has been tabooed for this state */
        Collection<A> tabooActionsForThisState = this.tabooActions.get(current);
        this.logger.debug("Found {} tabooed actions for this state.", tabooActionsForThisState != null ? tabooActionsForThisState.size() : 0);
        if (tabooActionsForThisState != null) {
            possibleActions.removeIf(tabooActionsForThisState::contains);
        }

        /* if nothing is left here, the action leading into this state is exhausted, too */
        if (possibleActions.isEmpty() && path.getNumberOfNodes() > 1) {
            this.tabooLastActionOfPath(path);
        }
        return possibleActions;
    }

    /**
     * Asks the MDP for the actions applicable in the given state, bounded by a timeout.
     *
     * <p>The timeout is the remaining time to the deadline of this algorithm minus one second.
     * NOTE(review): if less than one second remains, the computed timeout is negative; how {@code Timeout}
     * and {@code TimedComputation} treat that is not visible here - confirm.
     *
     * @param state the state for which the applicable actions are computed
     * @return an unmodifiable view on the actions returned by the MDP
     * @throws InterruptedException if the computation is interrupted; a termination check is conducted before it is rethrown
     */
    private Collection<A> getApplicableActions(N state) throws AlgorithmTimeoutedException, ExecutionException, InterruptedException, AlgorithmExecutionCanceledException {
        long budgetInMs = this.getRemainingTimeToDeadline().milliseconds() - 1000L;
        Timeout toForSuccessorComputation = new Timeout(budgetInMs, TimeUnit.MILLISECONDS);
        this.logger.debug("Computing all applicable actions with timeout {}.", toForSuccessorComputation);
        try {
            Collection<A> actionsReportedByMDP = TimedComputation.compute(() -> this.mdp.getApplicableActions(state), toForSuccessorComputation, "Timeout bound hit.");
            Collection<A> applicableActions = Collections.unmodifiableCollection(actionsReportedByMDP);
            this.logger.debug("Number of applicable actions is {}", applicableActions.size());
            return applicableActions;
        }
        catch (InterruptedException e) {
            /* the interrupt may stem from a timeout or cancellation of this algorithm */
            this.checkAndConductTermination();
            throw e;
        }
    }

    /**
     * Executes one step of the algorithm and returns the event that describes it.
     *
     * <p>In state {@code CREATED} the configuration is logged and the algorithm is activated. In state
     * {@code ACTIVE} the algorithm terminates if the iteration limit has been reached; otherwise one playout
     * is drawn from the initial state of the MDP. A playout runs through up to three phases:
     * <ol>
     * <li>as long as the current state is ready for the tree policy, the tree policy chooses among the
     * non-tabooed applicable actions,</li>
     * <li>in the first state that is not ready for the tree policy, the next untried action is taken (the
     * applicable actions are computed and cached on the first visit),</li>
     * <li>all remaining actions are chosen by the default policy.</li>
     * </ol>
     * The playout stops when the path has {@code maxDepth} nodes, a terminal state is reached, or no action
     * is left. Afterwards the tree policy is updated with the observed scores and an iteration event is
     * posted and returned.
     *
     * <p>The exception handlers guard the whole state machine, so that failures inside a playout are
     * translated to {@link AlgorithmException} and interrupts trigger a termination check.
     *
     * @return the initialization event, the finish event, or the event of the completed iteration
     * @throws InterruptedException if the executing thread is interrupted; a termination check is conducted before it is rethrown
     * @throws AlgorithmExecutionCanceledException propagated from the termination checks and timed computations
     * @throws AlgorithmTimeoutedException propagated from the termination checks and timed computations
     * @throws AlgorithmException if a policy or the MDP fails during the playout or a timed computation fails
     * @throws IllegalStateException if the algorithm is neither in state {@code CREATED} nor {@code ACTIVE}
     */
    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        this.logger.debug("Stepping MCTS with thread {}", Thread.currentThread());
        this.registerActiveThread();
        try {
            switch (this.getState()) {
                case CREATED: {
                    this.logger.info("Initialized MCTS algorithm {}.\n\tTree Policy: {}\n\tDefault Policy: {}\n\tMax Iterations: {}\n\tMax Depth: {}\n\tTaboo Exhausted Nodes: {}", this.getClass().getName(), this.treePolicy, this.defaultPolicy, this.maxIterations, this.maxDepth, this.tabooExhaustedNodes);
                    return this.activate();
                }
                case ACTIVE: {
                    if (this.iterations >= this.maxIterations) {
                        this.logger.info("Number of iterations reached limit of {}.", this.maxIterations);
                        return this.terminate();
                    }
                    long timeStart = System.currentTimeMillis();
                    ++this.iterations;

                    /* tell rollout-limit dependent tree policies how many rollouts can still be expected */
                    if (this.treePolicy instanceof IRolloutLimitDependentPolicy && this.isTimeoutDefined()) {
                        /* if no time has been accounted yet, the division yields infinity, which the int cast saturates to Integer.MAX_VALUE */
                        double avgTimeOfRollouts = (double)this.msSpentInRollouts * 1.0 / (double)this.iterations;
                        int expectedRemainingNumberOfRollouts = (int)Math.floor((double)this.getRemainingTimeToDeadline().milliseconds() / avgTimeOfRollouts);
                        ((IRolloutLimitDependentPolicy)((Object)this.treePolicy)).setEstimatedNumberOfRemainingRollouts(expectedRemainingNumberOfRollouts);
                    }
                    this.logger.info("Draw next playout: #{}.", this.iterations);

                    /* statistics of this iteration */
                    int invocationsOfTreePolicyInThisIteration = 0;
                    int invocationsOfDefaultPolicyInThisIteration = 0;
                    long timeSpentInActionApplicabilityComputationThisIteration = 0L;
                    long timeSpentInSuccessorGenerationThisIteration = 0L;
                    long timeSpentInTreePolicyQueriesThisIteration = 0L;
                    long timeSpentInTreePolicyUpdatesThisIteration = 0L;
                    long timeSpentInDefaultPolicyThisIteration = 0L;

                    /* state of the playout */
                    List<Double> scores = new ArrayList<>();
                    SearchGraphPath<N, A> path = new SearchGraphPath<>(this.mdp.getInitState());
                    N current = path.getRoot();
                    A action = null;
                    int phase = 1;
                    long lastTerminationCheck = 0L;
                    int depth = 0;
                    while (path.getNumberOfNodes() < this.maxDepth && !this.mdp.isTerminalState(current)) {
                        this.logger.debug("Now extending the roll-out in depth {}", depth);
                        ++depth;

                        /* check for timeouts/cancellation at most once per second */
                        long now = System.currentTimeMillis();
                        if (now - lastTerminationCheck > 1000L) {
                            this.checkAndConductTermination();
                            lastTerminationCheck = now;
                        }

                        if (phase == 1 && this.tpReadyStates.contains(current)) {

                            /* phase 1: the tree policy decides */
                            this.logger.debug("Computing possible actions for node {}", current);
                            assert (this.applicableActionsPerState.containsKey(current) && !this.applicableActionsPerState.get(current).isEmpty()) : "It makes no sense to apply the TP to a node without applicable actions!";
                            List<A> possibleActions = this.getPotentialActions(path, this.applicableActionsPerState.get(current));
                            if (possibleActions.isEmpty()) {
                                if (!path.isPoint()) {
                                    /* this node is exhausted; stop the playout here */
                                    break;
                                }
                                this.logger.info("There are no possible actions in the root. Finishing.");
                                this.summarizeIteration(System.currentTimeMillis() - timeStart, timeSpentInActionApplicabilityComputationThisIteration, timeSpentInSuccessorGenerationThisIteration, invocationsOfTreePolicyInThisIteration, invocationsOfDefaultPolicyInThisIteration, timeSpentInTreePolicyQueriesThisIteration, timeSpentInTreePolicyUpdatesThisIteration, timeSpentInDefaultPolicyThisIteration);
                                return this.terminate();
                            }
                            this.logger.debug("Ask tree policy to choose one action of: {}.", possibleActions);
                            long tpStart = System.currentTimeMillis();
                            action = this.treePolicy.getAction(current, possibleActions);
                            timeSpentInTreePolicyQueriesThisIteration += System.currentTimeMillis() - tpStart;
                            ++invocationsOfTreePolicyInThisIteration;
                            Objects.requireNonNull(action, "Actions in MCTS must never be null, but tree policy returned null!");
                            this.logger.debug("Tree policy recommended action {}.", action);
                        } else {
                            if (phase == 1) {
                                this.logger.debug("Switching to roll-out phase 2.");
                                phase = 2;
                            }
                            if (phase == 2) {

                                /* phase 2: take the next untried action of the first state outside the tree */
                                List<A> untriedActions;
                                if (!this.untriedActionsOfIncompleteStates.containsKey(current)) {
                                    long startActionTime = System.currentTimeMillis();

                                    /* with less than 2s left, wait for the deadline and let the termination check end the run */
                                    if (this.getRemainingTimeToDeadline().milliseconds() < 2000L) {
                                        if (this.getRemainingTimeToDeadline().milliseconds() > 0L) {
                                            Thread.sleep(this.getRemainingTimeToDeadline().milliseconds());
                                        }
                                        this.checkAndConductTermination();
                                    }
                                    Collection<A> applicableActions = this.getApplicableActions(current);
                                    untriedActions = new ArrayList<>(applicableActions);
                                    timeSpentInActionApplicabilityComputationThisIteration += System.currentTimeMillis() - startActionTime;
                                    this.applicableActionsPerState.put(current, applicableActions);
                                    if (untriedActions.isEmpty()) {

                                        /* non-terminal state without applicable actions: finish the iteration with what has been observed */
                                        long tpStart = System.currentTimeMillis();
                                        this.treePolicy.updatePath(path, scores);
                                        MCTSIterationCompletedEvent event = new MCTSIterationCompletedEvent((IAlgorithm<?, ?>)this, this.treePolicy, new SearchGraphPath<>(path), scores);
                                        this.post((Object)event);
                                        timeSpentInTreePolicyUpdatesThisIteration += System.currentTimeMillis() - tpStart;
                                        this.summarizeIteration(System.currentTimeMillis() - timeStart, timeSpentInActionApplicabilityComputationThisIteration, timeSpentInSuccessorGenerationThisIteration, invocationsOfTreePolicyInThisIteration, invocationsOfDefaultPolicyInThisIteration, timeSpentInTreePolicyQueriesThisIteration, timeSpentInTreePolicyUpdatesThisIteration, timeSpentInDefaultPolicyThisIteration);
                                        return event;
                                    }
                                    this.untriedActionsOfIncompleteStates.put(current, untriedActions);
                                } else {
                                    untriedActions = this.untriedActionsOfIncompleteStates.get(current);
                                }
                                this.logger.debug("There are {} untried actions: {}", untriedActions.size(), untriedActions);
                                assert (!untriedActions.isEmpty()) : "Untried actions must not be empty!";
                                action = untriedActions.remove(0);
                                this.logger.debug("Choosing untried action {}. There are {} remaining untried actions: {}", action, untriedActions.size(), untriedActions);
                                Objects.requireNonNull(action, "Actions in MCTS must never be null!");
                                if (untriedActions.isEmpty()) {

                                    /* every action has been tried once, so the state enters the domain of the tree policy */
                                    this.untriedActionsOfIncompleteStates.remove(current);
                                    this.tpReadyStates.add(current);
                                    if (path.isPoint()) {
                                        this.post(new GraphInitializedEvent((IAlgorithm)this, current));
                                    } else {
                                        this.post(new NodeAddedEvent((IAlgorithm)this, path.getPathToParentOfHead().getHead(), current, "none"));
                                    }
                                    this.logger.debug("Adding state {} to tree policy domain.", current);
                                }
                                phase = 3;
                                this.logger.debug("Switching to roll-out phase 3.");
                            } else if (phase == 3) {

                                /* phase 3: the default policy decides */
                                long startDP = System.currentTimeMillis();
                                if (this.uniformSamplingDefaultPolicy) {
                                    this.logger.debug("Sample a single action directly from the MDP.");
                                    action = this.mdp.getUniformlyRandomApplicableAction(current, this.randomSourceOfUniformSamplyPolicy);
                                } else {
                                    long startActionTime = System.currentTimeMillis();
                                    Collection<A> applicableActions = this.getApplicableActions(current);
                                    timeSpentInActionApplicabilityComputationThisIteration += System.currentTimeMillis() - startActionTime;
                                    this.logger.debug("Ask default policy to choose one action of: {}.", applicableActions);
                                    action = this.defaultPolicy.getAction(current, applicableActions);
                                    assert (applicableActions.contains(action));
                                }
                                timeSpentInDefaultPolicyThisIteration += System.currentTimeMillis() - startDP;
                                ++invocationsOfDefaultPolicyInThisIteration;
                                Objects.requireNonNull(action, "Actions in MCTS must never be null, but default policy has returned null!");
                                this.logger.debug("Default policy chose action {}.", action);
                            } else {
                                throw new IllegalStateException("Invalid phase " + phase);
                            }
                        }

                        /* apply the chosen action and record the observed score */
                        long startSuccessorComputation = System.currentTimeMillis();
                        N nextState = this.utils.drawSuccessorState(this.mdp, current, action);
                        timeSpentInSuccessorGenerationThisIteration += System.currentTimeMillis() - startSuccessorComputation;
                        scores.add(this.mdp.getScore(current, action, nextState));
                        current = nextState;
                        path.extend(current, action);
                    }

                    /* a playout that never left the tree ends in an exhausted region */
                    if (this.tabooExhaustedNodes && phase == 1) {
                        this.tabooLastActionOfPath(path);
                    }

                    /* report progress in steps of 5% */
                    int progress = (int)Math.round((double)this.iterations * 100.0 / (double)this.maxIterations);
                    if (progress > this.lastProgressReport && progress % 5 == 0) {
                        this.logger.info("Progress: {}%", progress);
                        this.lastProgressReport = progress;
                    }

                    boolean hasNullScore = scores.contains(null);
                    boolean isGoalPath = this.mdp.isTerminalState(path.getHead());
                    double totalUndiscountedScore = hasNullScore ? Double.NaN : scores.stream().reduce(0.0, Double::sum);
                    this.logger.info("Found playout of length {}. Head is goal: {}. (Undiscounted) score of path is {}.", path.getNumberOfNodes(), isGoalPath, totalUndiscountedScore);
                    this.logger.debug("Found leaf node with score {}. Now propagating this score over the path with actions {}. Leaf state is: {}.", totalUndiscountedScore, path.getArcs(), path.getHead());

                    /* propagate the scores along the path */
                    if (!path.isPoint()) {
                        long tpStart = System.currentTimeMillis();
                        this.treePolicy.updatePath(path, scores);
                        timeSpentInTreePolicyUpdatesThisIteration += System.currentTimeMillis() - tpStart;
                    }
                    MCTSIterationCompletedEvent event = new MCTSIterationCompletedEvent((IAlgorithm<?, ?>)this, this.treePolicy, new SearchGraphPath<>(path), scores);
                    this.post((Object)event);
                    this.summarizeIteration(System.currentTimeMillis() - timeStart, timeSpentInActionApplicabilityComputationThisIteration, timeSpentInSuccessorGenerationThisIteration, invocationsOfTreePolicyInThisIteration, invocationsOfDefaultPolicyInThisIteration, timeSpentInTreePolicyQueriesThisIteration, timeSpentInTreePolicyUpdatesThisIteration, timeSpentInDefaultPolicyThisIteration);
                    return event;
                }
                default: {
                    throw new IllegalStateException("Don't know what to do in state " + this.getState());
                }
            }
        }
        catch (ActionPredictionFailedException | ObjectEvaluationFailedException e) {
            throw new AlgorithmException("Could not create playout due to an exception! MCTS cannot deal with this in general. Please modify your MDP such that this kind of exceptions is resolved to some kind of score.", e);
        }
        catch (ExecutionException e) {
            throw new AlgorithmException("Observed error during timed computation.", (Throwable)e);
        }
        catch (InterruptedException e) {
            /* the interrupt may be caused by a timeout or cancellation of this algorithm */
            this.checkAndConductTermination();
            throw e;
        }
        finally {
            this.logger.debug("Unregistering thread {}", Thread.currentThread());
            this.unregisterActiveThread();
        }
    }

    /**
     * Adds the timings of a finished iteration to the global counters and logs a summary of the iteration.
     *
     * <p>Only the rollout time and the tree policy query/update times are accumulated; all other values are
     * merely logged. The counters are ints, so the sums are narrowed to int.
     *
     * @param timeForRolloutThisIteration total time (ms) of the iteration
     * @param timeSpentInActionApplicability time (ms) spent in computing applicable actions
     * @param timeSpentInSuccessorGenerationThisIteration time (ms) spent in drawing successor states
     * @param numInvocationsOfTP number of queries to the tree policy
     * @param numInvocationsOfDP number of queries to the default policy
     * @param timeSpentInTreePolicyQueriesThisIteration time (ms) spent in tree policy queries
     * @param timeSpentInTreePolicyUpdatesThisIteration time (ms) spent in tree policy updates
     * @param timeSpentInDefaultPolicyThisIteration time (ms) spent in the default policy
     */
    private void summarizeIteration(long timeForRolloutThisIteration, long timeSpentInActionApplicability, long timeSpentInSuccessorGenerationThisIteration, int numInvocationsOfTP, int numInvocationsOfDP, long timeSpentInTreePolicyQueriesThisIteration, long timeSpentInTreePolicyUpdatesThisIteration, long timeSpentInDefaultPolicyThisIteration) {
        /* compound assignment implies the narrowing cast back to int */
        this.msSpentInRollouts += timeForRolloutThisIteration;
        this.msSpentInTreePolicyQueries += timeSpentInTreePolicyQueriesThisIteration;
        this.msSpentInTreePolicyUpdates += timeSpentInTreePolicyUpdatesThisIteration;

        long usedMemoryInMB = (runtime.totalMemory() - runtime.freeMemory()) / 0x100000L;
        this.logger.info("Finished rollout in {}ms. Time for computing applicable actions was {}ms and for computing successors {}ms. Time for TP {} queries was {}ms, time to update TP {}ms, time for {} DP queries was {}ms. Currently used memory: {}MB.", timeForRolloutThisIteration, timeSpentInActionApplicability, timeSpentInSuccessorGenerationThisIteration, numInvocationsOfTP, timeSpentInTreePolicyQueriesThisIteration, timeSpentInTreePolicyUpdatesThisIteration, numInvocationsOfDP, timeSpentInDefaultPolicyThisIteration, usedMemoryInMB);
    }

    /**
     * Taboos the last action of the given path, i.e. the action taken in the parent of the head.
     *
     * @param path a path with at least one action
     * @throws IllegalArgumentException if the path consists of a single node only
     */
    private void tabooLastActionOfPath(ILabeledPath<N, A> path) {
        if (path.isPoint()) {
            throw new IllegalArgumentException("The path is a point, which has no last action to taboo.");
        }
        N lastStatePriorToEnd = path.getParentOfHead();
        A lastAction = path.getOutArc(lastStatePriorToEnd);
        this.tabooActions.computeIfAbsent(lastStatePriorToEnd, n -> new HashSet<>()).add(lastAction);
        this.logger.debug("Adding action {} to taboo list of state {}", lastAction, lastStatePriorToEnd);
    }

    /**
     * @return the number of iterations (playouts) that have been started so far
     */
    public int getNumberOfRealizedPlayouts() {
        return iterations;
    }

    /**
     * @return the tree policy that is queried and updated by this algorithm
     */
    public IPathUpdatablePolicy<N, A, Double> getTreePolicy() {
        return treePolicy;
    }

    /**
     * Runs the algorithm step by step until no further step is available.
     *
     * @return the tree policy in the state it has after the last step
     */
    public IPolicy<N, A> call() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        for (;;) {
            if (!this.hasNext()) {
                return this.treePolicy;
            }
            this.nextWithException();
        }
    }

    /**
     * Intended to force all future rollouts to start with the given prefix. This is currently not supported.
     *
     * <p>The prefix is validated, but the method then always fails with an
     * {@link UnsupportedOperationException}. It deliberately does not modify the algorithm state before
     * failing, so that a caller who catches the exception can continue with a consistent algorithm.
     *
     * @param path the prefix to enforce; its root must equal the initial state of the MDP
     * @throws NullPointerException if the path is {@code null}
     * @throws IllegalArgumentException if the root of the path is not the initial state of the MDP
     * @throws UnsupportedOperationException always, if the prefix is valid
     */
    public void enforcePrefixPathOnAllRollouts(ILabeledPath<N, A> path) {
        Objects.requireNonNull(path, "The prefix path must not be null.");
        if (!path.getRoot().equals(this.mdp.getInitState())) {
            throw new IllegalArgumentException("Illegal prefix, since root does not coincide with algorithm root. Proposed root is: " + path.getRoot());
        }
        throw new UnsupportedOperationException("Currently, enforced prefixes are ignored!");
    }

    /**
     * @return an unmodifiable accessor of the enforced prefix path, or {@code null} if no prefix has been set
     */
    public ILabeledPath<N, A> getEnforcedPrefixPath() {
        /* no prefix is set by default, so the field must not be dereferenced blindly */
        if (this.enforcedPrefixPath == null) {
            return null;
        }
        return this.enforcedPrefixPath.getUnmodifiableAccessor();
    }

    /**
     * Sets the logger of this algorithm and derives the logger names of its components from it.
     *
     * <p>The super class receives the suffix {@code .abstract}, the MDP {@code .mdp}, the tree policy
     * {@code .tp}, the default policy {@code .dp}, and the MDP utils {@code .utils}. The MDP and the policies
     * are only configured if they are {@link ILoggingCustomizable}.
     *
     * @param name the name of the logger to use for this algorithm
     */
    @Override
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
        super.setLoggerName(name + ".abstract");
        if (this.mdp instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable)this.mdp).setLoggerName(name + ".mdp");
        }

        /* the log messages state the suffixes that are actually assigned */
        if (this.treePolicy instanceof ILoggingCustomizable) {
            this.logger.info("Setting logger of tree policy to {}.tp", name);
            ((ILoggingCustomizable)this.treePolicy).setLoggerName(name + ".tp");
        } else {
            this.logger.info("Not setting logger of tree policy, because {} is not customizable.", this.treePolicy.getClass().getName());
        }
        if (this.defaultPolicy instanceof ILoggingCustomizable) {
            this.logger.info("Setting logger of default policy to {}.dp", name);
            ((ILoggingCustomizable)this.defaultPolicy).setLoggerName(name + ".dp");
        } else {
            this.logger.info("Not setting logger of default policy, because {} is not customizable.", this.defaultPolicy.getClass().getName());
        }
        this.utils.setLoggerName(name + ".utils");
    }

    /**
     * Not implemented.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public boolean hasTreePolicyReachedLeafs() {
        final String reason = "Currently not implemented.";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * @return the name of the logger currently used by this algorithm
     */
    public String getLoggerName() {
        return logger.getName();
    }

    /**
     * @return the number of states that are currently ready for the tree policy
     */
    public int getNumberOfNodesInMemory() {
        return tpReadyStates.size();
    }

    /**
     * @return the accumulated time (ms) of all iterations summarized so far
     */
    public int getMsSpentInRollouts() {
        return msSpentInRollouts;
    }

    /**
     * @return the accumulated time (ms) spent in queries to the tree policy
     */
    public int getMsSpentInTreePolicyQueries() {
        return msSpentInTreePolicyQueries;
    }

    /**
     * @return the accumulated time (ms) spent in updates of the tree policy
     */
    public int getMsSpentInTreePolicyUpdates() {
        return msSpentInTreePolicyUpdates;
    }

    /**
     * @return {@code true} iff actions leading into exhausted parts of the tree are tabooed
     */
    public boolean isTabooExhaustedNodes() {
        return tabooExhaustedNodes;
    }
}

