/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.pruning;

import edu.duke.cs.osprey.confspace.ParametricMolecule;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.dof.DegreeOfFreedom;
import edu.duke.cs.osprey.energy.ResInterGen;
import edu.duke.cs.osprey.energy.ResidueInteractions;
import edu.duke.cs.osprey.parallelism.TaskExecutor;
import edu.duke.cs.osprey.pruning.PruningMatrix;
import edu.duke.cs.osprey.structure.Atom;
import edu.duke.cs.osprey.structure.AtomConnectivity;
import edu.duke.cs.osprey.structure.AtomNeighbors;
import edu.duke.cs.osprey.structure.Probe;
import edu.duke.cs.osprey.structure.Residue;
import edu.duke.cs.osprey.tools.Progress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.linear.LinearConstraint;
import org.apache.commons.math3.optim.linear.LinearConstraintSet;
import org.apache.commons.math3.optim.linear.LinearObjectiveFunction;
import org.apache.commons.math3.optim.linear.NoFeasibleSolutionException;
import org.apache.commons.math3.optim.linear.Relationship;
import org.apache.commons.math3.optim.linear.SimplexSolver;

/**
 * Prunes residue-conformation tuples whose continuous voxels cannot avoid atomic overlaps.
 *
 * <p>For a tuple, every non-bonded atom pair across the tuple's residue interactions is scored with
 * {@link Probe}. For each pair whose atoms are moved by at least one voxel degree of freedom (DOF),
 * a Newton search looks for a point where the pair's violation is (nearly) zero. The violation
 * surface is linearized there into a half-space over the voxel DOFs.
 *
 * <p>The tuple is pruned in either of two cases:
 * <ul>
 *   <li>a simplex feasibility problem finds no point inside the voxel that satisfies every
 *       half-space;</li>
 *   <li>a Newton search steps out of the voxel while the pair is still violated.</li>
 * </ul>
 */
public class PLUG {
    public final SimpleConfSpace confSpace;

    /** Maximum number of Newton steps used when searching for a violation boundary. */
    public int maxNumIterations = 30;

    /** A Newton iterate whose absolute violation is at most this value is accepted as a boundary point. */
    public double violationThreshold = 0.01;

    /** Finite-difference step used for gradients and DOF probing, as a fraction of each DOF's voxel width. */
    public double gradientDxFactor = 1.0E-4;

    private final Probe probe;
    private final AtomConnectivity connectivity;

    /**
     * Creates a pruner for the given conformation space.
     *
     * <p>Probe templates are matched against {@code confSpace} once, up front. The connectivity is
     * built with {@code set15HasNonBonded(false)}, which presumably keeps 1-5 atom pairs out of the
     * non-bonded set; confirm against {@link AtomConnectivity}.
     *
     * @param confSpace the conformation space whose tuples will be examined
     */
    public PLUG(SimpleConfSpace confSpace) {
        this.confSpace = confSpace;
        this.probe = new Probe();
        this.probe.matchTemplates(this.confSpace);
        this.connectivity = new AtomConnectivity.Builder().set15HasNonBonded(false).build();
    }

    /** Same as {@link #pruneSingles(PruningMatrix, double, TaskExecutor)} with a default {@link TaskExecutor}. */
    public void pruneSingles(PruningMatrix pmat, double tolerance) {
        this.pruneSingles(pmat, tolerance, new TaskExecutor());
    }

    /**
     * Checks every unpruned single with {@link #shouldPruneTuple} and prunes the ones it rejects.
     *
     * @param pmat      pruning matrix to read from and prune into
     * @param tolerance overlap tolerance forwarded to the Probe violation calculation
     * @param tasks     executor running the per-tuple checks; this method waits for all of them
     */
    public void pruneSingles(PruningMatrix pmat, double tolerance, TaskExecutor tasks) {
        // First pass only counts, so the progress reporter knows the total.
        AtomicLong numSingles = new AtomicLong(0L);
        pmat.forEachUnprunedSingle((pos1, rc1) -> {
            numSingles.incrementAndGet();
            return PruningMatrix.IteratorCommand.Continue;
        });
        Progress progress = new Progress(numSingles.get());
        pmat.forEachUnprunedSingle((pos1, rc1) -> {
            this.submitPruneCheck(tasks, progress, new RCTuple(pos1, rc1), tolerance,
                () -> pmat.pruneSingle(pos1, rc1));
            return PruningMatrix.IteratorCommand.Continue;
        });
        tasks.waitForFinish();
    }

