/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.comparison.preferencekernel.bootstrapping;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.graphvisualizer.events.graph.NodePropertyChangedEvent;
import ai.libs.jaicore.search.algorithms.mdp.mcts.comparison.IPreferenceKernel;
import ai.libs.jaicore.search.algorithms.mdp.mcts.comparison.preferencekernel.bootstrapping.IBootstrapConfigurator;
import ai.libs.jaicore.search.algorithms.mdp.mcts.comparison.preferencekernel.bootstrapping.IBootstrappingParameterComputer;
import com.google.common.eventbus.EventBus;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IRelaxedEventEmitter;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Preference kernel that ranks the actions available in a node by bootstrapping over the scores observed for
 * those actions.
 *
 * <p>For every bootstrap, each action's sample consists of its best (lowest) score observed so far plus
 * {@code bootstrapSize - 1} scores picked at random from its recent observation history. The configured
 * {@link IBootstrappingParameterComputer} condenses each sample into one value, and actions are ranked by ascending
 * value. Observations are only propagated along the prefix of a path up to (and including) the first node that has
 * not been marked active via {@link #signalNodeActiveness(Object)}.
 *
 * <p>All state is kept in unsynchronized hash maps.
 *
 * @param <N> node type
 * @param <A> action (arc) type
 */
public class BootstrappingPreferenceKernel<N, A>
implements IPreferenceKernel<N, A>,
ILoggingCustomizable,
IRelaxedEventEmitter {
    /** Runtime (in ms) of {@link #drawNewRankingsForActions} above which a warning is logged. */
    private static final int MAXTIME_WARN_CREATERANKINGS = 1;
    private Logger logger = LoggerFactory.getLogger(BootstrappingPreferenceKernel.class);
    private final EventBus eventBus = new EventBus();
    private boolean hasListeners = false;
    private final Set<N> activeNodes = new HashSet<N>();
    /** Per node and action: the most recent observed scores (at most {@code maxNumSamplesInHistory}). */
    private final Map<N, Map<A, DoubleList>> observations = new HashMap<N, Map<A, DoubleList>>();
    /** Per node and action: the lowest score ever observed since the last {@link #clearKnowledge(Object)}. */
    private final Map<N, Map<A, Double>> bestObservationForAction = new HashMap<N, Map<A, Double>>();
    private final IBootstrappingParameterComputer bootstrapParameterComputer;
    private final IBootstrapConfigurator bootstrapConfigurator;
    private final int maxNumSamplesInHistory;
    private final Random random;
    /** The rankings most recently drawn for each node. */
    private final Map<N, List<List<A>>> rankingsForNodes = new HashMap<N, List<List<A>>>();
    /** Minimum number of observations every action needs before rankings are considered reliable. */
    private final int minSamplesToCreateRankings;
    private int erasedObservationsInTotal = 0;

    /**
     * Creates a kernel.
     *
     * @param bootstrapParameterComputer computes the score of an action from one bootstrap sample
     * @param bootstrapConfigurator decides the number of bootstraps and the sample size per action
     * @param random source of randomness for drawing bootstrap samples
     * @param minSamplesToCreateRankings minimum number of observations per action before
     *        {@link #canProduceReliableRankings(Object, Collection)} returns {@code true}
     * @param maxNumSamplesInHistory maximum number of observations kept per action (oldest ones are dropped)
     */
    public BootstrappingPreferenceKernel(IBootstrappingParameterComputer bootstrapParameterComputer, IBootstrapConfigurator bootstrapConfigurator, Random random, int minSamplesToCreateRankings, int maxNumSamplesInHistory) {
        this.bootstrapParameterComputer = bootstrapParameterComputer;
        this.bootstrapConfigurator = bootstrapConfigurator;
        this.random = random;
        this.minSamplesToCreateRankings = minSamplesToCreateRankings;
        this.maxNumSamplesInHistory = maxNumSamplesInHistory;
    }

    /**
     * Creates a kernel with a {@code Random} seeded with 0 and a history of at most 1000 observations per action.
     *
     * @param bootstrapParameterComputer computes the score of an action from one bootstrap sample
     * @param bootstrapConfigurator decides the number of bootstraps and the sample size per action
     * @param minSamplesToCreateRankings minimum number of observations per action before rankings are reliable
     */
    public BootstrappingPreferenceKernel(IBootstrappingParameterComputer bootstrapParameterComputer, IBootstrapConfigurator bootstrapConfigurator, int minSamplesToCreateRankings) {
        this(bootstrapParameterComputer, bootstrapConfigurator, new Random(0L), minSamplesToCreateRankings, 1000);
    }

    /**
     * Records {@code newScore} as an observation for every (node, arc) pair along {@code path}.
     *
     * <p>The update walks the path from its root and stops right after processing the first node that is not
     * marked active, so deeper (inactive) parts of the path are not recorded.
     *
     * @param path the path whose arcs receive the observation
     * @param newScore the observed score
     */
    @Override
    public void signalNewScore(ILabeledPath<N, A> path, double newScore) {
        List<N> nodes = path.getNodes();
        List<A> arcs = path.getArcs();
        int l = nodes.size();
        for (int i = 0; i < l - 1; ++i) {
            N node = nodes.get(i);
            A arc = arcs.get(i);
            DoubleList list = this.observations.computeIfAbsent(node, n -> new HashMap<>()).computeIfAbsent(arc, a -> new DoubleArrayList());
            list.add(newScore);
            this.bestObservationForAction.computeIfAbsent(node, n -> new HashMap<>()).merge(arc, newScore, Math::min);

            // keep only a bounded history; the oldest observation is dropped first
            if (list.size() > this.maxNumSamplesInHistory) {
                list.removeDouble(0);
            }
            this.logger.debug("Updated observations for action {} in node {}. New list of observations is: {}", arc, node, list);
            if (!this.activeNodes.contains(node)) {
                this.logger.info("The current node has not been marked active and hence, we abort the update procedure saving {} entries.", l - i);
                return;
            }
        }
    }

    /**
     * Draws a fresh set of bootstrap rankings of {@code actions} in {@code node}.
     *
     * @param node the node whose actions are ranked
     * @param actions the actions to rank; each must have at least one observation in {@code node}
     * @param parameterComputer condenses one bootstrap sample into a score (lower is ranked first)
     * @return one ranking per bootstrap, each listing {@code actions} by ascending score
     * @throws IllegalArgumentException if some action has no observation in {@code node}
     */
    public List<List<A>> drawNewRankingsForActions(N node, Collection<A> actions, IBootstrappingParameterComputer parameterComputer) {
        long start = System.currentTimeMillis();
        Map<A, DoubleList> observationsPerAction = this.observations.get(node);
        for (A action : actions) {
            if (observationsPerAction == null || !observationsPerAction.containsKey(action)) {
                throw new IllegalArgumentException("No observations available for action " + action + ", cannot draw ranking.");
            }
        }
        Map<A, Double> bestPerAction = this.bestObservationForAction.get(node);
        int numBootstraps = this.bootstrapConfigurator.getNumBootstraps(observationsPerAction);
        int numSamplesPerChildInEachBootstrap = this.bootstrapConfigurator.getBootstrapSizePerChild(observationsPerAction);
        this.logger.debug("Now creating {} bootstraps (rankings)", numBootstraps);

        // loop-invariant: only used for the slow-runtime warning below
        int totalObservations = 0;
        for (A action : actions) {
            totalObservations += observationsPerAction.get(action).size();
        }

        List<List<A>> rankings = new ArrayList<>(numBootstraps);
        for (int bootstrap = 0; bootstrap < numBootstraps; ++bootstrap) {
            Map<A, Double> scorePerAction = new HashMap<>();
            for (A action : actions) {
                DoubleList observedScoresForChild = observationsPerAction.get(action);
                DescriptiveStatistics statsForThisChild = new DescriptiveStatistics();

                // every sample is seeded with the best score seen so far for this action
                statsForThisChild.addValue(bestPerAction.get(action));
                for (int sample = 0; sample < numSamplesPerChildInEachBootstrap - 1; ++sample) {
                    statsForThisChild.addValue(SetUtil.getRandomElement(observedScoresForChild, this.random).doubleValue());
                }
                scorePerAction.put(action, parameterComputer.getParameter(statsForThisChild));
            }
            List<A> ranking = actions.stream().sorted((a1, a2) -> Double.compare(scorePerAction.get(a1), scorePerAction.get(a2))).collect(Collectors.toList());
            rankings.add(ranking);
        }
        long runtime = System.currentTimeMillis() - start;
        if (runtime > MAXTIME_WARN_CREATERANKINGS) {
            this.logger.warn("Creating the {} rankings took {}ms for {} options and {} total observations, which is more than the allowed {}ms!", numBootstraps, runtime, actions.size(), totalObservations, MAXTIME_WARN_CREATERANKINGS);
        }
        return rankings;
    }

    /**
     * Draws new rankings for {@code actions} in {@code node} using the kernel's parameter computer, and remembers
     * them as the node's current rankings.
     */
    @Override
    public List<List<A>> getRankingsForActions(N node, Collection<A> actions) {
        List<List<A>> rankings = this.drawNewRankingsForActions(node, actions, this.bootstrapParameterComputer);
        this.rankingsForNodes.put(node, rankings);
        return rankings;
    }

    /**
     * Tells whether every action in {@code actions} has at least {@code minSamplesToCreateRankings} observations in
     * {@code node}. If listeners are registered, posts a {@code plkernelstatus} property event (1.0 if reliable,
     * 0.0 otherwise).
     */
    @Override
    public boolean canProduceReliableRankings(N node, Collection<A> actions) {
        Map<A, DoubleList> scoresPerAction = this.observations.get(node);
        if (scoresPerAction == null) {
            this.postKernelStatus(node, 0.0);
            this.logger.info("No observations for node yet, not allowing to produce rankings.");
            return false;
        }
        for (A action : actions) {
            DoubleList scores = scoresPerAction.get(action);
            if (scores == null || scores.size() < this.minSamplesToCreateRankings) {
                this.logger.info("Refusing production of rankings, because are less than {} observations for action {}.", this.minSamplesToCreateRankings, action);
                this.postKernelStatus(node, 0.0);
                return false;
            }
        }
        this.logger.debug("Enough examples. Allowing the construction of rankings.");
        this.postKernelStatus(node, 1.0);
        return true;
    }

    /** Posts the {@code plkernelstatus} property of {@code node} to the event bus, if anybody listens. */
    private void postKernelStatus(N node, double status) {
        if (this.hasListeners) {
            this.eventBus.post(new NodePropertyChangedEvent(null, node, "plkernelstatus", (Object) status));
        }
    }

    public String getLoggerName() {
        return this.logger.getName();
    }

    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
    }

    /**
     * Forgets all observations, the best observed scores, and the stored rankings of {@code node}. The number of
     * discarded observations is added to {@link #getErasedObservationsInTotal()}. Does nothing if the node has no
     * observations.
     */
    @Override
    public void clearKnowledge(N node) {
        Map<A, DoubleList> obsForNode = this.observations.get(node);
        if (obsForNode == null || obsForNode.isEmpty()) {
            return;
        }
        int numErased = obsForNode.values().stream().mapToInt(DoubleList::size).sum();
        this.logger.info("Removing {} observations.", numErased);
        this.erasedObservationsInTotal += numErased;
        this.observations.remove(node);
        // the best-score map is part of the node's knowledge; keeping it would seed post-reset bootstraps with stale values
        this.bestObservationForAction.remove(node);
        List<List<A>> removedRankings = this.rankingsForNodes.remove(node);
        if (removedRankings != null) {
            this.logger.info("Removing {} rankings.", removedRankings.size());
        }
    }

    /** Returns the live observation map of {@code node} (may be {@code null} if nothing was observed). */
    public Map<A, DoubleList> getObservations(N node) {
        return this.observations.get(node);
    }

    /** Returns the live (mutable) set of nodes marked active. */
    public Set<N> getActiveNodes() {
        return this.activeNodes;
    }

    /** Registers {@code listener} on the internal event bus and enables posting of node property events. */
    public void registerListener(Object listener) {
        this.eventBus.register(listener);
        this.hasListeners = true;
    }

    /** Marks {@code node} active so that observations are propagated beyond it in {@link #signalNewScore}. */
    @Override
    public void signalNodeActiveness(N node) {
        this.activeNodes.add(node);
    }

    /** Returns the total number of observations discarded by {@link #clearKnowledge(Object)} so far. */
    @Override
    public int getErasedObservationsInTotal() {
        return this.erasedObservationsInTotal;
    }

    /**
     * Returns the action with the fewest observations in {@code node} (the first one in iteration order on ties).
     * Actions without any observation, including all actions of a node that has never been observed, count as zero
     * attempts.
     *
     * @return the least-tried action, or {@code null} if {@code actions} is empty
     */
    @Override
    public A getMostImportantActionToObtainApplicability(N node, Collection<A> actions) {
        Map<A, DoubleList> obsForNode = this.observations.get(node);
        A leastTriedAction = null;
        int minAttempts = Integer.MAX_VALUE;
        for (A action : actions) {
            DoubleList obs = obsForNode != null ? obsForNode.get(action) : null;
            int attempts = obs != null ? obs.size() : 0;
            if (attempts < minAttempts) {
                minAttempts = attempts;
                leastTriedAction = action;
            }
        }
        return leastTriedAction;
    }
}

