package tberg.murphy.structpred;

import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import jcuda.runtime.cudaError;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.counter.IntCounter;
import tberg.murphy.tuple.Pair;

/**
 * Structural SVM learner using the n-slack formulation (one slack variable per training example),
 * trained by repeatedly adding violated constraints to per-example working sets and then running
 * SMO-style pairwise coordinate ascent on the dual.
 *
 * <p>For each example {@code i} the learner keeps a working set of constraints. Each constraint is a
 * feature-difference vector {@code delta} (gold counts minus guess counts, as reported by the model's
 * update bundle) together with a loss. Each constraint has a dual variable {@code alpha}. The weight
 * vector is {@code w = (C / n) * sum_i sum_j alpha_ij * delta_ij}, where {@code n} is the number of
 * examples.
 *
 * <p>The slack of a constraint is {@code loss - delta . w}.
 *
 * @param <T> training instance type
 */
public class NSlackSVMSimpleLearner<T> implements LossAugmentedLearner<T> {
    // Not read anywhere in this class; kept so package-level code that references it still compiles.
    int maxLength;
    /** Regularization constant; every example's constraints are scaled by {@code C / numExamples}. */
    double C;
    // Not read anywhere in this class; kept so package-level code that references it still compiles.
    int N;
    /**
     * A candidate constraint is added only when its slack exceeds the example's current maximum
     * slack by more than this amount.
     */
    double epsilon;
    /** Per-example working sets: feature-difference vectors of the constraints. */
    List<IntCounter>[] indexToDelta;
    /** Per-example working sets: dual variables, parallel to {@link #indexToDelta}. */
    List<Double>[] indexToAlpha;
    /** Per-example working sets: constraint losses, parallel to {@link #indexToDelta}. */
    List<Double>[] indexToLoss;
    /** Gram matrix of all deltas, indexed by absolute (flattened across examples) constraint index. */
    double[][] dotProdCache;
    private SvmOpts opts;

    /** Tunable options for the SMO inner solver and alpha initialization. */
    public static class SvmOpts {
        /** Absolute primal/dual gap (or near-zero dual scale) below which SMO is treated as converged. */
        public double EPSILON = 1.0E-20d;
        /** Relative tolerance used by the SMO convergence test. */
        public double SMO_TOL = 1.0E-4d;
        /** Maximum number of SMO sweeps per outer training iteration. */
        public int SMO_ITERS = 10000;
        /** If true, alphas are re-initialized before every SMO run rather than only the first. */
        public boolean REFRESH_ALPHAS = false;
        /** Unnormalized alpha given to a newly added constraint before its example is renormalized. */
        public double NEW_ALPHA_MAG = 0.0d;
        /** If true, SMO convergence uses the primal/dual gap; otherwise the relative dual improvement. */
        public boolean smoCheckPrimal = true;
    }

    /**
     * Creates a learner with default {@link SvmOpts}.
     *
     * @param C regularization constant
     * @param epsilon minimum violation margin for adding a constraint
     */
    public NSlackSVMSimpleLearner(double C, double epsilon) {
        this(C, epsilon, new SvmOpts());
    }

    /**
     * Creates a learner.
     *
     * @param C regularization constant
     * @param epsilon minimum violation margin for adding a constraint
     * @param opts solver options
     */
    public NSlackSVMSimpleLearner(double C, double epsilon, SvmOpts opts) {
        this.C = C;
        this.epsilon = epsilon;
        this.opts = opts;
    }

    /**
     * Runs up to {@code iters} rounds of constraint generation followed by SMO. Training stops early
     * once a round adds no new constraints.
     *
     * @param initWeights weights used to decode in the first round
     * @param model model used for loss-augmented decoding; its weights are set to the final weights
     * @param data training instances
     * @param iters maximum number of outer iterations
     * @return the learned weights
     */
    @Override // tberg.murphy.structpred.LossAugmentedLearner
    public CounterInterface<Integer> train(CounterInterface<Integer> initWeights, LossAugmentedLinearModel<T> model, List<T> data, int iters) {
        CounterInterface<Integer> weights = initWeights;
        for (int iter = 0; iter < iters; iter++) {
            model.startIteration(iter);
            // The first round resets the working sets; later rounds only extend them.
            int added = reapConstraints(iter == 0, model, data, weights);
            System.out.printf("Added %d contraints.\n", added);
            System.out.printf("Iteration %d...\n", iter);
            System.out.printf("Num constraints: %d\n", numConstraints());
            if (this.opts.REFRESH_ALPHAS || iter == 0) {
                uniformInitializeAlphas();
            }
            buildDotProdCache();
            optimizeDualObjectiveSMO();
            System.out.printf("Primal objective: %.8f\n", getPrimalObjective());
            System.out.printf("Dual objective: %.8f\n", getDualObjective());
            CounterInterface<Integer> newWeights = getWeights();
            if (weights != null) {
                IntCounter diff = new IntCounter();
                diff.incrementAll(newWeights);
                diff.incrementAll(weights, -1.0d);
                // Cast keeps the same dotProduct overload the original bytecode targeted.
                System.out.printf("Mag of weights delta: %.8f\n", Math.sqrt(diff.dotProduct((CounterInterface<Integer>) diff)));
            }
            weights = newWeights;
            if (added == 0) {
                break;
            }
        }
        model.setWeights(getWeights());
        return getWeights();
    }

