package edu.berkeley.cs.nlp.ocular.model.em;

import jcuda.Pointer;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;
import tberg.murphy.gpu.CudaUtil;

/* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/model/em/CUDAInnerLoop.class */
public class CUDAInnerLoop implements EmissionCacheInnerLoop {
    /** Threads per block along x; each x-thread scores {@link #ROLL_X} consecutive observation columns. */
    public static final int BLOCK_SIZE_X = 16;
    /** Consecutive observation columns scored by one thread (unrolled as score0..score16 in the kernel). */
    public static final int ROLL_X = 17;
    /** Threads per block along y; each y-thread scores one template. */
    public static final int BLOCK_SIZE_Y = 64;

    /** Observation columns covered by one thread block along x (16 * 17 = 272). */
    private static final int COLUMNS_PER_BLOCK = BLOCK_SIZE_X * ROLL_X;
    /** Floats per column in both observations and templates (hard-coded as 30 by this loop). */
    private static final int COLUMN_HEIGHT = 30;
    /** Largest template width for which a {@code compute_emissions_<width>} kernel is generated. */
    private static final int MAX_KERNEL_TEMPLATE_WIDTH = 30;
    private static final int BYTES_PER_FLOAT = 4;
    /** Numeric value of the CUDA driver enum {@code CU_FUNC_CACHE_PREFER_SHARED}. */
    private static final int CU_FUNC_CACHE_PREFER_SHARED = 1;
    /** Numeric value of the CUDA driver enum {@code CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE}. */
    private static final int CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE = 1;
    /** Path argument handed to {@code CudaUtil.compileAndLoad}. */
    private static final String KERNEL_PATH = "/tmp/emission_cache_kernel";

    int numThreads;
    float[][] whiteTemplates;
    float[][] blackTemplates;
    int[] templateNumIndices;
    int[] templateIndicesOffsets;
    int maxTemplateWidth;
    int minTemplateWidth;
    int totalTemplateNumIndices;
    CUmodule cudaModule;
    CUdeviceptr d_Ow;
    CUdeviceptr d_Ob;
    CUdeviceptr d_scores;
    CUdeviceptr[] d_Tw;
    CUdeviceptr[] d_Tb;

    /** Largest sequence length the device buffers allocated in startup can hold. */
    private int maxSequenceLength;
    /** Kernel handle per width offset (null where the width has no templates); resolved once in startup. */
    private CUfunction[] kernels;

    /**
     * Initializes CUDA and compiles/loads the shared-memory kernels from {@link #kernelSrcShared()}.
     *
     * @param numThreads value reported by {@link #numPopulateThreads()}
     * @param gpuId argument forwarded to {@code CudaUtil.startup}; presumably the device index — confirm
     */
    public CUDAInnerLoop(int numThreads, int gpuId) {
        this.numThreads = numThreads;
        CudaUtil.startup(gpuId);
        this.cudaModule = CudaUtil.compileAndLoad(KERNEL_PATH, kernelSrcShared(), true);
    }

