/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.gpu.cuda.kernels;

import edu.duke.cs.osprey.energy.forcefield.BigForcefieldEnergy;
import edu.duke.cs.osprey.gpu.ForcefieldKernel;
import edu.duke.cs.osprey.gpu.cuda.CUBuffer;
import edu.duke.cs.osprey.gpu.cuda.GpuStream;
import edu.duke.cs.osprey.gpu.cuda.Kernel;
import edu.duke.cs.osprey.tools.MathTools;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;
import java.util.Objects;
import jcuda.NativePointerObject;
import jcuda.Pointer;

/**
 * CUDA implementation of {@link ForcefieldKernel} that evaluates a {@link BigForcefieldEnergy}
 * using the {@code calc} function of the kernel loaded as {@code "forcefield"}.
 *
 * <p>Owns device buffers for the coordinates, atom flags, precomputed forcefield data, the
 * active subset's atom-pair table, a packed argument struct, and one partial energy per thread
 * block. Call {@link #cleanup()} when finished to release them.
 */
public class ForcefieldKernelCuda extends Kernel implements ForcefieldKernel {

    /** Threads per block used when launching the kernel. */
    private static final int BLOCK_THREADS = 512;

    // Layout of the packed argument struct written into the args buffer (byte offsets):
    //    0 int    numAtomPairs                  4 int    num14AtomPairs
    //    8 double coulombFactor                16 double scaledCoulombFactor
    //   24 double solvationCutoff2
    //   32 byte   useDistDependentDielectric   33 byte   useHElectrostatics
    //   34 byte   useHVdw                      35 byte   useSubset
    //   36 byte   useEEF1
    // 37 bytes are written; the buffer is allocated as 40 bytes
    // (presumably to match the device-side struct alignment -- TODO confirm against kernel).
    private static final int ARGS_BYTES = 40;
    private static final int ARGS_NUM_ATOM_PAIRS_OFFSET = 0;
    private static final int ARGS_NUM_14_ATOM_PAIRS_OFFSET = 4;
    private static final int ARGS_USE_SUBSET_OFFSET = 35;

    private final Kernel.Function func;
    private final BigForcefieldEnergy ffenergy;
    private CUBuffer<DoubleBuffer> coords;
    private CUBuffer<IntBuffer> atomFlags;
    private CUBuffer<DoubleBuffer> precomputed;
    private CUBuffer<IntBuffer> subsetTable;
    private CUBuffer<DoubleBuffer> energies;
    private CUBuffer<ByteBuffer> args;
    private BigForcefieldEnergy.Subset subset;

    /**
     * Allocates the device buffers and uploads the static data. Binds the kernel arguments and
     * selects the forcefield's full subset.
     *
     * @param stream the GPU stream that owns the buffers and runs the kernel
     * @param ffenergy the forcefield whose coordinates, flags, precomputed data and params are used
     * @throws IOException if the kernel cannot be loaded
     */
    public ForcefieldKernelCuda(GpuStream stream, BigForcefieldEnergy ffenergy) throws IOException {
        super(stream, "forcefield");
        this.ffenergy = ffenergy;

        this.func = this.makeFunction("calc");
        this.func.blockThreads = BLOCK_THREADS;
        // 8 bytes of shared memory per thread (presumably one double each for a per-block
        // reduction -- TODO confirm against the kernel source)
        this.func.sharedMemCalc = new Kernel.SharedMemCalculator(this) {

            @Override
            public int calcBytes(int blockThreads) {
                return blockThreads * 8;
            }
        };

        // the subset table and energies buffers are sized from the full subset
        // (presumably the largest subset setSubset() will ever receive -- verify with callers)
        this.coords = this.getStream().makeBuffer(ffenergy.getCoords());
        this.atomFlags = this.getStream().makeBuffer(ffenergy.getAtomFlags());
        this.precomputed = this.getStream().makeBuffer(ffenergy.getPrecomputed());
        this.subsetTable = this.getStream().makeIntBuffer(ffenergy.getFullSubset().getNumAtomPairs());
        this.energies = this.getStream().makeDoubleBuffer(
            getEnergySize(ffenergy.getFullSubset(), this.func.blockThreads));

        // atom flags and precomputed data are uploaded once here;
        // coordinates are uploaded on demand by uploadCoordsAsync()
        this.atomFlags.uploadAsync();
        this.precomputed.uploadAsync();

        this.args = this.getStream().makeByteBuffer(ARGS_BYTES);
        ByteBuffer argsBuf = this.args.getHostBuffer();
        argsBuf.rewind();
        argsBuf.putInt(0); // numAtomPairs, filled in by setSubsetInternal()
        argsBuf.putInt(0); // num14AtomPairs, filled in by setSubsetInternal()
        argsBuf.putDouble(ffenergy.getParams().coulombFactor);
        argsBuf.putDouble(ffenergy.getParams().scaledCoulombFactor);
        argsBuf.putDouble(ffenergy.getParams().solvationCutoff2);
        argsBuf.put(toByte(ffenergy.getParams().useDistDependentDielectric));
        argsBuf.put(toByte(ffenergy.getParams().useHElectrostatics));
        argsBuf.put(toByte(ffenergy.getParams().useHVdw));
        argsBuf.put((byte) 0); // useSubset, filled in by setSubsetInternal()
        argsBuf.put(toByte(ffenergy.getParams().useEEF1));
        argsBuf.flip();

        // no subset yet, so the identity check in setSubsetInternal() cannot short-circuit;
        // the private helper is used to avoid calling the overridable setSubset() from here
        this.subset = null;
        this.setSubsetInternal(ffenergy.getFullSubset());

        this.func.setArgs(Pointer.to(new NativePointerObject[] {
            this.coords.getDevicePointer(),
            this.atomFlags.getDevicePointer(),
            this.precomputed.getDevicePointer(),
            this.subsetTable.getDevicePointer(),
            this.args.getDevicePointer(),
            this.energies.getDevicePointer()
        }));
    }

