/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.gpu.cuda;

import edu.duke.cs.osprey.gpu.cuda.Gpu;
import edu.duke.cs.osprey.gpu.cuda.GpuStream;
import edu.duke.cs.osprey.tools.FileTools;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import jcuda.Pointer;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;

/**
 * A CUDA driver-API context created on one {@link Gpu}, together with a cache of kernel modules
 * loaded into it by name.
 *
 * <p>NOTE(review): CUDA driver calls generally act on the context that is current to the calling
 * thread; threads other than the constructing thread presumably need {@link #attachCurrentThread()}
 * before using this object — confirm against callers.
 */
public class Context {

    /**
     * Value of the CUDA driver flag {@code CU_CTX_SCHED_BLOCKING_SYNC}: a host thread waiting on
     * the GPU blocks on a synchronization primitive instead of spin-waiting.
     */
    private static final int CONTEXT_FLAGS = 4;

    /** Value of the CUDA driver result code {@code CUDA_SUCCESS}. */
    private static final int CUDA_SUCCESS = 0;

    private final Gpu gpu;
    private final CUcontext context;

    /** Loaded modules keyed by kernel name; guarded by {@code this}. */
    private final Map<String, CUmodule> kernels;

    /** Set once {@link #cleanup()} has run, so the context handle is never destroyed twice; guarded by {@code this}. */
    private boolean isCleanedUp;

    /**
     * Creates a new CUDA context on the given GPU using blocking synchronization.
     *
     * <p>Per the CUDA driver API, {@code cuCtxCreate} makes the new context current to the calling thread.
     *
     * @param gpu the device to create the context on
     */
    public Context(Gpu gpu) {
        this.gpu = gpu;
        this.context = new CUcontext();
        JCudaDriver.cuCtxCreate(context, CONTEXT_FLAGS, gpu.getDevice());
        this.kernels = new HashMap<>();
    }

    /** @return the GPU this context was created on */
    public Gpu getGpu() {
        return gpu;
    }

    /**
     * Returns the module for the named kernel, loading and caching it on first use.
     *
     * <p>The module binary is read from the resource {@code /gpuKernels/cuda/<name>.bin}, which is
     * extracted to a temporary file before being passed to {@code cuModuleLoad}.
     *
     * @param name kernel name, without the {@code .bin} extension
     * @return the loaded (possibly cached) module
     * @throws IOException if the resource cannot be extracted, or if the driver reports that the
     *     module failed to load
     */
    public synchronized CUmodule getKernel(String name) throws IOException {
        CUmodule cached = kernels.get(name);
        if (cached != null) {
            return cached;
        }

        File kernelFile = new FileTools.ResourcePathRoot("/gpuKernels/cuda").extractToTempFile(name + ".bin");
        String kernelPath = kernelFile.getAbsolutePath();
        CUmodule kernel = new CUmodule();
        int result = JCudaDriver.cuModuleLoad(kernel, kernelPath);
        if (result != CUDA_SUCCESS) {
            // Only reachable when JCuda exceptions are disabled (otherwise JCuda throws itself).
            // Never cache a module that did not load, or every later caller would get it too.
            throw new IOException("failed to load CUDA kernel module '" + name + "' from " + kernelPath + " (CUresult " + result + ")");
        }
        kernels.put(name, kernel);
        return kernel;
    }

    /**
     * Allocates device memory.
     *
     * @param numBytes number of bytes to allocate
     * @return pointer to the new device buffer
     */
    public CUdeviceptr malloc(long numBytes) {
        CUdeviceptr pdBuf = new CUdeviceptr();
        JCudaDriver.cuMemAlloc(pdBuf, numBytes);
        return pdBuf;
    }

    /**
     * Frees device memory previously returned by {@link #malloc(long)}.
     *
     * @param pdBuf device buffer to free
     */
    public void free(CUdeviceptr pdBuf) {
        JCudaDriver.cuMemFree(pdBuf);
    }

