package tberg.murphy.opt;

import java.util.List;
import java.util.Random;
import tberg.murphy.arrays.a;
import tberg.murphy.opt.Minimizer;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/opt/SGDMinimizer.class */
/**
 * Plain stochastic gradient descent over a list of per-example differentiable functions.
 *
 * <p>Each epoch visits every function exactly once, in an order produced by shuffling with a
 * fixed-seed {@link Random}, so repeated runs visit examples in the same order. One gradient step
 * is taken per function. The step size is interpolated linearly from {@code startLearningRate}
 * towards {@code endLearningRate} over the total number of steps ({@code epochs * list.size()}).
 */
public class SGDMinimizer implements OnlineMinimizer {
    /** Step size used for the very first update. */
    double startLearningRate;
    /**
     * Step size the linear schedule moves towards. The last update uses
     * {@code start + (T - 1) / T * (end - start)} for {@code T} total steps, so this exact value
     * is never used.
     */
    double endLearningRate;
    /** Number of full passes over the function list. */
    int epochs;

    /**
     * Creates an SGD minimizer with a linearly interpolated learning rate.
     *
     * @param d  learning rate for the first update
     * @param d2 learning rate the schedule interpolates towards
     * @param i  number of epochs (full passes over the functions)
     */
    public SGDMinimizer(double d, double d2, int i) {
        this.startLearningRate = d;
        this.endLearningRate = d2;
        this.epochs = i;
    }

    /**
     * Runs {@code epochs} passes of SGD and returns the final parameter vector.
     *
     * @param list     per-example functions; each pass calls {@code calculate} once per element and
     *                 steps against the returned gradient
     * @param dArr     starting point; the optimizer works on {@code a.copy(dArr)}, so the caller's
     *                 array is presumably not modified
     * @param z        when true, prints the summed function value at the end of every epoch
     * @param callback optional (may be null); invoked after every epoch with the current parameter
     *                 array, the epoch index, the sum of function values seen during the epoch, and
     *                 the sum of gradients seen during the epoch. The parameter array passed is the
     *                 live working array, which later epochs keep updating in place.
     * @return the parameter vector after the last epoch
     */
    @Override // tberg.murphy.opt.OnlineMinimizer
    public double[] minimize(List<DifferentiableFunction> list, double[] dArr, boolean z, Minimizer.Callback callback) {
        // Fixed seed: repeated runs visit examples in the same order.
        final Random random = new Random(0L);
        final double[] weights = a.copy(dArr);
        final int numExamples = list.size();
        // Widen before multiplying: an int product overflows once epochs * numExamples exceeds
        // Integer.MAX_VALUE, which would corrupt the learning-rate schedule.
        final double totalSteps = (double) this.epochs * numExamples;
        final double rateSpan = this.endLearningRate - this.startLearningRate;
        double stepsTaken = 0.0d;
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            // Both sums are accumulated at the parameters current at each step, not at one fixed
            // point, so they only approximate the full objective/gradient for this epoch.
            double epochValue = 0.0d;
            final double[] epochGradient = new double[weights.length];
            for (int index : a.shuffle(a.enumerate(0, numExamples), random)) {
                final Pair<Double, double[]> valueAndGradient = list.get(index).calculate(weights);
                epochValue += valueAndGradient.getFirst().doubleValue();
                final double[] gradient = valueAndGradient.getSecond();
                // Linear interpolation from start towards end by the fraction of steps already taken.
                final double learningRate = this.startLearningRate + (stepsTaken / totalSteps) * rateSpan;
                // In-place update of weights (return value discarded): weights += -learningRate * gradient.
                a.combi(weights, 1.0d, gradient, -learningRate);
                a.combi(epochGradient, 1.0d, gradient, 1.0d);
                stepsTaken += 1.0d;
            }
            if (z) {
                System.out.println(String.format("[SGDMinimizer.minimize] Epoch %d ended with value %.6f", epoch, epochValue));
            }
            if (callback != null) {
                callback.callback(weights, epoch, epochValue, epochGradient);
            }
        }
        return weights;
    }
}
