package net.haesleinhuepf.clij2.assistant.optimize;

import java.util.Arrays;
import net.haesleinhuepf.clij2.assistant.utilities.Logger;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultiStartMultivariateOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer;
import org.apache.commons.math3.random.GaussianRandomGenerator;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.UncorrelatedRandomVectorGenerator;

/* loaded from: input_file:net/haesleinhuepf/clij2/assistant/optimize/GradientDescentOptimizer.class */
/**
 * {@link Optimizer} that refines a parameter vector with Commons Math's
 * {@link NonLinearConjugateGradientOptimizer} (Polak-Ribiere update), wrapped in a
 * {@link MultiStartMultivariateOptimizer} and repeated for several rounds.
 *
 * <p>Each round uses the best parameters of the previous round as its initial guess. It draws
 * additional start points from an uncorrelated Gaussian distribution centred on those
 * parameters. The standard deviations come from {@code OptimizationUtilities.range(...)}, called
 * with a scale factor that halves every round (see {@link #scaleForRound(int)}). The intent is a
 * coarse-to-fine search.
 *
 * <p>The random generator is created without a seed, so results are not reproducible between
 * runs.
 */
public class GradientDescentOptimizer implements Optimizer {
    /** Number of refinement rounds used by the no-argument constructor. */
    private static final int DEFAULT_ITERATIONS = 6;

    /** Number of starts the multi-start wrapper performs per round. */
    private static final int STARTS_PER_ROUND = 10;

    /** Objective-evaluation budget ({@link MaxEval}) handed to each round's optimization. */
    private static final int MAX_EVALUATIONS_PER_ROUND = 1000;

    /** Relative and absolute threshold of the {@link SimpleValueChecker} convergence test. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-10d;

    /** Number of refinement rounds performed by {@link #optimize}. */
    int iterations;

    /** Creates an optimizer that performs six refinement rounds. */
    public GradientDescentOptimizer() {
        this(DEFAULT_ITERATIONS);
    }

    /**
     * Creates an optimizer with a custom number of refinement rounds.
     *
     * @param i number of rounds; a value {@code <= 0} makes {@link #optimize} return its
     *     start parameters unchanged
     */
    public GradientDescentOptimizer(int i) {
        this.iterations = i;
    }

    /**
     * Minimizes {@code multivariateFunction}, starting from {@code dArr}.
     *
     * <p>Each round computes per-parameter standard deviations with
     * {@code OptimizationUtilities.range}. It then runs a multi-start conjugate-gradient
     * minimization with those deviations and continues from the best point found. Exceptions
     * raised by Commons Math are not caught here.
     *
     * @param dArr initial parameter vector, passed as the initial guess of the first round
     * @param workflow supplies the numeric parameter names passed to
     *     {@code OptimizationUtilities.range}
     * @param iArr passed unchanged to {@code OptimizationUtilities.range}; its meaning is
     *     defined there
     * @param multivariateFunction objective to minimize. Its gradient is provided by
     *     {@code GradientOfMultivariateFunction} (presumably a numerical approximation — confirm)
     * @param logger receives the start parameters, each round's standard deviations,
     *     intermediate parameters with their objective value, and the final parameters
     * @return the parameters found in the last round, or {@code dArr} itself when no round runs
     */
    @Override // net.haesleinhuepf.clij2.assistant.optimize.Optimizer
    public double[] optimize(double[] dArr, Workflow workflow, int[] iArr,
            MultivariateFunction multivariateFunction, Logger logger) {
        NonLinearConjugateGradientOptimizer localOptimizer = new NonLinearConjugateGradientOptimizer(
                NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
                new SimpleValueChecker(CONVERGENCE_TOLERANCE, CONVERGENCE_TOLERANCE));
        // Immutable optimization data that does not change between rounds.
        ObjectiveFunction objective = new ObjectiveFunction(multivariateFunction);
        MaxEval maxEval = new MaxEval(MAX_EVALUATIONS_PER_ROUND);

        double[] current = dArr;
        logger.log("Start:        " + Arrays.toString(current) + "\t");
        for (int round = 0; round < iterations; round++) {
            double[] stddevs = OptimizationUtilities.range(current.length,
                    workflow.getNumericParameterNames(), iArr, scaleForRound(round));
            // Reported through the caller's logger like every other progress message.
            logger.log("Stddevs: " + Arrays.toString(stddevs));

            MultiStartMultivariateOptimizer multiStart = new MultiStartMultivariateOptimizer(
                    localOptimizer, STARTS_PER_ROUND,
                    new UncorrelatedRandomVectorGenerator(current, stddevs,
                            new GaussianRandomGenerator(new JDKRandomGenerator())));
            PointValuePair best = multiStart.optimize(
                    maxEval,
                    objective,
                    new ObjectiveFunctionGradient(
                            new GradientOfMultivariateFunction(multivariateFunction, stddevs)),
                    GoalType.MINIMIZE,
                    new InitialGuess(current));

            current = best.getPoint();
            logger.log("Intermediate: " + Arrays.toString(current) + "\t f = " + best.getValue());
        }
        logger.log("Final:        " + Arrays.toString(current) + "\t");
        return current;
    }

    /**
     * Returns the scale factor passed to {@code OptimizationUtilities.range} in the given round.
     *
     * <p>The factor is {@code 2^(iterations / 2 - round - 1)}, using integer division. For six
     * rounds it is 4, 2, 1, 0.5, 0.25 and 0.125, halving each round.
     *
     * @param round zero-based round index
     * @return the scale factor for that round
     */
    private double scaleForRound(int round) {
        return Math.pow(2.0d, (iterations / 2) - round - 1);
    }
}
