package tberg.murphy.lazyopt;

import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import java.util.List;
import java.util.Map;
import java.util.Random;
import tberg.murphy.arrays.a;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.counter.IntCounter;
import tberg.murphy.lazyopt.OnlineMinimizer;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/lazyopt/AdaGradL1Minimizer.class */
/**
 * Online minimizer combining per-coordinate (diagonal) AdaGrad step sizes with an L1 penalty.
 *
 * <p>Coordinate {@code j} takes gradient steps of size {@code eta / sqrt(sqrGradSum[j])}, where
 * {@code sqrGradSum[j]} is seeded with {@code delta} (via {@code a.addi}) and accumulates the
 * squared gradient entries seen for that coordinate. The L1 penalty is applied lazily: a coordinate
 * is soft-thresholded toward zero only when it is next touched by a gradient or read through the
 * weight view. The threshold is {@code r * (steps since its last shrinkage) / sqrt(sqrGradSum[j])},
 * with {@code r = regConstant * eta}. Each example therefore only flushes the coordinates that
 * appear in its gradient, plus any coordinates the function itself reads.
 */
public class AdaGradL1Minimizer implements OnlineMinimizer {
    float eta;
    float delta;
    float regConstant;
    int epochs;
    /** Per-step L1 shrinkage numerator, {@code regConstant * eta} (set in the constructor). */
    float r;

    /**
     * View of the current weights that is handed to each {@link DifferentiableFunction}.
     *
     * <p>Only {@link #getCount} and {@link #size} are supported. Reading a coordinate first applies
     * the L1 shrinkage still pending for it, so the function sees an up-to-date value. Note that
     * this mutates the shared arrays. Every other operation throws {@link Error}.
     */
    private class LazyAdaGradResult implements CounterInterface<Integer> {
        /** Single-element holder for the global step counter (shared with the minimizer). */
        private final int[] current;
        /** Per-coordinate step at which shrinkage was last applied. */
        private final int[] lastUpdate;
        /** The weights themselves. */
        private final float[] x;
        /** Per-coordinate accumulated squared gradients (AdaGrad denominators before sqrt). */
        private final float[] sqrGradSum;

        public LazyAdaGradResult(int[] current, int[] lastUpdate, float[] x, float[] sqrGradSum) {
            this.current = current;
            this.lastUpdate = lastUpdate;
            this.x = x;
            this.sqrGradSum = sqrGradSum;
        }

        /** Error raised by every operation this view deliberately does not support. */
        private Error notImplemented() {
            return new Error("Method not implemented.");
        }

        @Override
        public double dotProduct(CounterInterface<Integer> other) {
            throw notImplemented();
        }

        @Override
        public Iterable<Map.Entry<Integer, Double>> entries() {
            throw notImplemented();
        }

        @Override
        public double incrementCount(Integer key, double amount) {
            throw notImplemented();
        }

        @Override
        public <T extends Integer> void incrementAll(CounterInterface<T> other, double scale) {
            throw notImplemented();
        }

        @Override
        public <T extends Integer> void incrementAll(CounterInterface<T> other) {
            throw notImplemented();
        }

        @Override
        public void scale(double factor) {
            throw notImplemented();
        }

        /** Flushes pending L1 shrinkage for {@code key}, then returns its weight. */
        @Override
        public final double getCount(Integer key) {
            final int i = key.intValue();
            flushShrinkageUpdates(i, current, lastUpdate, x, sqrGradSum, AdaGradL1Minimizer.this.r);
            return x[i];
        }

        @Override
        public void setCount(Integer key, double value) {
            throw notImplemented();
        }

        @Override
        public double totalCount() {
            throw notImplemented();
        }

        /** Returns the dimensionality of the weight vector. */
        @Override
        public final int size() {
            return x.length;
        }

        @Override
        public Iterable<Integer> keySet() {
            throw notImplemented();
        }
    }

    /**
     * Creates the minimizer.
     *
     * @param eta         base AdaGrad learning rate
     * @param delta       initial value of every squared-gradient accumulator
     * @param regConstant L1 regularization strength
     * @param epochs      number of passes over the function list
     */
    public AdaGradL1Minimizer(double eta, double delta, double regConstant, int epochs) {
        this.eta = (float) eta;
        this.delta = (float) delta;
        this.regConstant = (float) regConstant;
        this.epochs = epochs;
        this.r = (float) (regConstant * eta);
    }

