package tberg.murphy.opt;

import edu.berkeley.cs.nlp.ocular.model.CharacterTemplate;
import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import java.util.Arrays;
import java.util.LinkedList;
import tberg.murphy.arrays.a;
import tberg.murphy.opt.Minimizer;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/opt/LBFGSMinimizer.class */
/**
 * Limited-memory BFGS (L-BFGS) minimizer for {@link DifferentiableFunction}s.
 *
 * <p>The minimizer keeps the most recent {@link #maxHistorySize} correction pairs. Each pair is an
 * input difference {@code s = x_{k+1} - x_k} and a gradient difference
 * {@code y = g_{k+1} - g_k}. The pairs are applied to the gradient with the two-loop recursion to
 * obtain a search direction. Each iteration then runs a backtracking line search that enforces an
 * Armijo sufficient-decrease condition.
 *
 * <p>{@link #minimize} re-creates and mutates the history lists stored in instance fields. A single
 * instance therefore must not run concurrent {@code minimize} calls.
 */
public class LBFGSMinimizer implements Minimizer {
    /** Absolute slack used by the line search acceptance test, step-underflow test and convergence test. */
    private static final double EPS = 1.0E-12d;
    /** Armijo constant c1 in {@code f(x + t*p) <= f(x) + c1 * t * (g . p)}. */
    private static final double LINE_SEARCH_SUFF_DECR = 1.0E-4d;

    private static final int DEFAULT_MAX_HISTORY_SIZE = 5;
    private static final double DEFAULT_INITIAL_STEP_SIZE_MULTIPLIER = 0.5d;
    private static final double DEFAULT_STEP_SIZE_MULTIPLIER = 0.5d;
    private static final double DEFAULT_STEP_SIZE_GROW_MULTIPLIER = 1.5d;
    private static final boolean DEFAULT_FINISH_ON_FIRST_CONVERGE = true;

    /** Maximum number of (s, y) correction pairs retained. */
    int maxHistorySize;
    /** Backtracking shrink factor used by the line search of the first iteration. */
    double initialStepSizeMultiplier;
    /** Backtracking shrink factor used by every later line search. */
    double stepSizeMultiplier;
    /** Factor applied to the last accepted step size to get the next iteration's starting step. */
    double stepSizeGrowMultiplier;
    /**
     * If true, stop at the first iteration whose value change passes {@link #converged}. If false, the
     * first such iteration only clears the history, and the run stops once it converges again on the
     * immediately following iteration.
     */
    boolean finishOnFirstConverge;
    /** Relative tolerance on the change of the function value between iterations. */
    double tolerance;
    /** Maximum number of iterations performed by {@link #minimize}. */
    int maxIters;
    /** Input differences s, newest first (index 0 is the most recent pair). */
    LinkedList<double[]> inputDifferenceVectorList;
    /** Gradient differences y, newest first, aligned index-by-index with {@link #inputDifferenceVectorList}. */
    LinkedList<double[]> derivativeDifferenceVectorList;

    /**
     * Wraps a function and remembers its most recent evaluation. Asking again for the same point
     * returns the stored result instead of re-evaluating the underlying function.
     */
    private static class CachingFunctionWrapper implements DifferentiableFunction {
        DifferentiableFunction func;
        /** Private copy of the point of the cached evaluation; null before the first call. */
        double[] x;
        Pair<Double, double[]> valAndGrad;

        public CachingFunctionWrapper(DifferentiableFunction func) {
            this.func = func;
        }

        private void ensureCache(double[] point) {
            if (this.x == null || !Arrays.equals(this.x, point)) {
                this.valAndGrad = this.func.calculate(point);
                // Previously x was never assigned, so the cache never hit. Keep a copy in case the
                // caller later mutates the array it passed in.
                this.x = point.clone();
            }
        }

        @Override // tberg.murphy.opt.DifferentiableFunction
        public Pair<Double, double[]> calculate(double[] point) {
            ensureCache(point);
            return this.valAndGrad;
        }
    }

