/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.comparison;

import ai.libs.jaicore.search.algorithms.mdp.mcts.comparison.IGammaFunction;

/**
 * Gamma function that blends the results of two other gamma functions.
 *
 * <p>The blend is a convex combination: the long-term delegate is weighted by
 * a logistic function of the node probability, and the short-term delegate by
 * the complement of that weight. The logistic curve is fixed at construction
 * time so that its argument equals {@code ZERO_OFFSET} at probability 0 and
 * its value equals {@code MID_WEIGHT} (up to floating-point rounding) at
 * probability 0.5.
 */
public class CombinedGammaFunction
implements IGammaFunction {
    private final IGammaFunction shortTermGamma;
    private final IGammaFunction longTermGamma;
    /** Long-term weight the logistic curve yields at a node probability of 0.5. */
    private static final double MID_WEIGHT = 0.8;
    /** Argument of the logistic curve at a node probability of 0. */
    private static final double ZERO_OFFSET = -5.0;
    /** Long-term weight above which a zero long-term gamma forces an overall gamma of 0. */
    private static final double LONG_TERM_BREAK = 0.1;
    /** Slope of the linear map from node probability to the logistic argument. */
    private final double slope;

    /**
     * Creates the combined function and precomputes the slope of the logistic weighting curve.
     *
     * @param shortTermGamma delegate whose result is weighted by {@code 1 - longTermWeight}
     * @param longTermGamma delegate whose result is weighted by {@code longTermWeight}
     */
    public CombinedGammaFunction(IGammaFunction shortTermGamma, IGammaFunction longTermGamma) {
        this.shortTermGamma = shortTermGamma;
        this.longTermGamma = longTermGamma;
        // Logit of MID_WEIGHT: 1.0 / 0.8 - 1.0 evaluates to exactly 0.25 in double
        // arithmetic, so this is -log(0.25) = log(4).
        final double midLogit = -Math.log(1.0 / MID_WEIGHT - 1.0);
        // Chosen so that slope * 0.5 + ZERO_OFFSET == midLogit, i.e. the weight
        // reaches MID_WEIGHT at a node probability of 0.5.
        this.slope = (ZERO_OFFSET - midLogit) / -0.5;
    }

    /**
     * Computes the blended gamma value for a node.
     *
     * <p>All three arguments are forwarded unchanged to the delegates;
     * {@code nodeProbability} additionally determines the blending weight.
     * The long-term delegate is always queried. The short-term delegate is
     * skipped when the long-term weight exceeds {@code LONG_TERM_BREAK} and
     * the long-term gamma is exactly {@code 0.0}; in that case the result is
     * {@code 0.0}.
     *
     * @param visits forwarded to both delegates
     * @param nodeProbability forwarded to both delegates and used to derive the long-term weight
     * @param relativeDepth forwarded to both delegates
     * @return {@code longTerm * w + shortTerm * (1 - w)}, or {@code 0.0} in the short-circuit case
     */
    @Override
    public double getNodeGamma(int visits, double nodeProbability, double relativeDepth) {
        final double weight = this.getLongTermWeightBasedOnProbability(nodeProbability);
        final double longTerm = this.longTermGamma.getNodeGamma(visits, nodeProbability, relativeDepth);

        // A vanishing long-term gamma overrides the blend once the long-term
        // side carries enough weight.
        final boolean longTermVetoes = weight > LONG_TERM_BREAK && longTerm == 0.0;
        if (!longTermVetoes) {
            final double shortTerm = this.shortTermGamma.getNodeGamma(visits, nodeProbability, relativeDepth);
            return longTerm * weight + shortTerm * (1.0 - weight);
        }
        return 0.0;
    }

    /**
     * Maps a node probability to the weight given to the long-term gamma.
     *
     * <p>The weight is the logistic function {@code 1 / (1 + e^(-x))} applied
     * to {@code x = slope * nodeProbability + ZERO_OFFSET}. The input is not
     * range-checked here (presumably callers pass values in [0, 1] - verify
     * against callers).
     *
     * @param nodeProbability probability to convert into a weight
     * @return the logistic weight for the long-term gamma
     */
    public double getLongTermWeightBasedOnProbability(double nodeProbability) {
        final double logisticArgument = this.slope * nodeProbability + ZERO_OFFSET;
        final double denominator = 1.0 + Math.exp(-logisticArgument);
        return 1.0 / denominator;
    }
}

