package ai.libs.jaicore.ml.core.optimizing.graddesc;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.optimizing.IGradientBasedOptimizer;
import ai.libs.jaicore.ml.core.optimizing.IGradientDescendableFunction;
import ai.libs.jaicore.ml.core.optimizing.IGradientFunction;
import java.util.Map;
import org.aeonbits.owner.ConfigFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/optimizing/graddesc/GradientDescentOptimizer.class */
/**
 * Gradient-descent optimizer that repeatedly moves a weight vector against its gradient.
 *
 * <p>Each iteration evaluates the gradient function at the current weights and subtracts
 * {@code learningRate * gradient} from every weight whose gradient component is finite and
 * has an absolute value of at least {@code gradientThreshold}. Iteration stops as soon as
 * every gradient component is either below the threshold or non-finite, or once
 * {@code maxIterations} iterations have been performed. At least one iteration is always
 * executed.
 */
public class GradientDescentOptimizer implements IGradientBasedOptimizer {
    private static final Logger log = LoggerFactory.getLogger(GradientDescentOptimizer.class);

    /** Step size applied to each gradient component. */
    private final double learningRate;
    /** Absolute gradient value below which a component counts as converged. */
    private final double gradientThreshold;
    /** Upper bound on the number of iterations performed by {@link #optimize}. */
    private final int maxIterations;

    /**
     * Creates an optimizer whose learning rate, gradient threshold and iteration limit are
     * read from the given configuration.
     *
     * @param gradientDescentOptimizerConfig configuration supplying the three settings
     */
    public GradientDescentOptimizer(GradientDescentOptimizerConfig gradientDescentOptimizerConfig) {
        this.learningRate = gradientDescentOptimizerConfig.learningRate();
        this.gradientThreshold = gradientDescentOptimizerConfig.gradientThreshold();
        this.maxIterations = gradientDescentOptimizerConfig.maxIterations();
    }

    /**
     * Creates an optimizer from a {@link GradientDescentOptimizerConfig} instance obtained
     * through {@link ConfigFactory#create} without any additional property imports.
     */
    public GradientDescentOptimizer() {
        this(ConfigFactory.create(GradientDescentOptimizerConfig.class));
    }

    /**
     * Runs gradient descent starting at {@code vector}.
     *
     * <p>The passed vector is modified in place and the same instance is returned.
     *
     * @param iGradientDescendableFunction the function being minimized; not used by this
     *        implementation, which only evaluates {@code iGradientFunction}
     * @param iGradientFunction gradient of the function, evaluated once per iteration
     * @param vector initial weights; updated in place
     * @return {@code vector}, holding the weights after the last iteration
     */
    @Override // ai.libs.jaicore.ml.core.optimizing.IGradientBasedOptimizer
    public Vector optimize(IGradientDescendableFunction iGradientDescendableFunction, IGradientFunction iGradientFunction, Vector vector) {
        int iterations = 0;
        boolean converged;
        do {
            Vector gradients = iGradientFunction.apply(vector);
            iterations++;
            updatePredictions(vector, gradients);
            log.warn("iteration {}:\n weights \t{} \n gradients \t{}", iterations, vector, gradients);
            // Convergence is judged on the gradient computed before this iteration's update.
            converged = allGradientsAreBelowThreshold(gradients);
        } while (!converged && iterations < this.maxIterations);
        log.warn("Gradient descent based optimization took {} iterations.", iterations);
        return vector;
    }

    /**
     * Tells whether every component of the gradient is either non-finite or smaller in
     * absolute value than the gradient threshold.
     */
    private boolean allGradientsAreBelowThreshold(Vector gradients) {
        return gradients.stream().allMatch(g -> !Double.isFinite(g) || Math.abs(g) < this.gradientThreshold);
    }

    /**
     * Performs one descent step on {@code weights} in place.
     *
     * <p>Components whose gradient is below the threshold are left untouched. Non-finite
     * gradient components (NaN or infinite) are skipped as well: the convergence check
     * treats them as converged, so applying an infinite step here would leave an infinite
     * weight in a result that is reported as converged.
     */
    private void updatePredictions(Vector weights, Vector gradients) {
        for (int index = 0; index < weights.length(); index++) {
            double gradient = gradients.getValue(index);
            if (!Double.isFinite(gradient) || Math.abs(gradient) < this.gradientThreshold) {
                continue;
            }
            double weight = weights.getValue(index);
            weights.setValue(index, weight + (gradient * (-1.0d) * this.learningRate));
        }
    }
}
