/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.thompson;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.thompson.DNGBeliefUpdateEvent;
import ai.libs.jaicore.search.algorithms.mdp.mcts.thompson.DNGQSampleEvent;
import com.google.common.eventbus.EventBus;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IRelaxedEventEmitter;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Action-selection policy that scores every applicable action of a state by an estimated
 * Q-value and recommends the best one.
 *
 * <p>The policy keeps, per state, four belief parameters ({@code mu}, {@code lambda},
 * {@code alpha}, {@code beta}) from which state values are sampled
 * (see {@link #sampleWithNormalGamma(Object)}), and, per state/action pair, a count vector
 * {@code rho} of how often each successor state was observed. All of these are learned
 * exclusively through {@link #updatePath(ILabeledPath, List)}.
 *
 * <p>When {@linkplain #isSampling() sampling} is enabled, transition weights and state values
 * are drawn at random; otherwise the empirical transition frequencies and the stored
 * {@code mu} are used.
 *
 * @param <N> type of the states (nodes)
 * @param <A> type of the actions (arcs)
 */
public class DNGPolicy<N, A>
implements IPathUpdatablePolicy<N, A, Double>,
ILoggingCustomizable,
IRelaxedEventEmitter {
    /** Initial value of {@code alpha} for a state that is seen for the first time. */
    private static final double INIT_ALPHA = 1.0;
    /** Initial value of {@code mu} for a state that is seen for the first time. */
    private static final double INIT_MU = 0.5;

    /* not final: can be swapped via setLoggerName */
    private Logger logger = LoggerFactory.getLogger(DNGPolicy.class);
    private final EventBus eventBus = new EventBus();

    /** If true, the action with the largest Q-value is recommended, otherwise the smallest. */
    private final boolean maximize;
    private final double initLambda;
    /** Initial {@code beta}; derived in the constructor as {@code 1 / initLambda}. */
    private final double initBeta;

    /* per-state belief parameters */
    private final Map<N, Double> alpha = new HashMap<>();
    private final Map<N, Double> beta = new HashMap<>();
    private final Map<N, Double> mu = new HashMap<>();
    private final Map<N, Double> lambda = new HashMap<>();

    /** state -> action -> successor -> number of times this transition was observed. */
    private final Map<N, Map<A, Map<N, Integer>>> rho = new HashMap<>();

    /** Discount factor applied to future rewards. */
    private final double gammaMDP;
    private final Predicate<N> terminalStatePredicate;
    /** state -> action -> first reward that was observed for this pair. */
    private final Map<N, Map<A, Double>> rewardsMDP = new HashMap<>();
    /** Weight of the penalty term subtracted from sampled state values in {@link #getValue(Object)}. */
    private final double varianceFactor;
    private boolean sampling = true;

    /**
     * Creates a new policy without any learned beliefs.
     *
     * @param gammaMDP discount factor applied to future rewards
     * @param terminalStatePredicate tells whether a state is terminal (terminal states have value 0)
     * @param varianceFactor weight of the penalty term in sampled state values
     * @param lambda initial {@code lambda} of unseen states; the initial {@code beta} is {@code 1 / lambda}
     * @param maximize whether larger Q-values are better
     */
    public DNGPolicy(double gammaMDP, Predicate<N> terminalStatePredicate, double varianceFactor, double lambda, boolean maximize) {
        this.gammaMDP = gammaMDP;
        this.terminalStatePredicate = terminalStatePredicate;
        this.varianceFactor = varianceFactor;
        this.initLambda = lambda;
        this.initBeta = 1.0 / this.initLambda;
        this.maximize = maximize;
    }

    /**
     * @return true if Q-values and state values are sampled, false if point estimates are used
     */
    public boolean isSampling() {
        return this.sampling;
    }

    /**
     * @param sampling true to sample Q-values and state values, false to use point estimates
     */
    public void setSampling(boolean sampling) {
        this.sampling = sampling;
    }

    /**
     * Delegates to {@link #sampleWithThompson(Object, Collection)}.
     */
    @Override
    public A getAction(N node, Collection<A> actionsWithSuccessors) throws ActionPredictionFailedException, InterruptedException {
        return this.sampleWithThompson(node, actionsWithSuccessors);
    }

    /**
     * Computes the Q-value of every given action in the given state and returns the best one
     * with respect to the optimization direction. A {@link DNGQSampleEvent} is posted for every
     * evaluated action. Among equally scored actions, the first one encountered wins.
     *
     * @param state the state for which an action is required
     * @param actions the candidate actions
     * @return the action with the best Q-value
     * @throws NullPointerException if {@code actions} is empty
     * @throws IllegalStateException if no transition counts are known for one of the state/action pairs
     * @throws InterruptedException if the thread was interrupted while state values were computed
     */
    public A sampleWithThompson(N state, Collection<A> actions) throws InterruptedException {
        this.logger.info("Determining best action for state {}", state);
        A bestAction = null;
        double bestScore = this.maximize ? -Double.MAX_VALUE : Double.MAX_VALUE;
        for (A action : actions) {
            double score = this.getQValue(state, action);
            this.logger.debug("Score for action {} is {}", action, score);
            this.eventBus.post(new DNGQSampleEvent<N, A>(null, state, action, score));

            /* only a strictly better score (in the direction we optimize) replaces the incumbent;
             * the first action is always accepted so that a result exists even for extreme scores */
            boolean improves = this.maximize ? score > bestScore : score < bestScore;
            if (bestAction == null || improves) {
                bestAction = action;
                bestScore = score;
                this.logger.debug("Considering this as the new best action.");
            }
        }
        Objects.requireNonNull(bestAction, "Best action cannot be null if there were " + actions.size() + " options!");
        this.logger.info("Recommending action {}", bestAction);
        return bestAction;
    }

    /**
     * Estimates the Q-value of a state/action pair as
     * {@code reward(state, action) + gammaMDP * E[value(successor)]}.
     *
     * <p>In sampling mode the successor weights are obtained by normalising independent
     * {@code Gamma(rho, 1)} draws (i.e. a Dirichlet draw over the observed successors);
     * otherwise the relative observation frequencies are used. In both modes each weight is
     * multiplied with {@link #getValue(Object)} of the respective successor.
     *
     * @param state the state in which the action is applied
     * @param action the action to evaluate
     * @return the (sampled or expected) Q-value
     * @throws IllegalStateException if no transition counts exist for the pair, or if all sampled weights are 0
     * @throws InterruptedException if the thread was interrupted while state values were computed
     */
    public double getQValue(N state, A action) throws InterruptedException {

        /* make sure that transition counts are available for this pair (also for entirely unknown states) */
        Map<A, Map<N, Integer>> rhoForState = this.rho.get(state);
        Map<N, Integer> rhoForThisPair = rhoForState != null ? rhoForState.get(action) : null;
        if (rhoForThisPair == null) {
            throw new IllegalStateException("Have no rho vector for state/action pair " + state + "/" + action);
        }
        List<N> possibleSuccessors = new ArrayList<>(rhoForThisPair.keySet());
        int numSuccessors = possibleSuccessors.size();

        /* compute the (sampled or expected) value of the successor state */
        double r = 0.0;
        this.logger.debug("Now determining q-value of action {}. Sampling: {}", action, this.sampling);
        if (this.sampling) {
            double[] gammaVector = new double[numSuccessors];
            double totalGammas = 0.0;
            for (int i = 0; i < numSuccessors; ++i) {
                double shape = rhoForThisPair.get(possibleSuccessors.get(i)).doubleValue();
                double gamma = new GammaDistribution(shape, 1.0).sample();
                gammaVector[i] = gamma;
                totalGammas += gamma;
            }
            if (totalGammas == 0.0) {
                throw new IllegalStateException("The gamma estimates must not sum up to 0!");
            }
            for (int i = 0; i < numSuccessors; ++i) {
                r += gammaVector[i] / totalGammas * this.getValue(possibleSuccessors.get(i));
            }
        } else {
            double denominator = 0.0;
            for (int count : rhoForThisPair.values()) {
                denominator += count;
            }
            for (N succ : possibleSuccessors) {
                /* weight the successor's value by its empirical transition frequency */
                r += rhoForThisPair.get(succ) / denominator * this.getValue(succ);
            }
        }

        /* rewards are recorded in updatePath together with the transition counts */
        double reward = this.rewardsMDP.get(state).get(action);
        double totalReward = reward + this.gammaMDP * r;
        this.logger.debug("Considering a reward of {} + {} * {} = {}", reward, this.gammaMDP, r, totalReward);
        return totalReward;
    }

    /**
     * Draws a pair {@code (mu', tau)} from the belief stored for the given state: {@code tau}
     * is drawn from a gamma distribution parameterised with the state's {@code alpha} and
     * {@code beta}, and {@code mu'} from a normal distribution centred at the state's
     * {@code mu} whose spread is {@code 1 / (lambda * tau)}. If that spread is not positive,
     * the stored {@code mu} is returned unchanged.
     *
     * @param state a state for which belief parameters have been stored by {@code updatePath}
     * @return the pair of the sampled mean (X) and the sampled {@code tau} (Y)
     * @throws NullPointerException if no belief parameters are stored for the state
     */
    public Pair<Double, Double> sampleWithNormalGamma(N state) {
        /* NOTE(review): GammaDistribution takes (shape, scale), so beta acts as a scale here,
         * and 1 / (lambda * tau) is handed to NormalDistribution as its standard deviation.
         * Confirm that both are intended (as opposed to rate and variance, respectively). */
        double tau = new GammaDistribution(this.alpha.get(state), this.beta.get(state)).sample();
        double std = 1.0 / (this.lambda.get(state) * tau);
        double muNew = std > 0.0 ? new NormalDistribution(this.mu.get(state), std).sample() : this.mu.get(state);
        return new Pair<>(muNew, tau);
    }

    /**
     * Returns the value of a state: 0 for terminal states; in sampling mode the sampled mean
     * minus {@code varianceFactor * sqrt(tau)}; otherwise the stored {@code mu}.
     *
     * @param state the state to evaluate
     * @return the (sampled or stored) state value
     * @throws InterruptedException if the current thread has been interrupted (the flag is cleared)
     * @throws NullPointerException if the state is not terminal and no belief is stored for it
     */
    public double getValue(N state) throws InterruptedException {
        boolean isTerminal = this.terminalStatePredicate.test(state);
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
        if (isTerminal) {
            this.logger.debug("Returning value of 0 for terminal state {}", state);
            return 0.0;
        }
        if (this.sampling) {
            /* NOTE(review): Y is the sampled tau of sampleWithNormalGamma; confirm that
             * sqrt(tau) is the intended penalty term. */
            Pair<Double, Double> meanAndTau = this.sampleWithNormalGamma(state);
            double val = meanAndTau.getX() - this.varianceFactor * Math.sqrt(meanAndTau.getY());
            this.logger.debug("Returning sampled value of {}", val);
            return val;
        }
        double val = this.mu.get(state);
        this.logger.debug("Returning fixed value of {}", val);
        return val;
    }

    /**
     * Incorporates an observed path into the beliefs. The path is traversed from its
     * second-to-last node back to the root; for every node the discounted sum of the rewards
     * from that node onwards is computed and used as the observation.
     *
     * <p>A node seen for the first time only gets its initial parameters and an initial
     * transition count of 1 for the observed successor. For known nodes the four belief
     * parameters are updated with the observation, the transition count is incremented, and a
     * {@link DNGBeliefUpdateEvent} is posted.
     *
     * @param path the observed path
     * @param scores the reward of the i-th arc of the path at index i; a {@code null} entry is treated as NaN
     */
    @Override
    public void updatePath(ILabeledPath<N, A> path, List<Double> scores) {
        List<N> nodes = path.getNodes();
        List<A> actions = path.getArcs();
        int l = path.getNumberOfNodes();
        this.logger.info("Updating path with scores {}", scores);
        double accumulatedScores = 0.0;
        for (int i = l - 2; i >= 0; --i) {
            N node = nodes.get(i);
            A action = actions.get(i);
            N succNode = nodes.get(i + 1);

            /* NOTE(review): a missing score becomes NaN and propagates into the accumulated
             * score and thereby into the stored belief of all earlier nodes; confirm this is wanted. */
            Double observedScore = scores.get(i);
            double rewardOfThisAction = observedScore != null ? observedScore : Double.NaN;

            /* only the first reward observed for a state/action pair is remembered */
            this.rewardsMDP.computeIfAbsent(node, k -> new HashMap<>()).putIfAbsent(action, rewardOfThisAction);
            accumulatedScores = rewardOfThisAction + this.gammaMDP * accumulatedScores;
            this.logger.debug("Updating statistics for {}-th node with accumulated score {}. State here is: {}", i, accumulatedScores, node);

            /* first encounter of this node: only initialize its parameters */
            if (!this.lambda.containsKey(node)) {
                this.lambda.put(node, this.initLambda);
                this.mu.put(node, INIT_MU);
                this.alpha.put(node, INIT_ALPHA);
                this.beta.put(node, this.initBeta);
                Map<N, Integer> rhoForNodeActionPair = new HashMap<>();
                rhoForNodeActionPair.put(succNode, 1);
                Map<A, Map<N, Integer>> mapForAction = new HashMap<>();
                mapForAction.put(action, rhoForNodeActionPair);
                this.rho.put(node, mapForAction);
                continue;
            }

            /* known node: update the belief with the new observation (beta and mu are
             * computed from the old mu and lambda, hence they are read before any write) */
            double lambdaOfN = this.lambda.get(node);
            double muOfN = this.mu.get(node);
            this.alpha.put(node, this.alpha.get(node) + 0.5);
            this.beta.put(node, this.beta.get(node) + lambdaOfN * Math.pow(accumulatedScores - muOfN, 2.0) / (lambdaOfN + 1.0) / 2.0);
            this.mu.put(node, (muOfN * lambdaOfN + accumulatedScores) / (lambdaOfN + 1.0));
            this.lambda.put(node, lambdaOfN + 1.0);

            /* count the observed transition */
            this.rho.get(node).computeIfAbsent(action, k -> new HashMap<>()).merge(succNode, 1, Integer::sum);

            /* constructed as in the original; the event's type parameters are not visible from this file */
            this.eventBus.post(new DNGBeliefUpdateEvent(null, node, this.mu.get(node), this.alpha.get(node), this.beta.get(node), this.lambda.get(node)));
        }
    }

    /**
     * @return the name of the logger currently in use
     */
    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Replaces the logger by the one with the given name and logs the switch on the new logger.
     *
     * @param name name of the logger to use from now on
     */
    @Override
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
        this.logger.info("Logger is now {}", name);
    }

    /**
     * Registers a listener on the internal event bus, on which {@link DNGQSampleEvent}s and
     * {@link DNGBeliefUpdateEvent}s are posted.
     *
     * @param listener the listener object
     */
    @Override
    public void registerListener(Object listener) {
        this.eventBus.register(listener);
    }
}

