package tberg.murphy.regressor;

import org.jblas.FloatMatrix;
import org.jblas.Solve;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/regressor/KernelPLSRegressor.class */
/**
 * Kernel partial-least-squares regressor.
 *
 * <p>{@link #train} centres the training Gram matrix, extracts {@code k} score pairs
 * (u, t = K u) with deflation of both the kernel and the targets, and stores dual
 * coefficients {@code alpha = (U (T^T K U)^{-1} T^T Y)^T}. {@link #predict} centres the
 * train-vs-test Gram matrix consistently with training and applies {@code alpha}.
 *
 * <p>NOTE(review): the targets are used as given (no mean-centring), and a component whose
 * score vector is zero (e.g. {@code k} exceeds the effective rank) divides by zero and yields
 * NaNs. Presumably {@code KernelMatrixBuilder.build(a, b)} returns an
 * {@code a.length x b.length} Gram matrix; confirm against that class.
 */
public class KernelPLSRegressor implements Regressor {
    /**
     * Safety cap on power-iteration steps per component, so that a tolerance float rounding
     * can never reach (e.g. 0) cannot hang {@link #train}.
     */
    private static final int MAX_POWER_ITERATIONS = 1000;

    /** Training inputs, kept so {@link #predict} can build the train-vs-test kernel. */
    float[][] x;
    KernelMatrixBuilder kernelBuilder;
    /** Not read anywhere in this class. */
    float reg;
    /** Dual coefficients, (number of target columns) x (number of training samples). */
    FloatMatrix alpha;
    /** Row sums of the uncentred training kernel, used to centre kernels in predict. */
    FloatMatrix Kcolsum;
    /** Sum of all entries of the uncentred training kernel. */
    float Ksum;
    /** Number of PLS components to extract. */
    int k;
    /** Convergence tolerance (L2 distance between successive power-iteration vectors). */
    float tol;

    /**
     * @param kernelMatrixBuilder builds Gram matrices between two sets of inputs
     * @param numComponents number of PLS components to extract
     * @param tolerance power-iteration convergence threshold, used when targets have more
     *     than one column
     */
    public KernelPLSRegressor(KernelMatrixBuilder kernelMatrixBuilder, int numComponents, float tolerance) {
        this.k = numComponents;
        this.tol = tolerance;
        this.kernelBuilder = kernelMatrixBuilder;
    }

    /**
     * Fits the model.
     *
     * @param fArr training inputs, one row per sample
     * @param fArr2 training targets, one row per sample (same row count as {@code fArr})
     */
    @Override // tberg.murphy.regressor.Regressor
    public void train(float[][] fArr, float[][] fArr2) {
        this.x = fArr;
        final int n = fArr.length;

        // Double-centre the training kernel: K - r 1^T/n - 1 r^T/n + S/n^2.
        FloatMatrix kernel = new FloatMatrix(this.kernelBuilder.build(fArr, fArr));
        this.Kcolsum = kernel.rowSums();
        this.Ksum = this.Kcolsum.sum();
        kernel.addiColumnVector(this.Kcolsum.mul(-1.0f / n));
        kernel.addiRowVector(this.Kcolsum.transpose().mul(-1.0f / n));
        kernel.addi(this.Ksum / ((float) n * n));

        FloatMatrix y = new FloatMatrix(fArr2);
        FloatMatrix deflatedKernel = kernel.dup();
        FloatMatrix deflatedY = y.dup();
        FloatMatrix uScores = FloatMatrix.zeros(fArr2.length, this.k);
        FloatMatrix tScores = FloatMatrix.zeros(n, this.k);
        // Loop-invariant; addi() below mutates its receiver, never this argument.
        FloatMatrix identity = FloatMatrix.eye(n);

        for (int c = 0; c < this.k; c++) {
            // getColumn returns a copy, so normalising it leaves deflatedY untouched.
            FloatMatrix u = deflatedY.getColumn(0);
            u.divi(u.norm2());
            if (y.columns > 1) {
                // u is the dominant direction of Y Y^T K (single-target case needs no iteration).
                FloatMatrix operator = deflatedY.mmul(deflatedY.transpose().mmul(deflatedKernel));
                u = dominantDirection(operator, u, this.tol);
            }
            FloatMatrix t = deflatedKernel.mmul(u);
            float tNorm = t.norm2();
            float tNormSq = tNorm * tNorm;

            // Deflate Y: Y -= t (t^T Y) / (t^T t).
            deflatedY.subi(t.mmul(deflatedY.transpose().mmul(t).divi(tNormSq).transpose()));
            // Deflate K: K <- (I - t t^T / t^T t) K (I - t t^T / t^T t).
            FloatMatrix projector = t.mmul(t.transpose()).divi(-tNormSq).addi(identity);
            deflatedKernel = projector.mmul(deflatedKernel).mmul(projector);

            uScores.putColumn(c, u);
            tScores.putColumn(c, t);
        }

        FloatMatrix lhs = tScores.transpose().mmul(kernel).mmul(uScores);
        FloatMatrix rhs = tScores.transpose().mmul(y);
        this.alpha = uScores.mmul(Solve.solve(lhs, rhs)).transpose();
    }

    /**
     * Power iteration: repeatedly applies {@code operator} to {@code start}, renormalising
     * each step, until successive vectors are within {@code tolerance} (L2 distance) or
     * {@link #MAX_POWER_ITERATIONS} steps have been taken.
     */
    private static FloatMatrix dominantDirection(FloatMatrix operator, FloatMatrix start, float tolerance) {
        FloatMatrix current = start;
        FloatMatrix previous = null;
        for (int step = 0;
                step < MAX_POWER_ITERATIONS
                        && (previous == null || previous.distance2(current) > tolerance);
                step++) {
            previous = current;
            current = operator.mmul(current);
            current.divi(current.norm2());
        }
        return current;
    }

    /**
     * Predicts targets for new inputs.
     *
     * @param fArr inputs, one row per sample
     * @return one row of predicted targets per input row
     * @throws IllegalStateException if {@link #train} has not been called
     */
    @Override // tberg.murphy.regressor.Regressor
    public float[][] predict(float[][] fArr) {
        if (this.alpha == null) {
            throw new IllegalStateException("train must be called before predict");
        }
        final int n = this.x.length;
        FloatMatrix kernel = new FloatMatrix(this.kernelBuilder.build(this.x, fArr));
        // Centre consistently with train: Kt - r 1^T/n - 1 c^T/n + S/n^2, where c are the
        // column sums of the RAW test kernel. Taking them after the row-mean subtraction
        // would add S/n^2 a second time.
        FloatMatrix testColumnSums = kernel.columnSums();
        kernel.addiColumnVector(this.Kcolsum.mul(-1.0f / n));
        kernel.addiRowVector(testColumnSums.mul(-1.0f / n));
        kernel.addi(this.Ksum / ((float) n * n));
        return this.alpha.mmul(kernel).transpose().toArray2();
    }
}