    /** Same as {@link #prunePairs(PruningMatrix, double, TaskExecutor)} with a default {@link TaskExecutor}. */
    public void prunePairs(PruningMatrix pmat, double tolerance) {
        this.prunePairs(pmat, tolerance, new TaskExecutor());
    }

    /**
     * Checks every unpruned pair with {@link #shouldPruneTuple} and prunes the ones it rejects.
     *
     * @param pmat      pruning matrix to read from and prune into
     * @param tolerance overlap tolerance forwarded to the Probe violation calculation
     * @param tasks     executor running the per-tuple checks; this method waits for all of them
     */
    public void prunePairs(PruningMatrix pmat, double tolerance, TaskExecutor tasks) {
        // First pass only counts, so the progress reporter knows the total.
        AtomicLong numPairs = new AtomicLong(0L);
        pmat.forEachUnprunedPair((pos1, rc1, pos2, rc2) -> {
            numPairs.incrementAndGet();
            return PruningMatrix.IteratorCommand.Continue;
        });
        Progress progress = new Progress(numPairs.get());
        pmat.forEachUnprunedPair((pos1, rc1, pos2, rc2) -> {
            this.submitPruneCheck(tasks, progress, new RCTuple(pos1, rc1, pos2, rc2), tolerance,
                () -> pmat.prunePair(pos1, rc1, pos2, rc2));
            return PruningMatrix.IteratorCommand.Continue;
        });
        tasks.waitForFinish();
    }

    /** Same as {@link #pruneTriples(PruningMatrix, double, TaskExecutor)} with a default {@link TaskExecutor}. */
    public void pruneTriples(PruningMatrix pmat, double tolerance) {
        this.pruneTriples(pmat, tolerance, new TaskExecutor());
    }

    /**
     * Checks every unpruned triple with {@link #shouldPruneTuple} and prunes the ones it rejects.
     *
     * @param pmat      pruning matrix to read from and prune into
     * @param tolerance overlap tolerance forwarded to the Probe violation calculation
     * @param tasks     executor running the per-tuple checks; this method waits for all of them
     */
    public void pruneTriples(PruningMatrix pmat, double tolerance, TaskExecutor tasks) {
        // First pass only counts, so the progress reporter knows the total.
        AtomicLong numTriples = new AtomicLong(0L);
        pmat.forEachUnprunedTriple((pos1, rc1, pos2, rc2, pos3, rc3) -> {
            numTriples.incrementAndGet();
            return PruningMatrix.IteratorCommand.Continue;
        });
        Progress progress = new Progress(numTriples.get());
        pmat.forEachUnprunedTriple((pos1, rc1, pos2, rc2, pos3, rc3) -> {
            this.submitPruneCheck(tasks, progress, new RCTuple(pos1, rc1, pos2, rc2, pos3, rc3), tolerance,
                () -> pmat.pruneTriple(pos1, rc1, pos2, rc2, pos3, rc3));
            return PruningMatrix.IteratorCommand.Continue;
        });
        tasks.waitForFinish();
    }

    /**
     * Submits one {@link #shouldPruneTuple} check.
     *
     * <p>The task listener runs {@code prune} when the tuple is rejected. It then advances
     * {@code progress} whether or not the tuple was pruned.
     */
    private void submitPruneCheck(TaskExecutor tasks, Progress progress, RCTuple tuple, double tolerance, Runnable prune) {
        tasks.submit(() -> this.shouldPruneTuple(tuple, tolerance), shouldPrune -> {
            if (shouldPrune) {
                prune.run();
            }
            progress.incrementProgress();
        });
    }