    public CUBuffer<DoubleBuffer> getCoords() {
        return this.coords;
    }

    public CUBuffer<IntBuffer> getAtomFlags() {
        return this.atomFlags;
    }

    public CUBuffer<DoubleBuffer> getPrecomputed() {
        return this.precomputed;
    }

    public CUBuffer<IntBuffer> getSubsetTable() {
        return this.subsetTable;
    }

    public CUBuffer<DoubleBuffer> getEnergies() {
        return this.energies;
    }

    public CUBuffer<ByteBuffer> getArgs() {
        return this.args;
    }

    @Override
    public BigForcefieldEnergy getForcefield() {
        return this.ffenergy;
    }

    @Override
    public BigForcefieldEnergy.Subset getSubset() {
        return this.subset;
    }

    /**
     * Selects the subset of atom pairs the kernel evaluates.
     *
     * @param subset the subset to use; must not be null
     * @return true if the subset changed (and new arguments were uploaded), false if
     *     {@code subset} is the same instance as the current one
     * @throws NullPointerException if {@code subset} is null
     */
    @Override
    public boolean setSubset(BigForcefieldEnergy.Subset subset) {
        return this.setSubsetInternal(subset);
    }

    /**
     * Updates the launch size and the subset-dependent argument fields, then uploads the
     * arguments (and the subset table, if the subset has one) asynchronously.
     */
    private boolean setSubsetInternal(BigForcefieldEnergy.Subset subset) {
        // validate before mutating anything, so a null subset cannot leave this.subset null
        Objects.requireNonNull(subset, "subset");
        if (this.subset == subset) {
            return false;
        }
        this.subset = subset;
        boolean useSubset = subset.getSubsetTable() != null;
        this.func.numBlocks = MathTools.divUp(subset.getNumAtomPairs(), this.func.blockThreads);
        this.getStream().getContext().attachCurrentThread();

        // absolute puts leave the limit set by the constructor's flip() untouched
        ByteBuffer buf = this.args.getHostBuffer();
        buf.putInt(ARGS_NUM_ATOM_PAIRS_OFFSET, subset.getNumAtomPairs());
        buf.putInt(ARGS_NUM_14_ATOM_PAIRS_OFFSET, subset.getNum14AtomPairs());
        buf.put(ARGS_USE_SUBSET_OFFSET, toByte(useSubset));
        buf.rewind();
        this.args.uploadAsync();

        if (useSubset) {
            IntBuffer table = this.subsetTable.getHostBuffer();
            table.clear();
            subset.getSubsetTable().rewind();
            table.put(subset.getSubsetTable());
            table.flip();
            this.subsetTable.uploadAsync();
        }
        return true;
    }

    /** Launches the kernel asynchronously on this kernel's stream. */
    @Override
    public void runAsync() {
        this.getStream().getContext().attachCurrentThread();
        this.func.runAsync();
    }

    /** Number of per-block partial energies produced for {@code subset}: one per thread block. */
    private static int getEnergySize(BigForcefieldEnergy.Subset subset, int blockThreads) {
        return MathTools.divUp(subset.getNumAtomPairs(), blockThreads);
    }

    /** Converts a flag to the 0/1 byte stored in the args struct. */
    private static byte toByte(boolean value) {
        return (byte) (value ? 1 : 0);
    }

    /** Releases {@code buffer} if it was allocated; tolerates partially-constructed instances. */
    private static void cleanupIfAllocated(CUBuffer<?> buffer) {
        if (buffer != null) {
            buffer.cleanup();
        }
    }

    /**
     * Releases all device buffers. Safe to call more than once, and safe on an instance whose
     * constructor failed partway through allocation.
     */
    @Override
    public void cleanup() {
        if (this.coords != null) {
            cleanupIfAllocated(this.coords);
            cleanupIfAllocated(this.atomFlags);
            cleanupIfAllocated(this.precomputed);
            cleanupIfAllocated(this.subsetTable);
            cleanupIfAllocated(this.args);
            cleanupIfAllocated(this.energies);
            this.coords = null;
            this.atomFlags = null;
            this.precomputed = null;
            this.subsetTable = null;
            this.args = null;
            this.energies = null;
        }
    }

    @Override
    protected void finalize() throws Throwable {
        try {
            if (this.coords != null) {
                System.err.println("WARNING: " + this.getClass().getName() + " was garbage collected, but not cleaned up. Attempting cleanup now");
                this.cleanup();
            }
        } finally {
            super.finalize();
        }
    }

    /** Refreshes the forcefield's host coordinates and uploads them asynchronously. */
    @Override
    public void uploadCoordsAsync() {
        this.getStream().getContext().attachCurrentThread();
        this.ffenergy.updateCoords();
        this.coords.uploadAsync();
    }

    /**
     * Downloads the per-block partial energies (blocking). Returns their sum plus the current
     * subset's internal solvation energy.
     */
    @Override
    public double downloadEnergySync() {
        this.getStream().getContext().attachCurrentThread();
        this.energies.downloadSync();
        DoubleBuffer buf = this.energies.getHostBuffer();
        double energy = this.subset.getInternalSolvationEnergy();
        buf.rewind();
        int numBlocks = getEnergySize(this.subset, this.func.blockThreads);
        for (int i = 0; i < numBlocks; ++i) {
            energy += buf.get();
        }
        return energy;
    }
}

