/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.gpu.cuda.kernels;

import edu.duke.cs.osprey.dof.DegreeOfFreedom;
import edu.duke.cs.osprey.energy.EnergyFunction;
import edu.duke.cs.osprey.energy.ResidueInteractions;
import edu.duke.cs.osprey.energy.forcefield.ForcefieldParams;
import edu.duke.cs.osprey.energy.forcefield.ResPairCache;
import edu.duke.cs.osprey.gpu.cuda.CUBuffer;
import edu.duke.cs.osprey.gpu.cuda.GpuStream;
import edu.duke.cs.osprey.gpu.cuda.GpuStreamPool;
import edu.duke.cs.osprey.gpu.cuda.Kernel;
import edu.duke.cs.osprey.structure.Molecule;
import edu.duke.cs.osprey.structure.Residue;
import edu.duke.cs.osprey.structure.Residues;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import jcuda.NativePointerObject;
import jcuda.Pointer;

/**
 * Energy function that evaluates a set of residue-pair forcefield interactions
 * by running the {@code calc} function of the {@code residueForcefield} kernel.
 *
 * <p>At construction time the per-pair parameters are packed into a single byte
 * buffer and uploaded. Each call to {@link #getEnergy()} uploads the current
 * residue coordinates, launches the kernel over a list of residue-pair indices
 * and downloads a single double.
 *
 * <p>The instance holds a stream checked out of {@link #streams} plus several
 * buffers checked out of that stream; {@link #clean()} gives them back.
 *
 * <p>Layout of the packed data buffer, as written by this class:
 * <pre>
 * header    : long flags, double coulombFactor, double scaledCoulombFactor
 * offsets   : one long per residue pair (byte offset of that pair's record)
 * per pair  : long numAtomPairs, long atomOffset1, long atomOffset2,
 *             double weight, double (offset + solvEnergy),
 *             one long of flags per atom pair,
 *             then the precomputed values, grouped by value index
 *             (value 0 for every atom pair, then value 1 for every atom pair, ...)
 * </pre>
 */