    /**
     * Enqueues an asynchronous host-to-device copy on the given stream.
     *
     * <p>Per the CUDA docs, the copy is only truly asynchronous when the host buffer is page-locked
     * (see {@link #pinBuffer(Pointer, long)}).
     *
     * @param pdBuf destination device buffer
     * @param phBuf source host buffer
     * @param numBytes number of bytes to copy
     * @param stream stream to enqueue the copy on
     */
    public void uploadAsync(CUdeviceptr pdBuf, Pointer phBuf, long numBytes, GpuStream stream) {
        JCudaDriver.cuMemcpyHtoDAsync(pdBuf, phBuf, numBytes, stream.getStream());
    }

    /**
     * Enqueues an asynchronous device-to-host copy on the given stream.
     *
     * <p>Per the CUDA docs, the copy is only truly asynchronous when the host buffer is page-locked
     * (see {@link #pinBuffer(Pointer, long)}).
     *
     * @param phBuf destination host buffer
     * @param pdBuf source device buffer
     * @param numBytes number of bytes to copy
     * @param stream stream to enqueue the copy on
     */
    public void downloadAsync(Pointer phBuf, CUdeviceptr pdBuf, long numBytes, GpuStream stream) {
        JCudaDriver.cuMemcpyDtoHAsync(phBuf, pdBuf, numBytes, stream.getStream());
    }

    /**
     * Page-locks (registers with the driver) an existing host buffer, using default registration flags.
     *
     * @param phBuf host buffer to register
     * @param numBytes size of the buffer in bytes; must be positive
     * @throws IllegalArgumentException if {@code numBytes <= 0}
     */
    public void pinBuffer(Pointer phBuf, long numBytes) {
        if (numBytes <= 0L) {
            throw new IllegalArgumentException("bad buffer size: " + numBytes + " bytes");
        }
        JCudaDriver.cuMemHostRegister(phBuf, numBytes, 0);
    }

    /**
     * Unregisters a host buffer previously registered with {@link #pinBuffer(Pointer, long)}.
     *
     * @param phBuf host buffer to unregister
     */
    public void unpinBuffer(Pointer phBuf) {
        JCudaDriver.cuMemHostUnregister(phBuf);
    }

    /**
     * Launches a kernel with a one-dimensional grid of one-dimensional blocks.
     *
     * @param func kernel function to launch
     * @param gridBlocks number of blocks in the grid (x dimension)
     * @param blockThreads number of threads per block (x dimension)
     * @param sharedMemBytes dynamic shared memory per block, in bytes
     * @param pArgs pointer to the kernel parameter array
     * @param stream stream to launch the kernel on
     */
    public void launchKernel(CUfunction func, int gridBlocks, int blockThreads, int sharedMemBytes, Pointer pArgs, GpuStream stream) {
        JCudaDriver.cuLaunchKernel(func, gridBlocks, 1, 1, blockThreads, 1, 1, sharedMemBytes, stream.getStream(), pArgs, null);
    }

    /**
     * Blocks until all previously issued work in the calling thread's current context has completed.
     *
     * <p>NOTE(review): {@code cuCtxSynchronize} uses the thread's current context, which is only this
     * context if it was created on, or attached to, the calling thread.
     */
    public void waitForGpu() {
        JCudaDriver.cuCtxSynchronize();
    }

    /** Makes this context current to the calling thread. */
    public void attachCurrentThread() {
        JCudaDriver.cuCtxSetCurrent(context);
    }

    /**
     * Unloads all cached kernel modules and destroys the context.
     *
     * <p>Cleanup is best-effort: failures are printed to {@code System.err} instead of being thrown.
     * Each module is unloaded independently, so one failure does not stop the rest, and the context
     * is still destroyed. Calling this more than once has no further effect, so a stale context handle
     * is never destroyed twice.
     */
    public synchronized void cleanup() {
        if (isCleanedUp) {
            return;
        }
        isCleanedUp = true;

        for (CUmodule kernel : kernels.values()) {
            try {
                JCudaDriver.cuModuleUnload(kernel);
            } catch (Throwable t) {
                // best-effort: report and keep releasing the remaining resources
                t.printStackTrace(System.err);
            }
        }
        kernels.clear();

        try {
            JCudaDriver.cuCtxDestroy(context);
        } catch (Throwable t) {
            t.printStackTrace(System.err);
        }
    }
}

