package tberg.murphy.structpred;

import java.util.List;
import java.util.Map;
import tberg.murphy.arrays.a;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.counter.IntCounter;
import tberg.murphy.opt.DifferentiableFunction;
import tberg.murphy.opt.LBFGSMinimizer;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/structpred/LBFGSSoftMaxMarginLearner.class */
/**
 * Trains a linear model by minimizing an L2-regularized softmax-margin style objective with
 * {@link LBFGSMinimizer}.
 *
 * <p>At weight vector {@code w}, the value handed to the minimizer is
 * {@code sum_b (b.loss - b.gold . w) + C * ||w||^2}, summed over the update bundles returned by
 * the model's {@code getExpectedLossAugmentedUpdateBundleBatch}. The gradient is accumulated as
 * {@code sum_b (b.guess - b.gold)} and then combined with {@code w} (coefficient {@code 2 * C})
 * through {@code a.combi}.
 *
 * <p>NOTE(review): this presumes each bundle's {@code loss} is the loss-augmented log-partition
 * and {@code guess} holds expected feature counts; confirm against the model implementation.
 *
 * @param <D> the datum type consumed by the model
 */
public class LBFGSSoftMaxMarginLearner<D> implements SummingLossAugmentedLearner<D> {
    /** Length of the dense weight vector; feature indices are used directly as array indices. */
    int numFeatures;
    /** Regularization weight: the objective includes {@code C * sum(w_i^2)}. */
    double C;
    /** Tolerance passed to the {@link LBFGSMinimizer} constructor. */
    double tolerance;

    /**
     * @param regularization value stored in {@code C}, the L2 penalty weight
     * @param tolerance tolerance forwarded to {@link LBFGSMinimizer}
     * @param numFeatures length of the dense weight and gradient arrays
     */
    public LBFGSSoftMaxMarginLearner(double regularization, double tolerance, int numFeatures) {
        this.numFeatures = numFeatures;
        this.tolerance = tolerance;
        this.C = regularization;
    }

    /**
     * Runs L-BFGS from {@code initialWeights} and returns the minimizer's result.
     *
     * <p>Each objective evaluation installs the candidate weights on {@code model} via
     * {@code setWeights}, so the model is left holding the last evaluated point.
     *
     * @param initialWeights starting point; every key must be a valid index below
     *     {@code numFeatures}
     * @param model model queried for expected loss-augmented update bundles
     * @param data training data passed to the model on every evaluation
     * @param maxIterations second argument to the {@link LBFGSMinimizer} constructor
     *     (presumably the iteration cap; confirm against LBFGSMinimizer)
     * @return the optimized weights wrapped as a counter of length {@code numFeatures}
     */
    @Override
    public CounterInterface<Integer> train(CounterInterface<Integer> initialWeights,
            final SummingLossAugmentedLinearModel<D> model, final List<D> data, int maxIterations) {
        DifferentiableFunction objective = new DifferentiableFunction() {
            @Override
            public Pair<Double, double[]> calculate(double[] weights) {
                return LBFGSSoftMaxMarginLearner.this.softMaxMarginObjective(weights, model, data);
            }
        };
        LBFGSMinimizer minimizer = new LBFGSMinimizer(tolerance, maxIterations);
        double[] start = denseCopy(initialWeights);
        double[] optimum = minimizer.minimize(objective, start, true, null);
        return IntCounter.wrapArray(optimum, numFeatures);
    }

    /** Evaluates the regularized objective value and its gradient at {@code weights}. */
    private Pair<Double, double[]> softMaxMarginObjective(double[] weights,
            SummingLossAugmentedLinearModel<D> model, List<D> data) {
        // The counter wraps the minimizer's array, so the model sees these exact weights.
        IntCounter weightCounter = IntCounter.wrapArray(weights, numFeatures);
        model.setWeights(weightCounter);
        List<UpdateBundle> bundles = model.getExpectedLossAugmentedUpdateBundleBatch(data, 1.0d);

        double value = 0.0d;
        double[] gradient = new double[numFeatures];
        for (UpdateBundle bundle : bundles) {
            // Gold features enter the gradient negatively, guess features positively.
            for (Map.Entry<Integer, Double> goldEntry : bundle.gold.entries()) {
                gradient[goldEntry.getKey()] -= goldEntry.getValue();
            }
            for (Map.Entry<Integer, Double> guessEntry : bundle.guess.entries()) {
                gradient[guessEntry.getKey()] += guessEntry.getValue();
            }
            double goldScore = bundle.gold.dotProduct(weightCounter);
            value += bundle.loss - goldScore;
        }

        // Fold the regularizer's gradient contribution (coefficient 2*C on w) into the gradient.
        a.combi(gradient, 1.0d, weights, 2.0d * C);
        double penalty = C * a.sum(a.sqr(weights));
        return Pair.makePair(value + penalty, gradient);
    }

    /** Copies a sparse integer-keyed counter into a new dense array of length numFeatures. */
    private double[] denseCopy(CounterInterface<Integer> sparse) {
        double[] dense = new double[numFeatures];
        for (Map.Entry<Integer, Double> entry : sparse.entries()) {
            dense[entry.getKey()] = entry.getValue();
        }
        return dense;
    }
}
