package tberg.murphy.floatopt;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import tberg.murphy.arrays.a;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.floatopt.OnlineMinimizer;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatopt/AdaGradL2Minimizer.class */
/**
 * Online minimizer running AdaGrad with an L2 penalty applied as a per-coordinate proximal
 * (shrinkage) step.
 *
 * <p>Each epoch visits the functions in a shuffled order drawn from a {@link Random} seeded
 * with 0, so the visiting order is reproducible across runs. For every visited function:
 * <ol>
 *   <li>the squared gradient entries are added to a per-coordinate history {@code G}
 *       (initialised to {@code delta});</li>
 *   <li>every coordinate is shrunk: {@code w_j = sqrt(G_j) * w_j / (eta*regConstant + sqrt(G_j))};</li>
 *   <li>every coordinate present in the gradient takes the step
 *       {@code w_j += -eta * g_j / (eta*regConstant + sqrt(G_j))}.</li>
 * </ol>
 */
public class AdaGradL2Minimizer implements OnlineMinimizer {
    /** Base learning rate. */
    double eta;
    /** Initial value of every entry of the squared-gradient history; should be > 0 to keep denominators nonzero. */
    double delta;
    /** L2 regularization strength; the reported objective adds {@code regConstant * ||w||^2}. */
    double regConstant;
    /** Number of passes over the function list. */
    int epochs;

    /**
     * @param eta learning rate
     * @param delta initial value of the squared-gradient history
     * @param regConstant L2 regularization strength
     * @param epochs number of passes over the functions
     */
    public AdaGradL2Minimizer(double eta, double delta, double regConstant, int epochs) {
        this.eta = eta;
        this.delta = delta;
        this.regConstant = regConstant;
        this.epochs = epochs;
    }

    /**
     * Minimizes the sum of {@code functions} plus the L2 penalty, starting from {@code initial}.
     * The input array is copied (via {@code a.copy}) before any update.
     *
     * @param functions the per-example differentiable functions
     * @param initial starting weights
     * @param verbose if true, prints the objective at the end of every epoch
     * @param callback if non-null, invoked at the end of every epoch with the current weights,
     *     the epoch index and the objective value
     * @return the weights after {@code epochs} passes
     */
    @Override
    public float[] minimize(List<DifferentiableFunction> functions, float[] initial, boolean verbose,
            OnlineMinimizer.Callback callback) {
        Random random = new Random(0L);
        float[] weights = a.copy(initial);
        // Per-coordinate running sum of squared gradients, seeded with delta.
        float[] sqGradSums = new float[weights.length];
        a.addi(sqGradSums, (float) this.delta);
        // eta * lambda: the extra term the L2 proximal step adds to every AdaGrad denominator.
        float regStep = (float) (this.eta * this.regConstant);
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            float epochValue = 0.0f;
            for (int index : a.shuffle(a.enumerate(0, functions.size()), random)) {
                Pair<Double, CounterInterface<Integer>> valueAndGradient = functions.get(index).calculate(weights);
                epochValue = (float) (epochValue + valueAndGradient.getFirst().doubleValue());
                CounterInterface<Integer> gradient = valueAndGradient.getSecond();
                accumulateSquaredGradient(gradient, sqGradSums);
                applyL2Shrinkage(weights, sqGradSums, regStep);
                applyGradientStep(weights, gradient, sqGradSums, regStep);
            }
            // Only pay for the O(d) inner product when someone will consume the objective.
            if (verbose || callback != null) {
                double objective = epochValue + (this.regConstant * a.innerProd(weights, weights));
                if (verbose) {
                    System.out.println(String.format("[AdaGradMinimizer.minimize] Epoch %d ended with value %.6f", Integer.valueOf(epoch), Double.valueOf(objective)));
                }
                if (callback != null) {
                    callback.callback(weights, epoch, objective);
                }
            }
        }
        return weights;
    }

    /** Adds the square of each gradient entry to the matching history slot. */
    private static void accumulateSquaredGradient(CounterInterface<Integer> gradient, float[] sqGradSums) {
        for (Map.Entry<Integer, Double> entry : gradient.entries()) {
            int j = entry.getKey().intValue();
            double g = entry.getValue().doubleValue();
            sqGradSums[j] = (float) (sqGradSums[j] + (g * g));
        }
    }

    /** Applies the per-coordinate L2 proximal shrinkage to every weight. */
    private static void applyL2Shrinkage(float[] weights, float[] sqGradSums, float regStep) {
        for (int j = 0; j < weights.length; j++) {
            float root = (float) Math.sqrt(sqGradSums[j]);
            float denom = regStep + root;
            // denom == 0 only when regStep == 0 (identity prox) and the history is 0; skipping
            // avoids a 0/0 NaN that would otherwise poison the coordinate.
            if (denom != 0.0f) {
                weights[j] = (root * weights[j]) / denom;
            }
        }
    }

    /** Takes the AdaGrad step on every coordinate that has a gradient entry. */
    private void applyGradientStep(float[] weights, CounterInterface<Integer> gradient, float[] sqGradSums,
            float regStep) {
        for (Map.Entry<Integer, Double> entry : gradient.entries()) {
            int j = entry.getKey().intValue();
            float denom = regStep + (float) Math.sqrt(sqGradSums[j]);
            // A zero denominator implies a zero gradient here (history >= g^2), so no step is needed.
            if (denom != 0.0f) {
                weights[j] = (float) (weights[j] + (((-this.eta) * entry.getValue().floatValue()) / denom));
            }
        }
    }
}