    /**
     * Creates a fully configured minimizer.
     *
     * @param tolerance relative tolerance on the change of the function value
     * @param maxIters maximum number of iterations
     * @param finishOnFirstConverge whether to stop at the first convergence (see the field docs)
     * @param initialStepSizeMultiplier backtracking shrink factor for the first iteration
     * @param stepSizeMultiplier backtracking shrink factor for later iterations
     * @param stepSizeGrowMultiplier growth factor applied to the last accepted step size
     * @param maxHistorySize number of correction pairs to retain
     */
    public LBFGSMinimizer(double tolerance, int maxIters, boolean finishOnFirstConverge, double initialStepSizeMultiplier, double stepSizeMultiplier, double stepSizeGrowMultiplier, int maxHistorySize) {
        this.tolerance = tolerance;
        this.maxIters = maxIters;
        this.finishOnFirstConverge = finishOnFirstConverge;
        this.initialStepSizeMultiplier = initialStepSizeMultiplier;
        this.stepSizeMultiplier = stepSizeMultiplier;
        this.stepSizeGrowMultiplier = stepSizeGrowMultiplier;
        this.maxHistorySize = maxHistorySize;
    }

    /** Same as the full constructor with a history size of 5. */
    public LBFGSMinimizer(double tolerance, int maxIters, boolean finishOnFirstConverge, double initialStepSizeMultiplier, double stepSizeMultiplier, double stepSizeGrowMultiplier) {
        this(tolerance, maxIters, finishOnFirstConverge, initialStepSizeMultiplier, stepSizeMultiplier, stepSizeGrowMultiplier, DEFAULT_MAX_HISTORY_SIZE);
    }

    /** Uses default step-size multipliers (0.5 initial, 0.5 later, 1.5 growth) and a history size of 5. */
    public LBFGSMinimizer(double tolerance, int maxIters, boolean finishOnFirstConverge) {
        this(tolerance, maxIters, finishOnFirstConverge, DEFAULT_INITIAL_STEP_SIZE_MULTIPLIER, DEFAULT_STEP_SIZE_MULTIPLIER, DEFAULT_STEP_SIZE_GROW_MULTIPLIER);
    }

    /** Uses all defaults and stops at the first convergence. */
    public LBFGSMinimizer(double tolerance, int maxIters) {
        this(tolerance, maxIters, DEFAULT_FINISH_ON_FIRST_CONVERGE);
    }

    /**
     * Minimizes {@code function} starting from {@code initial}.
     *
     * @param function the objective; it returns (value, gradient)
     * @param initial starting point; it is copied and never modified
     * @param verbose whether to print progress to standard output
     * @param callback invoked after every completed iteration with (point, iteration, value,
     *     gradient); may be null
     * @return the converged point, or the last point if {@link #maxIters} is exhausted
     * @throws Error if no step decreasing the function can be found even along steepest descent
     * @throws RuntimeException if a stored correction pair has zero curvature {@code s . y}
     */
    @Override // tberg.murphy.opt.Minimizer
    public double[] minimize(DifferentiableFunction function, double[] initial, boolean verbose, Minimizer.Callback callback) {
        this.inputDifferenceVectorList = new LinkedList<>();
        this.derivativeDifferenceVectorList = new LinkedList<>();
        CachingFunctionWrapper cachedFunction = new CachingFunctionWrapper(function);
        double[] guess = a.copy(initial);
        Pair<Double, double[]> current = cachedFunction.calculate(guess);
        double value = current.getFirst().doubleValue();
        double[] gradient = current.getSecond();
        boolean convergedLastIteration = false;
        double stepSizeGuess = 1.0d;
        for (int iter = 0; iter < this.maxIters; iter++) {
            double[] direction = searchDirection(gradient, guess.length);
            Pair<double[], Double> lineSearchResult = lineSearch(cachedFunction, guess, direction, stepSizeGuess,
                    iter == 0 ? this.initialStepSizeMultiplier : this.stepSizeMultiplier);
            if (lineSearchResult == null) {
                clearHistories();
                if (verbose) {
                    System.out.println("[LBFGSMinimizer.minimize] Cleared history.");
                }
                // Recompute the direction now that the history is empty. It becomes -gradient, which
                // is a descent direction. Retrying along the stale quasi-Newton direction that just
                // failed would not help.
                direction = searchDirection(gradient, guess.length);
                lineSearchResult = lineSearch(cachedFunction, guess, direction, 1.0d, this.stepSizeMultiplier);
                if (lineSearchResult == null) {
                    // Kept as java.lang.Error for compatibility with existing callers.
                    throw new Error("[LBFGSMinimizer.minimize] Cannot find step that will decrease function value.");
                }
            }
            stepSizeGuess = lineSearchResult.getSecond().doubleValue() * this.stepSizeGrowMultiplier;
            double[] nextGuess = lineSearchResult.getFirst();
            // Cache hit: the line search evaluated this point last.
            Pair<Double, double[]> next = cachedFunction.calculate(nextGuess);
            double nextValue = next.getFirst().doubleValue();
            double[] nextGradient = next.getSecond();
            if (verbose) {
                System.out.println(String.format("[LBFGSMinimizer.minimize] Iteration %d ended with value %.6f", Integer.valueOf(iter), Double.valueOf(nextValue)));
            }
            if (converged(value, nextValue, this.tolerance)) {
                if (this.finishOnFirstConverge || convergedLastIteration) {
                    return nextGuess;
                }
                // First convergence: restart from a fresh history and require a second one.
                clearHistories();
                if (verbose) {
                    System.out.println("[LBFGSMinimizer.minimize] Cleared history.");
                }
                stepSizeGuess = 1.0d;
                convergedLastIteration = true;
            } else {
                convergedLastIteration = false;
            }
            updateHistories(guess, nextGuess, gradient, nextGradient);
            guess = nextGuess;
            value = nextValue;
            gradient = nextGradient;
            if (callback != null) {
                callback.callback(guess, iter, value, gradient);
            }
        }
        if (verbose) {
            System.out.println("[LBFGSMinimizer.minimize] Exceeded max iterations without converging.");
        }
        return guess;
    }

