/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.gpu.cuda.kernels;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleMatrix1D;
import edu.duke.cs.osprey.dof.DegreeOfFreedom;
import edu.duke.cs.osprey.dof.FreeDihedral;
import edu.duke.cs.osprey.energy.forcefield.ForcefieldParams;
import edu.duke.cs.osprey.energy.forcefield.ResPairCache;
import edu.duke.cs.osprey.energy.forcefield.ResidueForcefieldEnergy;
import edu.duke.cs.osprey.gpu.cuda.CUBuffer;
import edu.duke.cs.osprey.gpu.cuda.GpuStream;
import edu.duke.cs.osprey.gpu.cuda.GpuStreamPool;
import edu.duke.cs.osprey.gpu.cuda.Kernel;
import edu.duke.cs.osprey.minimization.Minimizer;
import edu.duke.cs.osprey.minimization.MoleculeObjectiveFunction;
import edu.duke.cs.osprey.minimization.ObjectiveFunction;
import edu.duke.cs.osprey.structure.Residue;
import edu.duke.cs.osprey.tools.MathTools;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import jcuda.NativePointerObject;
import jcuda.Pointer;

public class ResidueCudaCCDMinimizer
extends Kernel
implements Minimizer.NeedsCleanup {
    public final GpuStreamPool streams;
    private Kernel.Function func;
    private static AtomicInteger blockThreads = new AtomicInteger(-1);
    private MoleculeObjectiveFunction mof;
    private ResidueForcefieldEnergy efunc;
    private List<Dihedral> dihedrals;
    private CUBuffer<ByteBuffer> data;
    private CUBuffer<DoubleBuffer> coords;
    private CUBuffer<DoubleBuffer> xin;
    private CUBuffer<DoubleBuffer> out;
    private static final int HeaderBytes = 32;
    private static final int DihedralBytes = 48;
    private static final int ResPairBytes = 40;
    private static final int AtomPairBytes = 80;

    public ResidueCudaCCDMinimizer(GpuStreamPool streams, ObjectiveFunction f) {
        super(streams.checkout(), "residueCcd");
        int j;
        int d;
        this.streams = streams;
        GpuStream stream = this.getStream();
        if (!(f instanceof MoleculeObjectiveFunction)) {
            throw new Error("objective function should be a " + MoleculeObjectiveFunction.class.getSimpleName() + ", not a " + f.getClass().getSimpleName() + ". this is a bug");
        }
        this.mof = (MoleculeObjectiveFunction)f;
        if (!(this.mof.efunc instanceof ResidueForcefieldEnergy)) {
            throw new Error("energy function should be a " + ResidueForcefieldEnergy.class.getSimpleName() + ", not a " + this.mof.efunc.getClass().getSimpleName() + ". this is a bug");
        }
        this.efunc = (ResidueForcefieldEnergy)this.mof.efunc;
        if (this.efunc.isBroken) {
            return;
        }
        stream.getContext().attachCurrentThread();
        int[] atomOffsetsByResIndex = new int[this.efunc.residues.size()];
        Arrays.fill(atomOffsetsByResIndex, -1);
        int atomOffset = 0;
        int numAtoms = 0;
        for (int i = 0; i < this.efunc.residues.size(); ++i) {
            Residue res = (Residue)this.efunc.residues.get(i);
            atomOffsetsByResIndex[i] = atomOffset;
            atomOffset += 3 * res.atoms.size();
            numAtoms += res.atoms.size();
        }
        this.coords = stream.doubleBuffers.checkout(numAtoms * 3);
        this.dihedrals = new ArrayList<Dihedral>();
        ObjectiveFunction.DofBounds dofBounds = new ObjectiveFunction.DofBounds(this.mof.getConstraints());
        IdentityHashMap<Residue, int[]> cache = new IdentityHashMap<Residue, int[]>();
        int maxNumAtoms = 0;
        for (int d2 = 0; d2 < this.mof.getNumDOFs(); ++d2) {
            DegreeOfFreedom dof = this.mof.pmol.dofs.get(d2);
            if (!(dof instanceof FreeDihedral)) {
                throw new Error("degree-of-freedom type " + dof.getClass().getSimpleName() + " not yet supported by CCD kernel. Use CPU minimizer with GPU energy function instead");
            }
            Dihedral dihedral = new Dihedral(this, d2, (FreeDihedral)dof);
            dihedral.resIndex = this.efunc.residues.findIndex(dihedral.res);
            if (dihedral.resIndex < 0) continue;
            dihedral.resPairIndices = (int[])cache.get(dihedral.res);
            if (dihedral.resPairIndices == null) {
                dihedral.resPairIndices = this.efunc.makeResPairIndicesSubset(dof.getResidue());
                cache.put(dihedral.res, dihedral.resPairIndices);
            }
            dihedral.xdmin = dofBounds.getMin(d2);
            dihedral.xdmax = dofBounds.getMax(d2);
            maxNumAtoms = Math.max(maxNumAtoms, dihedral.res.atoms.size());
            this.dihedrals.add(dihedral);
        }
        int numRotatedAtoms = 0;
        int numResPairs = 0;
        for (Dihedral dihedral : this.dihedrals) {
            numRotatedAtoms += MathTools.roundUpToMultiple(dihedral.rotatedIndices.length, 4);
            numResPairs += MathTools.roundUpToMultiple(dihedral.resPairIndices.length, 2);
        }
        int numAtomPairs = 0;
        for (ResPairCache.ResPair resPair : this.efunc.resPairs) {
            numAtomPairs += resPair.info.numAtomPairs;
        }
        this.data = stream.byteBuffers.checkout(32 + 8 * this.dihedrals.size() + 8 * this.efunc.resPairs.length + 48 * this.dihedrals.size() + 2 * numRotatedAtoms + 4 * numResPairs + 40 * this.efunc.resPairs.length + 80 * numAtomPairs);
        ByteBuffer byteBuffer = this.data.getHostBuffer();
        ForcefieldParams ffparams = this.efunc.resPairCache.ffparams;
        int flags = ffparams.hElect ? 1 : 0;
        flags <<= 1;
        flags |= ffparams.hVDW ? 1 : 0;
        flags <<= 1;
        flags |= ffparams.distDepDielect ? 1 : 0;
        flags <<= 1;
        byteBuffer.putInt(flags |= ffparams.solvationForcefield == ForcefieldParams.SolvationForcefield.EEF1 ? 1 : 0);
        byteBuffer.putInt(this.dihedrals.size());
        byteBuffer.putInt(this.efunc.resPairs.length);
        byteBuffer.putInt(maxNumAtoms);
        double coulombFactor = 332.0 / ffparams.dielectric;
        double scaledCoulombFactor = coulombFactor * ffparams.forcefld.coulombScaling;
        byteBuffer.putDouble(coulombFactor);
        byteBuffer.putDouble(scaledCoulombFactor);
        int dihedralOffsetsPos = byteBuffer.position();
        for (int d3 = 0; d3 < this.dihedrals.size(); ++d3) {
            byteBuffer.putLong(0L);
        }
        int resPairsOffsetsPos = byteBuffer.position();
        for (d = 0; d < this.efunc.resPairs.length; ++d) {
            byteBuffer.putLong(0L);
        }
        for (d = 0; d < this.dihedrals.size(); ++d) {
            int i;
            Dihedral dihedral = this.dihedrals.get(d);
            byteBuffer.putLong(dihedralOffsetsPos + d * 8, byteBuffer.position());
            byteBuffer.putInt(dihedral.resIndex);
            byteBuffer.putInt(dihedral.res.atoms.size());
            byteBuffer.putLong(atomOffsetsByResIndex[dihedral.resIndex]);
            for (j = 0; j < dihedral.dihedralIndices.length; ++j) {
                byteBuffer.putShort((short)(dihedral.dihedralIndices[j] * 3));
            }
            byteBuffer.putInt(dihedral.rotatedIndices.length);
            byteBuffer.putInt(dihedral.resPairIndices.length);
            byteBuffer.putDouble(Math.toRadians(dihedral.xdmin));
            byteBuffer.putDouble(Math.toRadians(dihedral.xdmax));
            int n = MathTools.roundUpToMultiple(dihedral.rotatedIndices.length, 4);
            for (i = 0; i < n; ++i) {
                if (i < dihedral.rotatedIndices.length) {
                    byteBuffer.putShort((short)(dihedral.rotatedIndices[i] * 3));
                    continue;
                }
                byteBuffer.putShort((short)0);
            }
            n = MathTools.roundUpToMultiple(dihedral.resPairIndices.length, 2);
            for (i = 0; i < n; ++i) {
                if (i < dihedral.resPairIndices.length) {
                    byteBuffer.putInt(dihedral.resPairIndices[i]);
                    continue;
                }
                byteBuffer.putInt(0);
            }
        }
        for (int i = 0; i < this.efunc.resPairs.length; ++i) {
            ResPairCache.ResPair resPair = this.efunc.resPairs[i];
            byteBuffer.putLong(resPairsOffsetsPos + i * 8, byteBuffer.position());
            byteBuffer.putLong(resPair.info.numAtomPairs);
            byteBuffer.putLong(atomOffsetsByResIndex[this.efunc.residues.findIndexOrThrow(resPair.res1)]);
            byteBuffer.putLong(atomOffsetsByResIndex[this.efunc.residues.findIndexOrThrow(resPair.res2)]);
            byteBuffer.putDouble(resPair.weight);
            byteBuffer.putDouble(resPair.offset + resPair.solvEnergy);
            for (j = 0; j < resPair.info.numAtomPairs; ++j) {
                byteBuffer.putLong(resPair.info.flags[j]);
            }
            for (int k = 0; k < resPair.info.numPrecomputedPerAtomPair; ++k) {
                for (int j2 = 0; j2 < resPair.info.numAtomPairs; ++j2) {
                    byteBuffer.putDouble(resPair.info.precomputed[j2 * resPair.info.numPrecomputedPerAtomPair + k]);
                }
            }
        }
        byteBuffer.flip();
        this.data.uploadAsync();
        this.xin = stream.doubleBuffers.checkout(this.dihedrals.size());
        this.out = stream.doubleBuffers.checkout(this.dihedrals.size() + 1);
        this.func = this.makeFunction("ccd");
        this.func.numBlocks = 1;
        this.func.sharedMemCalc = blockThreads -> blockThreads * 8 + this.dihedrals.size() * 8 + this.dihedrals.size() * 8 + this.dihedrals.size() * 8;
        this.func.setArgs(Pointer.to((NativePointerObject[])new NativePointerObject[]{this.data.getDevicePointer(), this.coords.getDevicePointer(), this.xin.getDevicePointer(), this.out.getDevicePointer()}));
        this.func.blockThreads = this.func.getBestBlockThreads(ResidueCudaCCDMinimizer.blockThreads);
    }

    @Override
    public Minimizer.Result minimizeFromCenter() {
        DoubleMatrix1D x = DoubleFactory1D.dense.make(this.dihedrals.size());
        for (Dihedral dihedral : this.dihedrals) {
            x.set(dihedral.d, (dihedral.xdmax + dihedral.xdmin) / 2.0);
        }
        return this.minimizeFrom(x);
    }

    @Override
    public Minimizer.Result minimizeFrom(DoubleMatrix1D x) {
        if (this.efunc.isBroken) {
            return new Minimizer.Result(null, Double.POSITIVE_INFINITY);
        }
        for (Dihedral dihedral : this.dihedrals) {
            this.mof.setDOF(dihedral.d, x.get(dihedral.d));
        }
        DoubleBuffer coordsbuf = this.coords.getHostBuffer();
        coordsbuf.clear();
        for (Object res : this.efunc.residues) {
            coordsbuf.put(((Residue)res).coords);
        }
        coordsbuf.clear();
        this.coords.uploadAsync();
        DoubleBuffer doubleBuffer = this.xin.getHostBuffer();
        doubleBuffer.clear();
        for (Dihedral dihedral : this.dihedrals) {
            doubleBuffer.put(Math.toRadians(x.get(dihedral.d)));
        }
        doubleBuffer.clear();
        this.xin.uploadAsync();
        this.func.runAsync();
        DoubleBuffer outbuf = this.out.downloadSync();
        outbuf.rewind();
        x = DoubleFactory1D.dense.make(this.mof.getNumDOFs());
        for (Dihedral dihedral : this.dihedrals) {
            x.set(dihedral.d, Math.toDegrees(outbuf.get()));
        }
        Minimizer.Result result = new Minimizer.Result(x, outbuf.get());
        this.mof.setDOFs(result.dofValues);
        return result;
    }

    @Override
    public void clean() {
        GpuStream stream = this.getStream();
        if (this.data != null) {
            stream.byteBuffers.release(this.data);
            this.data = null;
        }
        if (this.coords != null) {
            stream.doubleBuffers.release(this.coords);
            this.coords = null;
        }
        if (this.xin != null) {
            stream.doubleBuffers.release(this.xin);
            this.xin = null;
        }
        if (this.out != null) {
            stream.doubleBuffers.release(this.out);
            this.out = null;
        }
        this.streams.release(stream);
    }

    private class Dihedral {
        public final int d;
        public final Residue res;
        public final int[] dihedralIndices;
        public final int[] rotatedIndices;
        public int resIndex;
        public int[] resPairIndices;
        public double xdmin;
        public double xdmax;

        public Dihedral(ResidueCudaCCDMinimizer residueCudaCCDMinimizer, int d, FreeDihedral diehedral) {
            this.d = d;
            this.res = diehedral.getResidue();
            this.dihedralIndices = this.res.template.getDihedralDefiningAtoms(diehedral.getDihedralNumber());
            this.rotatedIndices = this.toArray(this.res.template.getDihedralRotatedAtoms(diehedral.getDihedralNumber()));
        }

        private int[] toArray(List<Integer> list) {
            int[] array = new int[list.size()];
            for (int i = 0; i < list.size(); ++i) {
                array[i] = list.get(i);
            }
            return array;
        }
    }
}