    /**
     * Uploads the templates and allocates the device buffers used by {@link #compute}.
     *
     * <p>Template data for width {@code w} lives at index {@code w - minTemplateWidth}. It holds
     * {@code templateNumIndices[w - min]} templates of {@code 30 * w} floats each, where template
     * {@code t} starts at offset {@code t * 30 * w}. Calling this again first releases the
     * buffers of the previous call.
     *
     * @param whiteTemplates per-width white template data
     * @param blackTemplates per-width black template data, same layout as {@code whiteTemplates}
     * @param templateNumIndices per-width template counts; widths with count 0 are skipped
     * @param templateIndicesOffsets per-width offset, in templates, of that width's scores
     * @param minTemplateWidth smallest template width
     * @param maxTemplateWidth largest template width
     * @param maxSequenceLength largest observation length that {@link #compute} will be given
     * @param totalTemplateNumIndices total number of templates across all widths
     * @throws IllegalArgumentException if {@code maxSequenceLength} is negative, or a width with
     *     templates has no generated kernel (outside 1..30)
     */
    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public void startup(float[][] whiteTemplates, float[][] blackTemplates, int[] templateNumIndices,
            int[] templateIndicesOffsets, int minTemplateWidth, int maxTemplateWidth,
            int maxSequenceLength, int totalTemplateNumIndices) {
        if (maxSequenceLength < 0) {
            throw new IllegalArgumentException("maxSequenceLength must be >= 0: " + maxSequenceLength);
        }
        for (int width = minTemplateWidth; width <= maxTemplateWidth; width++) {
            boolean hasTemplates = templateNumIndices[width - minTemplateWidth] > 0;
            if (hasTemplates && (width < 1 || width > MAX_KERNEL_TEMPLATE_WIDTH)) {
                throw new IllegalArgumentException("No CUDA kernel is generated for template width " + width
                        + " (supported: 1.." + MAX_KERNEL_TEMPLATE_WIDTH + ")");
            }
        }
        if (this.d_Ow != null) {
            // Avoid leaking the device buffers of an earlier startup().
            shutdown();
        }

        this.whiteTemplates = whiteTemplates;
        this.blackTemplates = blackTemplates;
        this.templateNumIndices = templateNumIndices;
        this.templateIndicesOffsets = templateIndicesOffsets;
        this.maxTemplateWidth = maxTemplateWidth;
        this.minTemplateWidth = minTemplateWidth;
        this.totalTemplateNumIndices = totalTemplateNumIndices;
        this.maxSequenceLength = maxSequenceLength;

        // Observations are padded to whole blocks plus (maxTemplateWidth - 1) trailing columns, so the
        // last block's window never reads past the buffer. Byte counts use long to avoid int overflow.
        long paddedColumns = (long) paddedLength(maxSequenceLength) + maxTemplateWidth - 1;
        long observationBytes = paddedColumns * COLUMN_HEIGHT * BYTES_PER_FLOAT;
        this.d_Ow = allocate(observationBytes);
        this.d_Ob = allocate(observationBytes);
        this.d_scores = allocate((long) maxSequenceLength * totalTemplateNumIndices * BYTES_PER_FLOAT);

        int numWidths = maxTemplateWidth - minTemplateWidth + 1;
        this.d_Tw = new CUdeviceptr[numWidths];
        this.d_Tb = new CUdeviceptr[numWidths];
        this.kernels = new CUfunction[numWidths];
        for (int w = 0; w < numWidths; w++) {
            if (templateNumIndices[w] > 0) {
                this.d_Tw[w] = upload(whiteTemplates[w]);
                this.d_Tb[w] = upload(blackTemplates[w]);
                this.kernels[w] = loadKernel(minTemplateWidth + w);
            }
        }
    }

    /**
     * Frees all device buffers allocated by {@link #startup}.
     *
     * <p>Safe to call more than once, and before {@link #startup}: freed pointers are cleared.
     */
    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public void shutdown() {
        free(this.d_Ow);
        free(this.d_Ob);
        free(this.d_scores);
        // Only widths with templates received allocations; the other entries are null.
        if (this.d_Tw != null) {
            for (CUdeviceptr ptr : this.d_Tw) {
                free(ptr);
            }
        }
        if (this.d_Tb != null) {
            for (CUdeviceptr ptr : this.d_Tb) {
                free(ptr);
            }
        }
        this.d_Ow = null;
        this.d_Ob = null;
        this.d_scores = null;
        this.d_Tw = null;
        this.d_Tb = null;
        this.kernels = null;
    }

