/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.standard.mcts;

import ai.libs.jaicore.basic.algorithm.AlgorithmFinishedEvent;
import ai.libs.jaicore.basic.algorithm.AlgorithmInitializedEvent;
import ai.libs.jaicore.basic.algorithm.EAlgorithmState;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.algorithms.mdp.mcts.GraphBasedMDP;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTS;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTSFactory;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTSIterationCompletedEvent;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.events.EvaluatedSearchSolutionCandidateFoundEvent;
import ai.libs.jaicore.search.core.interfaces.AOptimalPathInORGraphSearch;
import ai.libs.jaicore.search.model.other.EvaluatedSearchGraphPath;
import ai.libs.jaicore.search.probleminputs.IMDP;
import com.google.common.eventbus.Subscribe;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.api4.java.ai.graphsearch.problem.IPathSearchWithPathEvaluationsInput;
import org.api4.java.algorithm.IAlgorithm;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.attributedobjects.ScoredItem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Optimal-path search over an OR-graph that delegates the actual exploration to a Monte Carlo Tree Search (MCTS).
 *
 * <p>The path search problem is wrapped into a {@link GraphBasedMDP}, and an {@link MCTS} obtained from the given
 * factory is run on that MDP. Every completed MCTS iteration whose rollout is a goal path is converted into an
 * {@link EvaluatedSearchGraphPath} scored with the sum of the iteration's scores. That path always updates the best
 * seen solution. It is emitted as an {@link EvaluatedSearchSolutionCandidateFoundEvent} only if no path with the same
 * hash code was emitted before.
 *
 * @param <I> type of the path search problem (path evaluations are {@link Double})
 * @param <N> node type of the search graph
 * @param <A> arc (action) type of the search graph
 */