    /**
     * Decides whether the voxel of {@code tuple} is infeasible under linearized overlap constraints.
     *
     * <p>Interactions include the tuple's intra, inter and shell terms. A linear constraint is
     * collected for each movable non-bonded atom pair. The voxel's DOF bounds are then added and a
     * simplex feasibility problem (zero objective) is solved.
     *
     * @param tuple     the residue-conformation tuple to examine
     * @param tolerance overlap tolerance forwarded to the Probe violation calculation
     * @return {@code true} if no voxel point satisfies all constraints, or if some pair was found
     *         violated while its Newton search left the voxel; {@code false} otherwise, including
     *         when no constraints were produced
     */
    public boolean shouldPruneTuple(RCTuple tuple, double tolerance) {
        ParametricMolecule pmol = this.confSpace.makeMolecule(tuple);
        ResidueInteractions inters = ResInterGen.of(this.confSpace).addIntras(tuple).addInters(tuple).addShell(tuple).make();
        Voxel voxel = new Voxel(pmol);
        try {
            List<LinearConstraint> pairConstraints = this.getLinearConstraints(voxel, inters, tolerance);
            if (pairConstraints.isEmpty()) {
                return false;
            }
            List<LinearConstraint> allConstraints = new ArrayList<>(pairConstraints);
            addVoxelBounds(voxel, allConstraints);
            // SimpleBounds is accepted by the optimizer API but not used by SimplexSolver,
            // so the voxel bounds are also encoded as explicit rows above.
            new SimplexSolver().optimize(
                new SimpleBounds(voxel.min, voxel.max),
                new LinearConstraintSet(allConstraints),
                new LinearObjectiveFunction(new double[voxel.numDofs], 0.0)
            );
            return false;
        } catch (NoFeasibleSolutionException ex) {
            return true;
        }
    }

    /** Appends {@code x_d >= min_d} and {@code x_d <= max_d} rows for every voxel DOF {@code d}. */
    private static void addVoxelBounds(Voxel voxel, List<LinearConstraint> constraints) {
        for (int d = 0; d < voxel.numDofs; d++) {
            double[] unit = new double[voxel.numDofs];
            unit[d] = 1.0;
            constraints.add(new LinearConstraint(unit, Relationship.GEQ, voxel.min[d]));
            constraints.add(new LinearConstraint(unit, Relationship.LEQ, voxel.max[d]));
        }
    }

    /**
     * Builds one linear constraint per non-bonded atom pair that voxel DOFs can move.
     *
     * <p>Pairs where neither atom moves with any voxel DOF are skipped. So are pairs for which
     * {@link #getLinearConstraint} returns {@code null}.
     *
     * @throws NoFeasibleSolutionException propagated from {@link #getLinearConstraint} when a pair is
     *         judged violated throughout the voxel
     */
    public List<LinearConstraint> getLinearConstraints(Voxel voxel, ResidueInteractions inters, double tolerance) {
        // Per-atom DOF analysis is cached; each computation applies every voxel DOF to the molecule.
        HashMap<Atom, AtomVoxel> atomVoxels = new HashMap<>();
        List<LinearConstraint> constraints = new ArrayList<>();
        for (ResidueInteractions.Pair resPair : inters) {
            Residue res1 = voxel.pmol.mol.residues.getOrThrow(resPair.resNum1);
            Residue res2 = voxel.pmol.mol.residues.getOrThrow(resPair.resNum2);
            for (int[] atomPair : this.connectivity.getAtomPairs(res1, res2).getPairs(AtomNeighbors.Type.NONBONDED)) {
                AtomVoxel v1 = atomVoxels.computeIfAbsent(res1.atoms.get(atomPair[0]),
                    atom -> new AtomVoxel(atom, this.probe, voxel, this.gradientDxFactor));
                AtomVoxel v2 = atomVoxels.computeIfAbsent(res2.atoms.get(atomPair[1]),
                    atom -> new AtomVoxel(atom, this.probe, voxel, this.gradientDxFactor));
                // Neither atom moves with any voxel DOF, so there is nothing to constrain.
                if (!v1.hasDofs() && !v2.hasDofs()) {
                    continue;
                }
                LinearConstraint constraint = this.getLinearConstraint(new AtomPairVoxel(this.probe, voxel, v1, v2), tolerance);
                if (constraint != null) {
                    constraints.add(constraint);
                }
            }
        }
        return constraints;
    }