    /**
     * Converts an absolute (flattened) constraint index into {@code (exampleIndex, constraintIndex)}.
     *
     * @throws IndexOutOfBoundsException if the index is negative or not less than {@link #numConstraints()}
     */
    Pair<Integer, Integer> getAlphaRelativeIndicesFromAbsoluteIndex(int absIndex) {
        if (absIndex < 0) {
            throw new IndexOutOfBoundsException("Negative constraint index: " + absIndex);
        }
        int numExamples = this.indexToAlpha == null ? 0 : this.indexToAlpha.length;
        int rel = absIndex;
        for (int example = 0; example < numExamples; example++) {
            int count = numConstraints(example);
            if (rel < count) {
                return Pair.makePair(example, rel);
            }
            rel -= count;
        }
        throw new IndexOutOfBoundsException("Constraint index " + absIndex + " out of range; total constraints: " + numConstraints());
    }

    /** Converts {@code (exampleIndex, constraintIndex)} into an absolute (flattened) constraint index. */
    int getAlphaAbsoluteIndexFromRelativeIndices(int example, int constraint) {
        int offset = 0;
        for (int e = 0; e < example; e++) {
            offset += numConstraints(e);
        }
        return offset + constraint;
    }

    /**
     * Returns prefix offsets over the working sets. {@code offsets[i]} is the absolute index of
     * example {@code i}'s first constraint; the last entry is the total constraint count.
     */
    private int[] constraintOffsets() {
        int numExamples = this.indexToAlpha == null ? 0 : this.indexToAlpha.length;
        int[] offsets = new int[numExamples + 1];
        for (int i = 0; i < numExamples; i++) {
            offsets[i + 1] = offsets[i] + numConstraints(i);
        }
        return offsets;
    }

    /**
     * Computes the primal objective at the current dual solution: {@code 0.5 * ||w||^2} plus
     * {@code C / n} times each example's hinge slack (its largest constraint slack, floored at zero).
     */
    public double getPrimalObjective() {
        CounterInterface<Integer> weights = getWeights();
        double scale = this.C / this.indexToAlpha.length;
        double objective = 0.5d * weights.dotProduct(weights);
        for (int i = 0; i < this.indexToAlpha.length; i++) {
            double xi = 0.0d;
            for (int j = 0; j < numConstraints(i); j++) {
                xi = Math.max(xi, getContraintSlack(i, j, weights));
            }
            objective += scale * xi;
        }
        return objective;
    }

    /** Slack of constraint {@code j} of example {@code i} under {@code weights}: {@code loss - delta . w}. */
    double getContraintSlack(int i, int j, CounterInterface<Integer> weights) {
        return this.indexToLoss[i].get(j).doubleValue() - this.indexToDelta[i].get(j).dotProduct(weights);
    }

    /** Slack of an arbitrary constraint: {@code loss - weights . delta}. */
    double getContraintSlack(CounterInterface<Integer> weights, CounterInterface<Integer> delta, double loss) {
        return loss - weights.dotProduct(delta);
    }

    /**
     * Computes the dual objective
     * {@code (C/n) * sum alpha*loss - 0.5 * (C/n)^2 * sum_{a,b} alpha_a alpha_b K[a][b]},
     * where K is {@link #dotProdCache}, which must be current.
     */
    public double getDualObjective() {
        int numExamples = this.indexToAlpha.length;
        double scale = this.C / numExamples;
        int[] offsets = constraintOffsets();
        double objective = 0.0d;
        for (int i = 0; i < numExamples; i++) {
            for (int j = 0; j < numConstraints(i); j++) {
                double alphaA = this.indexToAlpha[i].get(j);
                double[] row = this.dotProdCache[offsets[i] + j];
                objective += scale * alphaA * this.indexToLoss[i].get(j);
                for (int k = 0; k < numExamples; k++) {
                    for (int l = 0; l < numConstraints(k); l++) {
                        objective -= 0.5d * scale * scale * alphaA * this.indexToAlpha[k].get(l) * row[offsets[k] + l];
                    }
                }
            }
        }
        return objective;
    }

