package tberg.murphy.regressor;

import org.jblas.FloatMatrix;
import org.jblas.Solve;
import tberg.murphy.gpu.CublasUtil;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/regressor/GPUKernelPLSRegressor.class */
/**
 * Kernel partial-least-squares regressor whose matrix arithmetic runs on the GPU through
 * {@link CublasUtil}; only the small k-by-k linear solve is done on the CPU with jblas.
 *
 * <p>Training double-centres the training Gram matrix, then extracts {@code k} latent
 * components. Each component's weight vector comes from power iteration on {@code Y Y^T K}
 * (skipped when Y has a single column). After each component, Y and K are deflated by the
 * component's score vector {@code t = K u}. The learned dual coefficients are kept pinned on
 * the GPU between calls.
 *
 * <p>NOTE(review): {@code CublasUtil.freeAllBut}/{@code freeAll} are assumed to release every
 * tracked GPU matrix except the ones listed and those pinned with {@code setDontFree(true)}.
 * The code below relies on that to bound GPU memory — confirm against CublasUtil.
 */
public class GPUKernelPLSRegressor implements Regressor {
    /** Upper bound on power-iteration steps per component so a non-converging iteration cannot hang. */
    private static final int MAX_POWER_ITERATIONS = 10000;

    /** Training inputs, kept so {@link #predict} can evaluate the kernel against them. */
    float[][] x;
    /** Builds the kernel matrix between two point sets. */
    KernelMatrixBuilder kernelBuilder;
    /** Not read anywhere in this class. NOTE(review): possibly a leftover regularisation weight. */
    float reg;
    /** Dual coefficients, (#outputs x #training points); pinned on the GPU while a model exists. */
    CublasUtil.Matrix alpha;
    /** Axis sums of the raw (uncentred) training kernel matrix; pinned on the GPU while a model exists. */
    CublasUtil.Matrix Kcolsum;
    /** Sum of all entries of the raw (uncentred) training kernel matrix. */
    float Ksum;
    /** Number of latent components to extract. */
    int k;
    /** Stop power iteration once {@code distance2} between successive unit vectors is at most this. */
    float tol;

    /**
     * Creates an untrained regressor.
     *
     * @param kernelMatrixBuilder builds kernel matrices between point sets
     * @param numComponents number of latent PLS components to extract; must be at least 1
     * @param tolerance power-iteration convergence threshold on successive weight vectors
     * @throws IllegalArgumentException if {@code numComponents < 1}
     */
    public GPUKernelPLSRegressor(KernelMatrixBuilder kernelMatrixBuilder, int numComponents, float tolerance) {
        if (numComponents < 1) {
            throw new IllegalArgumentException("numComponents must be at least 1, got " + numComponents);
        }
        this.k = numComponents;
        this.tol = tolerance;
        this.kernelBuilder = kernelMatrixBuilder;
    }