    /**
     * Linearizes one atom pair's violation boundary into a half-space over ALL voxel DOFs.
     *
     * <p>With boundary point {@code p} and gradient {@code g} (both in the pair's local DOF space),
     * the constraint is {@code u.x >= w}, where {@code u = -g} and {@code w = u.p}. The coefficients
     * are scattered to the pair's voxel DOF indices. The vector therefore has the same dimension as
     * the LP in {@link #shouldPruneTuple}; DOFs the pair does not depend on get a zero coefficient.
     *
     * @return the constraint, or {@code null} when no boundary was found or the search left the
     *         voxel with a non-positive violation
     * @throws NoFeasibleSolutionException when the Newton search left the voxel while the pair was
     *         still violated ({@code violation > 0})
     */
    public LinearConstraint getLinearConstraint(AtomPairVoxel voxel, double tolerance) {
        BoundaryPoint p = this.findBoundaryNewton(voxel, tolerance);
        if (p == null) {
            return null;
        }
        if (!p.atBoundary()) {
            if (p.violation > 0.0) {
                throw new NoFeasibleSolutionException();
            }
            return null;
        }
        RealVector u = new ArrayRealVector(voxel.voxel.numDofs);
        double w = 0.0;
        for (int d = 0; d < p.dofValues.length; d++) {
            double coefficient = -p.gradient[d];
            // map the pair-local DOF index to the voxel-wide DOF index used by the LP
            u.setEntry(voxel.dofIndices.get(d), coefficient);
            w += p.dofValues[d] * coefficient;
        }
        return new LinearConstraint(u, Relationship.GEQ, w);
    }

    /**
     * Searches for a point on the atom pair's violation boundary using width-scaled Newton steps.
     *
     * <p>The search starts at the voxel center. If the violation there is exactly zero, the center
     * is returned with its gradient. Otherwise each step moves every DOF {@code d} by
     * {@code -v * g_d * w_d^2 / sum_k(w_k^2 * g_k^2)}, where {@code v} is the current violation,
     * {@code g} the forward-difference gradient and {@code w} the voxel widths.
     *
     * @return <ul>
     *   <li>a point with gradient, once {@code |violation| <= violationThreshold} (the gradient is
     *       the one used for the final step);</li>
     *   <li>a gradient-less {@link BoundaryPoint} carrying the last in-voxel violation, when a step
     *       leaves the voxel;</li>
     *   <li>{@code null} when the weighted gradient norm is zero or {@link #maxNumIterations} is
     *       exhausted.</li>
     * </ul>
     */
    public BoundaryPoint findBoundaryNewton(AtomPairVoxel pairVoxel, double tolerance) {
        int n = pairVoxel.numDofs;
        Function<double[], Double> f = x -> pairVoxel.getViolation(x, tolerance);
        // Forward-difference gradient. The returned array is reused between calls, so callers must
        // consume it (BoundaryPoint clones it) before the next evaluation.
        double[] gout = new double[n];
        Function<double[], double[]> g = x -> {
            double baseViolation = f.apply(x);
            for (int d = 0; d < n; d++) {
                double dx = this.gradientDxFactor * pairVoxel.width(d);
                gout[d] = (pairVoxel.getViolationAlong(d, x[d], dx, tolerance) - baseViolation) / dx;
            }
            return gout;
        };

        double[] point = new double[n];
        for (int d = 0; d < n; d++) {
            point[d] = pairVoxel.center(d);
        }
        double violation = f.apply(point);
        if (violation == 0.0) {
            return new BoundaryPoint(point, violation, g.apply(point));
        }

        for (int i = 0; i < this.maxNumIterations; i++) {
            double[] grad = g.apply(point);
            double s = 0.0;
            for (int d = 0; d < n; d++) {
                s += pairVoxel.width2(d) * grad[d] * grad[d];
            }
            if (s == 0.0) {
                // flat violation surface: no direction to step in
                return null;
            }
            for (int d = 0; d < n; d++) {
                point[d] -= violation * grad[d] * pairVoxel.width2(d) / s;
            }
            if (pairVoxel.outOfRange(point)) {
                return new BoundaryPoint(violation);
            }
            violation = f.apply(point);
            if (Math.abs(violation) <= this.violationThreshold) {
                return new BoundaryPoint(point, violation, grad);
            }
        }
        return null;
    }