    /**
     * Scores every observation position against every template on the GPU.
     *
     * <p>For each observation position {@code o < length} and template {@code t} of width {@code w},
     * the generated kernel writes
     * {@code sum_i Ow[o*30+i]*Tw[t*30w+i] + sum_i Ob[o*30+i]*Tb[t*30w+i]} (i &lt; 30w, accumulated with
     * round-up FMA) to {@code scores[templateIndicesOffsets[w-min]*length + o*templateNumIndices[w-min] + t]}.
     * The device buffers are shared by all calls on this instance.
     *
     * @param scores output; must hold at least {@code length * totalTemplateNumIndices} floats
     * @param whiteObservations white observation values, read as 30 floats per column position
     * @param blackObservations black observation values, same layout as {@code whiteObservations}
     * @param length number of observation positions to score
     * @throws IllegalStateException if {@link #startup} has not been called
     * @throws IllegalArgumentException if {@code length} is negative or exceeds the startup maximum,
     *     or {@code scores} is too small
     */
    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public void compute(float[] scores, float[] whiteObservations, float[] blackObservations, int length) {
        if (this.d_Ow == null) {
            throw new IllegalStateException("startup() must be called before compute()");
        }
        if (length < 0 || length > this.maxSequenceLength) {
            // Larger lengths would overrun the device buffers sized in startup().
            throw new IllegalArgumentException("length " + length + " outside [0, " + this.maxSequenceLength + "]");
        }
        long scoreCount = (long) length * this.totalTemplateNumIndices;
        if (scores.length < scoreCount) {
            throw new IllegalArgumentException("scores has " + scores.length + " elements, need " + scoreCount);
        }

        int gridX = numColumnBlocks(length);
        int paddedFloats = (gridX * COLUMNS_PER_BLOCK + this.maxTemplateWidth - 1) * COLUMN_HEIGHT;
        long paddedBytes = (long) paddedFloats * BYTES_PER_FLOAT;
        JCudaDriver.cuMemcpyHtoD(this.d_Ow, Pointer.to(CudaUtil.extendWithZeros(whiteObservations, paddedFloats)), paddedBytes);
        JCudaDriver.cuMemcpyHtoD(this.d_Ob, Pointer.to(CudaUtil.extendWithZeros(blackObservations, paddedFloats)), paddedBytes);

        for (int w = 0; w < this.kernels.length; w++) {
            int numTemplates = this.templateNumIndices[w];
            if (numTemplates <= 0) {
                continue;
            }
            int gridY = (numTemplates + BLOCK_SIZE_Y - 1) / BLOCK_SIZE_Y;
            Pointer kernelParams = Pointer.to(
                    Pointer.to(new int[] {this.templateIndicesOffsets[w] * length}),
                    Pointer.to(new int[] {length}),
                    Pointer.to(new int[] {numTemplates}),
                    Pointer.to(this.d_Tw[w]),
                    Pointer.to(this.d_Tb[w]),
                    Pointer.to(this.d_Ow),
                    Pointer.to(this.d_Ob),
                    Pointer.to(this.d_scores));
            JCudaDriver.cuLaunchKernel(this.kernels[w], gridX, gridY, 1, BLOCK_SIZE_X, BLOCK_SIZE_Y, 1, 0, null, kernelParams, null);
        }
        JCudaDriver.cuMemcpyDtoH(Pointer.to(scores), this.d_scores, scoreCount * BYTES_PER_FLOAT);
    }

    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public int numOuterThreads() {
        return 1;
    }

    @Override // edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop
    public int numPopulateThreads() {
        return this.numThreads;
    }

