package tberg.murphy.structpred;

import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import java.util.ArrayList;
import java.util.List;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.counter.IntCounter;
import tberg.murphy.structpred.NSlackSVMLearner;
import tberg.murphy.util.Maxer;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/structpred/OneSlackSVMLearner.class */
/**
 * Variant of {@link NSlackSVMLearner} whose {@link #reapConstraints} folds the
 * loss-augmented decodes of a whole batch into a single averaged constraint
 * (feature delta plus loss), instead of handling each example separately.
 *
 * <p>When {@code SvmOpts.oneSlackCacheSize} is non-zero, the most recent decodes of
 * every example are remembered and {@link #tryCache} first attempts to build a
 * constraint from those remembered decodes before a fresh decode is run.
 *
 * @param <T> type of a training example
 */
public class OneSlackSVMLearner<T> extends NSlackSVMLearner<T> {
    /**
     * Remembered decodes, one inner list per example. The outer list is indexed by
     * {@code i + offset}, i.e. the batch start index plus the position in the batch.
     */
    private final List<List<UpdateBundle>> cache;

    /**
     * Creates a learner with default {@link NSlackSVMLearner.SvmOpts}.
     *
     * @param d  forwarded unchanged to the superclass constructor
     * @param d2 forwarded unchanged to the superclass constructor
     */
    public OneSlackSVMLearner(double d, double d2) {
        this(d, d2, new NSlackSVMLearner.SvmOpts());
    }

    /**
     * Creates a learner with the given options and an empty decode cache.
     *
     * @param d       forwarded unchanged to the superclass constructor
     * @param d2      forwarded unchanged to the superclass constructor
     * @param svmOpts forwarded unchanged to the superclass constructor
     */
    public OneSlackSVMLearner(double d, double d2, NSlackSVMLearner.SvmOpts svmOpts) {
        super(d, d2, svmOpts);
        this.cache = new ArrayList<List<UpdateBundle>>();
    }

    /**
     * Produces (at most) one aggregated constraint for the batch {@code list}.
     *
     * <p>Steps: optionally reset the constraint slots, try to satisfy the request
     * from the decode cache, otherwise run a batch loss-augmented decode, record
     * the results in the cache (if enabled), and offer the batch-averaged
     * {@code gold - guess} delta and batch-averaged loss to
     * {@code addConstraintIfNecessary}.
     *
     * @param z                       when true, constraints are cleared and slots {@code 0..i2-1}
     *                                are re-seeded with an empty constraint
     * @param lossAugmentedLinearModel model used for the batch loss-augmented decode
     * @param list                    the batch of examples
     * @param counterInterface        current weights, passed to the decoder and to the
     *                                constraint check
     * @param i                       index of the first batch element within the cache
     * @param i2                      number of constraint slots to reset when {@code z} is true
     * @param i3                      passed through to {@code addConstraintIfNecessary}
     * @param i4                      unused by this override
     * @return the value returned by {@code addConstraintIfNecessary}, either for the
     *         cached constraint (when positive) or for the freshly decoded one
     */
    @Override // tberg.murphy.structpred.NSlackSVMLearner
    public int reapConstraints(boolean z, LossAugmentedLinearModel<T> lossAugmentedLinearModel, List<T> list, CounterInterface<Integer> counterInterface, int i, int i2, int i3, int i4) {
        // These options are static on SvmOpts, so this overrides them globally.
        // NOTE(review): Cropper.VERT_GROW_RATIO looks like a decompiler constant
        // substitution for a numeric literal - confirm the intended value.
        NSlackSVMLearner.SvmOpts.minDecodeToSmoTimeRatio = Cropper.VERT_GROW_RATIO;
        NSlackSVMLearner.SvmOpts.smoMiniBatch = false;

        if (z) {
            clearConstraints(i2);
            for (int slot = 0; slot < i2; slot++) {
                addConstraint(slot, new IntCounter(), Cropper.VERT_GROW_RATIO);
            }
        }

        int addedFromCache = tryCache(list, counterInterface, i, i3);
        if (addedFromCache > 0) {
            System.out.printf("Using cache\n");
            return addedFromCache;
        }

        List<UpdateBundle> decoded = batchLossAugmentedDecode(lossAugmentedLinearModel, list, counterInterface, 1.0d);
        int batchSize = list.size();
        IntCounter summedDelta = new IntCounter();
        double summedLoss = 0.0d;
        for (int offset = 0; offset < batchSize; offset++) {
            UpdateBundle bundle = decoded.get(offset);
            if (NSlackSVMLearner.SvmOpts.oneSlackCacheSize > 0) {
                // tryCache has already grown the cache to cover i .. i + batchSize - 1.
                List<UpdateBundle> history = this.cache.get(i + offset);
                history.add(bundle);
                // Drop the oldest entries until the history fits the configured size.
                while (history.size() > NSlackSVMLearner.SvmOpts.oneSlackCacheSize) {
                    history.remove(0);
                }
            }
            IntCounter delta = new IntCounter();
            delta.incrementAll(bundle.gold);
            delta.incrementAll(bundle.guess, -1.0d);
            summedDelta.incrementAll((CounterInterface) delta);
            summedLoss += bundle.loss;
        }
        summedDelta.scale(1.0d / batchSize);
        return addConstraintIfNecessary(counterInterface, i3, summedDelta, summedLoss / batchSize);
    }