    /** Returns the search direction {@code -H * gradient}, where H is the current inverse-Hessian approximation. */
    private double[] searchDirection(double[] gradient, int dimension) {
        double[] direction = implicitMultiply(getInitialInverseHessianDiagonal(dimension), gradient);
        a.scalei(direction, -1.0d);
        return direction;
    }

    /**
     * Backtracking line search along {@code direction} from {@code start}.
     *
     * <p>Starting at {@code initialStepSize}, the step is multiplied by {@code stepSizeMultiplier}
     * until the Armijo condition holds (with {@link #EPS} slack).
     *
     * @return (accepted point, accepted step size), or null once the step size underflows
     */
    private static Pair<double[], Double> lineSearch(DifferentiableFunction function, double[] start, double[] direction, double initialStepSize, double stepSizeMultiplier) {
        Pair<Double, double[]> atStart = function.calculate(start);
        double startValue = atStart.getFirst().doubleValue();
        double[] startGradient = atStart.getSecond();
        double directionalDerivative = a.innerProd(startGradient, direction);
        double maxAbsGradient = a.max(a.abs(startGradient));
        double stepSize = initialStepSize;
        while (true) {
            double[] candidate = a.comb(start, 1.0d, direction, stepSize);
            double candidateValue = function.calculate(candidate).getFirst().doubleValue();
            double threshold = startValue + (LINE_SEARCH_SUFF_DECR * directionalDerivative * stepSize);
            if (candidateValue <= threshold + EPS) {
                return Pair.makePair(candidate, Double.valueOf(stepSize));
            }
            if (stepSize < EPS && stepSize * maxAbsGradient < EPS) {
                System.out.printf("[LBFGSMinimizer.minimize]: Line search step size underflow: %.15f, %.15f, %.15f, %.15f, %.15f, %.15f\n", Double.valueOf(stepSize), Double.valueOf(directionalDerivative), Double.valueOf(maxAbsGradient), Double.valueOf(candidateValue), Double.valueOf(threshold), Double.valueOf(startValue));
                return null;
            }
            stepSize *= stepSizeMultiplier;
        }
    }

    /**
     * Returns true when the change between two values is negligible. That holds when it is at most
     * {@link #EPS} in absolute terms, or below {@code tol} relative to the mean magnitude of the two
     * values.
     */
    private boolean converged(double previous, double next, double tol) {
        double change = Math.abs(next - previous);
        return change <= EPS || change / (((Math.abs(next) + Math.abs(previous)) + EPS) / 2.0d) < tol;
    }

    /** Pushes the pair {@code s = nextGuess - guess} and {@code y = nextGradient - gradient} as the newest history entry. */
    private void updateHistories(double[] guess, double[] nextGuess, double[] gradient, double[] nextGradient) {
        double[] inputDifference = a.comb(nextGuess, 1.0d, guess, -1.0d);
        double[] derivativeDifference = a.comb(nextGradient, 1.0d, gradient, -1.0d);
        pushOntoList(inputDifference, this.inputDifferenceVectorList);
        pushOntoList(derivativeDifference, this.derivativeDifferenceVectorList);
    }

