/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.search.algorithms.mdp.mcts.comparison;

import ai.libs.jaicore.search.algorithms.mdp.mcts.comparison.IGammaFunction;
import java.util.function.DoubleFunction;

/**
 * {@link IGammaFunction} that maps the visit count of a node to a gamma value in three phases.
 *
 * <ol>
 *   <li>While a node has at most {@link #getMinRequiredVisits(double)} visits, its gamma is 0.</li>
 *   <li>Between that threshold and {@code visitsToReachOne} visits, gamma follows a half-cosine
 *       ramp that ends at exactly 1 when the node has {@code visitsToReachOne} visits.</li>
 *   <li>Beyond {@code visitsToReachOne} visits, gamma is the cube root of the surplus visits.</li>
 * </ol>
 *
 * In every phase the result is capped at {@code maxGamma}.
 */
public class CosLinGammaFunction
implements IGammaFunction {

    /** Slope of the logistic curve that moves the minimum-visits threshold with relative depth. */
    private static final double THRESHOLD_DECAY_SLOPE = 12.5;

    /**
     * Offset of that logistic curve. Together with the slope, the transition is centred at relative
     * depth 0.4 ({@code 5.0 / 12.5}) and is nearly complete at relative depth 0.8.
     */
    private static final double THRESHOLD_DECAY_OFFSET = 5.0;

    private final double maxGamma;
    private final int visitsToReachOne;
    private final int initialMinThreshold;
    private final int absoluteMinThreshold;

    /**
     * Creates a gamma function.
     *
     * @param maxGamma upper bound for every gamma value returned by
     *        {@link #getNodeGamma(int, double, double)}; a negative value makes that method throw
     *        once a node exceeds its minimum-visits threshold
     * @param visitsToReachOne number of visits at which the cosine ramp reaches a gamma of 1
     * @param initialMinThreshold minimum-visits threshold that is (approximately) used at relative depth 0
     * @param absoluteMinThreshold minimum-visits threshold that is (approximately) used at relative
     *        depth 0.8 and beyond
     */
    public CosLinGammaFunction(double maxGamma, int visitsToReachOne, int initialMinThreshold, int absoluteMinThreshold) {
        this.maxGamma = maxGamma;
        this.visitsToReachOne = visitsToReachOne;
        this.initialMinThreshold = initialMinThreshold;
        this.absoluteMinThreshold = absoluteMinThreshold;
    }

    /**
     * Computes how many visits a node must exceed before it is assigned a non-zero gamma.
     *
     * <p>The threshold blends {@code initialMinThreshold} and {@code absoluteMinThreshold} with a
     * logistic weight on {@code initialMinThreshold}. That weight is about 0.993 at relative depth 0,
     * 0.5 at relative depth 0.4 and about 0.007 at relative depth 0.8.
     *
     * @param relativeDepth relative depth of the node (presumably within [0, 1] — TODO confirm with
     *        callers); values outside that range are accepted and simply push the weight further
     *        towards 1 or 0
     * @return the blended threshold, rounded to the nearest integer
     */
    public int getMinRequiredVisits(double relativeDepth) {
        double factor = 1.0 - 1.0 / (1.0 + Math.exp(-(THRESHOLD_DECAY_SLOPE * relativeDepth - THRESHOLD_DECAY_OFFSET)));
        // subtract in double arithmetic so extreme thresholds cannot overflow int
        double min = factor * ((double) this.initialMinThreshold - this.absoluteMinThreshold) + this.absoluteMinThreshold;
        return (int) Math.round(min);
    }

    /**
     * Computes the gamma value of a node from its visit count.
     *
     * @param visits number of visits of the node
     * @param nodeProbability ignored by this implementation
     * @param relativeDepth relative depth of the node, used to derive the minimum-visits threshold
     *        via {@link #getMinRequiredVisits(double)}
     * @return 0 up to the threshold; then a half-cosine ramp reaching 1 at {@code visitsToReachOne}
     *         visits; then the cube root of the visits beyond {@code visitsToReachOne}. The result
     *         is always capped at {@code maxGamma}.
     * @throws IllegalStateException if the capped gamma is negative, which happens only when
     *         {@code maxGamma} is negative
     */
    @Override
    public double getNodeGamma(int visits, double nodeProbability, double relativeDepth) {
        int minThreshold = this.getMinRequiredVisits(relativeDepth);
        if (visits <= minThreshold) {
            return 0.0;
        }
        double gamma;
        if (visits > this.visitsToReachOne) {
            // growth phase: at least one visit beyond visitsToReachOne, so gamma >= 1 before capping
            gamma = Math.cbrt((double) visits - this.visitsToReachOne);
        } else {
            // Ramp phase: minThreshold < visits <= visitsToReachOne, so progress lies in (0, 1].
            // The subtractions are done in double arithmetic so they cannot overflow int.
            double progress = ((double) visits - minThreshold) / ((double) this.visitsToReachOne - minThreshold);
            gamma = exploitationShape(progress);
        }
        // Cap both phases, not only the growth phase. A maxGamma below 1 then bounds the ramp
        // instead of producing values that exceed maxGamma.
        gamma = Math.min(this.maxGamma, gamma);
        if (gamma < 0.0) {
            throw new IllegalStateException("Gamma must not be negative, but was " + gamma + " (maxGamma = " + this.maxGamma + ")");
        }
        return gamma;
    }

    /**
     * Half-cosine shape that rises monotonically from 0 at {@code x = 0} to 1 at {@code x = 1}, with
     * zero slope at both ends.
     *
     * @param x position within the ramp; must lie in [0, 1]
     * @return {@code 1 - (cos(pi * x) + 1) / 2}, which lies in [0, 1]
     * @throws IllegalArgumentException if {@code x} lies outside [0, 1]
     */
    private static double exploitationShape(double x) {
        if (x < 0.0 || x > 1.0) {
            throw new IllegalArgumentException("Shape argument must be within unit interval, but was " + x);
        }
        return 1.0 - 0.5 * (Math.cos(x * Math.PI) + 1.0);
    }
}