public class ResidueForcefieldEnergyCuda
extends Kernel
implements EnergyFunction.DecomposableByDof,
EnergyFunction.NeedsCleanup {
    private static final long serialVersionUID = 4015880661919715967L;
    public final GpuStreamPool streams;
    public final ForcefieldParams ffparams;
    public final ResidueInteractions inters;
    public final Residues residues;
    /** True when at least one of the filtered residues reports a conformation problem. */
    public final boolean isBroken;
    private CUBuffer<ByteBuffer> data;
    private CUBuffer<DoubleBuffer> coords;
    private CUBuffer<IntBuffer> allIndices;
    private CUBuffer<DoubleBuffer> energy;
    /** Size of the buffer header: one long of flags and two doubles. */
    private static final int HeaderBytes = 24;
    /** Size of the fixed part of a residue-pair record: three longs and two doubles. */
    private static final int ResPairBytes = 40;
    /** Per-atom-pair bytes that are always reserved. */
    private static final int AtomPairBytes = 32;
    /** Extra per-atom-pair bytes reserved when the EEF1 solvation forcefield is selected. */
    private static final int EEF1Bytes = 48;
    private ResPairCache.ResPair[] resPairs = new ResPairCache.ResPair[0];
    private Map<Residue, Subset> subsets;
    private Kernel.Function func;
    /** Shared across all instances; handed to {@code getBestBlockThreads} when sizing the launch. */
    private static final AtomicInteger blockThreads = new AtomicInteger(-1);

    /**
     * Convenience constructor that uses all residues of {@code mol}.
     *
     * @param streams      pool the GPU stream is checked out from
     * @param resPairCache source of the forcefield parameters and residue-pair records
     * @param inters       residue interactions to evaluate
     * @param mol          molecule whose residues are used
     */
    public ResidueForcefieldEnergyCuda(GpuStreamPool streams, ResPairCache resPairCache, ResidueInteractions inters, Molecule mol) {
        this(streams, resPairCache, inters, mol.residues);
    }

    /**
     * Checks out a stream, filters {@code residues} through {@code inters}, and,
     * unless the conformation is broken, packs and uploads everything the kernel needs.
     *
     * <p>If any step after the stream checkout fails, the stream and every buffer
     * acquired so far are released before the failure is rethrown, because the
     * caller never receives an instance on which it could call {@link #clean()}.
     *
     * @param streams      pool the GPU stream is checked out from
     * @param resPairCache source of the forcefield parameters and residue-pair records
     * @param inters       residue interactions to evaluate
     * @param residues     candidate residues; only those kept by {@code inters.filter} are used
     */
    public ResidueForcefieldEnergyCuda(GpuStreamPool streams, ResPairCache resPairCache, ResidueInteractions inters, Residues residues) {
        super(streams.checkout(), "residueForcefield");
        this.streams = streams;
        try {
            this.ffparams = resPairCache.ffparams;
            this.inters = inters;
            this.residues = inters.filter(residues);
            this.isBroken = ResidueForcefieldEnergyCuda.hasConfProblems(this.residues);
            // A broken conformation always evaluates to +infinity, so nothing is uploaded for it.
            if (!this.isBroken) {
                this.initGpuState(resPairCache);
            }
        } catch (RuntimeException | Error failure) {
            try {
                this.releaseAll();
            } catch (RuntimeException | Error cleanupFailure) {
                // Keep the original failure as the primary one.
                failure.addSuppressed(cleanupFailure);
            }
            throw failure;
        }
    }

    /** Returns true if any residue has a non-empty {@code confProblems} collection. */
    private static boolean hasConfProblems(Residues residues) {
        for (Residue res : residues) {
            if (!res.confProblems.isEmpty()) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code res} is one of the two residues of {@code resPair} (identity comparison). */
    private static boolean involves(ResPairCache.ResPair resPair, Residue res) {
        return resPair.res1 == res || resPair.res2 == res;
    }

    /** Looks up the residue pairs, then allocates, fills and uploads all buffers and configures the kernel function. */
    private void initGpuState(ResPairCache resPairCache) {
        ForcefieldParams.SolvationForcefield.ResiduesInfo solvInfo = null;
        if (this.ffparams.solvationForcefield != null) {
            solvInfo = this.ffparams.solvationForcefield.makeInfo(this.ffparams, this.residues);
        }
        this.resPairs = new ResPairCache.ResPair[this.inters.size()];
        int pairIndex = 0;
        for (ResidueInteractions.Pair pair : this.inters) {
            this.resPairs[pairIndex++] = resPairCache.get(this.residues, pair, solvInfo);
        }

        // Offset of each residue's first coordinate in the flat coordinate buffer (3 doubles per atom).
        int[] atomOffsetsByResIndex = new int[this.residues.size()];
        int numCoords = 0;
        int resIndex = 0;
        for (Residue res : this.residues) {
            atomOffsetsByResIndex[resIndex++] = numCoords;
            numCoords += 3 * res.atoms.size();
        }

        GpuStream stream = this.getStream();
        this.coords = stream.doubleBuffers.checkout(numCoords);

        int totalNumAtomPairs = 0;
        for (ResPairCache.ResPair resPair : this.resPairs) {
            totalNumAtomPairs += resPair.info.numAtomPairs;
        }
        boolean usesEef1 = this.ffparams.solvationForcefield == ForcefieldParams.SolvationForcefield.EEF1;
        int atomPairBytes = AtomPairBytes + (usesEef1 ? EEF1Bytes : 0);
        // Each residue pair costs one long in the offset table plus its fixed-size record.
        this.data = stream.byteBuffers.checkout(HeaderBytes + (Long.BYTES + ResPairBytes) * this.resPairs.length + atomPairBytes * totalNumAtomPairs);
        ByteBuffer databuf = this.data.getHostBuffer();
        this.writeData(databuf, atomOffsetsByResIndex, atomPairBytes, usesEef1);
        databuf.flip();
        this.data.uploadAsync();

        // Index list selecting every residue pair, used by the whole-function getEnergy().
        this.allIndices = stream.intBuffers.checkout(this.resPairs.length);
        IntBuffer allIndicesBuf = this.allIndices.getHostBuffer();
        for (int i = 0; i < this.resPairs.length; ++i) {
            allIndicesBuf.put(i);
        }
        allIndicesBuf.flip();
        this.allIndices.uploadAsync();

        this.energy = stream.doubleBuffers.checkout(1);

        this.func = this.makeFunction("calc");
        this.func.numBlocks = 1;
        // 8 bytes per thread; presumably one double of scratch space each (confirm against the kernel source).
        this.func.sharedMemCalc = numThreads -> numThreads * 8;
        // Placeholder index arguments; the real ones are bound on every getEnergy call.
        this.bindKernelArgs(0, Pointer.to(new int[]{0}));
        this.func.blockThreads = this.func.getBestBlockThreads(ResidueForcefieldEnergyCuda.blockThreads);
    }

    /** Writes the header, the per-pair offset table and the per-pair records into {@code databuf}. */
    private void writeData(ByteBuffer databuf, int[] atomOffsetsByResIndex, int atomPairBytes, boolean usesEef1) {
        // Flag bits, most significant first: hElect (bit 3), hVDW (bit 2), distDepDielect (bit 1), EEF1 (bit 0).
        long flags = this.ffparams.hElect ? 1L : 0L;
        flags <<= 1;
        flags |= this.ffparams.hVDW ? 1L : 0L;
        flags <<= 1;
        flags |= this.ffparams.distDepDielect ? 1L : 0L;
        flags <<= 1;
        flags |= usesEef1 ? 1L : 0L;
        databuf.putLong(flags);

        // NOTE(review): 332.0 looks like the usual Coulomb conversion constant; confirm units.
        double coulombFactor = 332.0 / this.ffparams.dielectric;
        double scaledCoulombFactor = coulombFactor * this.ffparams.forcefld.coulombScaling;
        databuf.putDouble(coulombFactor);
        databuf.putDouble(scaledCoulombFactor);

        // Offset table: the first record starts right after the header and the table itself.
        long offset = HeaderBytes + Long.BYTES * this.resPairs.length;
        for (ResPairCache.ResPair resPair : this.resPairs) {
            databuf.putLong(offset);
            offset += ResPairBytes + atomPairBytes * resPair.info.numAtomPairs;
        }

        for (ResPairCache.ResPair resPair : this.resPairs) {
            int numAtomPairs = resPair.info.numAtomPairs;
            int numPrecomputed = resPair.info.numPrecomputedPerAtomPair;
            databuf.putLong(numAtomPairs);
            databuf.putLong(atomOffsetsByResIndex[resPair.resIndex1]);
            databuf.putLong(atomOffsetsByResIndex[resPair.resIndex2]);
            databuf.putDouble(resPair.weight);
            databuf.putDouble(resPair.offset + resPair.solvEnergy);
            for (int j = 0; j < numAtomPairs; ++j) {
                databuf.putLong(resPair.info.flags[j]);
            }
            // The source array is atom-pair-major; it is written out value-index-major.
            for (int k = 0; k < numPrecomputed; ++k) {
                for (int j = 0; j < numAtomPairs; ++j) {
                    databuf.putDouble(resPair.info.precomputed[j * numPrecomputed + k]);
                }
            }
        }
    }

    /** Sets the five kernel arguments: coords, data, number of indices, indices, energy. */
    private void bindKernelArgs(int numIndices, NativePointerObject indicesPointer) {
        this.func.setArgs(Pointer.to(new NativePointerObject[]{
            this.coords.getDevicePointer(),
            this.data.getDevicePointer(),
            Pointer.to(new int[]{numIndices}),
            indicesPointer,
            this.energy.getDevicePointer()
        }));
    }

    /** Releases every buffer that is currently held, then returns the stream to the pool. */
    private void releaseAll() {
        GpuStream stream = this.getStream();
        if (this.coords != null) {
            stream.doubleBuffers.release(this.coords);
            this.coords = null;
        }
        if (this.data != null) {
            stream.byteBuffers.release(this.data);
            this.data = null;
        }
        if (this.allIndices != null) {
            stream.intBuffers.release(this.allIndices);
            this.allIndices = null;
        }
        if (this.energy != null) {
            stream.doubleBuffers.release(this.energy);
            this.energy = null;
        }
        if (this.subsets != null) {
            for (Subset subset : this.subsets.values()) {
                if (subset.indices != null) {
                    stream.intBuffers.release(subset.indices);
                }
            }
            this.subsets = null;
        }
        this.streams.release(stream);
    }

    /**
     * Releases all buffers (including those of any subsets created by
     * {@link #decomposeByDof}) and returns the stream to {@link #streams}.
     */
    @Override
    public void clean() {
        this.releaseAll();
    }

    /**
     * Evaluates all residue pairs.
     *
     * @return the value downloaded from the kernel, or {@link Double#POSITIVE_INFINITY}
     *         if {@link #isBroken} is set
     */
    @Override
    public double getEnergy() {
        return this.getEnergy(this.allIndices);
    }

    /**
     * Uploads the current residue coordinates, runs the kernel over the residue
     * pairs selected by {@code indices} and returns the downloaded value.
     *
     * @param indices residue-pair indices to evaluate; not read when {@link #isBroken} is set
     * @return the energy, or {@link Double#POSITIVE_INFINITY} for a broken conformation
     */
    private double getEnergy(CUBuffer<IntBuffer> indices) {
        if (this.isBroken) {
            return Double.POSITIVE_INFINITY;
        }
        this.getStream().getContext().attachCurrentThread();

        // Snapshot the residues' current coordinates in residue order.
        DoubleBuffer coordsbuf = this.coords.getHostBuffer();
        coordsbuf.clear();
        for (Residue res : this.residues) {
            coordsbuf.put(res.coords);
        }
        // clear() (not flip()) resets position to 0 and limit to capacity, so the whole buffer is uploaded.
        coordsbuf.clear();
        this.coords.uploadAsync();

        this.bindKernelArgs(indices.getHostBuffer().limit(), indices.getDevicePointer());
        this.func.runAsync();

        DoubleBuffer result = this.energy.downloadSync();
        result.rewind();
        return result.get();
    }

    /**
     * Builds one energy function per degree of freedom. A degree of freedom
     * without a residue gets this whole function; otherwise it gets a
     * {@link Subset} restricted to the pairs involving its residue. Subsets are
     * cached per residue and released by {@link #clean()}.
     *
     * @param mol  not used by this implementation
     * @param dofs degrees of freedom to build energy functions for
     * @return energy functions in the same order as {@code dofs}
     */
    @Override
    public List<EnergyFunction> decomposeByDof(Molecule mol, List<DegreeOfFreedom> dofs) {
        if (this.subsets == null) {
            this.subsets = new HashMap<>();
        }
        List<EnergyFunction> efuncs = new ArrayList<>(dofs.size());
        for (DegreeOfFreedom dof : dofs) {
            Residue res = dof.getResidue();
            if (res == null) {
                efuncs.add(this);
            } else {
                efuncs.add(this.subsets.computeIfAbsent(res, Subset::new));
            }
        }
        return efuncs;
    }

    /**
     * Energy function restricted to the residue pairs that involve one residue.
     * It shares the enclosing instance's kernel, coordinate buffer and data buffer.
     */
    private class Subset
    implements EnergyFunction {
        private static final long serialVersionUID = -1749739381007657718L;
        /** Indices of the matching residue pairs, or null when the residue is in no pair. */
        private CUBuffer<IntBuffer> indices;

        /**
         * Collects the indices of all residue pairs involving {@code res} and
         * uploads them. No buffer is checked out when there are none.
         */
        public Subset(Residue res) {
            ResPairCache.ResPair[] pairs = ResidueForcefieldEnergyCuda.this.resPairs;

            // First pass: count, so the buffer can be checked out at its final size.
            int numMatches = 0;
            for (ResPairCache.ResPair resPair : pairs) {
                if (ResidueForcefieldEnergyCuda.involves(resPair, res)) {
                    ++numMatches;
                }
            }
            if (numMatches <= 0) {
                return;
            }

            // Second pass: record the matching indices.
            this.indices = ResidueForcefieldEnergyCuda.this.getStream().intBuffers.checkout(numMatches);
            IntBuffer indicesbuf = this.indices.getHostBuffer();
            for (int i = 0; i < pairs.length; ++i) {
                if (ResidueForcefieldEnergyCuda.involves(pairs[i], res)) {
                    indicesbuf.put(i);
                }
            }
            indicesbuf.flip();
            this.indices.uploadAsync();
        }

        /**
         * @return the energy of the pairs involving this subset's residue, or
         *         {@link Double#NaN} when the residue is in no pair
         */
        @Override
        public double getEnergy() {
            if (this.indices == null) {
                // NOTE(review): this also yields NaN (not +infinity) for a broken conformation,
                // because no residue pairs are loaded in that case. Confirm callers expect that.
                return Double.NaN;
            }
            return ResidueForcefieldEnergyCuda.this.getEnergy(this.indices);
        }
    }
}