    /** Returns the gradient of the dual objective with respect to example {@code i}'s alphas. */
    double[] getDualGradient(int i) {
        int numExamples = this.indexToAlpha.length;
        double scale = this.C / numExamples;
        int[] offsets = constraintOffsets();
        double[] grad = new double[numConstraints(i)];
        for (int j = 0; j < grad.length; j++) {
            double[] row = this.dotProdCache[offsets[i] + j];
            double g = scale * this.indexToLoss[i].get(j);
            for (int k = 0; k < numExamples; k++) {
                for (int l = 0; l < numConstraints(k); l++) {
                    g -= scale * scale * this.indexToAlpha[k].get(l) * row[offsets[k] + l];
                }
            }
            grad[j] = g;
        }
        return grad;
    }

    /** Scales {@code values} in place so they sum to one; leaves them unchanged if the sum is not positive. */
    static void normalize(double[] values) {
        double sum = 0.0d;
        for (double v : values) {
            sum += v;
        }
        if (sum > 0.0d) {
            for (int i = 0; i < values.length; i++) {
                values[i] /= sum;
            }
        }
    }

    /** Overwrites example {@code i}'s alphas with the first {@code numConstraints(i)} entries of {@code alphas}. */
    void setAlphas(int i, double[] alphas) {
        for (int j = 0; j < numConstraints(i); j++) {
            this.indexToAlpha[i].set(j, alphas[j]);
        }
    }

    /** Returns a copy of example {@code i}'s alphas. */
    double[] getAlphas(int i) {
        double[] alphas = new double[numConstraints(i)];
        for (int j = 0; j < alphas.length; j++) {
            alphas[j] = this.indexToAlpha[i].get(j).doubleValue();
        }
        return alphas;
    }

    /**
     * Maximizes the dual by sweeping over all ordered constraint pairs within each example and applying
     * pairwise updates. Sweeps stop after {@code opts.SMO_ITERS} rounds or when {@link #converged}
     * says so. Requires a current {@link #dotProdCache}.
     */
    public void optimizeDualObjectiveSMO() {
        // Working sets do not change during SMO, so offsets can be computed once.
        int[] offsets = constraintOffsets();
        double prevDual = Double.NEGATIVE_INFINITY;
        for (int round = 1; round <= this.opts.SMO_ITERS; round++) {
            for (int example = 0; example < this.indexToAlpha.length; example++) {
                int count = numConstraints(example);
                for (int p = 0; p < count; p++) {
                    for (int q = 0; q < count; q++) {
                        if (p != q) {
                            updateAlphas(example, p, q, offsets);
                        }
                    }
                }
            }
            double dual = getDualObjective();
            double primal = this.opts.smoCheckPrimal ? getPrimalObjective() : Double.NaN;
            boolean done = converged(primal, dual, this.opts.SMO_TOL, prevDual);
            if (round == 1 || done || round == this.opts.SMO_ITERS) {
                System.out.printf("[SMO] Round %d: %.8f\n", round, dual);
            }
            if (done) {
                return;
            }
            prevDual = dual;
        }
    }

    /**
     * SMO convergence test.
     *
     * <p>When {@code opts.smoCheckPrimal} is set, the test uses the primal/dual gap: it returns true if
     * the absolute gap is below {@code opts.EPSILON} or the gap relative to the mean magnitude is below
     * {@code tol}. Otherwise it uses the relative improvement of the dual over the previous round.
     *
     * @param primal current primal objective (ignored unless {@code opts.smoCheckPrimal})
     * @param dual current dual objective
     * @param tol relative tolerance
     * @param prevDual dual objective of the previous round, or negative infinity on the first round
     */
    boolean converged(double primal, double dual, double tol, double prevDual) {
        if (this.opts.smoCheckPrimal) {
            double gap = Math.abs(primal - dual);
            return gap < this.opts.EPSILON || gap / ((Math.abs(dual) + Math.abs(primal)) / 2.0d) < tol;
        }
        if (Double.isInfinite(prevDual)) {
            // No previous round to compare with. The old formula divided by the signed dual here and
            // reported convergence whenever the first dual value was negative.
            return false;
        }
        double improvement = dual - prevDual;
        double magnitude = Math.abs(dual);
        if (magnitude < this.opts.EPSILON) {
            return Math.abs(improvement) < this.opts.EPSILON;
        }
        // A non-positive improvement (no progress) also counts as converged.
        return improvement / magnitude < tol;
    }

