package tberg.murphy.lazystructpred;

import java.util.ArrayList;
import java.util.List;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.counter.IntCounter;
import tberg.murphy.lazyopt.AdaGradL1Minimizer;
import tberg.murphy.lazyopt.AdaGradL2Minimizer;
import tberg.murphy.lazyopt.DifferentiableFunction;
import tberg.murphy.structpred.SummingLossAugmentedLearner;
import tberg.murphy.structpred.SummingLossAugmentedLinearModel;
import tberg.murphy.structpred.UpdateBundle;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/lazystructpred/SubgradientSoftMaxMarginLearner.class */
/**
 * Learner that splits the training data into mini-batches, wraps each batch in a
 * {@link DifferentiableFunction} built from the model's expected loss-augmented update bundles,
 * and hands the resulting list of objectives to an AdaGrad minimizer (L1 or L2 variant).
 *
 * <p>For a batch and a weight vector {@code w}, the objective value is the sum over the batch's
 * update bundles of {@code -gold.dotProduct(w) + loss}, and the reported gradient is the sum of
 * {@code guess - gold}.
 *
 * @param <D> type of a single training example
 */
public class SubgradientSoftMaxMarginLearner<D> implements SummingLossAugmentedLearner<D> {
    /** Length used when converting the initial weight counter to a dense float array. */
    int numFeatures;
    /** Forwarded to the minimizer as its third constructor argument (presumably the regularization constant; confirm against the minimizer). */
    double C;
    /** Forwarded to the minimizer as its second constructor argument. */
    double delta;
    /** Forwarded to the minimizer as its first constructor argument. */
    double stepSize;
    /** When true {@link AdaGradL1Minimizer} is used, otherwise {@link AdaGradL2Minimizer}. */
    boolean L1reg;
    /** Maximum number of examples per mini-batch; must be at least 1 when {@link #train} is called. */
    int batchSize;

    /**
     * Creates a learner with a batch size of 1 (one objective per training example).
     *
     * @param C           see {@link #C}
     * @param delta       see {@link #delta}
     * @param stepSize    see {@link #stepSize}
     * @param numFeatures see {@link #numFeatures}
     * @param L1reg       see {@link #L1reg}
     */
    public SubgradientSoftMaxMarginLearner(double C, double delta, double stepSize, int numFeatures, boolean L1reg) {
        this(C, delta, stepSize, numFeatures, L1reg, 1);
    }

    /**
     * Creates a learner with an explicit mini-batch size.
     *
     * @param C           see {@link #C}
     * @param delta       see {@link #delta}
     * @param stepSize    see {@link #stepSize}
     * @param numFeatures see {@link #numFeatures}
     * @param L1reg       see {@link #L1reg}
     * @param batchSize   see {@link #batchSize}
     */
    public SubgradientSoftMaxMarginLearner(double C, double delta, double stepSize, int numFeatures, boolean L1reg, int batchSize) {
        this.C = C;
        this.delta = delta;
        this.stepSize = stepSize;
        this.numFeatures = numFeatures;
        this.L1reg = L1reg;
        this.batchSize = batchSize;
    }

    /**
     * Builds one objective per mini-batch of {@code data} and minimizes them with AdaGrad.
     *
     * <p>Every example ends up in exactly one batch; the last batch may be shorter than
     * {@link #batchSize}. An empty {@code data} list yields an empty list of objectives.
     *
     * @param initWeights starting weights, densified to a float array of length {@link #numFeatures}
     * @param model       model whose weights are set and whose update bundles are queried on each evaluation
     * @param data        training examples
     * @param iters       forwarded to the minimizer as its fourth constructor argument
     *                    (presumably the number of passes; confirm against the minimizer)
     * @return whatever the minimizer's {@code minimize} call returns
     * @throws IllegalArgumentException if {@link #batchSize} is not positive
     */
    @Override // tberg.murphy.structpred.SummingLossAugmentedLearner
    public CounterInterface<Integer> train(CounterInterface<Integer> initWeights, final SummingLossAugmentedLinearModel<D> model, List<D> data, int iters) {
        if (this.batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + this.batchSize);
        }
        final int total = data.size();
        List<DifferentiableFunction> objectives = new ArrayList<DifferentiableFunction>();
        int start = 0;
        while (start < total) {
            // Compare against the remaining count instead of computing start + batchSize first,
            // so a very large batchSize cannot overflow and the final short batch is kept.
            final int end = (total - start <= this.batchSize) ? total : start + this.batchSize;
            objectives.add(batchObjective(model, data.subList(start, end)));
            start = end;
        }
        return (this.L1reg ? new AdaGradL1Minimizer(this.stepSize, this.delta, this.C, iters) : new AdaGradL2Minimizer(this.stepSize, this.delta, this.C, iters)).minimize(objectives, IntCounter.convertToFloatArray(initWeights, this.numFeatures), true, null);
    }

    /** Wraps one mini-batch in an objective that reports its value and gradient at given weights. */
    private DifferentiableFunction batchObjective(final SummingLossAugmentedLinearModel<D> model, final List<D> batch) {
        return new DifferentiableFunction() {
            @Override // tberg.murphy.lazyopt.DifferentiableFunction
            public Pair<Double, CounterInterface<Integer>> calculate(CounterInterface<Integer> weights) {
                // The model is stateful: it must see the candidate weights before bundles are computed.
                model.setWeights(weights);
                List<UpdateBundle> bundles = model.getExpectedLossAugmentedUpdateBundleBatch(batch, 1.0d);
                double value = 0.0d;
                IntCounter gradient = new IntCounter();
                for (UpdateBundle bundle : bundles) {
                    // Gradient contribution of one example: guess features minus gold features.
                    gradient.incrementAll(bundle.gold, -1.0d);
                    gradient.incrementAll(bundle.guess, 1.0d);
                    value += (-bundle.gold.dotProduct(weights)) + bundle.loss;
                }
                return Pair.makePair(Double.valueOf(value), gradient);
            }
        };
    }
}
