package tberg.murphy.floatopt;

import java.util.List;
import java.util.Map;
import java.util.Random;
import tberg.murphy.arrays.a;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.floatopt.OnlineMinimizer;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatopt/AdaGradL1Minimizer.class */
/**
 * Online minimizer running AdaGrad with an L1 penalty applied as a per-example proximal
 * (soft-thresholding) step.
 *
 * <p>For every example, in a shuffled order that is reproducible across runs:
 * <ol>
 *   <li>the example's gradient is squared into a per-coordinate accumulator (seeded with
 *       {@code delta});</li>
 *   <li>each touched coordinate moves by {@code -eta * g / sqrt(accumulator)};</li>
 *   <li>every coordinate is shrunk towards zero by {@code regConstant * eta / sqrt(accumulator)},
 *       and set to zero if it would cross zero.</li>
 * </ol>
 */
public class AdaGradL1Minimizer implements OnlineMinimizer {
    float eta;
    float delta;
    float regConstant;
    int epochs;

    /**
     * @param eta base learning rate
     * @param delta initial value of every per-coordinate squared-gradient accumulator
     * @param regConstant weight of the L1 penalty
     * @param epochs number of passes over the examples
     */
    public AdaGradL1Minimizer(double eta, double delta, double regConstant, int epochs) {
        this.eta = (float) eta;
        this.delta = (float) delta;
        this.regConstant = (float) regConstant;
        this.epochs = epochs;
    }

    /**
     * Runs {@code epochs} passes of AdaGrad-L1 over {@code functions}, starting from a copy of
     * {@code initial}; the input array is not modified.
     *
     * <p>The value reported per epoch is the sum of per-example losses plus
     * {@code regConstant * ||w||_1} for the weights at the end of the epoch. Each loss is evaluated
     * at the iterate current when that example was visited.
     *
     * @param functions per-example differentiable objectives
     * @param initial starting weights
     * @param verbose whether to print the objective after each epoch
     * @param callback optional hook invoked after each epoch; may be {@code null}
     * @return the final weights
     */
    @Override // tberg.murphy.floatopt.OnlineMinimizer
    public float[] minimize(List<DifferentiableFunction> functions, float[] initial, boolean verbose,
            OnlineMinimizer.Callback callback) {
        // Fixed seed: the example order, and therefore the result, is reproducible.
        Random random = new Random(0L);
        float[] weights = a.copy(initial);
        float[] sumSqGrad = new float[weights.length];
        a.addi(sumSqGrad, this.delta);
        float l1StepScale = this.regConstant * this.eta;
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            // A double accumulator avoids the precision loss of summing many losses into a float.
            double epochLoss = 0.0;
            for (int index : a.shuffle(a.enumerate(0, functions.size()), random)) {
                Pair<Double, CounterInterface<Integer>> valueAndGradient = functions.get(index).calculate(weights);
                epochLoss += valueAndGradient.getFirst().doubleValue();
                adaGradStep(weights, sumSqGrad, valueAndGradient.getSecond());
                // With no penalty the proximal step is the identity. Running it anyway would
                // compute 0/0 = NaN (and zero the weight) for coordinates whose accumulator is 0.
                if (l1StepScale != 0.0f) {
                    applyL1Proximal(weights, sumSqGrad, l1StepScale);
                }
            }
            if (verbose || callback != null) {
                double l1Norm = a.sum(a.abs(weights));
                float objective = (float) (epochLoss + this.regConstant * l1Norm);
                if (verbose) {
                    System.out.println(String.format("[AdaGradMinimizer.minimize] Epoch %d ended with value %.6f", Integer.valueOf(epoch), Float.valueOf(objective)));
                }
                if (callback != null) {
                    callback.callback(weights, epoch, objective);
                }
            }
        }
        return weights;
    }

    /**
     * Adds the squared gradient into {@code sumSqGrad}, then moves each touched weight by
     * {@code -eta * g / sqrt(sumSqGrad)}.
     */
    private void adaGradStep(float[] weights, float[] sumSqGrad, CounterInterface<Integer> gradient) {
        // Every squared gradient is accumulated before any weight moves.
        for (Map.Entry<Integer, Double> entry : gradient.entries()) {
            int k = entry.getKey().intValue();
            double g = entry.getValue().doubleValue();
            sumSqGrad[k] = (float) (sumSqGrad[k] + (g * g));
        }
        for (Map.Entry<Integer, Double> entry : gradient.entries()) {
            float g = entry.getValue().floatValue();
            if (g == 0.0f) {
                // A zero gradient moves nothing. Skipping it avoids (eta / 0) * 0 = NaN
                // when delta == 0 and the accumulator is still 0.
                continue;
            }
            int k = entry.getKey().intValue();
            weights[k] += -(this.eta / ((float) Math.sqrt(sumSqGrad[k]))) * g;
        }
    }

    /**
     * Soft-thresholds every weight towards zero by {@code l1StepScale / sqrt(sumSqGrad)}, setting
     * it to zero when the shrinkage would make it cross zero.
     */
    private static void applyL1Proximal(float[] weights, float[] sumSqGrad, float l1StepScale) {
        for (int k = 0; k < weights.length; k++) {
            float w = weights[k];
            float shrunk = Math.abs(w) - (l1StepScale / ((float) Math.sqrt(sumSqGrad[k])));
            weights[k] = shrunk > 0.0f ? Math.copySign(shrunk, w) : 0.0f;
        }
    }
}
