package tberg.murphy.classifier;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import tberg.murphy.arrays.a;
import tberg.murphy.counter.Counter;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.opt.DifferentiableFunction;
import tberg.murphy.opt.LBFGSMinimizer;
import tberg.murphy.threading.BetterThreader;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/classifier/LogisticRegressionClassifier.class */
/**
 * Multiclass logistic-regression (softmax) classifier with an L2 penalty, trained by L-BFGS.
 *
 * <p>Examples are pairs of (sparse feature vector, label). Feature ids and label ids are
 * non-negative integers. Weights are stored in one flat array laid out feature-major: the weight
 * of feature {@code f} for label {@code y} is at index {@code f * numLabels + y}.
 *
 * <p>The training objective is the summed negative log-likelihood of the gold labels plus
 * {@code regConstant * ||w||^2}. It is evaluated in parallel with {@link BetterThreader}, using one
 * loss slot and one gradient buffer per worker thread.
 */
public class LogisticRegressionClassifier implements Classifier {
    private final double tol;
    private final double regConstant;
    private final int maxIters;
    private int numLabels;
    private int numFeatures;
    private final int numThreads;
    private float[] weights;

    /**
     * @param regConstant coefficient {@code c} of the L2 penalty {@code c * ||w||^2}
     * @param tol convergence tolerance handed to {@link LBFGSMinimizer}
     * @param maxIters iteration cap handed to {@link LBFGSMinimizer}
     * @param numThreads number of worker threads used to evaluate the objective; must be positive
     * @throws IllegalArgumentException if {@code numThreads <= 0}
     */
    public LogisticRegressionClassifier(double regConstant, double tol, int maxIters, int numThreads) {
        if (numThreads <= 0) {
            throw new IllegalArgumentException("numThreads must be positive, got " + numThreads);
        }
        this.regConstant = regConstant;
        this.tol = tol;
        this.maxIters = maxIters;
        this.numThreads = numThreads;
    }

    /**
     * Fits the weights to {@code list}, replacing any previously trained model.
     *
     * <p>The label and feature spaces are sized from the largest ids in the data. The optimisation
     * starts from the all-zero weight vector.
     *
     * @param list training examples as (feature counter, gold label) pairs
     * @throws IllegalArgumentException if any feature id or label id is negative
     */
    @Override // tberg.murphy.classifier.Classifier
    public void train(final List<Pair<CounterInterface<Integer>, Integer>> list) {
        int maxLabel = 0;
        int maxFeature = 0;
        for (Pair<CounterInterface<Integer>, Integer> datum : list) {
            int label = datum.getSecond().intValue();
            if (label < 0) {
                throw new IllegalArgumentException("negative label id: " + label);
            }
            maxLabel = Math.max(maxLabel, label);
            for (Integer feature : datum.getFirst().keySet()) {
                int f = feature.intValue();
                if (f < 0) {
                    throw new IllegalArgumentException("negative feature id: " + f);
                }
                maxFeature = Math.max(maxFeature, f);
            }
        }
        this.numLabels = maxLabel + 1;
        this.numFeatures = maxFeature + 1;

        double[] initial = new double[this.numFeatures * this.numLabels];
        DifferentiableFunction objective = new DifferentiableFunction() {
            @Override // tberg.murphy.opt.DifferentiableFunction
            public Pair<Double, double[]> calculate(double[] point) {
                return computeObjective(list, point);
            }
        };
        this.weights = a.toFloat(
            new LBFGSMinimizer(this.tol, this.maxIters, true).minimize(objective, initial, true, null));
    }

    /**
     * Evaluates the regularised negative log-likelihood and its gradient at {@code point}.
     *
     * @param data training examples
     * @param point weight vector in the flat feature-major layout
     * @return (objective value, gradient) with the gradient in the same layout as {@code point}
     */
    private Pair<Double, double[]> computeObjective(
            List<Pair<CounterInterface<Integer>, Integer>> data, double[] point) {
        final float[] w = a.toFloat(point);
        // One accumulator per worker thread. Worker t only writes to slot t, so the threads never
        // share a write target.
        final double[] threadLoss = new double[this.numThreads];
        final float[][] threadGrad = new float[this.numThreads][w.length];

        // NOTE(review): assumes BetterThreader hands each function argument to one worker and passes
        // that worker's thread argument alongside it. Confirm against BetterThreader.
        BetterThreader<Pair<CounterInterface<Integer>, Integer>, Integer> threader =
            new BetterThreader<Pair<CounterInterface<Integer>, Integer>, Integer>(
                new BetterThreader.Function<Pair<CounterInterface<Integer>, Integer>, Integer>() {
                    @Override // tberg.murphy.threading.BetterThreader.Function
                    public void call(Pair<CounterInterface<Integer>, Integer> datum, Integer thread) {
                        int t = thread.intValue();
                        threadLoss[t] += accumulateExample(
                            w, datum.getFirst(), datum.getSecond().intValue(), threadGrad[t]);
                    }
                }, this.numThreads);
        for (int t = 0; t < this.numThreads; t++) {
            threader.setThreadArgument(t, Integer.valueOf(t));
        }
        for (Pair<CounterInterface<Integer>, Integer> datum : data) {
            threader.addFunctionArgument(datum);
        }
        threader.run();

        // The gradient of regConstant * ||w||^2 is 2 * regConstant * w. Add each thread's share.
        float[] gradient = a.scale(w, 2.0f * ((float) this.regConstant));
        for (float[] partial : threadGrad) {
            a.combi(gradient, 1.0f, partial, 1.0f);
        }

        double sumSquares = 0.0;
        for (float v : w) {
            sumSquares += (double) v * v;
        }
        double loss = this.regConstant * sumSquares;
        for (double partialLoss : threadLoss) {
            loss += partialLoss;
        }
        return Pair.makePair(Double.valueOf(loss), a.toDouble(gradient));
    }