    /**
     * Performs one clipped Newton step on the pair {@code (alpha[i][p], alpha[i][q])}, moving mass
     * between them. The step keeps their sum fixed and both values non-negative.
     */
    public void updateAlphas(int i, int p, int q) {
        updateAlphas(i, p, q, constraintOffsets());
    }

    /** Pairwise update using precomputed {@link #constraintOffsets()}. */
    private void updateAlphas(int example, int p, int q, int[] offsets) {
        int a = offsets[example] + p;
        int b = offsets[example] + q;
        if (this.dotProdCache[a][a] == 0.0d && this.dotProdCache[b][b] == 0.0d) {
            return;
        }
        double scale = this.C / this.indexToAlpha.length;
        // Difference of the two dual gradient components, divided by C/n.
        double numer = this.indexToLoss[example].get(p) - this.indexToLoss[example].get(q);
        double[] rowA = this.dotProdCache[a];
        double[] rowB = this.dotProdCache[b];
        for (int k = 0; k < this.indexToAlpha.length; k++) {
            for (int l = 0; l < numConstraints(k); l++) {
                int c = offsets[k] + l;
                double alpha = this.indexToAlpha[k].get(l);
                numer = (numer - ((scale * alpha) * rowA[c])) + (scale * alpha * rowB[c]);
            }
        }
        // Equals (C/n) * ||delta_a - delta_b||^2, which is >= 0 in exact arithmetic. Rounding can make
        // it slightly negative, and a negative value would reverse the step, so it is skipped.
        double curvature = scale * rowA[a] - 2.0d * scale * rowA[b] + scale * rowB[b];
        if (curvature <= 0.0d) {
            return;
        }
        List<Double> alphas = this.indexToAlpha[example];
        double step = Math.max(-alphas.get(p), Math.min(alphas.get(q), numer / curvature));
        alphas.set(p, alphas.get(p) + step);
        alphas.set(q, alphas.get(q) - step);
    }

    /** Replaces all working sets with {@code numExamples} empty ones. */
    @SuppressWarnings("unchecked") // generic array creation; each element is populated with a typed list
    void clearConstraints(int numExamples) {
        this.indexToDelta = new List[numExamples];
        this.indexToAlpha = new List[numExamples];
        this.indexToLoss = new List[numExamples];
        for (int i = 0; i < numExamples; i++) {
            this.indexToDelta[i] = new ArrayList<>();
            this.indexToAlpha[i] = new ArrayList<>();
            this.indexToLoss[i] = new ArrayList<>();
        }
    }

