package tberg.murphy.gpu;

import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import org.jocl.CL;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/gpu/CudaUtil.class */
/**
 * Static helpers for working with CUDA through JCuda.
 *
 * <p>The class covers:
 * <ul>
 *   <li>driver setup and teardown;</li>
 *   <li>compiling CUDA C source to PTX with {@code nvcc}, and loading modules;</li>
 *   <li>IEEE 754 half-precision (binary16) conversion;</li>
 *   <li>row-major array and index flattening, including building index expressions as source
 *       text for generated kernels.</li>
 * </ul>
 *
 * <p>NOTE(review): {@link #device} and {@link #context} are unsynchronized mutable globals
 * assigned by {@link #startup(int)}.
 */
public class CudaUtil {
    /** Device selected by the most recent {@link #startup(int)} call. */
    public static CUdevice device;
    /** Driver-API context created on {@link #device} by {@link #startup(int)}. */
    public static CUcontext context;

    /**
     * Initializes the CUDA driver API, selects device {@code i}, and creates a context on it.
     *
     * <p>JCuda driver exceptions are enabled first, so failing driver calls are reported as
     * exceptions by JCuda rather than as ignored status codes.
     *
     * @param i ordinal of the CUDA device to use
     */
    public static void startup(int i) {
        JCudaDriver.setExceptionsEnabled(true);
        JCudaDriver.cuInit(0);
        device = new CUdevice();
        JCudaDriver.cuDeviceGet(device, i);
        context = new CUcontext();
        JCudaDriver.cuCtxCreate(context, 0, device);
    }

    /**
     * Resets the current device through the CUDA runtime API ({@code cudaDeviceReset}).
     *
     * <p>NOTE(review): the context created by {@code cuCtxCreate} in {@link #startup(int)} is not
     * destroyed here with {@code cuCtxDestroy}. Confirm whether the device reset is sufficient.
     */
    public static void shutdown() {
        JCuda.cudaDeviceReset();
    }

    /**
     * Compiles (or reuses) the PTX for {@code kernelSource} and loads it as a module.
     *
     * @param baseName path prefix for the generated {@code .cu} and {@code .ptx} files
     * @param kernelSource CUDA C source code
     * @param forceCompile if true, recompile even when the {@code .ptx} file already exists
     * @return the loaded module
     * @see #preparePtxFile(String, String, boolean)
     */
    public static CUmodule compileAndLoad(String baseName, String kernelSource, boolean forceCompile) {
        return loadModule(preparePtxFile(baseName, kernelSource, forceCompile));
    }

    /**
     * Ensures that a PTX file for {@code kernelSource} exists and returns its path.
     *
     * <p>When compilation is needed, the method writes {@code kernelSource} to
     * {@code baseName + ".cu"}. It then runs {@code nvcc} to produce {@code baseName + ".ptx"}
     * for the compute capability of {@link #device}, so {@link #startup(int)} must have been
     * called first.
     *
     * @param baseName path prefix for the generated {@code .cu} and {@code .ptx} files
     * @param kernelSource CUDA C source code
     * @param forceCompile if false and the {@code .ptx} file already exists, it is reused as-is
     * @return the path of the {@code .ptx} file
     * @throws RuntimeException wrapping the {@link IOException} if any of these fail:
     *     writing the source, starting {@code nvcc}, compiling, or waiting for {@code nvcc}
     *     (interruption). On interruption the thread's interrupt flag is restored.
     */
    public static String preparePtxFile(String baseName, String kernelSource, boolean forceCompile) {
        String ptxFileName = baseName + ".ptx";
        if (!forceCompile && new File(ptxFileName).exists()) {
            return ptxFileName;
        }
        try {
            return compilePtx(baseName, kernelSource, ptxFileName);
        } catch (IOException e) {
            // The public signature declares no checked exceptions, so surface failures unchecked
            // while preserving the cause (returning a bogus path would only fail later, in the
            // module load).
            throw new RuntimeException("Could not prepare PTX file " + ptxFileName, e);
        }
    }