    /** Snapshot of a molecule's DOF bounds: per-DOF width, squared width, min, max and center. */
    private static final class Voxel {
        final ParametricMolecule pmol;
        final int numDofs;
        final double[] width;
        final double[] width2;
        final double[] min;
        final double[] max;
        final double[] center;

        Voxel(ParametricMolecule pmol) {
            this.pmol = pmol;
            this.numDofs = pmol.dofs.size();
            this.width = new double[this.numDofs];
            this.width2 = new double[this.numDofs];
            this.min = new double[this.numDofs];
            this.max = new double[this.numDofs];
            this.center = new double[this.numDofs];
            for (int d = 0; d < this.numDofs; d++) {
                this.width[d] = pmol.dofBounds.getWidth(d);
                this.width2[d] = this.width[d] * this.width[d];
                this.min[d] = pmol.dofBounds.getMin(d);
                this.max[d] = pmol.dofBounds.getMax(d);
                this.center[d] = pmol.dofBounds.getCenter(d);
            }
        }

        DegreeOfFreedom getDof(int d) {
            return this.pmol.dofs.get(d);
        }

        /** Applies value {@code val} to DOF {@code d}, mutating the voxel's molecule. */
        void applyDof(int d, double val) {
            this.getDof(d).apply(val);
        }
    }

    /** An atom plus the indices of the voxel DOFs that were observed to move it. */
    private static final class AtomVoxel {
        final Atom atom;
        final Probe.AtomInfo probeInfo;
        final List<Integer> dofIndices = new ArrayList<>();

        /**
         * Probes each DOF with positive width.
         *
         * <p>Each such DOF is set to its center, then nudged by {@code dxFactor * width}. If the
         * atom's coordinates differ between the two settings, the DOF is recorded. The DOF is then
         * restored to its center.
         */
        AtomVoxel(Atom atom, Probe probe, Voxel voxel, double dxFactor) {
            this.atom = atom;
            this.probeInfo = probe.getAtomInfo(atom);
            for (int d = 0; d < voxel.numDofs; d++) {
                if (voxel.width[d] <= 0.0) {
                    continue;
                }
                voxel.applyDof(d, voxel.center[d]);
                double[] start = atom.getCoords();
                voxel.applyDof(d, voxel.center[d] + dxFactor * voxel.width[d]);
                double[] stop = atom.getCoords();
                // restore the DOF so probing does not leave the molecule perturbed
                voxel.applyDof(d, voxel.center[d]);
                if (!Arrays.equals(start, stop)) {
                    this.dofIndices.add(d);
                }
            }
        }

        boolean hasDofs() {
            return !this.dofIndices.isEmpty();
        }
    }

    /**
     * A pair of atoms restricted to the union of the voxel DOFs that move either atom.
     *
     * <p>Methods taking a DOF index {@code d} use the pair-local index 0..numDofs-1. It maps to the
     * voxel index {@code dofIndices.get(d)}.
     */
    private static final class AtomPairVoxel {
        final Voxel voxel;
        final Probe.AtomPair probePair;
        final List<Integer> dofIndices = new ArrayList<>();
        final int numDofs;

