/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.gpu.cuda;

import edu.duke.cs.osprey.gpu.cuda.Context;
import edu.duke.cs.osprey.gpu.cuda.Gpu;
import edu.duke.cs.osprey.gpu.cuda.GpuStream;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;

/**
 * A CUDA module obtained from the {@link Context} of a {@link GpuStream}, together with
 * a factory for {@link Function} handles that launch functions from that module on the
 * stream.
 */
public class Kernel {
    private final GpuStream stream;
    private final CUmodule module;

    /**
     * Attaches the calling thread to the stream's context and asks that context for the
     * module identified by {@code filename}.
     *
     * @param stream   stream that kernel launches are issued on; must not be null
     * @param filename name passed to {@code Context.getKernel} to locate the module
     * @throws IllegalArgumentException if {@code stream} is null
     * @throws Error if the context reports an {@link IOException} while loading the module
     */
    public Kernel(GpuStream stream, String filename) {
        if (stream == null) {
            throw new IllegalArgumentException("stream can't be null");
        }
        this.stream = stream;
        // Use the stream's context directly instead of the overridable getContext(),
        // so no subclass code runs while this object is still being constructed.
        Context context = stream.getContext();
        context.attachCurrentThread();
        try {
            this.module = context.getKernel(filename);
        } catch (IOException ex) {
            throw new Error("can't load Cuda kernel: " + filename, ex);
        }
    }

    /**
     * @return the stream this kernel was created with
     */
    public GpuStream getStream() {
        return this.stream;
    }

    /**
     * @return the context of this kernel's stream
     */
    public Context getContext() {
        return this.stream.getContext();
    }

    /**
     * Creates a launchable handle for a function in this kernel's module.
     *
     * @param name name of the function to look up in the module
     * @return a new {@link Function} bound to this kernel
     */
    public Function makeFunction(String name) {
        return new Function(name);
    }

    /**
     * Delegates to {@code waitForGpu()} on this kernel's stream.
     */
    public void waitForGpu() {
        this.getStream().waitForGpu();
    }

    /**
     * A single function from the enclosing kernel's module plus its launch configuration.
     *
     * <p>The launch configuration is held in the public mutable fields
     * {@link #numBlocks}, {@link #blockThreads} and {@link #sharedMemCalc}; the argument
     * pointer is set through {@link #setArgs(Pointer)}.
     */
    public class Function {
        private final CUfunction func = new CUfunction();
        private Pointer pArgs;
        /** Number of blocks passed to the launch; initialised to 1. */
        public int numBlocks;
        /** Threads per block passed to the launch; initialised to 1. */
        public int blockThreads;
        /** Computes the shared-memory byte count for a launch; initialised to {@link SharedMemCalculator.None}. */
        public SharedMemCalculator sharedMemCalc;

        /**
         * Looks up {@code name} in the enclosing kernel's module and sets the default
         * launch configuration (one block, one thread, no shared memory, no arguments).
         *
         * @param name name of the function inside the module
         */
        public Function(String name) {
            JCudaDriver.cuModuleGetFunction(this.func, Kernel.this.module, name);
            this.pArgs = null;
            this.numBlocks = 1;
            this.blockThreads = 1;
            this.sharedMemCalc = new SharedMemCalculator.None();
        }

        /**
         * Sets the argument pointer handed to every subsequent launch.
         *
         * @param val kernel argument pointer
         */
        public void setArgs(Pointer val) {
            this.pArgs = val;
        }

        /**
         * Launches the function on the enclosing kernel's stream using the current
         * {@link #numBlocks}, {@link #blockThreads}, shared-memory size and arguments.
         * This method itself does not call {@code waitForGpu()}.
         */
        public void runAsync() {
            int sharedMemBytes = this.sharedMemCalc.calcBytes(this.blockThreads);
            Kernel.this.getContext().launchKernel(this.func, this.numBlocks, this.blockThreads, sharedMemBytes, this.pArgs, Kernel.this.stream);
        }

        /**
         * Finds the largest launchable block size by trial launches.
         *
         * <p>Starts at {@code Gpu.getMaxBlockThreads()} and steps down by
         * {@code Gpu.getWarpThreads()} until a single-block test launch succeeds. Each
         * trial really launches the function with the currently set arguments.
         *
         * @return the first thread count for which the test launch succeeded
         * @throws Error if no positive thread count could be launched
         */
        public int calcMaxBlockThreads() {
            Gpu gpu = Kernel.this.getContext().getGpu();
            for (int candidate = gpu.getMaxBlockThreads(); candidate > 0; candidate -= gpu.getWarpThreads()) {
                if (this.canLaunch(candidate)) {
                    return candidate;
                }
            }
            throw new Error("can't determine thread count for kernel launch, all thread counts failed");
        }

        /**
         * Returns the block size cached in {@code blockThreads}, computing and storing it
         * first when the holder still contains the sentinel {@code -1}.
         *
         * @param blockThreads shared cache; {@code -1} means "not yet computed"
         * @return the value held by {@code blockThreads} after this call
         */
        public int getBestBlockThreads(AtomicInteger blockThreads) {
            int cached = blockThreads.get();
            if (cached != -1) {
                return cached;
            }
            // The probe launches real kernels, so it must not run inside the update
            // function: AtomicInteger may re-apply that function under contention.
            final int computed = this.calcMaxBlockThreads();
            return blockThreads.updateAndGet(val -> val == -1 ? computed : val);
        }

        /**
         * Test-launches one block of {@code blockThreads} threads and waits for the stream.
         *
         * @param blockThreads thread count to try
         * @return {@code true} if launch and wait completed; {@code false} if a
         *         {@link CudaException} with message
         *         {@code CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES} (case-insensitive) was thrown
         * @throws CudaException any other CUDA failure, rethrown unchanged
         */
        private boolean canLaunch(int blockThreads) {
            try {
                int sharedMemBytes = this.sharedMemCalc.calcBytes(blockThreads);
                Kernel.this.getContext().launchKernel(this.func, 1, blockThreads, sharedMemBytes, this.pArgs, Kernel.this.stream);
                Kernel.this.stream.waitForGpu();
                return true;
            } catch (CudaException ex) {
                // Constant on the left: getMessage() may be null, and an NPE here would
                // mask the original CUDA failure.
                if ("CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES".equalsIgnoreCase(ex.getMessage())) {
                    return false;
                }
                throw ex;
            }
        }
    }

    /**
     * Strategy for computing the shared-memory byte count passed to a kernel launch.
     */
    @FunctionalInterface
    interface SharedMemCalculator {
        /**
         * @param blockThreads threads per block of the launch being configured
         * @return shared-memory size in bytes to pass to the launch
         */
        int calcBytes(int blockThreads);

        /**
         * Calculator that always reports zero bytes.
         */
        class None
        implements SharedMemCalculator {
            @Override
            public int calcBytes(int blockThreads) {
                return 0;
            }
        }
    }
}

