package edu.berkeley.cs.nlp.ocular.model.em;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_event;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_platform_id;
import org.jocl.cl_program;
import tberg.murphy.gpu.CudaUtil;

/* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/model/em/JOCLInnerLoop.class */
/**
 * {@link EmissionCacheInnerLoop} that computes emission scores on an OpenCL device through JOCL.
 *
 * <p>The constructor picks a device on the first OpenCL platform (the first GPU whose name
 * contains "radeon" or "nvidia", otherwise the first CPU device), chooses a work-group geometry
 * for that device type, and builds one kernel per template width 1..30. {@link #startup} then
 * allocates device buffers and kernel handles, {@link #compute} may be called repeatedly, and
 * {@link #shutdown} releases what {@code startup} allocated. The context, command queue and
 * program created by the constructor are not released by this class.
 *
 * <p>JOCL exceptions are enabled globally ({@code CL.setExceptionsEnabled(true)}), so OpenCL
 * errors surface as unchecked exceptions from every method.
 */
public class JOCLInnerLoop implements EmissionCacheInnerLoop {
    public static final int GPU_BLOCK_SIZE_X = 1;
    public static final int GPU_ROLL_X = 32;
    public static final int GPU_BLOCK_SIZE_Y = 64;
    public static final int CPU_BLOCK_SIZE_X = 1;
    public static final int CPU_ROLL_X = 8;
    public static final int CPU_BLOCK_SIZE_Y = 1;

    /** Floats stored per observation/template column; the generated kernels index with this stride. */
    private static final int COLUMN_SIZE = 30;

    /** Kernels {@code compute_emissions_1} .. {@code compute_emissions_N} are generated for this N. */
    private static final int MAX_KERNEL_TEMPLATE_WIDTH = 30;

    /** OpenCL compiler flags used when building the generated program. */
    private static final String BUILD_OPTIONS =
            "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math -cl-no-signed-zeros";

    /** Work-items per work-group along NDRange dimension 0 (observations). */
    int blockSizeX;
    /** Consecutive observations handled by one work-item (unrolled in the generated kernel). */
    int rollX;
    /** Work-items per work-group along NDRange dimension 1 (templates). */
    int blockSizeY;
    /** Value reported by {@link #numPopulateThreads()}. */
    int numThreads;
    /** Indexed by (template width - minTemplateWidth): number of templates of that width. */
    int[] templateNumIndices;
    /** Indexed by (template width - minTemplateWidth): template offset used to place that width's scores. */
    int[] templateIndicesOffsets;
    int maxTemplateWidth;
    int minTemplateWidth;
    cl_context context;
    cl_command_queue queue;
    cl_program program;
    /** Device copy of the zero-padded "W" observations. */
    cl_mem d_Ow;
    /** Device copy of the zero-padded "B" observations. */
    cl_mem d_Ob;
    /** Device buffer the kernels write scores into. */
    cl_mem d_scores;
    /** Per template width: device copy of the "W" templates, or null if that width has no templates. */
    cl_mem[] d_Tw;
    /** Per template width: device copy of the "B" templates, or null if that width has no templates. */
    cl_mem[] d_Tb;
    /** Per template width: kernel handle, or null if that width has no templates. */
    cl_kernel[] kernels;

    /**
     * Reads a string-valued device property.
     *
     * @param device the device to query
     * @param paramName the {@code CL_DEVICE_*} property to read
     * @return the property value without its trailing NUL terminator
     */
    private static String getString(cl_device_id device, int paramName) {
        long[] size = new long[1];
        CL.clGetDeviceInfo(device, paramName, 0L, null, size);
        byte[] buffer = new byte[(int) size[0]];
        CL.clGetDeviceInfo(device, paramName, buffer.length, Pointer.to(buffer), null);
        // OpenCL returns NUL-terminated strings; strip the terminator, tolerating an empty result.
        int length = buffer.length > 0 && buffer[buffer.length - 1] == 0 ? buffer.length - 1 : buffer.length;
        return new String(buffer, 0, length, StandardCharsets.UTF_8);
    }