    /**
     * Generates CUDA source for kernels {@code compute_emissions_1..30}.
     *
     * <p>Each block first copies a tile of 272 + (w - 1) observation columns into a
     * {@code __shared__} array. It then accumulates the white channel, reloads the tile with the
     * black channel, and accumulates again.
     */
    public String kernelSrcShared() {
        StringBuilder sb = new StringBuilder();
        for (int width = 1; width <= MAX_KERNEL_TEMPLATE_WIDTH; width++) {
            int tileColumns = COLUMNS_PER_BLOCK + (width - 1);
            int templateSize = COLUMN_HEIGHT * width;
            appendKernelHeader(sb, width);
            sb.append("__shared__ float sO[").append(tileColumns * COLUMN_HEIGHT).append("];\n");
            sb.append("int sharedIndex = threadIdx.x * ").append(BLOCK_SIZE_Y).append(" + threadIdx.y;\n");
            appendSharedTileLoad(sb, "Ow", tileColumns);
            appendScoreDeclarations(sb);
            sb.append("int Tindex = blockIdx.y * ").append(BLOCK_SIZE_Y).append(" + threadIdx.y;\n");
            sb.append("if (Tindex < Tlength) {\n");
            appendAccumulate(sb, "Tw", "tw", templateSize, true);
            sb.append("}\n");
            // All threads must finish reading the white tile before it is overwritten.
            sb.append("__syncthreads();\n");
            appendSharedTileLoad(sb, "Ob", tileColumns);
            sb.append("if (Tindex < Tlength) {\n");
            appendAccumulate(sb, "Tb", "tb", templateSize, true);
            appendScoreWrites(sb);
            sb.append("}\n");
            sb.append("}\n");
        }
        return sb.toString();
    }

    /**
     * Generates CUDA source for kernels {@code compute_emissions_1..30} that copy each thread's
     * 17 + (w - 1) observation columns into a per-thread local array instead of shared memory.
     */
    public String kernelSrcPrivate() {
        StringBuilder sb = new StringBuilder();
        for (int width = 1; width <= MAX_KERNEL_TEMPLATE_WIDTH; width++) {
            int windowColumns = ROLL_X + (width - 1);
            int templateSize = COLUMN_HEIGHT * width;
            appendKernelHeader(sb, width);
            sb.append("int Tindex = blockIdx.y * ").append(BLOCK_SIZE_Y).append(" + threadIdx.y;\n");
            sb.append("if (Tindex < Tlength) {\n");
            sb.append("float pO[").append(windowColumns * COLUMN_HEIGHT).append("];\n");
            appendPrivateLoad(sb, "Ow", windowColumns);
            appendScoreDeclarations(sb);
            appendAccumulate(sb, "Tw", "tw", templateSize, false);
            appendPrivateLoad(sb, "Ob", windowColumns);
            appendAccumulate(sb, "Tb", "tb", templateSize, false);
            appendScoreWrites(sb);
            sb.append("}\n");
            sb.append("}\n");
        }
        return sb.toString();
    }

    /** Emits the {@code extern "C"} signature line of the kernel for {@code width}. */
    private static void appendKernelHeader(StringBuilder sb, int width) {
        sb.append("extern \"C\"\n");
        sb.append("__global__ void compute_emissions_").append(width)
                .append("(int scoresOffset, int Olength, int Tlength, float const* __restrict__ Tw, float const* __restrict__ Tb, float const* __restrict__ Ow, float const* __restrict__ Ob, float* scores) {\n");
    }

    /** Emits a cooperative copy of {@code tileColumns} columns of {@code source} into {@code sO}, then a barrier. */
    private static void appendSharedTileLoad(StringBuilder sb, String source, int tileColumns) {
        sb.append("if (sharedIndex < ").append(tileColumns).append(") {\n");
        sb.append("for (int i=0; i<").append(COLUMN_HEIGHT).append("; ++i) {\n");
        sb.append("sO[sharedIndex * ").append(COLUMN_HEIGHT).append(" + i] = ").append(source)
                .append("[(blockIdx.x * ").append(COLUMNS_PER_BLOCK).append(" + sharedIndex) * ")
                .append(COLUMN_HEIGHT).append(" + i];\n");
        sb.append("}\n");
        sb.append("}\n");
        sb.append("__syncthreads();\n");
    }

    /** Emits a per-thread copy of {@code windowColumns} columns of {@code source} into {@code pO}. */
    private static void appendPrivateLoad(StringBuilder sb, String source, int windowColumns) {
        sb.append("for (int r=0; r<").append(windowColumns).append("; ++r) {\n");
        sb.append("for (int i=0; i<").append(COLUMN_HEIGHT).append("; ++i) {\n");
        sb.append("pO[r * ").append(COLUMN_HEIGHT).append(" + i] = ").append(source)
                .append("[(blockIdx.x * ").append(COLUMNS_PER_BLOCK).append(" + threadIdx.x * ")
                .append(ROLL_X).append(" + r) * ").append(COLUMN_HEIGHT).append(" + i];\n");
        sb.append("}\n");
        sb.append("}\n");
    }

