package tberg.murphy.opt;

import edu.berkeley.cs.nlp.ocular.model.CharacterTemplate;
import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import tberg.murphy.arrays.a;
import tberg.murphy.opt.Minimizer;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/opt/AdaGradL1Minimizer.class */
/**
 * Online (stochastic) minimizer that combines diagonal AdaGrad with an L1 penalty.
 *
 * <p>The objective is {@code sum_k f_k(x) + regConstant * ||x||_1}, where each {@code f_k} is one
 * element of the function list passed to {@link #minimize}. Each step evaluates a single component
 * function at the current iterate. It adds that gradient's element-wise square to a per-coordinate
 * accumulator. It then applies a soft-thresholding (proximal L1) update to every coordinate:
 *
 * <pre>
 *   G_j  += g_j^2                                   (every G_j starts at delta)
 *   z_j   = x_j - (eta / sqrt(G_j)) * g_j
 *   x_j   = sign(z_j) * (|z_j| - eta * regConstant / sqrt(G_j))   if that is > 0
 *   x_j   = 0                                                      otherwise
 * </pre>
 *
 * <p>The accumulation line assumes {@code a.combi(x, 1, y, 1)} performs an in-place
 * {@code x += y}. TODO confirm against {@code tberg.murphy.arrays.a}.
 *
 * <p>Component functions are visited in a shuffled order each epoch. The shuffle uses a
 * {@link Random} seeded with 0, so repeated runs with the same inputs follow the same order.
 */
public class AdaGradL1Minimizer implements OnlineMinimizer {
    /** Base step size; each coordinate's step is {@code eta / sqrt(G_j)}. */
    double eta;
    /**
     * Initial value of every squared-gradient accumulator. A positive value keeps the divisor
     * {@code sqrt(G_j)} nonzero even before any nonzero gradient has been seen.
     */
    double delta;
    /** Weight of the L1 penalty. */
    double regConstant;
    /** Number of full passes over the component functions. */
    int epochs;

    /**
     * Creates a minimizer.
     *
     * @param eta base step size
     * @param delta initial value of each squared-gradient accumulator
     * @param regConstant L1 penalty weight
     * @param epochs number of passes over the component functions
     */
    public AdaGradL1Minimizer(double eta, double delta, double regConstant, int epochs) {
        this.eta = eta;
        this.delta = delta;
        this.regConstant = regConstant;
        this.epochs = epochs;
    }

    /**
     * Runs {@link #epochs} passes of AdaGrad-L1 over {@code functions}, starting from a copy of
     * {@code initial}. The input array itself is not modified.
     *
     * @param functions component functions whose sum (plus the L1 term) is minimized
     * @param initial starting point; copied before use
     * @param verbose if true, prints one line per epoch. The printed value is the summed component
     *     values plus {@code regConstant * ||x||_1}, where x is the iterate at the end of the epoch.
     * @param callback if non-null, invoked once per epoch with four arguments:
     *     <ul>
     *       <li>the live iterate array (later epochs keep mutating it);</li>
     *       <li>the epoch index;</li>
     *       <li>the summed component values, WITHOUT the L1 term;</li>
     *       <li>the summed component gradients.</li>
     *     </ul>
     *     Both sums are taken at the iterates current when each function was visited.
     * @return the final iterate
     */
    @Override // tberg.murphy.opt.OnlineMinimizer
    public double[] minimize(List<DifferentiableFunction> functions, double[] initial, boolean verbose, Minimizer.Callback callback) {
        Random random = new Random(0L);
        double[] x = a.copy(initial);
        // Running per-coordinate sum of squared gradients (G_j), seeded with delta.
        double[] sqGradSum = new double[x.length];
        a.addi(sqGradSum, this.delta);
        // L1 shrinkage amount before the per-coordinate 1/sqrt(G_j) scaling.
        double shrink = this.eta * this.regConstant;
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            double epochValue = 0.0d;
            double[] epochGrad = new double[x.length];
            // Presumably a random permutation of 0..size-1; verify against a.shuffle/a.enumerate.
            for (int k : a.shuffle(a.enumerate(0, functions.size()), random)) {
                Pair<Double, double[]> valueAndGrad = functions.get(k).calculate(x);
                epochValue += valueAndGrad.getFirst().doubleValue();
                double[] grad = valueAndGrad.getSecond();
                a.combi(epochGrad, 1.0d, grad, 1.0d);
                a.combi(sqGradSum, 1.0d, a.sqr(grad), 1.0d);
                for (int j = 0; j < x.length; j++) {
                    double scale = Math.sqrt(sqGradSum[j]);
                    double stepped = x[j] - ((this.eta / scale) * grad[j]);
                    // Soft-threshold: move |stepped| toward zero by shrink/scale, clamping at 0.
                    // The comparisons are against a literal 0.0. A decompiler had substituted an
                    // unrelated OCR constant (Cropper.VERT_GROW_RATIO) that happens to equal 0.0.
                    double shrunk = Math.abs(stepped) - (shrink / scale);
                    if (shrunk > 0.0d) {
                        x[j] = (stepped > 0.0d ? 1.0d : -1.0d) * shrunk;
                    } else {
                        x[j] = 0.0d;
                    }
                }
            }
            if (verbose) {
                double penalized = epochValue + (this.regConstant * a.sum(a.abs(x)));
                System.out.println(String.format("[AdaGradMinimizer.minimize] Epoch %d ended with value %.6f", Integer.valueOf(epoch), Double.valueOf(penalized)));
            }
            if (callback != null) {
                callback.callback(x, epoch, epochValue, epochGrad);
            }
        }
        return x;
    }

    /**
     * Small smoke test with two quadratic components over 10 coordinates.
     *
     * <p>Per coordinate, the components sum to {@code x + 4x^2}. Adding the 0.1 L1 term gives
     * {@code x + 4x^2 + 0.1|x|}, whose minimum is at {@code x = -0.1125}.
     */
    public static void main(String[] args) {
        List<DifferentiableFunction> functions = new ArrayList<DifferentiableFunction>();
        // f1(x) = -sum(x) + 2 * x.x ; gradient = -1 + 4x
        functions.add(new DifferentiableFunction() {
            @Override // tberg.murphy.opt.DifferentiableFunction
            public Pair<Double, double[]> calculate(double[] point) {
                return Pair.makePair(Double.valueOf((-a.sum(point)) + (2.0d * a.innerProd(point, point))), a.comb(a.scale(a.onesDouble(point.length), -1.0d), 1.0d, a.scale(point, 4.0d), 1.0d));
            }
        });
        // f2(x) = 2 * sum(x) + 2 * x.x ; gradient = 2 + 4x
        functions.add(new DifferentiableFunction() {
            @Override // tberg.murphy.opt.DifferentiableFunction
            public Pair<Double, double[]> calculate(double[] point) {
                return Pair.makePair(Double.valueOf((2.0d * a.sum(point)) + (2.0d * a.innerProd(point, point))), a.comb(a.scale(a.onesDouble(point.length), 2.0d), 1.0d, a.scale(point, 4.0d), 1.0d));
            }
        });
        // NOTE(review): CharacterTemplate.INIT_LBFGS_ITERS is almost certainly a decompiler
        // substitution for an integer literal epoch count. Replace it with that literal once
        // known, to drop this optimizer's dependency on the OCR model package.
        new AdaGradL1Minimizer(0.1d, 0.01d, 0.1d, CharacterTemplate.INIT_LBFGS_ITERS).minimize(functions, a.onesDouble(10), true, null);
    }
}