    /** Returns every device of any type on the given platform. */
    private static cl_device_id[] getAllDevices(cl_platform_id platform) {
        int[] numDevices = new int[1];
        CL.clGetDeviceIDs(platform, CL.CL_DEVICE_TYPE_ALL, 0, null, numDevices);
        cl_device_id[] devices = new cl_device_id[numDevices[0]];
        CL.clGetDeviceIDs(platform, CL.CL_DEVICE_TYPE_ALL, devices.length, devices, null);
        return devices;
    }

    /** Returns true if the device's {@code CL_DEVICE_TYPE} bitfield contains {@code type}. */
    private static boolean hasType(cl_device_id device, long type) {
        long[] deviceType = new long[1];
        CL.clGetDeviceInfo(device, CL.CL_DEVICE_TYPE, Sizeof.cl_long, Pointer.to(deviceType), null);
        return (deviceType[0] & type) != 0;
    }

    /** Returns the first GPU whose lower-cased name contains "radeon" or "nvidia", or null. */
    private static cl_device_id findKnownGpu(cl_device_id[] devices) {
        for (cl_device_id candidate : devices) {
            if (hasType(candidate, CL.CL_DEVICE_TYPE_GPU)) {
                String name = getString(candidate, CL.CL_DEVICE_NAME).toLowerCase(Locale.ROOT);
                if (name.contains("radeon") || name.contains("nvidia")) {
                    return candidate;
                }
            }
        }
        return null;
    }