    /** Emits zero-initialized accumulators {@code score0..score16}. */
    private static void appendScoreDeclarations(StringBuilder sb) {
        for (int k = 0; k < ROLL_X; k++) {
            sb.append("float score").append(k).append(" = 0;\n");
        }
    }

    /** Emits the dot-product loop of one template channel against the staged observation columns. */
    private static void appendAccumulate(StringBuilder sb, String templateArray, String localName,
            int templateSize, boolean fromSharedTile) {
        sb.append("for (int i=0; i<").append(templateSize).append("; ++i) {\n");
        sb.append("float ").append(localName).append(" = ").append(templateArray)
                .append("[Tindex * ").append(templateSize).append(" + i];\n");
        for (int k = 0; k < ROLL_X; k++) {
            sb.append("score").append(k).append(" = __fmaf_ru(");
            if (fromSharedTile) {
                sb.append("sO[(threadIdx.x * ").append(ROLL_X).append(" + ").append(k).append(") * ")
                        .append(COLUMN_HEIGHT).append(" + i]");
            } else {
                sb.append("pO[").append(k * COLUMN_HEIGHT).append(" + i]");
            }
            sb.append(", ").append(localName).append(", score").append(k).append(");\n");
        }
        sb.append("}\n");
    }

    /** Emits the bounds-checked stores of {@code score0..score16} into {@code scores}. */
    private static void appendScoreWrites(StringBuilder sb) {
        sb.append("int Oindex;\n");
        for (int k = 0; k < ROLL_X; k++) {
            sb.append("Oindex = blockIdx.x * ").append(COLUMNS_PER_BLOCK).append(" + threadIdx.x * ")
                    .append(ROLL_X).append(" + ").append(k).append(";\n");
            sb.append("if (Oindex < Olength) scores[scoresOffset + Oindex * Tlength + Tindex] = score")
                    .append(k).append(";\n");
        }
    }

    /** Number of 272-column thread blocks needed to cover {@code length} (>= 0) columns. */
    private static int numColumnBlocks(int length) {
        return (length + COLUMNS_PER_BLOCK - 1) / COLUMNS_PER_BLOCK;
    }

    /** Rounds {@code length} (>= 0) up to a whole number of thread blocks' columns. */
    private static int paddedLength(int length) {
        return numColumnBlocks(length) * COLUMNS_PER_BLOCK;
    }

    /** Allocates {@code bytes} of device memory. */
    private static CUdeviceptr allocate(long bytes) {
        CUdeviceptr ptr = new CUdeviceptr();
        JCudaDriver.cuMemAlloc(ptr, bytes);
        return ptr;
    }

    /** Allocates device memory sized for {@code host} and copies {@code host} into it. */
    private static CUdeviceptr upload(float[] host) {
        long bytes = (long) host.length * BYTES_PER_FLOAT;
        CUdeviceptr ptr = allocate(bytes);
        JCudaDriver.cuMemcpyHtoD(ptr, Pointer.to(host), bytes);
        return ptr;
    }

    /** Frees {@code ptr} if it is non-null. */
    private static void free(CUdeviceptr ptr) {
        if (ptr != null) {
            JCudaDriver.cuMemFree(ptr);
        }
    }

    /** Resolves {@code compute_emissions_<width>} in the loaded module and applies its cache/bank config. */
    private CUfunction loadKernel(int width) {
        CUfunction function = new CUfunction();
        JCudaDriver.cuModuleGetFunction(function, this.cudaModule, "compute_emissions_" + width);
        JCudaDriver.cuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED);
        JCudaDriver.cuFuncSetSharedMemConfig(function, CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE);
        return function;
    }
}