public class MCTSPathSearch<I extends IPathSearchWithPathEvaluationsInput<N, A, Double>, N, A>
        extends AOptimalPathInORGraphSearch<I, N, A, Double> {

    /** Logger of this search; replaced in {@link #setLoggerName(String)}. */
    private Logger logger = LoggerFactory.getLogger(MCTSPathSearch.class);

    /** MDP view of the graph search problem on which the MCTS runs. */
    private final GraphBasedMDP<N, A> mdp;

    /** The MCTS instance that does the actual search. */
    private final MCTS<N, A> mcts;

    /**
     * Hash codes of all solution paths already emitted as solution events.
     *
     * <p>NOTE(review): deduplicating by hash code alone means two distinct paths whose hash codes collide are treated
     * as the same solution. The later one still updates the best seen solution but is never emitted as an event.
     * Consider storing the paths themselves once the equals/hashCode contract of {@code EvaluatedSearchGraphPath}
     * is confirmed.
     */
    private final Set<Integer> hashCodesOfReturnedPaths = new HashSet<>();

    /**
     * Creates a new MCTS-based path search.
     *
     * @param problem the path search problem to solve
     * @param mctsFactory factory used to create the MCTS that runs on the MDP view of {@code problem}
     */
    @SuppressWarnings("unchecked")
    public MCTSPathSearch(final I problem, final MCTSFactory<N, A, ?> mctsFactory) {
        super(problem);
        this.mdp = new GraphBasedMDP<>(problem);
        this.mcts = (MCTS<N, A>) mctsFactory.getAlgorithm(this.mdp);

        // Re-post the MCTS events on this algorithm's own event bus, except the MCTS's initialization and finish
        // events. Presumably these are dropped because this search emits its own lifecycle events via activate()
        // and terminate().
        this.mcts.registerListener(new Object() {

            @Subscribe
            public void receiveMCTSEvent(final IAlgorithmEvent e) {
                if (!(e instanceof AlgorithmInitializedEvent) && !(e instanceof AlgorithmFinishedEvent)) {
                    MCTSPathSearch.this.post(e);
                }
            }
        });
    }

    /**
     * Performs the next step of the search.
     *
     * <p>In state {@code CREATED} the MCTS is stepped until it is initialized, and this search is then activated.
     * In state {@code ACTIVE} the MCTS is stepped until either a new goal path is found, which is reported as a
     * solution event, or the MCTS finishes, in which case this search terminates.
     *
     * @return the emitted event
     * @throws IllegalStateException if the algorithm is in any state other than {@code CREATED} or {@code ACTIVE}
     */
    @Override
    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        switch (this.getState()) {
            case CREATED:
                return this.initializeMCTS();
            case ACTIVE:
                return this.continueMCTSSearch();
            default:
                throw new IllegalStateException("Cannot compute next event in state " + this.getState());
        }
    }

    /** Steps the MCTS until it reports initialization, then activates this search. */
    private IAlgorithmEvent initializeMCTS() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        this.mdp.setLoggerName(this.getLoggerName() + ".mdp");
        // Use nextWithException (as in the ACTIVE phase) so that timeouts, cancellations, interrupts and algorithm
        // failures during MCTS initialization surface through the exceptions this method declares.
        IAlgorithmEvent mctsEvent;
        do {
            mctsEvent = this.mcts.nextWithException();
        } while (!(mctsEvent instanceof AlgorithmInitializedEvent));
        return this.activate();
    }

    /** Steps the MCTS until a not-yet-reported goal path is found (returned as event) or the MCTS finishes. */
    @SuppressWarnings({ "rawtypes", "unchecked" }) // raw event/path types as in the original code
    private IAlgorithmEvent continueMCTSSearch() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        if (this.mcts.getState() != EAlgorithmState.ACTIVE) {
            return this.terminate();
        }
        IAlgorithmEvent mctsEvent;
        while (!((mctsEvent = this.mcts.nextWithException()) instanceof AlgorithmFinishedEvent)) {
            if (!(mctsEvent instanceof MCTSIterationCompletedEvent)) {
                continue;
            }
            MCTSIterationCompletedEvent iteration = (MCTSIterationCompletedEvent) mctsEvent;
            double overallScore = SetUtil.sum(iteration.getScores());
            this.logger.info("Registered rollout with score {}. Updating best seen solution correspondingly.", overallScore);
            EvaluatedSearchGraphPath path = new EvaluatedSearchGraphPath(iteration.getRollout(), overallScore);
            if (!this.getGoalTester().isGoal(path)) {
                continue;
            }
            this.updateBestSeenSolution(path);

            // Set.add returns false if the hash code was already recorded, i.e. the path was reported before.
            int hashCode = path.hashCode();
            if (!this.hashCodesOfReturnedPaths.add(hashCode)) {
                this.logger.info("Skipping (and supressing) previously found solution with hash code {}", hashCode);
                continue;
            }
            EvaluatedSearchSolutionCandidateFoundEvent solutionEvent = new EvaluatedSearchSolutionCandidateFoundEvent(this, path);
            this.post(solutionEvent);
            return solutionEvent;
        }
        return this.terminate();
    }

    /**
     * Sets the timeout of this search. The wrapped MCTS receives a timeout one second shorter, presumably to leave
     * this search time to wrap up after the MCTS stops.
     *
     * @param to the timeout; must be at least 2 seconds
     * @throws IllegalArgumentException if {@code to.seconds()} is less than 2
     */
    @Override
    public void setTimeout(final Timeout to) {
        long toInSeconds = to.seconds();
        if (toInSeconds < 2L) {
            throw new IllegalArgumentException("Cannot run MCTS with a timeout of less than 2 seconds.");
        }
        super.setTimeout(to);
        this.mcts.setTimeout(new Timeout(toInSeconds - 1L, TimeUnit.SECONDS));
    }

    /** Cancels this search and the wrapped MCTS. */
    @Override
    public void cancel() {
        super.cancel();
        this.mcts.cancel();
    }

    /**
     * Sets the logger names. This search logs under {@code name}, the superclass under {@code name + "._algorithm"},
     * and the MCTS under {@code name + ".mcts"}.
     *
     * @param name the base logger name
     */
    @Override
    public void setLoggerName(final String name) {
        super.setLoggerName(name + "._algorithm");
        this.logger = LoggerFactory.getLogger(name);
        this.mcts.setLoggerName(name + ".mcts");
    }

    /** @return the name of this search's own logger */
    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    /** @return the MDP view of the search problem on which the MCTS runs */
    public IMDP<N, A, Double> getMdp() {
        return this.mdp;
    }

    /** @return the wrapped MCTS instance */
    public MCTS<N, A> getMcts() {
        return this.mcts;
    }

    /** @return the number of nodes the wrapped MCTS currently holds in memory */
    public int getNumberOfNodesInMemory() {
        return this.mcts.getNumberOfNodesInMemory();
    }
}

