package tberg.murphy.regressor;

import java.util.ArrayList;
import tberg.murphy.arrays.a;
import tberg.murphy.classifier.Classifier;
import tberg.murphy.classifier.LogisticRegressionClassifier;
import tberg.murphy.counter.IntCounter;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/regressor/BinaryRegressor.class */
/**
 * Multi-output regressor that reduces each output dimension to a binary classification problem.
 *
 * <p>For every output column a {@link LogisticRegressionClassifier} is trained to predict whether
 * the target value exceeds {@code thresh}. At prediction time each classifier's decision is mapped
 * back to a constant: {@code c0} when the classifier predicts label 0, {@code c1} otherwise.
 */
public class BinaryRegressor implements Regressor {
    /**
     * Constant bias feature added to every input (via {@code a.append}) so the classifiers can
     * learn an intercept. It MUST be the same value in {@link #train} and {@link #predict};
     * previously training used 10 and prediction used 1, which mis-scaled the learned intercept.
     */
    private static final float BIAS_FEATURE = 10.0f;

    /** One classifier per output dimension; {@code null} until {@link #train} has been called. */
    Classifier[] classifiers;
    /** Targets strictly greater than this value are labelled 1, all others 0. */
    float thresh;
    /** Value emitted for an output whose classifier predicts label 0. */
    float c0;
    /** Value emitted for an output whose classifier predicts any non-zero label. */
    float c1;
    /** Regularization strength passed as the first argument to each LogisticRegressionClassifier. */
    float reg;

    /**
     * @param reg regularization strength forwarded to each logistic-regression classifier
     * @param thresh targets {@code > thresh} are treated as the positive class during training
     * @param c0 prediction emitted when a classifier outputs label 0
     * @param c1 prediction emitted when a classifier outputs a non-zero label
     */
    public BinaryRegressor(float reg, float thresh, float c0, float c1) {
        this.reg = reg;
        this.thresh = thresh;
        this.c0 = c0;
        this.c1 = c1;
    }

    /**
     * Trains one binary classifier per output column.
     *
     * @param fArr input feature rows, one per example
     * @param fArr2 target rows, one per example; every row is assumed to have the same width,
     *     taken from the first row
     * @throws IllegalArgumentException if there are no examples or the two arrays differ in length
     */
    // Raw ArrayList: the generic element type expected by Classifier.train is not visible here.
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Override // tberg.murphy.regressor.Regressor
    public void train(float[][] fArr, float[][] fArr2) {
        if (fArr.length != fArr2.length) {
            throw new IllegalArgumentException(
                    "inputs and targets differ in length: " + fArr.length + " vs " + fArr2.length);
        }
        if (fArr2.length == 0) {
            throw new IllegalArgumentException("cannot train on an empty data set");
        }

        float[][] augmented = new float[fArr.length][];
        for (int row = 0; row < fArr.length; row++) {
            augmented[row] = a.append(BIAS_FEATURE, fArr[row]);
        }

        int numOutputs = fArr2[0].length;
        this.classifiers = new Classifier[numOutputs];
        for (int out = 0; out < numOutputs; out++) {
            // Constructor arguments after reg are kept as in the original; their meaning
            // (presumably tolerance / iteration limit / threads) is not visible here — TODO confirm.
            this.classifiers[out] = new LogisticRegressionClassifier(this.reg, 1.0E-20d, 100, 8);
            ArrayList examples = new ArrayList();
            for (int row = 0; row < augmented.length; row++) {
                // A fresh counter per classifier keeps each training set independent.
                int label = fArr2[row][out] > this.thresh ? 1 : 0;
                examples.add(Pair.makePair(toCounter(augmented[row]), Integer.valueOf(label)));
            }
            this.classifiers[out].train(examples);
        }
    }

    /**
     * Predicts every output for every input row.
     *
     * @param fArr input feature rows
     * @return a {@code fArr.length x numOutputs} array whose entries are {@code c0} or {@code c1}
     * @throws IllegalStateException if {@link #train} has not been called yet
     */
    @Override // tberg.murphy.regressor.Regressor
    public float[][] predict(float[][] fArr) {
        if (this.classifiers == null) {
            throw new IllegalStateException("predict called before train");
        }
        float[][] result = new float[fArr.length][this.classifiers.length];
        for (int row = 0; row < fArr.length; row++) {
            // Must use the same bias feature as train() so the learned intercept applies correctly.
            IntCounter counter = toCounter(a.append(BIAS_FEATURE, fArr[row]));
            for (int out = 0; out < this.classifiers.length; out++) {
                result[row][out] =
                        this.classifiers[out].predict(counter).intValue() == 0 ? this.c0 : this.c1;
            }
        }
        return result;
    }

    /** Builds a counter mapping each feature index to its value. */
    private static IntCounter toCounter(float[] features) {
        IntCounter counter = new IntCounter();
        for (int k = 0; k < features.length; k++) {
            counter.setCount(Integer.valueOf(k), features[k]);
        }
        return counter;
    }
}