    /**
     * Adds one example's negative log-likelihood gradient into {@code grad} and returns its loss.
     *
     * <p>The softmax is computed with the max score subtracted, which avoids overflow in
     * {@code exp}. The loss is {@code logZ - score[label]}, which avoids taking the log of an
     * underflowed probability.
     *
     * @param w current weights (flat, feature-major)
     * @param features sparse feature vector of the example
     * @param label gold label id, in {@code [0, numLabels)}
     * @param grad per-thread gradient buffer, updated in place
     * @return {@code -log p(label | features)}
     */
    private double accumulateExample(
            float[] w, CounterInterface<Integer> features, int label, float[] grad) {
        float[] scores = getScores(w, features);
        float max = Float.NEGATIVE_INFINITY;
        for (float s : scores) {
            max = Math.max(max, s);
        }
        double partition = 0.0;
        for (float s : scores) {
            partition += Math.exp(s - max);
        }
        double logPartition = max + Math.log(partition);

        float[] probs = new float[scores.length];
        for (int y = 0; y < scores.length; y++) {
            probs[y] = (float) Math.exp(scores[y] - logPartition);
        }

        // d(-log p(label))/d w[f, y] = value_f * (p(y) - [y == label])
        for (Map.Entry<Integer, Double> entry : features.entries()) {
            int base = entry.getKey().intValue() * this.numLabels;
            float value = (float) entry.getValue().doubleValue();
            for (int y = 0; y < this.numLabels; y++) {
                grad[base + y] += value * probs[y];
            }
            grad[base + label] -= value;
        }
        return logPartition - scores[label];
    }

    /**
     * Computes the unnormalised per-label scores {@code w_y . x} for a feature vector.
     *
     * <p>A feature id outside {@code [0, numFeatures)} was never seen during training and has an
     * implicit weight of zero, so it is skipped instead of indexing past the weight array. Intended
     * for internal use; public only for compatibility with existing callers.
     *
     * @param fArr weight vector (flat, feature-major)
     * @param counterInterface sparse feature vector
     * @return array of length {@code numLabels} with one score per label
     */
    public float[] getScores(float[] fArr, CounterInterface<Integer> counterInterface) {
        float[] scores = new float[this.numLabels];
        for (Map.Entry<Integer, Double> entry : counterInterface.entries()) {
            int feature = entry.getKey().intValue();
            if (feature < 0 || feature >= this.numFeatures) {
                continue;
            }
            float value = (float) entry.getValue().doubleValue();
            int base = feature * this.numLabels;
            for (int y = 0; y < this.numLabels; y++) {
                scores[y] += value * fArr[base + y];
            }
        }
        return scores;
    }

    /**
     * Exports the trained weights as one counter per label, mapping feature id to weight.
     *
     * @return map from label id to that label's feature weights; empty before training
     */
    @Override // tberg.murphy.classifier.Classifier
    public Map<Integer, CounterInterface<Integer>> getWeights() {
        Map<Integer, CounterInterface<Integer>> byLabel = new HashMap<Integer, CounterInterface<Integer>>();
        for (int label = 0; label < this.numLabels; label++) {
            CounterInterface<Integer> counter = new Counter<Integer>();
            for (int feature = 0; feature < this.numFeatures; feature++) {
                counter.setCount(Integer.valueOf(feature), this.weights[(feature * this.numLabels) + label]);
            }
            byLabel.put(Integer.valueOf(label), counter);
        }
        return byLabel;
    }

    /**
     * Returns the highest-scoring label for a feature vector.
     *
     * @param counterInterface sparse feature vector; features unseen during training are ignored
     * @return predicted label id
     * @throws IllegalStateException if called before {@link #train}
     */
    @Override // tberg.murphy.classifier.Classifier
    public Integer predict(CounterInterface<Integer> counterInterface) {
        if (this.weights == null) {
            throw new IllegalStateException("predict called before train");
        }
        return Integer.valueOf(a.argmax(getScores(this.weights, counterInterface)));
    }
}