    /** Writes the {@code .cu} file, runs {@code nvcc} on it, and returns {@code ptxFileName}. */
    private static String compilePtx(String baseName, String kernelSource, String ptxFileName)
            throws IOException {
        long startNanos = System.nanoTime();
        File cuFile = new File(baseName + ".cu");
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(cuFile))) {
            writer.append(kernelSource);
        }

        int[] major = new int[1];
        int[] minor = new int[1];
        JCudaDriver.cuDeviceComputeCapability(major, minor, device);
        // Pass arguments as a list so paths containing spaces are not split apart.
        List<String> command = Arrays.asList(
                "nvcc",
                "-use_fast_math",
                "-arch=sm_" + major[0] + minor[0],
                "-m" + System.getProperty("sun.arch.data.model"),
                "-ptx",
                cuFile.getPath(),
                "-o",
                ptxFileName);
        System.out.println("Executing\n" + joinWithSpaces(command));

        // Merge stderr into stdout so a single reader drains everything. Reading two pipes one
        // after the other can deadlock if the unread pipe fills its OS buffer.
        Process process = new ProcessBuilder(command).redirectErrorStream(true).start();
        String output;
        try (InputStream processOutput = process.getInputStream()) {
            output = new String(toByteArray(processOutput));
        }

        int exitCode;
        try {
            exitCode = process.waitFor();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Interrupted while waiting for nvcc output", e);
        }
        if (exitCode != 0) {
            System.out.println("nvcc process exitValue " + exitCode);
            System.out.println("output:\n" + output);
            throw new IOException("Could not create .ptx file: " + output);
        }
        System.out.println("Finished creating PTX file");
        System.out.println("Compile time: " + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms");
        return ptxFileName;
    }

    /** Joins {@code parts} with single spaces, for logging a command line. */
    private static String joinWithSpaces(List<String> parts) {
        StringBuilder joined = new StringBuilder();
        for (String part : parts) {
            if (joined.length() > 0) {
                joined.append(' ');
            }
            joined.append(part);
        }
        return joined.toString();
    }

    /**
     * Loads the module file at {@code path} (e.g. a PTX file) with {@code cuModuleLoad}.
     *
     * @param path file system path of the module image
     * @return the loaded module
     */
    public static CUmodule loadModule(String path) {
        CUmodule module = new CUmodule();
        JCudaDriver.cuModuleLoad(module, path);
        return module;
    }

    /** Reads {@code inputStream} to end-of-stream and returns its bytes. The stream is not closed. */
    private static byte[] toByteArray(InputStream inputStream) throws IOException {
        ByteArrayOutputStream collected = new ByteArrayOutputStream();
        byte[] buffer = new byte[8192];
        int read;
        while ((read = inputStream.read(buffer)) != -1) {
            collected.write(buffer, 0, read);
        }
        return collected.toByteArray();
    }

    /**
     * Converts IEEE 754 binary16 bits to the equal {@code float}.
     *
     * <p>Only the low 16 bits of {@code i} are used; higher bits are ignored. Every binary16
     * value is exactly representable as a float, so the result is exact:
     * <ul>
     *   <li>normal and subnormal values map to the identical float;</li>
     *   <li>zero keeps its sign;</li>
     *   <li>infinities map to infinities;</li>
     *   <li>NaNs map to NaNs, with the payload shifted into the float mantissa.</li>
     * </ul>
     *
     * @param i binary16 bit pattern in the low 16 bits
     * @return the equivalent float
     */
    public static float toFloat(int i) {
        int mantissa = i & 0x03ff;
        int exponent = i & 0x7c00;
        if (exponent == 0x7c00) {
            // Inf/NaN: use the all-ones float exponent.
            exponent = 0x3fc00;
        } else if (exponent != 0) {
            // Normal value: rebias the exponent from 15 to 127 (in the pre-shift layout).
            exponent += 0x1c000;
        } else if (mantissa != 0) {
            // Subnormal: shift until the implicit leading bit appears, lowering the exponent.
            exponent = 0x1c400;
            do {
                mantissa <<= 1;
                exponent -= 0x400;
            } while ((mantissa & 0x400) == 0);
            mantissa &= 0x3ff;
        }
        // (a signed zero falls through with exponent == mantissa == 0)
        return Float.intBitsToFloat(((i & 0x8000) << 16) | ((exponent | mantissa) << 13));
    }

    /**
     * Converts a {@code float} to IEEE 754 binary16 bits, returned in the low 16 bits.
     *
     * <p>Rounding and range handling:
     * <ul>
     *   <li>Rounding adds half a unit of the discarded mantissa bits before truncating, i.e.
     *       round half away from zero rather than ties-to-even.</li>
     *   <li>Finite inputs of magnitude at least 65536 become infinity.</li>
     *   <li>Magnitudes that would round past the largest finite half but are below 65536 clamp
     *       to 65504 ({@code 0x7bff}).</li>
     *   <li>Infinities stay infinite and NaNs stay NaN.</li>
     *   <li>Values too small for a binary16 subnormal become signed zero.</li>
     * </ul>
     *
     * @param f value to convert
     * @return binary16 bit pattern (sign in bit 15)
     */
    public static int fromFloat(float f) {
        int bits = Float.floatToIntBits(f);
        int sign = (bits >>> 16) & 0x8000;
        int magnitude = bits & 0x7fffffff;
        int rounded = magnitude + 0x1000;
        if (rounded >= 0x47800000) {
            if (magnitude >= 0x47800000) {
                if (rounded < 0x7f800000) {
                    // Finite but too large: make it infinite.
                    return sign | 0x7c00;
                }
                // Inf or NaN: keep the top NaN payload bits.
                return sign | 0x7c00 | ((bits & 0x7fffff) >>> 13);
            }
            // Only rounding would overflow: clamp to the largest finite half.
            return sign | 0x7bff;
        }
        if (rounded >= 0x38800000) {
            // Stays normal: rebias the exponent from 127 to 15 and drop 13 mantissa bits.
            return sign | ((rounded - 0x38000000) >>> 13);
        }
        if (rounded < 0x33000000) {
            // Too small even for a half subnormal.
            return sign;
        }
        // Subnormal result: restore the implicit bit, round at the cut-off, then shift down.
        int exponent = magnitude >>> 23;
        return sign
                | ((((bits & 0x7fffff) | 0x800000) + (0x800000 >>> (exponent - 102)))
                        >>> (126 - exponent));
    }

    /**
     * Converts each element of {@code fArr} with {@link #fromFloat(float)}.
     *
     * @param fArr values to convert
     * @return binary16 bit patterns stored in {@code char}s, same length as {@code fArr}
     */
    public static char[] convertToHalfFloat(float[] fArr) {
        char[] halves = new char[fArr.length];
        for (int k = 0; k < fArr.length; k++) {
            halves[k] = (char) fromFloat(fArr[k]);
        }
        return halves;
    }

    /**
     * Converts each binary16 bit pattern in {@code cArr} with {@link #toFloat(int)}.
     *
     * @param cArr binary16 bit patterns
     * @return decoded floats, same length as {@code cArr}
     */
    public static float[] convertFromHalfFloat(char[] cArr) {
        float[] floats = new float[cArr.length];
        for (int k = 0; k < cArr.length; k++) {
            floats[k] = toFloat(cArr[k]);
        }
        return floats;
    }

    /**
     * Concatenates the rows of {@code fArr} in row-major order.
     *
     * <p>The output size is {@code rows * fArr[0].length}, so rows are expected to share the
     * first row's length. A shorter row leaves zeros at the end of its slot. Empty input yields
     * an empty array.
     *
     * @param fArr rows to flatten
     * @return the row-major concatenation
     */
    public static float[] flatten(float[][] fArr) {
        if (fArr.length == 0) {
            return new float[0];
        }
        int rowLength = fArr[0].length;
        float[] flat = new float[fArr.length * rowLength];
        for (int r = 0; r < fArr.length; r++) {
            System.arraycopy(fArr[r], 0, flat, r * rowLength, fArr[r].length);
        }
        return flat;
    }

    /**
     * Concatenates the rows of {@code dArr} in row-major order.
     *
     * <p>The output size is {@code rows * dArr[0].length}, so rows are expected to share the
     * first row's length. A shorter row leaves zeros at the end of its slot. Empty input yields
     * an empty array.
     *
     * @param dArr rows to flatten
     * @return the row-major concatenation
     */
    public static double[] flatten(double[][] dArr) {
        if (dArr.length == 0) {
            return new double[0];
        }
        int rowLength = dArr[0].length;
        double[] flat = new double[dArr.length * rowLength];
        for (int r = 0; r < dArr.length; r++) {
            System.arraycopy(dArr[r], 0, flat, r * rowLength, dArr[r].length);
        }
        return flat;
    }

    /**
     * Flattens a 3-D array in row-major order, using {@code fArr[0]} and {@code fArr[0][0]} for
     * the inner dimensions.
     *
     * <p>Empty outer or middle dimensions yield an empty array.
     *
     * @param fArr array to flatten
     * @return the row-major concatenation
     */
    public static float[] flatten(float[][][] fArr) {
        if (fArr.length == 0 || fArr[0].length == 0) {
            return new float[0];
        }
        int middle = fArr[0].length;
        int inner = fArr[0][0].length;
        float[] flat = new float[fArr.length * middle * inner];
        for (int a = 0; a < fArr.length; a++) {
            for (int b = 0; b < middle; b++) {
                System.arraycopy(fArr[a][b], 0, flat, (a * middle * inner) + (b * inner), fArr[a][b].length);
            }
        }
        return flat;
    }

    /**
     * Concatenates the rows in {@code list} in order.
     *
     * <p>Each row occupies a slot of {@code list.get(0).length} elements. The list is traversed
     * once with its iterator, so this is linear for any {@link List} implementation. An empty
     * list yields an empty array.
     *
     * @param list rows to flatten
     * @return the concatenation
     */
    public static float[] flatten(List<float[]> list) {
        if (list.isEmpty()) {
            return new float[0];
        }
        int rowLength = list.get(0).length;
        float[] flat = new float[list.size() * rowLength];
        int offset = 0;
        for (float[] row : list) {
            System.arraycopy(row, 0, flat, offset, row.length);
            offset += rowLength;
        }
        return flat;
    }

    /**
     * Row-major index of ({@code row}, {@code col}) in a {@code numRows x numCols} layout.
     * {@code numRows} is accepted for symmetry but unused.
     */
    public static int flatten(int numRows, int numCols, int row, int col) {
        return (row * numCols) + col;
    }

    /** Row of a row-major flat {@code index} given {@code numCols} columns ({@code numRows} unused). */
    public static int unflattenFirst(int numRows, int numCols, int index) {
        return index / numCols;
    }

    /** Column of a row-major flat {@code index} given {@code numCols} columns ({@code numRows} unused). */
    public static int unflattenSecond(int numRows, int numCols, int index) {
        return index % numCols;
    }

    /**
     * Builds a 2-D row-major index expression as kernel source text. This overload takes a
     * literal row and a column expression.
     */
    public static String flatten(int numRows, int numCols, int row, String col) {
        return "(" + (row * numCols) + " + " + col + ")";
    }

    /**
     * Builds a 2-D row-major index expression as kernel source text. This overload takes a row
     * expression and a literal column.
     */
    public static String flatten(int numRows, int numCols, String row, int col) {
        return "(" + row + " * " + numCols + " + " + col + ")";
    }

    /**
     * Builds a 2-D row-major index expression as kernel source text. This overload takes row
     * and column expressions.
     */
    public static String flatten(int numRows, int numCols, String row, String col) {
        return "(" + row + " * " + numCols + " + " + col + ")";
    }

    /** Row-major index of ({@code i}, {@code j}, {@code k}) in a {@code d0 x d1 x d2} layout ({@code d0} unused). */
    public static int flatten(int d0, int d1, int d2, int i, int j, int k) {
        return (i * d1 * d2) + (j * d2) + k;
    }

    /** 3-D row-major index expression as source text; {@code k} is an expression. */
    public static String flatten(int d0, int d1, int d2, int i, int j, String k) {
        return "(" + ((i * d1 * d2) + (j * d2)) + " + " + k + ")";
    }

    /** 3-D row-major index expression as source text; {@code i} is an expression. */
    public static String flatten(int d0, int d1, int d2, String i, int j, int k) {
        return "(" + i + " * " + (d1 * d2) + " + " + (j * d2) + " + " + k + ")";
    }

    /** 3-D row-major index expression as source text; {@code j} is an expression. */
    public static String flatten(int d0, int d1, int d2, int i, String j, int k) {
        return "(" + (i * d1 * d2) + " + " + j + "  * " + d2 + " + " + k + ")";
    }

    /** 3-D row-major index expression as source text; {@code i} and {@code j} are expressions. */
    public static String flatten(int d0, int d1, int d2, String i, String j, int k) {
        return "(" + i + " * " + (d1 * d2) + " + " + j + "  * " + d2 + " + " + k + ")";
    }

    /** 3-D row-major index expression as source text; {@code i} and {@code k} are expressions. */
    public static String flatten(int d0, int d1, int d2, String i, int j, String k) {
        return "(" + i + " * " + (d1 * d2) + " + " + (j * d2) + " + " + k + ")";
    }

    /** 3-D row-major index expression as source text; {@code j} and {@code k} are expressions. */
    public static String flatten(int d0, int d1, int d2, int i, String j, String k) {
        return "(" + (i * d1 * d2) + " + " + j + "  * " + d2 + " + " + k + ")";
    }

    /** 3-D row-major index expression as source text; all indices are expressions. */
    public static String flatten(int d0, int d1, int d2, String i, String j, String k) {
        return "(" + i + " * " + (d1 * d2) + " + " + j + "  * " + d2 + " + " + k + ")";
    }

    /**
     * Returns a copy of {@code fArr} with length {@code i}.
     *
     * <p>The copy is zero-padded when {@code i} exceeds {@code fArr.length} and truncated when
     * it is shorter.
     *
     * @param fArr source values
     * @param i length of the result
     * @return the resized copy
     * @throws NegativeArraySizeException if {@code i} is negative
     */
    public static float[] extendWithZeros(float[] fArr, int i) {
        return Arrays.copyOf(fArr, i);
    }
}
