package tberg.murphy.lazystructpred;

import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.counter.IntCounter;
import tberg.murphy.lazyopt.AdaGradL1Minimizer;
import tberg.murphy.lazyopt.AdaGradL2Minimizer;
import tberg.murphy.lazyopt.DifferentiableFunction;
import tberg.murphy.structpred.LossAugmentedLearner;
import tberg.murphy.structpred.LossAugmentedLinearModel;
import tberg.murphy.structpred.UpdateBundle;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/lazystructpred/PrimalSubgradientSVMLearner.class */
/**
 * Structured SVM learner that optimizes a hinge-loss objective with an AdaGrad minimizer.
 *
 * <p>The training data is split into consecutive mini-batches of {@code batchSize} examples. Each
 * batch becomes one {@link DifferentiableFunction}. For weights {@code w}, it runs loss-augmented
 * decoding on the batch and returns:
 *
 * <ul>
 *   <li>the objective: the sum over examples of {@code max(0, loss + w·(f(guess) - f(gold)))};
 *   <li>the subgradient: the sum of {@code f(guess) - f(gold)} over the examples whose term is
 *       positive.
 * </ul>
 *
 * <p>Regularization is not part of these functions. The constant {@code C} is forwarded to an
 * {@link AdaGradL1Minimizer} or {@link AdaGradL2Minimizer}, depending on {@code L1reg}.
 */
public class PrimalSubgradientSVMLearner<D> implements LossAugmentedLearner<D> {
    /** Length of the dense weight array handed to the minimizer. */
    int numFeatures;
    /** Regularization constant, forwarded to the AdaGrad minimizer. */
    double C;
    /** AdaGrad {@code delta} parameter, forwarded to the minimizer. */
    double delta;
    /** Step size, forwarded to the minimizer. */
    double stepSize;
    /** If true, use {@link AdaGradL1Minimizer}; otherwise {@link AdaGradL2Minimizer}. */
    boolean L1reg;
    /** Number of training examples per mini-batch objective (always at least 1). */
    int batchSize;

    /**
     * Creates a learner with a batch size of 1, so each example forms its own objective.
     *
     * @param C regularization constant passed to the minimizer
     * @param delta AdaGrad delta passed to the minimizer
     * @param stepSize step size passed to the minimizer
     * @param numFeatures length of the weight array used during optimization
     * @param L1reg true for the L1 AdaGrad minimizer, false for the L2 one
     */
    public PrimalSubgradientSVMLearner(double C, double delta, double stepSize, int numFeatures, boolean L1reg) {
        this(C, delta, stepSize, numFeatures, L1reg, 1);
    }

    /**
     * Creates a learner.
     *
     * @param C regularization constant passed to the minimizer
     * @param delta AdaGrad delta passed to the minimizer
     * @param stepSize step size passed to the minimizer
     * @param numFeatures length of the weight array used during optimization
     * @param L1reg true for the L1 AdaGrad minimizer, false for the L2 one
     * @param batchSize number of examples per mini-batch objective
     * @throws IllegalArgumentException if {@code batchSize < 1}
     */
    public PrimalSubgradientSVMLearner(double C, double delta, double stepSize, int numFeatures, boolean L1reg, int batchSize) {
        if (batchSize < 1) {
            // A zero batch size would otherwise surface later as an ArithmeticException in train().
            throw new IllegalArgumentException("batchSize must be at least 1, got " + batchSize);
        }
        this.C = C;
        this.delta = delta;
        this.stepSize = stepSize;
        this.numFeatures = numFeatures;
        this.L1reg = L1reg;
        this.batchSize = batchSize;
    }

    /**
     * Trains weights by minimizing the mini-batch hinge objectives, starting from
     * {@code initialWeights}.
     *
     * <p>Note: each objective evaluation calls {@code model.setWeights(...)}, so the supplied model's
     * weights are overwritten repeatedly during training.
     *
     * @param initialWeights starting weights; converted to a float array of length
     *     {@code numFeatures}
     * @param model model used for loss-augmented decoding
     * @param data training examples; every example is assigned to exactly one mini-batch
     * @param numIterations forwarded unchanged to the AdaGrad minimizer's constructor (presumably the
     *     number of passes over the data; confirm against the minimizer)
     * @return the weights returned by the minimizer
     */
    @Override // tberg.murphy.structpred.LossAugmentedLearner
    public CounterInterface<Integer> train(CounterInterface<Integer> initialWeights, final LossAugmentedLinearModel<D> model, List<D> data, int numIterations) {
        final int size = data.size();
        // True ceiling division: the final partial batch must not be dropped.
        final int numBatches = size / batchSize + (size % batchSize == 0 ? 0 : 1);

        List<DifferentiableFunction> objectives = new ArrayList<DifferentiableFunction>(numBatches);
        for (int b = 0; b < numBatches; b++) {
            int start = b * batchSize;
            // start + min(...) avoids overflowing when size is close to Integer.MAX_VALUE.
            int end = start + Math.min(batchSize, size - start);
            final List<D> batch = data.subList(start, end);
            objectives.add(new DifferentiableFunction() {
                @Override // tberg.murphy.lazyopt.DifferentiableFunction
                public Pair<Double, CounterInterface<Integer>> calculate(CounterInterface<Integer> weights) {
                    return hingeObjective(model, batch, weights);
                }
            });
        }

        return (this.L1reg
                ? new AdaGradL1Minimizer(this.stepSize, this.delta, this.C, numIterations)
                : new AdaGradL2Minimizer(this.stepSize, this.delta, this.C, numIterations))
            .minimize(objectives, IntCounter.convertToFloatArray(initialWeights, this.numFeatures), true, null);
    }

    /**
     * Computes the hinge objective and its subgradient for one mini-batch at {@code weights}.
     *
     * @param model model used for loss-augmented decoding; its weights are set to {@code weights}
     * @param batch the mini-batch examples
     * @param weights the point at which the objective is evaluated
     * @return a pair of (summed positive hinge terms, summed {@code guess - gold} feature
     *     differences of the violating examples)
     */
    private static <T> Pair<Double, CounterInterface<Integer>> hingeObjective(
            LossAugmentedLinearModel<T> model, List<T> batch, CounterInterface<Integer> weights) {
        model.setWeights(weights);
        List<UpdateBundle> bundles = model.getLossAugmentedUpdateBundleBatch(batch, 1.0d);

        double objective = 0.0d;
        IntCounter subgradient = new IntCounter();
        for (UpdateBundle bundle : bundles) {
            // Feature difference f(guess) - f(gold).
            IntCounter diff = new IntCounter();
            diff.incrementAll(bundle.gold, -1.0d);
            diff.incrementAll(bundle.guess, 1.0d);

            // w · (f(guess) - f(gold)), accumulated in double (it is added to a double loss).
            double scoreGap = 0.0d;
            Iterator<Map.Entry<Integer, Double>> it = diff.entries().iterator();
            while (it.hasNext()) {
                Map.Entry<Integer, Double> entry = it.next();
                scoreGap += entry.getValue() * weights.getCount(entry.getKey());
            }

            double violation = bundle.loss + scoreGap;
            // Hinge: only examples with a positive margin violation contribute.
            if (violation > 0.0d) {
                objective += violation;
                subgradient.incrementAll((CounterInterface<Integer>) diff);
            }
        }
        return Pair.<Double, CounterInterface<Integer>>makePair(objective, subgradient);
    }
}