    /**
     * Applies, in place, every L1 shrinkage step that coordinate {@code i} has missed since it was
     * last updated.
     *
     * <p>The shrink amount is {@code (current[0] - lastUpdate[i]) * r / sqrt(sqrGradSum[i])}. The
     * weight moves toward zero by that amount and is set to exactly zero if it would cross zero.
     * {@code lastUpdate[i]} is then set to {@code current[0]}. Nothing happens when no steps have
     * elapsed.
     *
     * @param i          coordinate to flush
     * @param current    single-element holder for the global step counter
     * @param lastUpdate per-coordinate step at which shrinkage was last applied (updated)
     * @param x          weights (updated in place)
     * @param sqrGradSum per-coordinate accumulated squared gradients
     * @param r          per-step shrinkage numerator ({@code regConstant * eta})
     */
    public static final void flushShrinkageUpdates(int i, int[] current, int[] lastUpdate, float[] x, float[] sqrGradSum, float r) {
        final int missedSteps = current[0] - lastUpdate[i];
        if (missedSteps <= 0) {
            return;
        }
        final float shrink = (missedSteps * r) / ((float) Math.sqrt(sqrGradSum[i]));
        final float value = x[i];
        // Soft-threshold around zero. The decompiled code compared against an unrelated OCR
        // constant (Cropper.VERT_GROW_RATIO); the L1 proximal step is only correct against 0.
        if (value < 0.0f && value < -shrink) {
            x[i] = value + shrink;
        } else if (value <= 0.0f || value <= shrink) {
            x[i] = 0.0f;
        } else {
            x[i] = value - shrink;
        }
        lastUpdate[i] = current[0];
    }

    /** Applies all pending L1 shrinkage so every weight reflects the current step. */
    private void flushAll(int[] step, int[] lastUpdate, float[] weights, float[] sqrGradSum) {
        for (int j = 0; j < weights.length; j++) {
            flushShrinkageUpdates(j, step, lastUpdate, weights, sqrGradSum, this.r);
        }
    }

    /**
     * Runs {@code epochs} passes of AdaGrad with lazy L1 shrinkage over {@code functions}.
     *
     * <p>Each epoch visits the functions in an order shuffled by a {@link Random} seeded with 0, so
     * runs are repeatable. After each epoch, if {@code verbose} is set or a callback is given, all
     * pending shrinkage is flushed. The reported value is then the sum of the per-function values
     * observed during the epoch plus {@code regConstant} times the L1 norm of the weights.
     *
     * @param functions objectives to minimize; each returns (value, gradient) for the current weights
     * @param initial   starting weights (copied, not modified)
     * @param verbose   whether to print the per-epoch objective
     * @param callback  optional per-epoch hook; receives the live (fully flushed) weight array, the
     *                  epoch index and the objective value
     * @return the final weights, with all pending shrinkage applied
     */
    @Override
    public CounterInterface<Integer> minimize(List<DifferentiableFunction> functions, float[] initial, boolean verbose, OnlineMinimizer.Callback callback) {
        final Random random = new Random(0L);
        final int[] step = {0};
        final float[] weights = a.copy(initial);
        final int[] lastUpdate = new int[weights.length];
        final float[] sqrGradSum = new float[weights.length];
        a.addi(sqrGradSum, this.delta);
        final LazyAdaGradResult view = new LazyAdaGradResult(step, lastUpdate, weights, sqrGradSum);

        for (int epoch = 0; epoch < this.epochs; epoch++) {
            // Accumulate in double: summing many per-example values in float loses precision.
            double epochValue = 0.0;
            for (int idx : a.shuffle(a.enumerate(0, functions.size()), random)) {
                Pair<Double, CounterInterface<Integer>> valueAndGradient = functions.get(idx).calculate(view);
                epochValue += valueAndGradient.getFirst().doubleValue();
                CounterInterface<Integer> gradient = valueAndGradient.getSecond();

                // First fold this example's squared gradients into the AdaGrad accumulators...
                for (Map.Entry<Integer, Double> entry : gradient.entries()) {
                    int j = entry.getKey().intValue();
                    double g = entry.getValue().doubleValue();
                    sqrGradSum[j] = (float) (sqrGradSum[j] + (g * g));
                }
                // ...then for each touched coordinate apply its pending shrinkage before the
                // gradient step, so the step starts from the up-to-date weight.
                for (Map.Entry<Integer, Double> entry : gradient.entries()) {
                    int j = entry.getKey().intValue();
                    flushShrinkageUpdates(j, step, lastUpdate, weights, sqrGradSum, this.r);
                    weights[j] = weights[j] + ((-(this.eta / ((float) Math.sqrt(sqrGradSum[j])))) * entry.getValue().floatValue());
                }
                step[0]++;
            }

            if (verbose || callback != null) {
                flushAll(step, lastUpdate, weights, sqrGradSum);
                double l1Norm = 0.0;
                for (float w : weights) {
                    l1Norm += Math.abs(w);
                }
                double objective = epochValue + ((double) this.regConstant * l1Norm);
                if (verbose) {
                    System.out.println(String.format("[AdaGradMinimizer.minimize] Epoch %d ended with value %.6f", Integer.valueOf(epoch), Double.valueOf(objective)));
                }
                if (callback != null) {
                    callback.callback(weights, epoch, objective);
                }
            }
        }

        flushAll(step, lastUpdate, weights, sqrGradSum);
        double[] result = new double[weights.length];
        for (int j = 0; j < result.length; j++) {
            result[j] = weights[j];
        }
        return IntCounter.wrapArray(result, result.length);
    }
}