    /**
     * Fits the model to the given inputs and targets, replacing any previously trained model.
     *
     * @param inputs training points, one per row; must be non-empty
     * @param targets training targets, one row per input row
     * @throws IllegalArgumentException if {@code inputs} is empty or the row counts differ
     */
    @Override
    public void train(float[][] inputs, float[][] targets) {
        if (inputs.length == 0) {
            throw new IllegalArgumentException("inputs must contain at least one row");
        }
        if (targets.length != inputs.length) {
            throw new IllegalArgumentException(
                    "targets must have one row per input: expected " + inputs.length + ", got " + targets.length);
        }
        final int n = inputs.length;

        // Unpin the previous model so the freeAllBut calls below can reclaim its GPU buffers, and
        // clear it so predict() fails fast instead of using stale/freed state if training fails.
        if (this.Kcolsum != null) {
            this.Kcolsum.setDontFree(false);
            this.Kcolsum = null;
        }
        if (this.alpha != null) {
            this.alpha.setDontFree(false);
            this.alpha = null;
        }
        this.x = inputs;

        // Double-centre the training kernel: subtract Kcolsum/n along both axes, add back Ksum/n^2.
        CublasUtil.Matrix kCentered = CublasUtil.Matrix.build(this.kernelBuilder.build(inputs, inputs));
        CublasUtil.freeAllBut(kCentered);
        this.Kcolsum = kCentered.colSum();
        this.Kcolsum.setDontFree(true); // needed again by predict()
        this.Ksum = this.Kcolsum.rowSum().toArray()[0];
        kCentered.colAddi(this.Kcolsum.mul(-1.0f / n));
        kCentered.rowAddi(this.Kcolsum.transpose().mul(-1.0f / n));
        kCentered.addi(grandMean(n));
        CublasUtil.freeAllBut(kCentered);

        CublasUtil.Matrix kDeflated = kCentered.copy();
        CublasUtil.Matrix y = CublasUtil.Matrix.build(targets);
        CublasUtil.Matrix yDeflated = y.copy();
        CublasUtil.Matrix weights = CublasUtil.Matrix.zeros(targets.length, this.k); // column c = u_c
        CublasUtil.Matrix scores = CublasUtil.Matrix.zeros(n, this.k);               // column c = t_c = K u_c

        for (int component = 0; component < this.k; component++) {
            // Start from the first column of the deflated targets, normalised.
            // NOTE(review): if that column has been deflated to zero this divides by 0 and yields NaNs.
            CublasUtil.Matrix u = yDeflated.copyCol(0);
            u.muli(1.0f / u.norm2());

            if (y.cols() > 1) {
                // Power iteration on Y Y^T K for its dominant direction; with one target column
                // the normalised column itself is already that direction.
                CublasUtil.Matrix operator = yDeflated.mmul(yDeflated.transpose().mmul(kDeflated));
                CublasUtil.freeAllBut(u, operator, kDeflated, kCentered, yDeflated, y, weights, scores);
                for (int iteration = 0; iteration < MAX_POWER_ITERATIONS; iteration++) {
                    CublasUtil.Matrix previous = u;
                    u = operator.mmul(previous);
                    u.muli(1.0f / u.norm2());
                    boolean converged = previous.distance2(u) <= this.tol;
                    // Release `previous` and per-step temporaries so memory does not grow with iterations.
                    CublasUtil.freeAllBut(u, operator, kDeflated, kCentered, yDeflated, y, weights, scores);
                    if (converged) {
                        break;
                    }
                }
                CublasUtil.freeAllBut(u, kDeflated, kCentered, yDeflated, y, weights, scores);
            }

            weights.setCol(component, u);
            CublasUtil.Matrix t = kDeflated.mmul(u);
            scores.setCol(component, t);
            float tNorm = t.norm2();
            float tNormSquared = tNorm * tNorm;

            // Deflate Y: Y <- Y - t t^T Y / ||t||^2.
            yDeflated.subi(t.mmul(yDeflated.transpose().mmul(t).muli(1.0f / tNormSquared).transpose()));
            CublasUtil.freeAllBut(t, kDeflated, kCentered, yDeflated, y, weights, scores);

            // Deflate K: K <- P K P with P = I - t t^T / ||t||^2.
            CublasUtil.Matrix projector =
                    t.mmul(t.transpose()).muli(-1.0f / tNormSquared).addi(CublasUtil.Matrix.eye(n));
            CublasUtil.freeAllBut(projector, kDeflated, kCentered, yDeflated, y, weights, scores);
            kDeflated = projector.mmul(kDeflated).mmul(projector);
            CublasUtil.freeAllBut(kDeflated, kCentered, yDeflated, y, weights, scores);
        }

        // alpha = (U (T^T Kc U)^{-1} T^T Y)^T; the k-by-k system is solved on the CPU with jblas.
        FloatMatrix lhs = new FloatMatrix(scores.transpose().mmul(kCentered).mmul(weights).toArray2());
        FloatMatrix rhs = new FloatMatrix(scores.transpose().mmul(y).toArray2());
        CublasUtil.Matrix coefficients = CublasUtil.Matrix.build(Solve.solve(lhs, rhs).toArray2());
        this.alpha = weights.mmul(coefficients).transpose();
        this.alpha.setDontFree(true); // survives the freeAll below and later predict() calls
        CublasUtil.freeAll();
    }

    /**
     * Predicts targets for new points.
     *
     * @param queries points to predict for, one per row
     * @return one row of predicted targets per query row
     * @throws IllegalStateException if {@link #train} has not completed successfully
     */
    @Override
    public float[][] predict(float[][] queries) {
        if (this.alpha == null || this.Kcolsum == null || this.x == null) {
            throw new IllegalStateException("predict called before train");
        }
        final int n = this.x.length;

        // Cross kernel with training points on rows and queries on columns (alpha.mmul below
        // requires the row count to equal the number of training points).
        CublasUtil.Matrix kernel = CublasUtil.Matrix.build(this.kernelBuilder.build(this.x, queries));

        // Per-query sums over the training axis, taken from the RAW kernel. The previous version
        // summed after colAddi had already removed the training means. Those sums therefore
        // already carried the +Ksum/n^2 term, and the addi below then added the grand mean a
        // second time, shifting every prediction by a constant.
        CublasUtil.Matrix querySums = kernel.rowSum();

        kernel.colAddi(this.Kcolsum.mul(-1.0f / n));
        kernel.rowAddi(querySums.mul(-1.0f / n));
        kernel.addi(grandMean(n));

        float[][] predictions = this.alpha.mmul(kernel).transpose().toArray2();
        CublasUtil.freeAll();
        return predictions;
    }

    /**
     * Returns the grand mean {@code Ksum / n^2} of the raw training kernel.
     * The product is formed in double so {@code n * n} cannot overflow int for large {@code n}.
     */
    private float grandMean(int n) {
        return (float) (this.Ksum / ((double) n * n));
    }
}
