package tberg.murphy.gpu;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.CUresult;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.JCuda;
import org.jblas.FloatMatrix;
import org.jblas.Solve;
import tberg.murphy.arrays.a;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/gpu/CublasUtil.class */
public class CublasUtil {
    /** Debug switch; not referenced in the visible part of this file (presumably toggles extra synchronisation; confirm). */
    public static final boolean DEBUG_SYNC = false;
    /** Registry of matrices: every {@code Matrix} constructor appends the new instance; (re)created by {@code startup(int)}. */
    public static LinkedList<Matrix> allocated;
    /** cuBLAS handle created in {@code startup(int)} and passed to every {@code JCublas2} call made by {@code Matrix}. */
    public static cublasHandle cublasHandle;
    /** Module built from {@code Matrix.kernels} in {@code startup(int)}; the element-wise kernels are looked up in it by name. */
    public static CUmodule helperModule;

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/gpu/CublasUtil$Matrix.class */
    public static class Matrix {
        private int rows;
        private int cols;
        private static final int BLOCK_SIZE = 512;
        public static final String kernels = "extern \"C\"\n__global__ void vectorScalarSet(float* A, float alpha, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        A[i] = alpha;\n    }\n}\nextern \"C\"\n__global__ void vectorScalarAdd(const float* __restrict__ A, float* B, float alpha, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = A[i] + alpha;\n    }\n}\nextern \"C\"\n__global__ void vectorLog(const float* __restrict__ A, float* B, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = log(A[i]);\n    }\n}\nextern \"C\"\n__global__ void vectorExp(const float* __restrict__ A, float* B, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = exp(A[i]);\n    }\n}\nextern \"C\"\n__global__ void vectorSign(const float* __restrict__ A, float* B, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = (A[i] > 0.0 ? 
1.0 : -1.0);\n    }\n}\nextern \"C\"\n__global__ void vectorAbs(const float* __restrict__ A, float* B, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = abs(A[i]);\n    }\n}\nextern \"C\"\n__global__ void vectorDiv(const float* __restrict__ A, const float* __restrict__ B, float* C, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        C[i] = A[i] / B[i];\n    }\n}\nextern \"C\"\n__global__ void vectorMul(const float* __restrict__ A, const float* __restrict__ B, float* C, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        C[i] = A[i] * B[i];\n    }\n}\nextern \"C\"\n__global__ void vectorMax(const float* __restrict__ A, float* B, float val, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = max(A[i], val);\n    }\n}\nextern \"C\"\n__global__ void vectorMin(const float* __restrict__ A, float* B, float val, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = min(A[i], val);\n    }\n}\nextern \"C\"\n__global__ void vectorPow(const float* __restrict__ A, float* B, float val, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = pow((double) A[i], (double) val);\n    }\n}\nextern \"C\"\n__global__ void vectorSqr(const float* __restrict__ A, float* B, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    float val;\n    if (i < numElements)\n    {\n        val = A[i];\n        B[i] = val*val;\n    }\n}\nextern \"C\"\n__global__ void vectorSqrt(const float* __restrict__ A, float* B, int numElements)\n{\n    int i = blockDim.x * blockIdx.x + threadIdx.x;\n    if (i < numElements)\n    {\n        B[i] = sqrt(A[i]);\n    }\n}\n";
        private boolean dontFree = false;
        private Pointer data_d = new Pointer();

        public Matrix(int i, int i2) {
            this.rows = i;
            this.cols = i2;
            JCuda.cudaMalloc(this.data_d, i * i2 * 4);
            CublasUtil.allocated.add(this);
        }

        public void setDontFree(boolean z) {
            this.dontFree = z;
        }

        public boolean dontFree() {
            return this.dontFree;
        }

        public boolean equals(Object obj) {
            return (obj instanceof Matrix) && this.data_d.equals(((Matrix) obj).data_d);
        }

        public int hashCode() {
            return this.data_d.hashCode();
        }

        public static Matrix build(float[][] fArr) {
            Matrix matrix = new Matrix(fArr.length, fArr[0].length);
            JCublas2.cublasSetMatrix(matrix.rows, matrix.cols, 4, Pointer.to(toColMajor(fArr)), matrix.rows, matrix.data_d, matrix.rows);
            return matrix;
        }

        public static Matrix build(int i, int i2, float[] fArr) {
            Matrix matrix = new Matrix(i, i2);
            JCublas2.cublasSetMatrix(matrix.rows, matrix.cols, 4, Pointer.to(fArr), matrix.rows, matrix.data_d, matrix.rows);
            return matrix;
        }

        public static Matrix rand(int i, int i2, Random random) {
            return build(a.randFloat(i, i2, random));
        }

        public static Matrix ones(int i, int i2) {
            Matrix matrix = new Matrix(i, i2);
            matrix.set(1.0f);
            return matrix;
        }

        public static Matrix zeros(int i, int i2) {
            Matrix matrix = new Matrix(i, i2);
            matrix.zeroi();
            return matrix;
        }

        public static Matrix eye(int i) {
            Matrix zeros = zeros(i, i);
            zeros.diagAddi(1.0f);
            return zeros;
        }

        public boolean isVector() {
            return this.rows == 1 || this.cols == 1;
        }

        public boolean isScalar() {
            return this.rows == 1 && this.cols == 1;
        }

        public int rows() {
            return this.rows;
        }

        public int cols() {
            return this.cols;
        }

        public Matrix copy() {
            Matrix matrix = new Matrix(this.rows, this.cols);
            JCublas2.cublasScopy(CublasUtil.cublasHandle, this.rows * this.cols, this.data_d, 1, matrix.data_d, 1);
            return matrix;
        }

        public Matrix copySubmatrix(int i, int i2, int i3, int i4) {
            Matrix matrix = new Matrix(i2 - i, i4 - i3);
            JCublas2.cublasSetMatrix(matrix.rows, matrix.cols, 4, this.data_d.withByteOffset(((i3 * this.rows) + i) * 4), this.rows, matrix.data_d, matrix.rows);
            return matrix;
        }

        public Matrix setSubmatrix(int i, int i2, Matrix matrix, int i3, int i4, int i5, int i6) {
            JCublas2.cublasSetMatrix(i4 - i3, i6 - i5, 4, matrix.data_d.withByteOffset(((i5 * matrix.rows) + i3) * 4), matrix.rows, this.data_d.withByteOffset(((i2 * this.rows) + i) * 4), this.rows);
            return this;
        }

        public Matrix setSubmatrix(Matrix matrix, int i, int i2) {
            JCublas2.cublasSetMatrix(matrix.rows, matrix.cols, 4, matrix.data_d, matrix.rows, this.data_d.withByteOffset(((i2 * this.rows) + i) * 4), this.rows);
            return this;
        }

        public Matrix setSubmatrix(float[][] fArr, int i, int i2) {
            JCublas2.cublasSetMatrix(fArr.length, fArr[0].length, 4, Pointer.to(toColMajor(fArr)), fArr.length, this.data_d.withByteOffset(((i2 * this.rows) + i) * 4), this.rows);
            return this;
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r1v1, types: [float[], float[][]] */
        public Matrix set(int i, int i2, float f) {
            return setSubmatrix((float[][]) new float[]{new float[]{f}}, i, i2);
        }

        public Matrix copyRow(int i) {
            return copySubmatrix(i, i + 1, 0, this.cols);
        }

        public Matrix copyCol(int i) {
            return copySubmatrix(0, this.rows, i, i + 1);
        }

        public Matrix setRow(int i, Matrix matrix) {
            return setSubmatrix(matrix, i, 0);
        }

        public Matrix setCol(int i, Matrix matrix) {
            return setSubmatrix(matrix, 0, i);
        }

        public Matrix set(float f) {
            scalarSet(this, f);
            return this;
        }

        public float[] toArray() {
            float[] fArr = new float[this.rows * this.cols];
            JCublas2.cublasGetMatrix(this.rows, this.cols, 4, this.data_d, this.rows, Pointer.to(fArr), this.rows);
            return fArr;
        }

        public float[][] toArray2() {
            return fromColMajor(toArray(), this.rows);
        }

        public void free() {
            setDontFree(false);
            if (this.data_d != null) {
                JCuda.cudaFree(this.data_d);
            }
        }

        public Matrix diagAdd(float f) {
            Matrix matrix = new Matrix(1, this.cols);
            matrix.set(f);
            return diagAdd(matrix);
        }

        public Matrix diagAddi(float f) {
            Matrix matrix = new Matrix(1, this.cols);
            matrix.set(f);
            return diagAddi(matrix);
        }

        public Matrix diagAdd(Matrix matrix) {
            return copy().diagAddi(matrix);
        }

        public Matrix diagAddi(Matrix matrix) {
            JCublas2.cublasSaxpy(CublasUtil.cublasHandle, matrix.cols, Pointer.to(new float[]{1.0f}), matrix.data_d, 1, this.data_d, this.rows + 1);
            return this;
        }

        public Matrix rowMul(Matrix matrix) {
            Matrix matrix2 = new Matrix(this.rows, this.cols);
            dgmm(this, matrix, matrix2, false);
            return matrix2;
        }

        public Matrix colMul(Matrix matrix) {
            Matrix matrix2 = new Matrix(this.rows, this.cols);
            dgmm(this, matrix, matrix2, true);
            return matrix2;
        }

        public Matrix rowMuli(Matrix matrix) {
            dgmm(this, matrix, this, false);
            return this;
        }

        public Matrix colMuli(Matrix matrix) {
            dgmm(this, matrix, this, true);
            return this;
        }

        public Matrix rowDiv(Matrix matrix) {
            Matrix matrix2 = new Matrix(this.rows, this.cols);
            Matrix ones = ones(matrix.rows, matrix.cols);
            ones.divi(matrix);
            dgmm(this, ones, matrix2, false);
            return matrix2;
        }

        public Matrix colDiv(Matrix matrix) {
            Matrix matrix2 = new Matrix(this.rows, this.cols);
            Matrix ones = ones(matrix.rows, matrix.cols);
            ones.divi(matrix);
            dgmm(this, ones, matrix2, true);
            return matrix2;
        }

        public Matrix rowDivi(Matrix matrix) {
            Matrix ones = ones(matrix.rows, matrix.cols);
            ones.divi(matrix);
            dgmm(this, ones, this, false);
            return this;
        }

        public Matrix colDivi(Matrix matrix) {
            Matrix ones = ones(matrix.rows, matrix.cols);
            ones.divi(matrix);
            dgmm(this, ones, this, true);
            return this;
        }

        public Matrix rowAdd(Matrix matrix) {
            return rowComb(1.0f, matrix);
        }

        public Matrix rowAddi(Matrix matrix) {
            return rowCombi(1.0f, matrix);
        }

        public Matrix rowSub(Matrix matrix) {
            return rowComb(-1.0f, matrix);
        }

        public Matrix rowSubi(Matrix matrix) {
            return rowCombi(-1.0f, matrix);
        }

        public Matrix colAdd(Matrix matrix) {
            return colComb(1.0f, matrix);
        }

        public Matrix colAddi(Matrix matrix) {
            return colCombi(1.0f, matrix);
        }

        public Matrix colSub(Matrix matrix) {
            return colComb(-1.0f, matrix);
        }

        public Matrix colSubi(Matrix matrix) {
            return colCombi(-1.0f, matrix);
        }

        public Matrix rowSum() {
            return ones(1, this.rows).mmul(this);
        }

        public Matrix colSum() {
            return mmul(ones(this.cols, 1));
        }

        public Matrix sub(Matrix matrix) {
            return comb(1.0f, -1.0f, matrix);
        }

        public Matrix subi(Matrix matrix) {
            replaceRef(sub(matrix), this);
            return this;
        }

        public Matrix add(Matrix matrix) {
            return comb(1.0f, 1.0f, matrix);
        }

        public Matrix addi(Matrix matrix) {
            replaceRef(add(matrix), this);
            return this;
        }

        public Matrix rowComb(float f, Matrix matrix) {
            return copy().rowCombi(f, matrix);
        }

        public Matrix rowCombi(float f, Matrix matrix) {
            ger(f, ones(this.rows, 1), matrix, this);
            return this;
        }

        public Matrix colComb(float f, Matrix matrix) {
            return copy().colCombi(f, matrix);
        }

        public Matrix colCombi(float f, Matrix matrix) {
            ger(f, matrix, ones(this.cols, 1), this);
            return this;
        }

        public Matrix comb(float f, float f2, Matrix matrix) {
            Matrix matrix2 = new Matrix(this.rows, this.cols);
            JCublas2.cublasSgeam(CublasUtil.cublasHandle, 0, 0, this.rows, this.cols, Pointer.to(new float[]{f}), this.data_d, this.rows, Pointer.to(new float[]{f2}), matrix.data_d, matrix.rows, matrix2.data_d, matrix2.rows);
            return matrix2;
        }

        public Matrix combi(float f, float f2, Matrix matrix) {
            replaceRef(comb(f, f2, matrix), this);
            return this;
        }

        public Matrix mmul(Matrix matrix) {
            Matrix matrix2 = new Matrix(this.rows, matrix.cols);
            gemm(1.0f, this, matrix, 0.0f, matrix2);
            return matrix2;
        }

        public Matrix mmuli(Matrix matrix) {
            replaceRef(mmul(matrix), this);
            return this;
        }

        public Matrix add(float f) {
            Matrix matrix = new Matrix(this.rows, this.cols);
            scalarAdd(this, f, matrix);
            return matrix;
        }

        public Matrix addi(float f) {
            replaceRef(add(f), this);
            return this;
        }

        public Matrix log() {
            Matrix matrix = new Matrix(this.rows, this.cols);
            log(this, matrix);
            return matrix;
        }

        public Matrix logi() {
            replaceRef(log(), this);
            return this;
        }

        public Matrix exp() {
            Matrix matrix = new Matrix(this.rows, this.cols);
            exp(this, matrix);
            return matrix;
        }

        public Matrix expi() {
            replaceRef(exp(), this);
            return this;
        }

        public Matrix sign() {
            Matrix matrix = new Matrix(this.rows, this.cols);
            sign(this, matrix);
            return matrix;
        }

        public Matrix signi() {
            replaceRef(sign(), this);
            return this;
        }

        public Matrix abs() {
            Matrix matrix = new Matrix(this.rows, this.cols);
            abs(this, matrix);
            return matrix;
        }

        public Matrix absi() {
            replaceRef(abs(), this);
            return this;
        }

        public Matrix mul(Matrix matrix) {
            Matrix matrix2 = new Matrix(this.rows, this.cols);
            mul(this, matrix, matrix2);
            return matrix2;
        }

        public Matrix muli(Matrix matrix) {
            replaceRef(mul(matrix), this);
            return this;
        }

        public Matrix mul(float f) {
            return copy().muli(f);
        }

        public Matrix muli(float f) {
            JCublas2.cublasSscal(CublasUtil.cublasHandle, this.rows * this.cols, Pointer.to(new float[]{f}), this.data_d, 1);
            return this;
        }

        public Matrix div(Matrix matrix) {
            Matrix matrix2 = new Matrix(this.rows, this.cols);
            div(this, matrix, matrix2);
            return matrix2;
        }

        public Matrix divi(Matrix matrix) {
            replaceRef(div(matrix), this);
            return this;
        }

        public Matrix max(float f) {
            Matrix matrix = new Matrix(this.rows, this.cols);
            max(this, matrix, f);
            return matrix;
        }

        public Matrix maxi(float f) {
            replaceRef(max(f), this);
            return this;
        }

        public Matrix min(float f) {
            Matrix matrix = new Matrix(this.rows, this.cols);
            min(this, matrix, f);
            return matrix;
        }

        public Matrix mini(float f) {
            replaceRef(min(f), this);
            return this;
        }

        public Matrix pow(float f) {
            Matrix matrix = new Matrix(this.rows, this.cols);
            pow(this, matrix, f);
            return matrix;
        }

        public Matrix powi(float f) {
            replaceRef(pow(f), this);
            return this;
        }

        public Matrix sqr() {
            Matrix matrix = new Matrix(this.rows, this.cols);
            sqr(this, matrix);
            return matrix;
        }

        public Matrix sqri() {
            replaceRef(sqr(), this);
            return this;
        }

        public Matrix sqrt() {
            Matrix matrix = new Matrix(this.rows, this.cols);
            sqrt(this, matrix);
            return matrix;
        }

        public Matrix sqrti() {
            replaceRef(sqrt(), this);
            return this;
        }

        public Matrix transpose() {
            Matrix matrix = new Matrix(this.cols, this.rows);
            JCublas2.cublasSgeam(CublasUtil.cublasHandle, 1, 1, this.cols, this.rows, Pointer.to(new float[]{1.0f}), this.data_d, this.rows, Pointer.to(new float[]{0.0f}), new Pointer(), this.rows, matrix.data_d, matrix.rows);
            return matrix;
        }

        public Matrix transposei() {
            if (!isScalar()) {
                if (isVector()) {
                    int i = this.rows;
                    this.rows = this.cols;
                    this.cols = i;
                } else {
                    replaceRef(transpose(), this);
                }
            }
            return this;
        }

        public Matrix zeroi() {
            JCublas2.cublasSgeam(CublasUtil.cublasHandle, 0, 0, this.rows, this.cols, Pointer.to(new float[]{0.0f}), new Pointer(), this.rows, Pointer.to(new float[]{0.0f}), new Pointer(), this.rows, this.data_d, this.rows);
            return this;
        }

        public float norm1() {
            float[] fArr = new float[1];
            JCublas2.cublasSasum(CublasUtil.cublasHandle, this.rows * this.cols, this.data_d, 1, Pointer.to(fArr));
            return fArr[0];
        }

        public float norm2() {
            float[] fArr = new float[1];
            JCublas2.cublasSnrm2(CublasUtil.cublasHandle, this.rows * this.cols, this.data_d, 1, Pointer.to(fArr));
            return fArr[0];
        }

        public float distance1(Matrix matrix) {
            return comb(1.0f, -1.0f, matrix).norm1();
        }

        public float distance2(Matrix matrix) {
            return comb(1.0f, -1.0f, matrix).norm2();
        }

        public static List<Matrix> invert(List<Matrix> list) {
            ArrayList arrayList = new ArrayList();
            for (int i = 0; i < list.size(); i++) {
                arrayList.add(new Matrix(list.get(0).rows, list.get(0).cols));
            }
            getrfGetriBatched(list, arrayList);
            return arrayList;
        }

        public static List<Matrix> mmul(List<Matrix> list, List<Matrix> list2) {
            ArrayList arrayList = new ArrayList();
            for (int i = 0; i < list.size(); i++) {
                arrayList.add(new Matrix(list.get(0).rows, list2.get(0).cols));
            }
            gemmBatched(1.0f, list, list2, 0.0f, arrayList);
            return arrayList;
        }

        private static float[] toColMajor(float[][] fArr) {
            int length = fArr.length;
            int length2 = fArr[0].length;
            float[] fArr2 = new float[length * length2];
            int i = 0;
            for (int i2 = 0; i2 < length2; i2++) {
                for (float[] fArr3 : fArr) {
                    fArr2[i] = fArr3[i2];
                    i++;
                }
            }
            return fArr2;
        }

        private static float[][] fromColMajor(float[] fArr, int i) {
            int length = fArr.length / i;
            float[][] fArr2 = new float[i][length];
            int i2 = 0;
            for (int i3 = 0; i3 < length; i3++) {
                for (int i4 = 0; i4 < i; i4++) {
                    fArr2[i4][i3] = fArr[i2];
                    i2++;
                }
            }
            return fArr2;
        }

        private static void replaceRef(Matrix matrix, Matrix matrix2) {
            matrix2.free();
            matrix2.rows = matrix.rows;
            matrix2.cols = matrix.cols;
            matrix2.data_d = matrix.data_d;
        }

        private static void getrfGetriBatched(List<Matrix> list, List<Matrix> list2) {
            Pointer[] pointerArr = new Pointer[list.size()];
            Pointer[] pointerArr2 = new Pointer[list2.size()];
            for (int i = 0; i < list.size(); i++) {
                pointerArr[i] = list.get(i).data_d;
                pointerArr2[i] = list2.get(i).data_d;
            }
            Pointer pointer = new Pointer();
            JCuda.cudaMalloc(pointer, list.size() * Sizeof.POINTER);
            JCuda.cudaMemcpy(pointer, Pointer.to(pointerArr), list.size() * Sizeof.POINTER, 1);
            Pointer pointer2 = new Pointer();
            JCuda.cudaMalloc(pointer2, list2.size() * Sizeof.POINTER);
            JCuda.cudaMemcpy(pointer2, Pointer.to(pointerArr2), list2.size() * Sizeof.POINTER, 1);
            Pointer pointer3 = new Pointer();
            JCuda.cudaMalloc(pointer3, list.size() * 4);
            Pointer pointer4 = new Pointer();
            JCuda.cudaMalloc(pointer4, list.get(0).rows * list.size() * 4);
            JCublas2.cublasSgetrfBatched(CublasUtil.cublasHandle, list.get(0).rows, pointer, list.get(0).rows, pointer4, pointer3, list.size());
            JCublas2.cublasSgetriBatched(CublasUtil.cublasHandle, list.get(0).rows, pointer, list.get(0).rows, pointer4, pointer2, list2.get(0).rows, pointer3, list.size());
            JCuda.cudaFree(pointer);
            JCuda.cudaFree(pointer2);
            JCuda.cudaFree(pointer3);
            JCuda.cudaFree(pointer4);
        }

        private static void gemmBatched(float f, List<Matrix> list, List<Matrix> list2, float f2, List<Matrix> list3) {
            Pointer[] pointerArr = new Pointer[list.size()];
            Pointer[] pointerArr2 = new Pointer[list2.size()];
            Pointer[] pointerArr3 = new Pointer[list3.size()];
            for (int i = 0; i < list.size(); i++) {
                pointerArr[i] = list.get(i).data_d;
                pointerArr2[i] = list2.get(i).data_d;
                pointerArr3[i] = list3.get(i).data_d;
            }
            Pointer pointer = new Pointer();
            JCuda.cudaMalloc(pointer, list.size() * Sizeof.POINTER);
            JCuda.cudaMemcpy(pointer, Pointer.to(pointerArr), list.size() * Sizeof.POINTER, 1);
            Pointer pointer2 = new Pointer();
            JCuda.cudaMalloc(pointer2, list2.size() * Sizeof.POINTER);
            JCuda.cudaMemcpy(pointer2, Pointer.to(pointerArr2), list2.size() * Sizeof.POINTER, 1);
            Pointer pointer3 = new Pointer();
            JCuda.cudaMalloc(pointer3, list3.size() * Sizeof.POINTER);
            JCuda.cudaMemcpy(pointer3, Pointer.to(pointerArr3), list3.size() * Sizeof.POINTER, 1);
            JCublas2.cublasSgemmBatched(CublasUtil.cublasHandle, 0, 0, list3.get(0).rows, list3.get(0).cols, list2.get(0).rows, Pointer.to(new float[]{f}), pointer, list.get(0).rows, pointer2, list2.get(0).rows, Pointer.to(new float[]{f2}), pointer3, list3.get(0).rows, list.size());
            JCuda.cudaFree(pointer);
            JCuda.cudaFree(pointer2);
            JCuda.cudaFree(pointer3);
        }

        private static void gemm(float f, Matrix matrix, Matrix matrix2, float f2, Matrix matrix3) {
            JCublas2.cublasSgemm(CublasUtil.cublasHandle, 0, 0, matrix3.rows, matrix3.cols, matrix2.rows, Pointer.to(new float[]{f}), matrix.data_d, matrix.rows, matrix2.data_d, matrix2.rows, Pointer.to(new float[]{f2}), matrix3.data_d, matrix3.rows);
        }

        private static void dgmm(Matrix matrix, Matrix matrix2, Matrix matrix3, boolean z) {
            JCublas2.cublasSdgmm(CublasUtil.cublasHandle, z ? 0 : 1, matrix.rows, matrix.cols, matrix.data_d, matrix.rows, matrix2.data_d, 1, matrix3.data_d, matrix3.rows);
        }

        private static void ger(float f, Matrix matrix, Matrix matrix2, Matrix matrix3) {
            JCublas2.cublasSger(CublasUtil.cublasHandle, matrix3.rows, matrix3.cols, Pointer.to(new float[]{f}), matrix.data_d, 1, matrix2.data_d, 1, matrix3.data_d, matrix3.rows);
        }

        private static void scalarSet(Matrix matrix, float f) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorScalarSet");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(new float[]{f}), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void scalarAdd(Matrix matrix, float f, Matrix matrix2) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorScalarAdd");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(new float[]{f}), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void log(Matrix matrix, Matrix matrix2) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorLog");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void exp(Matrix matrix, Matrix matrix2) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorExp");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void sign(Matrix matrix, Matrix matrix2) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorSign");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void abs(Matrix matrix, Matrix matrix2) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorAbs");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void div(Matrix matrix, Matrix matrix2, Matrix matrix3) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorDiv");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(matrix3.data_d), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void mul(Matrix matrix, Matrix matrix2, Matrix matrix3) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorMul");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(matrix3.data_d), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void max(Matrix matrix, Matrix matrix2, float f) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorMax");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(new float[]{f}), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void min(Matrix matrix, Matrix matrix2, float f) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorMin");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(new float[]{f}), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        private static void pow(Matrix matrix, Matrix matrix2, float f) {
            int i = matrix.rows * matrix.cols;
            CUfunction cUfunction = new CUfunction();
            JCudaDriver.cuModuleGetFunction(cUfunction, CublasUtil.helperModule, "vectorPow");
            Pointer pointer = Pointer.to(Pointer.to(matrix.data_d), Pointer.to(matrix2.data_d), Pointer.to(new float[]{f}), Pointer.to(new int[]{i}));
            int min = Math.min(i, BLOCK_SIZE);
            JCudaDriver.cuLaunchKernel(cUfunction, (int) Math.ceil(i / min), 1, 1, min, 1, 1, 0, null, pointer, null);
        }

        /**
         * Launches the {@code vectorSqr} helper kernel with one thread per element of {@code matrix}.
         *
         * <p>The kernel receives, in order: the device buffer of {@code matrix}, the device buffer of
         * {@code matrix2} and the element count ({@code rows * cols} of {@code matrix}). The
         * dimensions of {@code matrix2} are not checked here.
         * NOTE(review): presumably {@code matrix} is the input and {@code matrix2} the output, as in
         * the other helper kernels of this class - confirm against the {@code vectorSqr} source.
         *
         * <p>The launch is issued on the default stream and this method does not synchronize.
         *
         * @param matrix  matrix whose size determines how many elements are processed
         * @param matrix2 second matrix whose device buffer is handed to the kernel
         */
        private static void sqr(Matrix matrix, Matrix matrix2) {
            int numElements = matrix.rows * matrix.cols;
            if (numElements == 0) {
                // Nothing to do; also avoids a zero block size and a division by zero below.
                return;
            }
            CUfunction kernel = new CUfunction();
            JCudaDriver.cuModuleGetFunction(kernel, CublasUtil.helperModule, "vectorSqr");
            Pointer kernelParams = Pointer.to(
                    Pointer.to(matrix.data_d),
                    Pointer.to(matrix2.data_d),
                    Pointer.to(new int[]{numElements}));
            int blockSize = Math.min(numElements, BLOCK_SIZE);
            // Integer ceiling division. The previous Math.ceil(i / min) truncated first, so a
            // partially filled last block was never launched and the tail stayed unprocessed.
            int gridSize = (numElements - 1) / blockSize + 1;
            JCudaDriver.cuLaunchKernel(kernel, gridSize, 1, 1, blockSize, 1, 1, 0, null, kernelParams, null);
        }

        /**
         * Launches the {@code vectorSqrt} helper kernel with one thread per element of {@code matrix}.
         *
         * <p>The kernel receives, in order: the device buffer of {@code matrix}, the device buffer of
         * {@code matrix2} and the element count ({@code rows * cols} of {@code matrix}). The
         * dimensions of {@code matrix2} are not checked here.
         * NOTE(review): presumably {@code matrix} is the input and {@code matrix2} the output, as in
         * the other helper kernels of this class - confirm against the {@code vectorSqrt} source.
         *
         * <p>The launch is issued on the default stream and this method does not synchronize.
         *
         * @param matrix  matrix whose size determines how many elements are processed
         * @param matrix2 second matrix whose device buffer is handed to the kernel
         */
        private static void sqrt(Matrix matrix, Matrix matrix2) {
            int numElements = matrix.rows * matrix.cols;
            if (numElements == 0) {
                // Nothing to do; also avoids a zero block size and a division by zero below.
                return;
            }
            CUfunction kernel = new CUfunction();
            JCudaDriver.cuModuleGetFunction(kernel, CublasUtil.helperModule, "vectorSqrt");
            Pointer kernelParams = Pointer.to(
                    Pointer.to(matrix.data_d),
                    Pointer.to(matrix2.data_d),
                    Pointer.to(new int[]{numElements}));
            int blockSize = Math.min(numElements, BLOCK_SIZE);
            // Integer ceiling division. The previous Math.ceil(i / min) truncated first, so a
            // partially filled last block was never launched and the tail stayed unprocessed.
            int gridSize = (numElements - 1) / blockSize + 1;
            JCudaDriver.cuLaunchKernel(kernel, gridSize, 1, 1, blockSize, 1, 1, 0, null, kernelParams, null);
        }
    }

    /**
     * Initializes the shared state of this utility.
     *
     * <p>In order: delegates to {@code CudaUtil.startup(i)}, creates the cuBLAS handle, compiles and
     * loads the helper kernels from {@code Matrix.kernels} into {@link #helperModule}, resets the
     * {@link #allocated} tracking list to an empty list, and configures the handle's atomics and
     * pointer modes.
     *
     * @param i value forwarded unchanged to {@code CudaUtil.startup}
     *          (presumably the CUDA device index - confirm against CudaUtil)
     */
    public static void startup(int i) {
        // NOTE(review): raw mode values kept exactly as found. Presumably 1 is
        // CUBLAS_ATOMICS_ALLOWED and 0 is CUBLAS_POINTER_MODE_HOST - confirm against the
        // jcuda.jcublas constants before replacing them with the named versions.
        final int atomicsMode = 1;
        final int pointerMode = 0;

        CudaUtil.startup(i);

        cublasHandle = new cublasHandle();
        JCublas2.cublasCreate(cublasHandle);

        helperModule = CudaUtil.compileAndLoad("la_helper_funs", Matrix.kernels, true);
        allocated = new LinkedList<>();

        JCublas2.cublasSetAtomicsMode(cublasHandle, atomicsMode);
        JCublas2.cublasSetPointerMode(cublasHandle, pointerMode);
    }

    /**
     * Tears down the state created by {@link #startup(int)}.
     *
     * <p>Calls {@link #freeAll(boolean)} with {@code true}, so every tracked matrix is freed
     * regardless of its {@code dontFree} flag, then destroys the cuBLAS handle and finally
     * delegates to {@code CudaUtil.shutdown()}.
     */
    public static void shutdown() {
        // Free matrices first, while the handle and the CUDA state behind it still exist.
        freeAll(true);
        JCublas2.cublasDestroy(cublasHandle);
        CudaUtil.shutdown();
    }

    /**
     * Frees every tracked matrix except those flagged {@code dontFree}.
     *
     * <p>Equivalent to {@code freeAll(false)}.
     */
    public static void freeAll() {
        final boolean includeDontFree = false;
        freeAll(includeDontFree);
    }

    /**
     * Drains the {@link #allocated} list, freeing matrices and retaining the rest.
     *
     * <p>A matrix is retained only when {@code z} is {@code false} and its {@code dontFree} flag
     * is set; every other matrix has {@code free()} called on it. Afterwards {@link #allocated}
     * is replaced by a new list holding just the retained matrices, in their original order.
     *
     * @param z when {@code true}, matrices flagged {@code dontFree} are freed as well
     */
    public static void freeAll(boolean z) {
        LinkedList<Matrix> retained = new LinkedList<>();
        // Elements are removed from the list before free() is called on them, so the list is
        // never iterated while a matrix is being released.
        while (!allocated.isEmpty()) {
            Matrix head = allocated.removeFirst();
            boolean keep = !z && head.dontFree;
            if (keep) {
                retained.add(head);
                continue;
            }
            head.free();
        }
        allocated = retained;
    }

    /**
     * Frees every tracked matrix except the given ones.
     *
     * <p>Collects the arguments into a hash set and delegates to
     * {@link #freeAllBut(Collection)}, which also retains matrices flagged {@code dontFree}.
     *
     * @param matrixArr matrices that must survive the sweep
     */
    public static void freeAllBut(Matrix... matrixArr) {
        // Parameterized instead of the raw HashSet, so the adds are type-checked.
        HashSet<Matrix> keep = new HashSet<>();
        for (Matrix matrix : matrixArr) {
            keep.add(matrix);
        }
        freeAllBut(keep);
    }

    /**
     * Drains the {@link #allocated} list, freeing every matrix that is neither contained in
     * {@code collection} nor flagged {@code dontFree}.
     *
     * <p>Afterwards {@link #allocated} is replaced by a new list holding just the retained
     * matrices, in their original order.
     *
     * @param collection matrices that must survive the sweep; membership is tested with
     *                   {@code collection.contains}
     */
    public static void freeAllBut(Collection<Matrix> collection) {
        LinkedList<Matrix> survivors = new LinkedList<>();
        // Elements are removed from the list before free() is called on them, so the list is
        // never iterated while a matrix is being released.
        while (!allocated.isEmpty()) {
            Matrix head = allocated.removeFirst();
            boolean release = !collection.contains(head) && !head.dontFree;
            if (release) {
                head.free();
                continue;
            }
            survivors.add(head);
        }
        allocated = survivors;
    }

    /* JADX WARN: Type inference failed for: r0v26, types: [float[], float[][]] */
    /* JADX WARN: Type inference failed for: r0v28, types: [float[], float[][]] */
    /* JADX WARN: Type inference failed for: r0v30, types: [float[], float[][]] */
    public static void main(String[] strArr) {
        startup(0);
        Random random = new Random(1L);
        float[][] randFloat = a.randFloat(2, 3, random);
        float[][] randFloat2 = a.randFloat(2, 3, random);
        System.out.println("\n\nCPU");
        FloatMatrix floatMatrix = new FloatMatrix(randFloat);
        System.out.println(a.toString(floatMatrix.transpose().toArray2()));
        System.out.println(a.toString(floatMatrix.add(new FloatMatrix(randFloat2).mul(-2.0f)).toArray2()));
        System.out.println("\n\nGPU");
        Matrix build = Matrix.build(randFloat);
        System.out.println(a.toString(build.transpose().toArray2()));
        Matrix build2 = Matrix.build(randFloat2);
        Matrix comb = build.comb(1.0f, -2.0f, build2);
        System.out.println(a.toString(comb.toArray2()));
        build.free();
        build2.free();
        comb.free();
        ?? r0 = {new float[]{1.0f, 2.0f, 3.0f}, new float[]{-1.0f, -2.5f, 1.0f}};
        ?? r02 = {new float[]{4.0f, 0.0f, -1.0f}, new float[]{0.0f, -2.5f, 1.0f}, new float[]{9.0f, -10.0f, -0.5f}};
        ?? r03 = {new float[]{1.0f, 2.0f, 3.0f}, new float[]{1.0f, 2.0f, 3.0f}};
        System.out.println("\n\nCPU");
        FloatMatrix floatMatrix2 = new FloatMatrix((float[][]) r0);
        FloatMatrix mul = floatMatrix2.mmul(new FloatMatrix((float[][]) r02).transpose()).add(FloatMatrix.ones(2, 3)).muli(2.0f).mul(new FloatMatrix((float[][]) r03));
        mul.maxi(-68.0f);
        System.out.println(a.toString(mul.toArray2()));
        System.out.println(mul.norm1());
        System.out.println(mul.norm2());
        FloatMatrix div = mul.div(floatMatrix2);
        System.out.println(a.toString(div.toArray2()));
        System.out.println(div.norm1());
        System.out.println(div.norm2());
        System.out.println("\n\nGPU");
        Matrix build3 = Matrix.build(r0);
        Matrix build4 = Matrix.build(r02);
        Matrix build5 = Matrix.build(r03);
        Matrix mul2 = build3.mmul(build4.transpose()).add(Matrix.ones(2, 3)).muli(2.0f).mul(build5);
        mul2.maxi(-68.0f);
        System.out.println(a.toString(mul2.toArray2()));
        System.out.println(mul2.norm1());
        System.out.println(mul2.norm2());
        Matrix div2 = mul2.div(build3);
        System.out.println(a.toString(div2.toArray2()));
        System.out.println(div2.norm1());
        System.out.println(div2.norm2());
        build3.free();
        build4.free();
        build5.free();
        mul2.free();
        System.out.println("\n\n" + a.toString(Matrix.ones(CUresult.CUDA_ERROR_INVALID_SOURCE, 3).toArray2()));
        System.out.println("\n\n" + a.toString(Matrix.ones(CUresult.CUDA_ERROR_INVALID_SOURCE, 3).toArray2()));
        System.out.println("\n\n" + a.toString(Matrix.ones(CUresult.CUDA_ERROR_INVALID_SOURCE, 3).toArray2()));
        Matrix ones = Matrix.ones(3, 3);
        ones.muli(4.0f);
        System.out.println(a.toString(ones.toArray2()));
        System.out.println(a.toString(ones.sqrt().toArray2()));
        System.out.println(a.toString(ones.toArray2()));
        ones.sqrti();
        System.out.println(a.toString(ones.toArray2()));
        ones.free();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i = 0; i < 10; i++) {
            arrayList.add(Matrix.build(a.randFloat(3, 3, random)));
            arrayList2.add(Matrix.build(a.randFloat(3, 3, random)));
        }
        System.out.println("CPU:");
        for (int i2 = 0; i2 < 10; i2++) {
            System.out.println(a.toString(new FloatMatrix(((Matrix) arrayList.get(i2)).toArray2()).mmul(new FloatMatrix(((Matrix) arrayList2.get(i2)).toArray2())).toArray2()));
        }
        System.out.println("GPU:");
        List<Matrix> mmul = Matrix.mmul(arrayList, arrayList2);
        for (int i3 = 0; i3 < 10; i3++) {
            System.out.println(a.toString(mmul.get(i3).toArray2()));
        }
        System.out.println("CPU:");
        for (int i4 = 0; i4 < 10; i4++) {
            FloatMatrix floatMatrix3 = new FloatMatrix(((Matrix) arrayList.get(i4)).toArray2());
            System.out.println(a.toString(Solve.solve(floatMatrix3, FloatMatrix.eye(floatMatrix3.rows)).toArray2()));
        }
        System.out.println("GPU:");
        List<Matrix> invert = Matrix.invert(arrayList);
        for (int i5 = 0; i5 < 10; i5++) {
            System.out.println(a.toString(invert.get(i5).toArray2()));
        }
        System.out.println("\n\nCPU");
        FloatMatrix floatMatrix4 = new FloatMatrix((float[][]) r0);
        System.out.println(a.toString(floatMatrix4.rowSums().toArray2()));
        System.out.println(a.toString(floatMatrix4.columnSums().toArray2()));
        System.out.println("\n\nGPU");
        Matrix build6 = Matrix.build(r0);
        System.out.println(a.toString(build6.colSum().toArray2()));
        System.out.println(a.toString(build6.rowSum().toArray2()));
        build6.free();
        System.out.println("\n\nCPU");
        FloatMatrix floatMatrix5 = new FloatMatrix((float[][]) r0);
        FloatMatrix ones2 = FloatMatrix.ones(2, 1);
        ones2.put(0, 0, 2.0f);
        FloatMatrix ones3 = FloatMatrix.ones(1, 3);
        ones3.put(0, 0, 2.0f);
        System.out.println(a.toString(floatMatrix5.addColumnVector(ones2).toArray2()));
        System.out.println(a.toString(floatMatrix5.addRowVector(ones3).toArray2()));
        System.out.println("\n\nGPU");
        Matrix build7 = Matrix.build(r0);
        float[][] onesFloat = a.onesFloat(2, 1);
        onesFloat[0][0] = 2.0f;
        float[][] onesFloat2 = a.onesFloat(1, 3);
        onesFloat2[0][0] = 2.0f;
        System.out.println(a.toString(build7.colAdd(Matrix.build(onesFloat)).toArray2()));
        System.out.println(a.toString(build7.rowAdd(Matrix.build(onesFloat2)).toArray2()));
        build7.free();
        System.out.println("\n\nGPU");
        Matrix rand = Matrix.rand(5, 7, random);
        System.out.println(a.toString(rand.toArray2()));
        System.out.println(a.toString(rand.copySubmatrix(1, 3, 2, 4).toArray2()));
        rand.setSubmatrix(Matrix.ones(2, 2), 1, 2);
        System.out.println(a.toString(rand.toArray2()));
        rand.setSubmatrix(a.onesFloat(2, 2), 1, 0);
        System.out.println(a.toString(rand.toArray2()));
        rand.set(4, 3, 5.0f);
        System.out.println(a.toString(rand.toArray2()));
        Matrix rand2 = Matrix.rand(2, 3, random);
        System.out.println(a.toString(rand2.toArray2()));
        rand.setSubmatrix(1, 1, rand2, 1, 2, 1, 3);
        System.out.println(a.toString(rand.toArray2()));
        rand.free();
        System.out.println("\n\nCPU");
        FloatMatrix floatMatrix6 = new FloatMatrix((float[][]) r0);
        FloatMatrix ones4 = FloatMatrix.ones(2, 1);
        ones4.put(0, 0, 2.0f);
        FloatMatrix ones5 = FloatMatrix.ones(1, 3);
        ones5.put(0, 0, 2.0f);
        System.out.println(a.toString(floatMatrix6.toArray2()));
        System.out.println(a.toString(floatMatrix6.mulColumnVector(ones4).toArray2()));
        System.out.println(a.toString(floatMatrix6.mulRowVector(ones5).toArray2()));
        System.out.println(a.toString(floatMatrix6.divColumnVector(ones4).toArray2()));
        System.out.println(a.toString(floatMatrix6.divRowVector(ones5).toArray2()));
        System.out.println("\n\nGPU");
        Matrix build8 = Matrix.build(r0);
        float[][] onesFloat3 = a.onesFloat(2, 1);
        onesFloat3[0][0] = 2.0f;
        float[][] onesFloat4 = a.onesFloat(1, 3);
        onesFloat4[0][0] = 2.0f;
        System.out.println(a.toString(build8.toArray2()));
        System.out.println(a.toString(build8.colMul(Matrix.build(onesFloat3)).toArray2()));
        System.out.println(a.toString(build8.rowMul(Matrix.build(onesFloat4)).toArray2()));
        System.out.println(a.toString(build8.colDiv(Matrix.build(onesFloat3)).toArray2()));
        System.out.println(a.toString(build8.rowDiv(Matrix.build(onesFloat4)).toArray2()));
        build8.rowDivi(Matrix.build(onesFloat4));
        System.out.println(a.toString(build8.toArray2()));
        build8.free();
        System.out.println("\n\nGPU");
        Matrix ones6 = Matrix.ones(5, 5);
        System.out.println(a.toString(ones6.toArray2()));
        System.out.println(a.toString(ones6.diagAdd(10.0f).toArray2()));
        ones6.diagAddi(1.0f);
        System.out.println(a.toString(ones6.toArray2()));
        ones6.free();
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        for (int i6 = 0; i6 < 3; i6++) {
            arrayList3.add(Matrix.build(a.randFloat(3, 3, random)));
            arrayList4.add(Matrix.build(a.randFloat(3, 3, random)));
        }
        System.out.println("CPU:");
        for (int i7 = 0; i7 < 3; i7++) {
            FloatMatrix floatMatrix7 = new FloatMatrix(((Matrix) arrayList3.get(i7)).toArray2());
            System.out.println(a.toString(Solve.solve(floatMatrix7, FloatMatrix.eye(floatMatrix7.rows)).toArray2()));
        }
        System.out.println("GPU:");
        List<Matrix> invert2 = Matrix.invert(arrayList3);
        for (int i8 = 0; i8 < 3; i8++) {
            System.out.println(a.toString(invert2.get(i8).toArray2()));
        }
        System.out.println("\n\nGPU");
        Matrix ones7 = Matrix.ones(5, 5);
        ones7.muli(-0.2f);
        System.out.println(a.toString(ones7.toArray2()));
        ones7.powi(1.0f);
        System.out.println(a.toString(ones7.toArray2()));
        ones7.free();
        shutdown();
    }
}
