/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.gpu.cuda.kernels;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleMatrix1D;
import edu.duke.cs.osprey.dof.DegreeOfFreedom;
import edu.duke.cs.osprey.dof.FreeDihedral;
import edu.duke.cs.osprey.energy.forcefield.BigForcefieldEnergy;
import edu.duke.cs.osprey.gpu.cuda.CUBuffer;
import edu.duke.cs.osprey.gpu.cuda.GpuStream;
import edu.duke.cs.osprey.gpu.cuda.Kernel;
import edu.duke.cs.osprey.minimization.Minimizer;
import edu.duke.cs.osprey.minimization.MoleculeModifierAndScorer;
import edu.duke.cs.osprey.minimization.MoleculeObjectiveFunction;
import edu.duke.cs.osprey.minimization.ObjectiveFunction;
import edu.duke.cs.osprey.structure.Residue;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.List;
import jcuda.NativePointerObject;
import jcuda.Pointer;

public class CCDKernelCuda
extends Kernel {
    private BigForcefieldEnergy ffenergy;
    private int ffSequenceNumber;
    private Kernel.Function func;
    private static Integer blockThreads = null;
    private CUBuffer<DoubleBuffer> coords;
    private CUBuffer<IntBuffer> atomFlags;
    private CUBuffer<DoubleBuffer> precomputed;
    private CUBuffer<ByteBuffer> ffargs;
    private CUBuffer<ByteBuffer> dofargs;
    private CUBuffer<IntBuffer> subsetTables;
    private CUBuffer<IntBuffer> dihedralIndices;
    private CUBuffer<IntBuffer> rotatedIndices;
    private CUBuffer<DoubleBuffer> xAndBounds;
    private CUBuffer<DoubleBuffer> ccdOut;
    private List<DofInfo> dofInfos;

    public CCDKernelCuda(GpuStream stream) throws IOException {
        super(stream, "ccd");
        this.ffargs = stream.makeByteBuffer(48);
    }

    @Deprecated
    public void init(MoleculeModifierAndScorer mof) {
        this.init(new MoleculeObjectiveFunction(mof));
    }

    public void init(MoleculeObjectiveFunction mof) {
        GpuStream stream = this.getStream();
        stream.getContext().attachCurrentThread();
        if (!(mof.efunc instanceof BigForcefieldEnergy)) {
            throw new Error("CCD kernel needs a " + BigForcefieldEnergy.class.getSimpleName() + ", not a " + mof.efunc.getClass().getSimpleName() + ". this is a bug.");
        }
        this.ffenergy = (BigForcefieldEnergy)mof.efunc;
        this.ffSequenceNumber = this.ffenergy.getFullSubset().handleChemicalChanges();
        this.ffenergy.updateCoords();
        this.coords = stream.makeOrExpandBuffer(this.coords, this.ffenergy.getCoords());
        this.atomFlags = stream.makeOrExpandBuffer(this.atomFlags, this.ffenergy.getAtomFlags());
        this.precomputed = stream.makeOrExpandBuffer(this.precomputed, this.ffenergy.getPrecomputed());
        ByteBuffer argsBuf = this.ffargs.getHostBuffer();
        argsBuf.rewind();
        argsBuf.putInt(this.ffenergy.getFullSubset().getNumAtomPairs());
        argsBuf.putInt(this.ffenergy.getFullSubset().getNum14AtomPairs());
        argsBuf.putDouble(this.ffenergy.getParams().coulombFactor);
        argsBuf.putDouble(this.ffenergy.getParams().scaledCoulombFactor);
        argsBuf.putDouble(this.ffenergy.getParams().solvationCutoff2);
        argsBuf.putDouble(this.ffenergy.getFullSubset().getInternalSolvationEnergy());
        argsBuf.put((byte)(this.ffenergy.getParams().useDistDependentDielectric ? 1 : 0));
        argsBuf.put((byte)(this.ffenergy.getParams().useHElectrostatics ? 1 : 0));
        argsBuf.put((byte)(this.ffenergy.getParams().useHVdw ? 1 : 0));
        argsBuf.put((byte)(this.ffenergy.getParams().useEEF1 ? 1 : 0));
        this.atomFlags.uploadAsync();
        this.precomputed.uploadAsync();
        this.ffargs.uploadAsync();
        this.dofInfos = new ArrayList<DofInfo>();
        int subsetsSize = 0;
        int numRotatedAtoms = 0;
        for (int d = 0; d < mof.getNumDOFs(); ++d) {
            DegreeOfFreedom dofBase = mof.pmol.dofs.get(d);
            if (!(dofBase instanceof FreeDihedral)) {
                throw new Error("degree-of-freedom type " + dofBase.getClass().getSimpleName() + " not yet supported by CCD kernel. Use CPU minimizer with GPU energy function instead");
            }
            FreeDihedral dof = (FreeDihedral)dofBase;
            BigForcefieldEnergy.Subset subset = (BigForcefieldEnergy.Subset)mof.getEfunc(d);
            DofInfo dofInfo = new DofInfo(this, dof, subset);
            this.dofInfos.add(dofInfo);
            subsetsSize += dofInfo.subset.getNumAtomPairs();
            numRotatedAtoms += dofInfo.rotatedIndices.size();
        }
        this.subsetTables = stream.makeOrExpandIntBuffer(this.subsetTables, subsetsSize);
        IntBuffer subsetTablesBuf = this.subsetTables.getHostBuffer();
        subsetTablesBuf.clear();
        this.dofargs = stream.makeOrExpandByteBuffer(this.dofargs, this.dofInfos.size() * 48);
        ByteBuffer dofargsBuf = this.dofargs.getHostBuffer();
        dofargsBuf.clear();
        this.dihedralIndices = stream.makeOrExpandIntBuffer(this.dihedralIndices, this.dofInfos.size() * 4);
        IntBuffer dihedralIndicesBuf = this.dihedralIndices.getHostBuffer();
        dihedralIndicesBuf.clear();
        this.rotatedIndices = stream.makeOrExpandIntBuffer(this.rotatedIndices, numRotatedAtoms);
        IntBuffer rotatedIndicesBuf = this.rotatedIndices.getHostBuffer();
        rotatedIndicesBuf.clear();
        int maxNumCoords = 0;
        for (int d = 0; d < this.dofInfos.size(); ++d) {
            int i;
            DofInfo dofInfo = this.dofInfos.get(d);
            int firstCoord = dofInfo.atomOffset * 3;
            int lastCoord = (dofInfo.atomOffset + dofInfo.numAtoms) * 3 - 1;
            int numCoords = lastCoord - firstCoord + 1;
            maxNumCoords = Math.max(maxNumCoords, numCoords);
            dofargsBuf.putInt(subsetTablesBuf.position());
            dofargsBuf.putInt(dofInfo.subset.getNumAtomPairs());
            dofargsBuf.putInt(dofInfo.subset.getNum14AtomPairs());
            dofargsBuf.putInt(rotatedIndicesBuf.position());
            dofargsBuf.putInt(dofInfo.rotatedIndices.size());
            dofargsBuf.putInt(firstCoord);
            dofargsBuf.putInt(lastCoord);
            dofargsBuf.putInt(0);
            dofargsBuf.putDouble(dofInfo.subset.getInternalSolvationEnergy());
            dofInfo.subset.getSubsetTable().rewind();
            subsetTablesBuf.put(dofInfo.subset.getSubsetTable());
            for (i = 0; i < dofInfo.dihedralIndices.length; ++i) {
                dihedralIndicesBuf.put(dofInfo.dihedralIndices[i]);
            }
            for (i = 0; i < dofInfo.rotatedIndices.size(); ++i) {
                rotatedIndicesBuf.put(dofInfo.rotatedIndices.get(i));
            }
        }
        final int fMaxNumCoords = maxNumCoords;
        this.dofargs.uploadAsync();
        this.subsetTables.uploadAsync();
        this.dihedralIndices.uploadAsync();
        this.rotatedIndices.uploadAsync();
        this.xAndBounds = stream.makeOrExpandDoubleBuffer(this.xAndBounds, this.dofInfos.size() * 3);
        this.ccdOut = stream.makeOrExpandDoubleBuffer(this.ccdOut, this.dofInfos.size() + 1);
        this.func = this.makeFunction("ccd");
        this.func.numBlocks = 1;
        this.func.sharedMemCalc = new Kernel.SharedMemCalculator(){

            @Override
            public int calcBytes(int blockThreads) {
                return blockThreads * 8 + fMaxNumCoords * 8 + CCDKernelCuda.this.dofInfos.size() * 8 + CCDKernelCuda.this.dofInfos.size() * 8 + CCDKernelCuda.this.dofInfos.size() * 8;
            }
        };
        this.func.setArgs(Pointer.to((NativePointerObject[])new NativePointerObject[]{this.coords.getDevicePointer(), this.atomFlags.getDevicePointer(), this.precomputed.getDevicePointer(), this.ffargs.getDevicePointer(), this.subsetTables.getDevicePointer(), this.dihedralIndices.getDevicePointer(), this.rotatedIndices.getDevicePointer(), this.dofargs.getDevicePointer(), Pointer.to((int[])new int[]{maxNumCoords}), this.xAndBounds.getDevicePointer(), Pointer.to((int[])new int[]{this.dofInfos.size()}), this.ccdOut.getDevicePointer()}));
        if (blockThreads == null) {
            blockThreads = this.func.calcMaxBlockThreads();
        }
        this.func.blockThreads = blockThreads;
    }

    public void uploadCoordsAsync() {
        this.ffenergy.updateCoords();
        this.coords.uploadAsync();
    }

    public void runAsync(DoubleMatrix1D x, ObjectiveFunction.DofBounds dofBounds) {
        if (this.ffenergy.getFullSubset().handleChemicalChanges() != this.ffSequenceNumber) {
            throw new Error("don't re-use kernel instances after chemical changes. This is a bug");
        }
        DoubleBuffer buf = this.xAndBounds.getHostBuffer();
        buf.clear();
        int numDofs = this.dofInfos.size();
        for (int d = 0; d < numDofs; ++d) {
            buf.put(Math.toRadians(x.get(d)));
            buf.put(Math.toRadians(dofBounds.getMin(d)));
            buf.put(Math.toRadians(dofBounds.getMax(d)));
        }
        this.xAndBounds.uploadAsync();
        this.func.runAsync();
    }

    public Minimizer.Result downloadResultSync() {
        int numDofs = this.dofInfos.size();
        Minimizer.Result result = new Minimizer.Result(DoubleFactory1D.dense.make(numDofs), 0.0);
        DoubleBuffer buf = this.ccdOut.downloadSync();
        buf.rewind();
        result.energy = buf.get();
        for (int d = 0; d < numDofs; ++d) {
            result.dofValues.set(d, Math.toDegrees(buf.get()));
        }
        return result;
    }

    public Minimizer.Result runSync(DoubleMatrix1D x, ObjectiveFunction.DofBounds dofBounds) {
        this.runAsync(x, dofBounds);
        return this.downloadResultSync();
    }

    public void cleanup() {
        if (this.coords != null) {
            this.coords.cleanup();
            this.atomFlags.cleanup();
            this.precomputed.cleanup();
            this.ffargs.cleanup();
            this.dofargs.cleanup();
            this.subsetTables.cleanup();
            this.dihedralIndices.cleanup();
            this.rotatedIndices.cleanup();
            this.xAndBounds.cleanup();
            this.ccdOut.cleanup();
            this.coords = null;
        }
    }

    protected void finalize() throws Throwable {
        try {
            if (this.coords != null) {
                System.err.println("WARNING: " + this.getClass().getName() + " was garbage collected, but not cleaned up. Attempting cleanup now");
                this.cleanup();
            }
        }
        finally {
            super.finalize();
        }
    }

    private class DofInfo {
        public final Residue res;
        public final BigForcefieldEnergy.Subset subset;
        public final int[] dihedralIndices;
        public final List<Integer> rotatedIndices;
        public final int atomOffset;
        public final int numAtoms;

        public DofInfo(CCDKernelCuda cCDKernelCuda, FreeDihedral dof, BigForcefieldEnergy.Subset subset) {
            this.res = dof.getResidue();
            this.subset = subset;
            this.dihedralIndices = this.res.template.getDihedralDefiningAtoms(dof.getDihedralNumber());
            this.rotatedIndices = this.res.template.getDihedralRotatedAtoms(dof.getDihedralNumber());
            this.atomOffset = cCDKernelCuda.ffenergy.getAtomOffset(this.res);
            this.numAtoms = this.res.atoms.size();
        }
    }
}