    /** Adds {@code vector} at the front and evicts the oldest entry if the list exceeds {@link #maxHistorySize}. */
    private void pushOntoList(double[] vector, LinkedList<double[]> list) {
        list.addFirst(vector);
        if (list.size() > this.maxHistorySize) {
            list.removeLast();
        }
    }

    private void clearHistories() {
        this.inputDifferenceVectorList.clear();
        this.derivativeDifferenceVectorList.clear();
    }

    private int historySize() {
        return this.inputDifferenceVectorList.size();
    }

    /** Returns the {@code i}-th input difference; 0 is the newest. */
    private double[] getInputDifference(int i) {
        return this.inputDifferenceVectorList.get(i);
    }

    /** Returns the {@code i}-th gradient difference; 0 is the newest. */
    private double[] getDerivativeDifference(int i) {
        return this.derivativeDifferenceVectorList.get(i);
    }

    private double[] getLastDerivativeDifference() {
        return this.derivativeDifferenceVectorList.getFirst();
    }

    private double[] getLastInputDifference() {
        return this.inputDifferenceVectorList.getFirst();
    }

    /**
     * Computes {@code H * vector} with the L-BFGS two-loop recursion, where H is built from the
     * stored pairs on top of a diagonal initial approximation.
     *
     * <p>The histories store the newest pair at index 0. The first loop must therefore run newest to
     * oldest (0 .. m-1) and the second oldest to newest (m-1 .. 0). The previous code had both loops
     * reversed, which made H honor the secant equation of the oldest pair instead of the newest.
     *
     * @throws RuntimeException if some pair has curvature {@code s . y == 0}
     */
    private double[] implicitMultiply(double[] initialInverseHessianDiagonal, double[] vector) {
        int size = historySize();
        double[] curvatures = new double[size]; // s_i . y_i, i.e. 1 / rho_i
        double[] alphas = new double[size];
        double[] q = a.copy(vector);
        for (int i = 0; i < size; i++) {
            double[] s = getInputDifference(i);
            double[] y = getDerivativeDifference(i);
            curvatures[i] = a.innerProd(s, y);
            // Zero curvature would make the divisions below undefined.
            if (curvatures[i] == 0.0d) {
                throw new RuntimeException("LBFGSMinimizer.implicitMultiply: Curvature problem.");
            }
            alphas[i] = a.innerProd(s, q) / curvatures[i];
            q = a.comb(q, 1.0d, y, -alphas[i]);
        }
        double[] r = a.pointwiseMult(initialInverseHessianDiagonal, q);
        for (int i = size - 1; i >= 0; i--) {
            double beta = a.innerProd(getDerivativeDifference(i), r) / curvatures[i];
            r = a.comb(r, 1.0d, getInputDifference(i), alphas[i] - beta);
        }
        return r;
    }

    /**
     * Returns the diagonal initial inverse-Hessian approximation. The diagonal is
     * {@code (y . s) / (y . y)} from the newest pair, or 1 when there is no history.
     */
    private double[] getInitialInverseHessianDiagonal(int dimension) {
        double scale = 1.0d;
        if (!this.derivativeDifferenceVectorList.isEmpty()) {
            double[] y = getLastDerivativeDifference();
            scale = a.innerProd(y, getLastInputDifference()) / a.innerProd(y, y);
        }
        double[] diagonal = new double[dimension];
        Arrays.fill(diagonal, scale);
        return diagonal;
    }

    /**
     * Smoke test. Minimizes {@code f(x) = 2 x.x - sum(x)} over R^10, whose gradient
     * {@code 4x - 1} vanishes at {@code x_i = 0.25}, and prints progress.
     */
    public static void main(String[] args) {
        DifferentiableFunction quadratic = new DifferentiableFunction() {
            @Override // tberg.murphy.opt.DifferentiableFunction
            public Pair<Double, double[]> calculate(double[] x) {
                double value = (-a.sum(x)) + (2.0d * a.innerProd(x, x));
                double[] gradient = a.comb(a.scale(a.onesDouble(x.length), -1.0d), 1.0d, a.scale(x, 4.0d), 1.0d);
                return Pair.makePair(Double.valueOf(value), gradient);
            }
        };
        // NOTE(review): an Ocular constant here is likely a decompiler-substituted literal; confirm the intended iteration cap.
        new LBFGSMinimizer(1.0E-5d, CharacterTemplate.INIT_LBFGS_ITERS).minimize(quadratic, a.zerosDouble(10), true, null);
    }
}