    /** Returns the first device of the given type, or null. */
    private static cl_device_id findFirstOfType(cl_device_id[] devices, long type) {
        for (cl_device_id candidate : devices) {
            if (hasType(candidate, type)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Selects an OpenCL device, creates a context and command queue for it, and builds the
     * emission kernels.
     *
     * @param numThreads value returned by {@link #numPopulateThreads()}
     * @throws IllegalStateException if there is no OpenCL platform, or neither a recognised GPU
     *     nor a CPU device on the first platform
     */
    public JOCLInnerLoop(int numThreads) {
        this.numThreads = numThreads;
        CL.setExceptionsEnabled(true);

        int[] numPlatforms = new int[1];
        CL.clGetPlatformIDs(0, null, numPlatforms);
        if (numPlatforms[0] == 0) {
            throw new IllegalStateException("No OpenCL platform found");
        }
        cl_platform_id[] platforms = new cl_platform_id[numPlatforms[0]];
        CL.clGetPlatformIDs(platforms.length, platforms, null);
        cl_platform_id platform = platforms[0];

        cl_context_properties contextProperties = new cl_context_properties();
        contextProperties.addProperty(CL.CL_CONTEXT_PLATFORM, platform);

        // Enumerate all devices once and filter by type. Asking clGetDeviceIDs for GPUs directly
        // throws CL_DEVICE_NOT_FOUND (exceptions are enabled) on GPU-less machines, which would
        // make the CPU fallback unreachable.
        cl_device_id[] devices = getAllDevices(platform);
        cl_device_id device = findKnownGpu(devices);
        boolean useGpu = device != null;
        if (!useGpu) {
            device = findFirstOfType(devices, CL.CL_DEVICE_TYPE_CPU);
        }
        if (device == null) {
            throw new IllegalStateException("No Radeon/NVIDIA GPU or CPU OpenCL device found on the first platform");
        }

        if (useGpu) {
            this.blockSizeX = GPU_BLOCK_SIZE_X;
            this.rollX = GPU_ROLL_X;
            this.blockSizeY = GPU_BLOCK_SIZE_Y;
        } else {
            this.blockSizeX = CPU_BLOCK_SIZE_X;
            this.rollX = CPU_ROLL_X;
            this.blockSizeY = CPU_BLOCK_SIZE_Y;
        }
        System.out.printf("Device name: %s\n", getString(device, CL.CL_DEVICE_NAME));
        System.out.println("Block size x: " + this.blockSizeX);
        System.out.println("Roll x: " + this.rollX);
        System.out.println("Block size y: " + this.blockSizeY);

        this.context = CL.clCreateContext(contextProperties, 1, new cl_device_id[] {device}, null, null, null);
        // Properties 0: the queue is in-order, so commands run in enqueue order.
        this.queue = CL.clCreateCommandQueue(this.context, device, 0L, null);
        // kernelSrc() depends on blockSizeX/rollX, so it must run after they are set above.
        this.program = CL.clCreateProgramWithSource(this.context, 1, new String[] {kernelSrc()}, null, null);
        CL.clBuildProgram(this.program, 0, null, BUILD_OPTIONS, null, null);
    }

    /** Integer ceiling of {@code a / b} for {@code a >= 0} and {@code b > 0}. */
    private static int ceilDiv(int a, int b) {
        return (a + b - 1) / b;
    }

    /** Rounds {@code numObservations} up to a multiple of the observations one work-group covers. */
    private int paddedObservationCount(int numObservations) {
        int perGroup = this.blockSizeX * this.rollX;
        return ceilDiv(numObservations, perGroup) * perGroup;
    }

    /**
     * Creates the kernels and device buffers for one run.
     *
     * <p>Widths whose template count is 0 get neither a kernel nor template buffers. Only
     * kernels for widths 1..30 exist in the built program.
     *
     * @param templatesW per width (offset by minTemplateWidth): flattened "W" template data
     * @param templatesB per width (offset by minTemplateWidth): flattened "B" template data
     * @param templateNumIndices per width: number of templates of that width
     * @param templateIndicesOffsets per width: template offset used when placing scores
     * @param minTemplateWidth smallest template width
     * @param maxTemplateWidth largest template width
     * @param maxNumObservations upper bound on the {@code numObservations} later passed to
     *     {@link #compute} (assumed by the buffer sizes; not checked)
     * @param totalTemplateCount number of templates summed over all widths; the score buffer holds
     *     {@code maxNumObservations * totalTemplateCount} floats
     */
    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public void startup(float[][] templatesW, float[][] templatesB, int[] templateNumIndices, int[] templateIndicesOffsets,
            int minTemplateWidth, int maxTemplateWidth, int maxNumObservations, int totalTemplateCount) {
        this.templateNumIndices = templateNumIndices;
        this.templateIndicesOffsets = templateIndicesOffsets;
        this.maxTemplateWidth = maxTemplateWidth;
        this.minTemplateWidth = minTemplateWidth;
        int numWidths = maxTemplateWidth - minTemplateWidth + 1;

        this.kernels = new cl_kernel[numWidths];
        for (int w = 0; w < numWidths; w++) {
            if (templateNumIndices[w] > 0) {
                this.kernels[w] = CL.clCreateKernel(this.program, "compute_emissions_" + (minTemplateWidth + w), null);
            }
        }

        // Observations are padded to whole work-groups plus (maxTemplateWidth - 1) extra columns
        // so the widest template can be read past the last observation. Sizes use long math.
        long observationBytes = (long) Sizeof.cl_float
                * (paddedObservationCount(maxNumObservations) + maxTemplateWidth - 1) * COLUMN_SIZE;
        this.d_Ow = CL.clCreateBuffer(this.context, CL.CL_MEM_READ_WRITE, observationBytes, null, null);
        this.d_Ob = CL.clCreateBuffer(this.context, CL.CL_MEM_READ_WRITE, observationBytes, null, null);
        this.d_scores = CL.clCreateBuffer(this.context, CL.CL_MEM_READ_WRITE,
                (long) Sizeof.cl_float * maxNumObservations * totalTemplateCount, null, null);

        this.d_Tw = new cl_mem[numWidths];
        this.d_Tb = new cl_mem[numWidths];
        long templateFlags = CL.CL_MEM_READ_ONLY | CL.CL_MEM_COPY_HOST_PTR;
        for (int w = 0; w < numWidths; w++) {
            if (templateNumIndices[w] > 0) {
                this.d_Tw[w] = CL.clCreateBuffer(this.context, templateFlags,
                        (long) Sizeof.cl_float * templatesW[w].length, Pointer.to(templatesW[w]), null);
                this.d_Tb[w] = CL.clCreateBuffer(this.context, templateFlags,
                        (long) Sizeof.cl_float * templatesB[w].length, Pointer.to(templatesB[w]), null);
            }
        }
    }

    /** Releases the buffers and kernels allocated by {@link #startup}, skipping widths that had none. */
    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public void shutdown() {
        CL.clReleaseMemObject(this.d_Ow);
        CL.clReleaseMemObject(this.d_Ob);
        CL.clReleaseMemObject(this.d_scores);
        for (int w = 0; w < this.kernels.length; w++) {
            // Entries stay null for widths without templates; JOCL rejects null handles.
            if (this.d_Tw[w] != null) {
                CL.clReleaseMemObject(this.d_Tw[w]);
            }
            if (this.d_Tb[w] != null) {
                CL.clReleaseMemObject(this.d_Tb[w]);
            }
            if (this.kernels[w] != null) {
                CL.clReleaseKernel(this.kernels[w]);
            }
        }
    }

    /**
     * Uploads the observations, runs one kernel per template width that has templates, and reads
     * the scores back (blocking).
     *
     * <p>For template width index {@code w}, template {@code t} and observation {@code o}, the
     * kernel writes
     * {@code scores[templateIndicesOffsets[w] * numObservations + o * templateNumIndices[w] + t]}.
     *
     * @param scores output; its whole length is copied back from the device score buffer
     * @param observationsW "W" observations, zero-padded on upload to
     *     {@code (padded numObservations + maxTemplateWidth - 1) * 30} floats
     * @param observationsB "B" observations, padded the same way
     * @param numObservations number of observation positions to score (assumed at most the
     *     {@code maxNumObservations} given to {@link #startup})
     */
    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public void compute(float[] scores, float[] observationsW, float[] observationsB, int numObservations) {
        int numWidths = this.maxTemplateWidth - this.minTemplateWidth + 1;
        int groupsX = ceilDiv(numObservations, this.blockSizeX * this.rollX);
        int paddedObservations = groupsX * this.blockSizeX * this.rollX;
        int paddedLength = (paddedObservations + this.maxTemplateWidth - 1) * COLUMN_SIZE;
        long paddedBytes = (long) Sizeof.cl_float * paddedLength;

        // Every event created here is released in finally; previously they all leaked.
        List<cl_event> createdEvents = new ArrayList<>();
        try {
            cl_event[] writeEvents = {new cl_event(), new cl_event()};
            CL.clEnqueueWriteBuffer(this.queue, this.d_Ow, true, 0L, paddedBytes,
                    Pointer.to(CudaUtil.extendWithZeros(observationsW, paddedLength)), 0, null, writeEvents[0]);
            createdEvents.add(writeEvents[0]);
            CL.clEnqueueWriteBuffer(this.queue, this.d_Ob, true, 0L, paddedBytes,
                    Pointer.to(CudaUtil.extendWithZeros(observationsB, paddedLength)), 0, null, writeEvents[1]);
            createdEvents.add(writeEvents[1]);

            List<cl_event> kernelEvents = new ArrayList<>();
            for (int w = 0; w < numWidths; w++) {
                int numTemplates = this.templateNumIndices[w];
                if (numTemplates <= 0) {
                    continue;
                }
                // Round up so a final partial group of templates is still scored.
                int groupsY = ceilDiv(numTemplates, this.blockSizeY);
                cl_kernel kernel = this.kernels[w];
                setIntArg(kernel, 0, this.templateIndicesOffsets[w] * numObservations);
                setIntArg(kernel, 1, numObservations);
                setIntArg(kernel, 2, numTemplates);
                setMemArg(kernel, 3, this.d_Tw[w]);
                setMemArg(kernel, 4, this.d_Tb[w]);
                setMemArg(kernel, 5, this.d_Ow);
                setMemArg(kernel, 6, this.d_Ob);
                setMemArg(kernel, 7, this.d_scores);
                cl_event kernelEvent = new cl_event();
                CL.clEnqueueNDRangeKernel(this.queue, kernel, 2, null,
                        new long[] {(long) groupsX * this.blockSizeX, (long) groupsY * this.blockSizeY},
                        new long[] {this.blockSizeX, this.blockSizeY},
                        writeEvents.length, writeEvents, kernelEvent);
                createdEvents.add(kernelEvent);
                kernelEvents.add(kernelEvent);
            }

            // Only events of kernels that were actually enqueued may appear in the wait list.
            cl_event[] waitList = kernelEvents.toArray(new cl_event[0]);
            cl_event readEvent = new cl_event();
            CL.clEnqueueReadBuffer(this.queue, this.d_scores, true, 0L, (long) Sizeof.cl_float * scores.length,
                    Pointer.to(scores), waitList.length, waitList.length == 0 ? null : waitList, readEvent);
            createdEvents.add(readEvent);
            CL.clWaitForEvents(1, new cl_event[] {readEvent});
        } finally {
            for (cl_event event : createdEvents) {
                CL.clReleaseEvent(event);
            }
        }
    }

    /** Sets an {@code int} kernel argument. */
    private static void setIntArg(cl_kernel kernel, int index, int value) {
        CL.clSetKernelArg(kernel, index, Sizeof.cl_int, Pointer.to(new int[] {value}));
    }

    /** Sets a buffer kernel argument. */
    private static void setMemArg(cl_kernel kernel, int index, cl_mem value) {
        CL.clSetKernelArg(kernel, index, Sizeof.cl_mem, Pointer.to(value));
    }

    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public int numOuterThreads() {
        return 1;
    }

    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public int numPopulateThreads() {
        return this.numThreads;
    }

    /**
     * Generates OpenCL C source for kernels {@code compute_emissions_1} .. {@code compute_emissions_30}.
     *
     * <p>Kernel {@code compute_emissions_w} maps NDRange dimension 1 to a template index and
     * dimension 0 to groups of {@code rollX} consecutive observations. Each work-item
     * accumulates dot products of the {@code 30 * w} template values against the observation
     * columns starting at each of its observations, for both the W and B buffers. It writes one
     * score per observation below {@code Olength}. The text depends on {@link #blockSizeX} and
     * {@link #rollX}, which must already be set.
     *
     * @return the concatenated kernel source
     */
    public String kernelSrc() {
        int observationsPerGroup = this.blockSizeX * this.rollX;
        StringBuilder src = new StringBuilder();
        for (int width = 1; width <= MAX_KERNEL_TEMPLATE_WIDTH; width++) {
            int templateLength = COLUMN_SIZE * width;
            src.append("__kernel void compute_emissions_").append(width)
                    .append("(__const int scoresOffset, __const int Olength, __const int Tlength, __global float const* __restrict__ Tw, __global float const* __restrict__ Tb, __global float const* __restrict__ Ow, __global float const* __restrict__ Ob, __global float* scores) {\n");
            src.append("int Tindex = get_global_id(1);\n");
            src.append("if (Tindex < Tlength) {\n");
            for (int r = 0; r < this.rollX; r++) {
                src.append("float o").append(r).append(" = 0;\n");
                src.append("float score").append(r).append(" = 0;\n");
            }
            appendAccumulateLoop(src, "Tw", "tw", "Ow", templateLength, observationsPerGroup);
            appendAccumulateLoop(src, "Tb", "tb", "Ob", templateLength, observationsPerGroup);
            src.append("int Oindex;\n");
            for (int r = 0; r < this.rollX; r++) {
                src.append("Oindex = get_group_id(0) * ").append(observationsPerGroup)
                        .append(" + get_local_id(0) * ").append(this.rollX).append(" + ").append(r).append(";\n");
                src.append("if (Oindex < Olength) scores[scoresOffset + Oindex * Tlength + Tindex] = score")
                        .append(r).append(";\n");
            }
            src.append("}\n");
            src.append("}\n");
        }
        return src.toString();
    }

    /**
     * Appends one unrolled accumulation loop. The loop reads template buffer {@code templateBuf}
     * into local {@code templateVar} and adds its product with observation buffer
     * {@code observationBuf} into every {@code score<r>}.
     */
    private void appendAccumulateLoop(StringBuilder src, String templateBuf, String templateVar, String observationBuf,
            int templateLength, int observationsPerGroup) {
        src.append("for (int i=0; i<").append(templateLength).append("; ++i) {\n");
        src.append("float ").append(templateVar).append(" = ").append(templateBuf)
                .append("[Tindex * ").append(templateLength).append(" + i];\n");
        for (int r = 0; r < this.rollX; r++) {
            src.append("o").append(r).append(" = ").append(observationBuf)
                    .append("[(get_group_id(0) * ").append(observationsPerGroup)
                    .append(" + get_local_id(0) * ").append(this.rollX).append(" + ").append(r)
                    .append(") * ").append(COLUMN_SIZE).append(" + i];\n");
            src.append("score").append(r).append(" += o").append(r).append(" * ").append(templateVar).append(";\n");
        }
        src.append("}\n");
    }
}
