package tberg.murphy.regressor;

import org.jblas.FloatMatrix;
import org.jblas.Solve;
import tberg.murphy.gpu.CublasUtil;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/regressor/KernelLinearRegressor.class */
/**
 * Kernel ridge regression on a double-centered kernel matrix.
 *
 * <p>Training builds the Gram matrix {@code K} of the training inputs. It then centers it by
 * subtracting row means and column means and adding back the grand mean, adds {@code reg} to the
 * diagonal, and solves {@code (Kc + reg * I) A = Y} for the dual coefficients.
 *
 * <p>Prediction builds the cross-kernel between the training and query inputs and centers it with
 * the same training statistics. It then multiplies the result by the dual coefficients.
 *
 * <p>A double-centered Gram matrix has the all-ones vector in its null space, so {@code reg} should
 * be positive for the Cholesky-based solve to succeed.
 *
 * <p>NOTE(review): the targets are not centered and no intercept is added back to the
 * predictions. Confirm that callers center {@code y} themselves if a non-zero mean matters.
 */
public class KernelLinearRegressor implements Regressor {
    /** Training inputs, retained because prediction needs the cross-kernel against them. */
    float[][] x;
    /**
     * Builds kernel matrices between two input sets. Assumes {@code build(a, b)} returns an
     * {@code a.length x b.length} matrix of {@code k(a_i, b_j)} — TODO confirm.
     */
    KernelMatrixBuilder kernelBuilder;
    /** Ridge penalty added to the diagonal of the centered training kernel. */
    float reg;
    /** Dual coefficients, one row per target column (the transpose of the solve result). */
    FloatMatrix alpha;
    /**
     * Row sums of the raw training kernel (n x 1). These equal its column sums when the kernel is
     * symmetric.
     */
    FloatMatrix Kcolsum;
    /** Sum of all entries of the raw training kernel. */
    float Ksum;

    /**
     * Creates an untrained regressor.
     *
     * @param kernelMatrixBuilder builds the kernel matrices used for training and prediction
     * @param f ridge regularization strength added to the kernel diagonal
     */
    public KernelLinearRegressor(KernelMatrixBuilder kernelMatrixBuilder, float f) {
        this.reg = f;
        this.kernelBuilder = kernelMatrixBuilder;
    }

    /**
     * Fits the dual coefficients on the given training data.
     *
     * @param fArr training inputs, one row per example (retained by reference for prediction)
     * @param fArr2 training targets, one row per example
     * @throws IllegalArgumentException if there are no examples or the row counts differ
     */
    @Override
    public void train(float[][] fArr, float[][] fArr2) {
        if (fArr.length == 0) {
            throw new IllegalArgumentException("train requires at least one example");
        }
        if (fArr.length != fArr2.length) {
            throw new IllegalArgumentException(
                "input/target row count mismatch: " + fArr.length + " vs " + fArr2.length);
        }
        final int n = fArr.length;
        this.x = fArr;
        FloatMatrix k = new FloatMatrix(this.kernelBuilder.build(fArr, fArr));
        // Presumably releases GPU buffers used while building the kernel — TODO confirm.
        CublasUtil.freeAll();

        this.Kcolsum = k.rowSums();
        this.Ksum = this.Kcolsum.sum();

        // Double-center: Kc = K - rowMean - colMean + grandMean.
        // Row sums are reused as column sums, which assumes the training kernel is symmetric.
        k.addiColumnVector(this.Kcolsum.mul(-1.0f / n));
        k.addiRowVector(this.Kcolsum.transpose().mul(-1.0f / n));
        // Multiply in float so n * n cannot overflow int.
        k.addi(this.Ksum / ((float) n * n));

        k.addi(FloatMatrix.eye(n).muli(this.reg));
        this.alpha = Solve.solvePositive(k, new FloatMatrix(fArr2)).transpose();
    }

    /**
     * Predicts targets for the given query inputs.
     *
     * @param fArr query inputs, one row per example
     * @return one row of predicted targets per query example
     * @throws IllegalStateException if {@link #train} has not been called
     */
    @Override
    public float[][] predict(float[][] fArr) {
        if (this.alpha == null) {
            throw new IllegalStateException("predict called before train");
        }
        final int n = this.x.length;
        // Rows index training examples, columns index query examples.
        FloatMatrix k = new FloatMatrix(this.kernelBuilder.build(this.x, fArr));
        CublasUtil.freeAll();

        // Take the query column sums from the RAW cross-kernel. Summing after the
        // training-row-mean subtraction would add the grand-mean term a second time.
        FloatMatrix queryColSums = k.columnSums();

        k.addiColumnVector(this.Kcolsum.mul(-1.0f / n));
        k.addiRowVector(queryColSums.mul(-1.0f / n));
        k.addi(this.Ksum / ((float) n * n));

        return this.alpha.mmul(k).transpose().toArray2();
    }
}