    /**
     * Attempts to build the aggregated constraint from remembered decodes.
     *
     * <p>The cache is first grown so that it has an entry for every index in
     * {@code i .. i + list.size() - 1}. If it had to grow, or caching is disabled
     * ({@code oneSlackCacheSize == 0}), nothing is attempted. Otherwise, for each
     * example whose history holds at least {@code SvmOpts.cacheWarmup} entries, the
     * remembered decode with the highest {@code delta . weights + loss} score is
     * selected; examples below the warm-up threshold contribute nothing, but the
     * sums are still divided by the full batch size.
     *
     * @param list             the batch of examples (only its size is used)
     * @param counterInterface current weights used to score remembered decodes
     * @param i                index of the first batch element within the cache
     * @param i2               passed through to {@code addConstraintIfNecessary}
     * @return 0 when no attempt was made, otherwise the value returned by
     *         {@code addConstraintIfNecessary}
     */
    private int tryCache(List<T> list, CounterInterface<Integer> counterInterface, int i, int i2) {
        if (NSlackSVMLearner.SvmOpts.oneSlackCacheSize == 0) {
            return 0;
        }

        boolean grew = false;
        while (i + list.size() > this.cache.size()) {
            this.cache.add(new ArrayList<UpdateBundle>());
            grew = true;
        }
        if (grew) {
            // At least one example has no history yet, so the cache cannot be used.
            return 0;
        }

        int batchSize = list.size();
        IntCounter summedDelta = new IntCounter();
        double summedLoss = 0.0d;
        for (int offset = 0; offset < batchSize; offset++) {
            List<UpdateBundle> history = this.cache.get(i + offset);
            if (history.size() < NSlackSVMLearner.SvmOpts.cacheWarmup) {
                continue;
            }
            Maxer maxer = new Maxer();
            for (UpdateBundle candidate : history) {
                maxer.observe(candidate, getDelta(candidate.guess, candidate.gold).dotProduct(counterInterface) + candidate.loss);
            }
            UpdateBundle best = (UpdateBundle) maxer.argMax();
            summedDelta.incrementAll(best.gold);
            summedDelta.incrementAll(best.guess, -1.0d);
            summedLoss += best.loss;
        }
        summedDelta.scale(1.0d / batchSize);
        return addConstraintIfNecessary(counterInterface, i2, summedDelta, summedLoss / batchSize);
    }
}