    /**
     * Decodes every example with {@code weights} and adds each loss-augmented guess as a constraint.
     * A guess is added only when its slack beats the example's current maximum slack by more than
     * {@link #epsilon}.
     *
     * @param reset if true, first reset the working sets so each holds only the trivial constraint
     *     (zero delta, zero loss)
     * @return number of constraints added
     */
    public int reapConstraints(boolean reset, LossAugmentedLinearModel<T> model, List<T> data, CounterInterface<Integer> weights) {
        if (reset) {
            clearConstraints(data.size());
            for (int i = 0; i < data.size(); i++) {
                addConstraint(i, new IntCounter(), 0.0d);
            }
        }
        model.setWeights(weights);
        List<UpdateBundle> bundles = model.getLossAugmentedUpdateBundleBatch(data, 1.0d);
        int added = 0;
        for (int i = 0; i < data.size(); i++) {
            UpdateBundle bundle = bundles.get(i);
            IntCounter delta = new IntCounter();
            delta.incrementAll(bundle.gold);
            delta.incrementAll(bundle.guess, -1.0d);
            double loss = bundle.loss;
            double maxSlack = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < numConstraints(i); j++) {
                maxSlack = Math.max(maxSlack, getContraintSlack(i, j, weights));
            }
            if (getContraintSlack(weights, delta, loss) > maxSlack + this.epsilon) {
                addConstraint(i, delta, loss);
                added++;
            }
        }
        return added;
    }

    /** Rebuilds {@link #dotProdCache} as the Gram matrix of all deltas in absolute-index order. */
    public void buildDotProdCache() {
        int total = numConstraints();
        int[] offsets = constraintOffsets();
        // Flatten the deltas once instead of re-resolving relative indices for every pair.
        IntCounter[] deltas = new IntCounter[total];
        for (int i = 0; i + 1 < offsets.length; i++) {
            for (int j = 0; j < numConstraints(i); j++) {
                deltas[offsets[i] + j] = this.indexToDelta[i].get(j);
            }
        }
        this.dotProdCache = new double[total][total];
        for (int a = 0; a < total; a++) {
            for (int b = 0; b < total; b++) {
                this.dotProdCache[a][b] = deltas[a].dotProduct(deltas[b]);
            }
        }
    }

    /**
     * Puts all of each example's alpha mass on its first constraint and zero on the rest. The first
     * constraint is the trivial zero-delta constraint when the sets were built by
     * {@link #reapConstraints} with reset.
     */
    public void zeroInitializeAlphas() {
        for (int i = 0; i < this.indexToAlpha.length; i++) {
            for (int j = 0; j < numConstraints(i); j++) {
                this.indexToAlpha[i].set(j, j == 0 ? 1.0d : 0.0d);
            }
        }
    }

    /**
     * Puts alpha 0.9 on each example's first constraint and spreads the remaining 0.1 evenly over
     * its other constraints. An example with a single constraint gets alpha 1.0, so that every
     * example's alphas sum to one.
     */
    public void uniformInitializeAlphas() {
        for (int i = 0; i < this.indexToAlpha.length; i++) {
            int count = numConstraints(i);
            for (int j = 0; j < count; j++) {
                double value;
                if (j == 0) {
                    value = count == 1 ? 1.0d : 0.9d;
                } else {
                    value = 0.1d / (count - 1.0d);
                }
                this.indexToAlpha[i].set(j, value);
            }
        }
    }

    /**
     * Appends a constraint to example {@code i}'s working set. A copy of {@code delta} is stored. The
     * new alpha starts at {@code opts.NEW_ALPHA_MAG}, and then the example's alphas are renormalized.
     */
    public void addConstraint(int i, CounterInterface<Integer> delta, double loss) {
        this.indexToAlpha[i].add(this.opts.NEW_ALPHA_MAG);
        normalizeAlphas(i);
        this.indexToDelta[i].add(new IntCounter(delta));
        this.indexToLoss[i].add(loss);
    }

    /** Scales example {@code i}'s alphas to sum to one; leaves them unchanged if the sum is not positive. */
    void normalizeAlphas(int i) {
        List<Double> alphas = this.indexToAlpha[i];
        double sum = 0.0d;
        for (int j = 0; j < numConstraints(i); j++) {
            sum += alphas.get(j);
        }
        if (sum > 0.0d) {
            for (int j = 0; j < numConstraints(i); j++) {
                alphas.set(j, alphas.get(j) / sum);
            }
        }
    }

    /** Number of constraints in example {@code i}'s working set (0 before any sets exist). */
    public int numConstraints(int i) {
        if (this.indexToAlpha == null) {
            return 0;
        }
        return this.indexToAlpha[i].size();
    }

    /** Total number of constraints across all examples (0 before any sets exist). */
    public int numConstraints() {
        if (this.indexToAlpha == null) {
            return 0;
        }
        int total = 0;
        for (int i = 0; i < this.indexToAlpha.length; i++) {
            total += numConstraints(i);
        }
        return total;
    }

    /** Builds a fresh weight vector {@code (C / n) * sum_i sum_j alpha_ij * delta_ij}. */
    public CounterInterface<Integer> getWeights() {
        IntCounter weights = new IntCounter();
        double scale = this.C / this.indexToAlpha.length;
        for (int i = 0; i < this.indexToAlpha.length; i++) {
            for (int j = 0; j < numConstraints(i); j++) {
                double alpha = this.indexToAlpha[i].get(j);
                for (Map.Entry<Integer, Double> entry : this.indexToDelta[i].get(j).entries()) {
                    // The decompiled code cast the Integer key to IntCounter, which does not compile.
                    weights.incrementCount(entry.getKey(), scale * alpha * entry.getValue().doubleValue());
                }
            }
        }
        return weights;
    }

    /**
     * Interns {@code item}. The first time an item is seen, it is appended to {@code list} and mapped
     * to its position in {@code map}.
     *
     * @return the item's index in {@code list}
     */
    static <D> int index(D item, List<D> list, Map<D, Integer> map) {
        Integer idx = map.get(item);
        if (idx == null) {
            idx = list.size();
            map.put(item, idx);
            list.add(item);
        }
        return idx;
    }
}