        AtomPairVoxel(Probe probe, Voxel voxel, AtomVoxel v1, AtomVoxel v2) {
            this.voxel = voxel;
            Objects.requireNonNull(probe);
            this.probePair = new Probe.AtomPair(probe, v1.atom, v2.atom, v1.probeInfo, v2.probeInfo);
            // union of both atoms' DOFs, keeping first-seen order
            for (int d : v1.dofIndices) {
                if (!this.dofIndices.contains(d)) {
                    this.dofIndices.add(d);
                }
            }
            for (int d : v2.dofIndices) {
                if (!this.dofIndices.contains(d)) {
                    this.dofIndices.add(d);
                }
            }
            this.numDofs = this.dofIndices.size();
        }

        double min(int d) {
            return this.voxel.min[this.dofIndices.get(d)];
        }

        double max(int d) {
            return this.voxel.max[this.dofIndices.get(d)];
        }

        double center(int d) {
            return this.voxel.center[this.dofIndices.get(d)];
        }

        double width(int d) {
            return this.voxel.width[this.dofIndices.get(d)];
        }

        double width2(int d) {
            return this.voxel.width2[this.dofIndices.get(d)];
        }

        void applyDof(int d, double val) {
            this.voxel.applyDof(this.dofIndices.get(d), val);
        }

        /** Applies all pair DOFs from {@code x}, then returns the Probe violation at {@code tolerance}. */
        double getViolation(double[] x, double tolerance) {
            for (int d = 0; d < this.numDofs; d++) {
                this.applyDof(d, x[d]);
            }
            return this.probePair.getViolation(tolerance);
        }

        /**
         * Returns the violation with pair DOF {@code d} set to {@code x + dx}.
         *
         * <p>DOF {@code d} is set back to {@code x} afterwards.
         */
        double getViolationAlong(int d, double x, double dx, double tolerance) {
            this.applyDof(d, x + dx);
            double violation = this.probePair.getViolation(tolerance);
            this.applyDof(d, x);
            return violation;
        }

        /** Returns {@code true} if any coordinate of {@code x} lies outside its DOF's [min, max]. */
        boolean outOfRange(double[] x) {
            for (int d = 0; d < this.numDofs; d++) {
                if (x[d] < this.min(d) || x[d] > this.max(d)) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Result of a boundary search: the violation plus, when a boundary was found, the DOF values and
     * gradient there (in the atom pair's local DOF space).
     */
    public static class BoundaryPoint {
        double[] dofValues;
        double violation;
        double[] gradient;

        /** A point on (or near) the boundary; both arrays are defensively copied. */
        public BoundaryPoint(double[] dofValues, double violation, double[] gradient) {
            this.dofValues = dofValues.clone();
            this.violation = violation;
            this.gradient = gradient.clone();
        }

        /** A search that ended without a boundary point; only the violation is kept. */
        public BoundaryPoint(double violation) {
            this.dofValues = null;
            this.violation = violation;
            this.gradient = null;
        }

        /** Returns {@code true} when this result carries a gradient, i.e. a usable boundary point. */
        public boolean atBoundary() {
            return this.gradient != null;
        }

        @Override
        public String toString() {
            StringBuilder buf = new StringBuilder();
            buf.append(String.format("violation: %.3f", this.violation));
            if (this.dofValues != null) {
                for (int d = 0; d < this.dofValues.length; d++) {
                    buf.append(String.format("\n\tx[%d]: %.3f", d, this.dofValues[d]));
                }
            }
            if (this.gradient != null) {
                for (int d = 0; d < this.gradient.length; d++) {
                    buf.append(String.format("\n\tg[%d]: %.3f", d, this.gradient[d]));
                }
            }
            return buf.toString();
        }
    }
}

