/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.pruning;

import edu.duke.cs.osprey.confspace.ConfSpace;
import edu.duke.cs.osprey.confspace.HigherTupleFinder;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SearchProblem;
import edu.duke.cs.osprey.confspace.TupleEnumerator;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.ematrix.epic.EPICMatrix;
import edu.duke.cs.osprey.pruning.PruningMatrix;
import edu.duke.cs.osprey.pruning.PruningMethod;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.TreeSet;

/**
 * Dead-end-elimination (DEE) style pruner over a {@link SearchProblem}'s conformational space.
 *
 * <p>Competitive (Goldstein) pruning marks a residue-conformation (RC) tuple as pruned in
 * {@link #pruneMat} when some competitor tuple at the same positions yields a checksum greater than
 * {@link #pruningInterval}. Steric pruning marks single and pair tuples whose energy exceeds a
 * given threshold.
 */
public class Pruner {
    /** If true, a candidate may only be pruned by competitors with matching amino-acid types. */
    boolean typeDep;
    /** Energy threshold used (together with {@link #pruningInterval}) by bounds pruning. */
    double boundsThreshold;
    /** Margin a pruning checksum must exceed before a tuple is pruned. */
    double pruningInterval;
    /** Matrix in which this pruner marks tuples as pruned. */
    PruningMatrix pruneMat;
    /** Matrix used to enumerate competitors; falls back to {@link #pruneMat} when absent. */
    PruningMatrix competitorPruneMat;
    /** Energy matrix used for checksums (the tuple-expansion matrix when useTupExp is set). */
    EnergyMatrix emat;
    /** EPIC matrix used for a continuous lower bound when {@link #useEPIC} is set. */
    EPICMatrix epicMat;
    ConfSpace confSpace;
    /** Enumerates unpruned candidate tuples from {@link #pruneMat}. */
    TupleEnumerator tupEnum;
    boolean useEPIC;
    /** Argument passed to {@code TupleEnumerator.topPositionTriples} when pruning triples. */
    static int triplesNumPartners = 5;
    private boolean verbose = true;

    public Pruner(SearchProblem searchSpace, boolean typeDep, double boundsThreshold, double pruningInterval, boolean useEPIC, boolean useTupExp) {
        this(searchSpace, searchSpace.pruneMat, typeDep, boundsThreshold, pruningInterval, useEPIC, useTupExp);
    }

    /**
     * Creates a pruner that writes into {@code pruneMat}.
     *
     * @throws Error if the selected energy matrix is missing from {@code searchSpace}
     * @throws RuntimeException if both {@code useEPIC} and {@code useTupExp} are requested
     */
    public Pruner(SearchProblem searchSpace, PruningMatrix pruneMat, boolean typeDep, double boundsThreshold, double pruningInterval, boolean useEPIC, boolean useTupExp) {
        this.pruneMat = pruneMat;
        this.epicMat = searchSpace.epicMat;
        this.confSpace = searchSpace.confSpace;
        this.competitorPruneMat = searchSpace.competitorPruneMat;
        if (this.competitorPruneMat == null) {
            this.competitorPruneMat = pruneMat;
        }
        this.typeDep = typeDep;
        this.boundsThreshold = boundsThreshold;
        this.pruningInterval = pruningInterval;
        this.useEPIC = useEPIC;
        this.emat = useTupExp ? searchSpace.tupExpEMat : searchSpace.emat;
        if (this.emat == null) {
            throw new Error("Pruner found no EnergyMatrix in SearchSpace. This is a bug");
        }
        if (useTupExp && useEPIC) {
            throw new RuntimeException("ERROR: Can't prune with both EPIC and tup-exp at the same time");
        }
        this.tupEnum = new TupleEnumerator(pruneMat, this.emat, searchSpace.confSpace.numPos);
    }

    public void setVerbose(boolean val) {
        this.verbose = val;
    }

    /** Looks up a pruning method by name and runs it; returns true if anything was pruned. */
    public boolean prune(String methodName) {
        return this.prune(PruningMethod.getMethod(methodName));
    }

    /**
     * Runs a competitive pruning method repeatedly until a full sweep prunes nothing.
     *
     * @return true if at least one tuple was marked as pruned
     * @throws RuntimeException if the method is not competitive (bounds pruning is unsupported)
     */
    boolean prune(PruningMethod method) {
        if (this.verbose) {
            System.out.println("Starting pruning with " + method.name());
        }
        if (!method.useCompetitor) {
            throw new RuntimeException("ERROR: Bounds pruning not supported yet");
        }
        boolean prunedSomething = false;
        boolean prunedSomethingThisCycle;
        do {
            prunedSomethingThisCycle = false;
            for (RCTuple cand : this.enumerateCandidates(method)) {
                // An earlier candidate in this sweep may already have pruned this one. Checking
                // first also avoids a wasted EPIC minimization whose result would be discarded.
                if (this.pruneMat.isPruned(cand)) {
                    continue;
                }
                double contELB = 0.0;
                if (this.useEPIC && cand.pos.size() > 1) {
                    contELB = this.epicMat.minimizeEnergy(cand, false);
                }
                if (this.prunedByAnyCompetitor(cand, method, contELB)) {
                    this.pruneMat.markAsPruned(cand);
                    prunedSomething = true;
                    prunedSomethingThisCycle = true;
                }
            }
        } while (prunedSomethingThisCycle);
        return prunedSomething;
    }

    /** Returns true if some unpruned competitor at the candidate's positions lets it be pruned. */
    private boolean prunedByAnyCompetitor(RCTuple cand, PruningMethod method, double contELB) {
        for (RCTuple competitor : this.competitorPruneMat.unprunedRCTuplesAtPos(cand.pos)) {
            // Comparing a tuple with itself only helps when there is a continuous (EPIC) bound.
            if (cand.isSameTuple(competitor) && contELB == 0.0) {
                continue;
            }
            if (this.typeDep && !this.resTypesMatch(cand, competitor)) {
                continue;
            }
            if (this.canPrune(cand, competitor, method.cst, contELB)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Enumerates unpruned candidate tuples for the method's tuple size.
     *
     * <p>Singles and pairs are enumerated exhaustively. Triples are restricted to the position
     * triples chosen by {@code topPositionTriples(triplesNumPartners)}.
     *
     * @throws RuntimeException for tuple sizes above 3
     */
    public ArrayList<RCTuple> enumerateCandidates(PruningMethod method) {
        if (method.numPos <= 2) {
            return this.tupEnum.enumerateUnprunedTuples(method.numPos);
        }
        if (method.numPos == 3) {
            ArrayList<ArrayList<Integer>> posTriples = this.tupEnum.topPositionTriples(triplesNumPartners);
            return this.tupEnum.enumerateUnprunedTuples(posTriples);
        }
        throw new RuntimeException("ERROR: Number of positions not currently supported for pruning: " + method.numPos);
    }

    /**
     * Returns true if, index by index, the two tuples' RCs have the same amino-acid type
     * (compared case-insensitively). Assumes {@code tup2} has at least as many entries as
     * {@code tup1}.
     */
    boolean resTypesMatch(RCTuple tup1, RCTuple tup2) {
        for (int indexInTup = 0; indexInTup < tup1.pos.size(); ++indexInTup) {
            if (!this.aaTypeAt(tup1, indexInTup).equalsIgnoreCase(this.aaTypeAt(tup2, indexInTup))) {
                return false;
            }
        }
        return true;
    }

    /** Amino-acid type of the RC stored at {@code indexInTup} of {@code tup}. */
    private String aaTypeAt(RCTuple tup, int indexInTup) {
        int pos = tup.pos.get(indexInTup);
        int rc = tup.RCs.get(indexInTup);
        return this.confSpace.posFlex.get(pos).RCs.get(rc).AAType;
    }

    /**
     * Goldstein competitive check: can {@code cand} be pruned in favor of {@code comp}?
     *
     * <p>The checksum is computed as follows:
     * <ul>
     *   <li>start with the internal energy of {@code cand}, plus {@code contELB}, minus the
     *       internal energy of {@code comp};</li>
     *   <li>then, for every position outside the candidate, add the minimum over unpruned witness
     *       RCs of (interactions with {@code cand} minus interactions with {@code comp}).</li>
     * </ul>
     *
     * @param contELB continuous lower bound from EPIC (0 when unused); pruning succeeds
     *     immediately if it alone exceeds the pruning interval
     * @return true if the checksum exceeds {@link #pruningInterval}, or if some witness position
     *     has no usable RC
     * @throws RuntimeException for checksum types other than GOLDSTEIN
     */
    boolean canPrune(RCTuple cand, RCTuple comp, PruningMethod.CheckSumType checkSumType, double contELB) {
        if (contELB > this.pruningInterval) {
            return true;
        }
        if (checkSumType != PruningMethod.CheckSumType.GOLDSTEIN) {
            throw new RuntimeException("ERROR: Not supporting indirect and conf-splitting pruning yet...");
        }
        assert (this.pruningInterval < Double.POSITIVE_INFINITY);

        // Local copies of the fields used inside the hot witness loop.
        EnergyMatrix emat = this.emat;
        PruningMatrix pruneMat = this.pruneMat;
        ArrayList<Integer> candPos = cand.pos;
        ArrayList<Integer> candRCs = cand.RCs;
        ArrayList<Integer> compRCs = comp.RCs;
        int numCandPos = candPos.size();
        boolean useHigherOrder = emat.hasHigherOrderTerms();

        // Reusable tuple "cand + one witness RC". The extra slot is overwritten per witness.
        // The lists are read after set() so they are the ones candAndExtra actually holds.
        RCTuple candAndExtra = new RCTuple();
        candAndExtra.set(cand);
        ArrayList<Integer> candAndExtraPos = candAndExtra.pos;
        ArrayList<Integer> candAndExtraRCs = candAndExtra.RCs;
        candAndExtraPos.add(-1);
        candAndExtraRCs.add(-1);
        int extraIndex = candAndExtraPos.size() - 1;

        double checkSum = emat.getInternalEnergy(cand) + contELB - emat.getInternalEnergy(comp);

        ArrayList<Integer> unprunedRCs = new ArrayList<Integer>(64);
        int numPos = this.confSpace.numPos;
        for (int posWit = 0; posWit < numPos; ++posWit) {
            if (candPos.contains(posWit)) {
                continue;
            }
            pruneMat.unprunedRCsAtPos(unprunedRCs, posWit);
            if (unprunedRCs.isEmpty()) {
                return true;
            }
            double minDiff = Double.POSITIVE_INFINITY;
            for (int rcWit : unprunedRCs) {
                candAndExtraPos.set(extraIndex, posWit);
                candAndExtraRCs.set(extraIndex, rcWit);
                if (pruneMat.isPruned(candAndExtra)) {
                    continue;
                }
                double diff = 0.0;
                for (int i = 0; i < numCandPos; ++i) {
                    int posCand = candPos.get(i);
                    int rcCand = candRCs.get(i);
                    int rcComp = compRCs.get(i);
                    // NOTE(review): if both pairwise terms are +inf the difference is NaN. Math.min
                    // then propagates NaN into minDiff and checkSum, so this method returns false
                    // (no pruning). That is conservative, but confirm it is intended.
                    diff += emat.getPairwise(posWit, rcWit, posCand, rcCand) - emat.getPairwise(posWit, rcWit, posCand, rcComp);
                    if (diff == Double.POSITIVE_INFINITY) {
                        break;
                    }
                }
                if (useHigherOrder) {
                    diff += this.minInteractionDiffHigher(cand, comp, posWit, rcWit);
                }
                minDiff = Math.min(minDiff, diff);
            }
            if (minDiff == Double.POSITIVE_INFINITY) {
                return true;
            }
            checkSum += minDiff;
        }
        return checkSum > this.pruningInterval;
    }

    /**
     * Higher-order (beyond pairwise) contribution of witness ({@code pos}, {@code rc}) to the
     * Goldstein checksum. Returns 0 when the energy matrix has no higher-order terms for the
     * witness paired with any candidate or competitor RC.
     */
    double minInteractionDiffHigher(RCTuple cand, RCTuple comp, int pos, int rc) {
        ArrayList<HigherTupleFinder<Double>> candHigher = new ArrayList<HigherTupleFinder<Double>>();
        ArrayList<HigherTupleFinder<Double>> compHigher = new ArrayList<HigherTupleFinder<Double>>();
        for (int indexInTup = 0; indexInTup < cand.pos.size(); ++indexInTup) {
            int pos2 = cand.pos.get(indexInTup);
            int rc2 = cand.RCs.get(indexInTup);
            int rc2Comp = comp.RCs.get(indexInTup);
            HigherTupleFinder<Double> htfCand = this.emat.getHigherOrderTerms(pos, rc, pos2, rc2);
            if (htfCand != null) {
                candHigher.add(htfCand);
            }
            HigherTupleFinder<Double> htfComp = this.emat.getHigherOrderTerms(pos, rc, pos2, rc2Comp);
            if (htfComp != null) {
                compHigher.add(htfComp);
            }
        }
        if (candHigher.isEmpty() && compHigher.isEmpty()) {
            return 0.0;
        }
        return this.higherOrderContribGoldstein(cand, comp, pos, candHigher, compHigher);
    }

    /**
     * Higher-order Goldstein contribution for one witness position {@code pos}.
     *
     * <p>The result is the internal triple terms (candidate minus competitor), plus, for each
     * interacting position below {@code pos} and outside the candidate, the minimum over its
     * unpruned RCs of the candidate-minus-competitor triple interaction.
     *
     * @throws UnsupportedOperationException if any term beyond triples is present
     */
    double higherOrderContribGoldstein(RCTuple cand, RCTuple comp, int pos, ArrayList<HigherTupleFinder<Double>> candHigher, ArrayList<HigherTupleFinder<Double>> compHigher) {
        double contrib = this.higherOrderContribInternal(cand, candHigher) - this.higherOrderContribInternal(comp, compHigher);
        TreeSet<Integer> interactingPos = new TreeSet<Integer>();
        addInteractingPosBelow(interactingPos, candHigher, cand, pos);
        // cand and comp share the same positions, so cand.pos is the right exclusion set here too.
        addInteractingPosBelow(interactingPos, compHigher, cand, pos);
        for (int iPos : interactingPos) {
            double levelBestE = Double.POSITIVE_INFINITY;
            for (int rc : this.pruneMat.unprunedRCsAtPos(iPos)) {
                double interactionE = 0.0;
                for (HigherTupleFinder<Double> htf : candHigher) {
                    interactionE += htf.getInteraction(iPos, rc).doubleValue();
                    requireAtMostTriples(htf, iPos, rc);
                }
                for (HigherTupleFinder<Double> htf : compHigher) {
                    interactionE -= htf.getInteraction(iPos, rc).doubleValue();
                    requireAtMostTriples(htf, iPos, rc);
                }
                levelBestE = Math.min(levelBestE, interactionE);
            }
            contrib += levelBestE;
        }
        return contrib;
    }

    /** Adds the positions of {@code htfList} that are below {@code pos} and not in {@code cand}. */
    private static void addInteractingPosBelow(TreeSet<Integer> out, ArrayList<HigherTupleFinder<Double>> htfList, RCTuple cand, int pos) {
        for (HigherTupleFinder<Double> htf : htfList) {
            for (int iPos : htf.getInteractingPos()) {
                if (!cand.pos.contains(iPos) && iPos < pos) {
                    out.add(iPos);
                }
            }
        }
    }

    /** Throws if {@code htf} carries interactions beyond triples at ({@code iPos}, {@code rc}). */
    private static void requireAtMostTriples(HigherTupleFinder<Double> htf, int iPos, int rc) {
        if (htf.getHigherInteractions(iPos, rc) != null) {
            throw new UnsupportedOperationException("ERROR: Not supporting energy >triples in DEE");
        }
    }

    /**
     * Sums the interactions of every finder in {@code htfList} with each (pos, RC) of {@code tup},
     * then halves the result.
     *
     * <p>NOTE(review): the halving presumably corrects for each triple being reached through two
     * of its pairs; confirm.
     */
    double higherOrderContribInternal(RCTuple tup, ArrayList<HigherTupleFinder<Double>> htfList) {
        double E = 0.0;
        for (HigherTupleFinder<Double> htf : htfList) {
            for (int posCount = 0; posCount < tup.pos.size(); ++posCount) {
                E += htf.getInteraction(tup.pos.get(posCount), tup.RCs.get(posCount)).doubleValue();
            }
        }
        return E / 2.0;
    }

    /**
     * Non-competitive (bounds) check: returns true if a lower bound on the energy of any
     * conformation containing {@code cand} exceeds {@code boundsThreshold + pruningInterval}.
     *
     * <p>The bound is the internal energy of {@code cand}, plus, for each position not in
     * {@code cand}, the minimum over its unpruned RCs of {@link #RCContributionLB}. Positions
     * inside {@code cand} are already covered by the internal energy and are skipped. (Previously
     * they added +infinity, which made every candidate prunable.)
     *
     * @throws RuntimeException for checksum types other than BOUNDS
     */
    boolean canPrune(RCTuple cand, PruningMethod.CheckSumType checkSumType) {
        if (checkSumType != PruningMethod.CheckSumType.BOUNDS) {
            throw new RuntimeException("ERROR: Unrecognized checksum type for non-competitive pruning: " + checkSumType.name());
        }
        double checkSum = this.emat.getInternalEnergy(cand);
        for (int level = 0; level < this.confSpace.numPos; ++level) {
            if (cand.pos.contains(level)) {
                continue;
            }
            // Stays +infinity if no RC is left at this level, meaning no conformation exists.
            double resContribLB = Double.POSITIVE_INFINITY;
            for (int rc : this.pruneMat.unprunedRCsAtPos(level)) {
                resContribLB = Math.min(resContribLB, this.RCContributionLB(level, rc, cand));
            }
            checkSum += resContribLB;
        }
        return checkSum > this.boundsThreshold + this.pruningInterval;
    }

    /**
     * Lower bound on the energy contribution of RC {@code rc} at undefined position {@code level},
     * given the assignments in {@code definedTuple}.
     *
     * <p>The contribution is the RC's one-body energy, plus its best pairwise interaction with
     * every defined position (at any index) and with every undefined position numbered below
     * {@code level}. This split counts each pair exactly once across the undefined positions.
     * The earlier loop stopped at {@code level}, so it dropped defined positions above
     * {@code level} and also omitted the one-body term.
     */
    double RCContributionLB(int level, int rc, RCTuple definedTuple) {
        double rcContrib = this.emat.getOneBody(level, rc).doubleValue();
        int numPos = this.confSpace.numPos;
        for (int level2 = 0; level2 < numPos; ++level2) {
            if (level2 == level) {
                continue;
            }
            int definedIndex = definedTuple.pos.indexOf(level2);
            // Undefined higher-numbered positions own their interaction with this level.
            if (definedIndex < 0 && level2 > level) {
                continue;
            }
            ArrayList<Integer> allowedRCs;
            if (definedIndex >= 0) {
                allowedRCs = new ArrayList<Integer>(1);
                allowedRCs.add(definedTuple.RCs.get(definedIndex));
            } else {
                allowedRCs = this.pruneMat.unprunedRCsAtPos(level2);
            }
            double levelBestE = Double.POSITIVE_INFINITY;
            for (int rc2 : allowedRCs) {
                double interactionE = this.emat.getPairwise(level, rc, level2, rc2);
                levelBestE = Math.min(levelBestE, interactionE);
            }
            rcContrib += levelBestE;
        }
        return rcContrib;
    }

    /**
     * Marks as pruned every unpruned single (one-body energy) and then every unpruned pair
     * (pairwise energy) whose energy is strictly greater than {@code stericThresh}.
     */
    public void pruneSteric(double stericThresh) {
        if (this.verbose) {
            System.out.println("Starting steric pruning.");
        }
        for (int numBodies = 1; numBodies <= 2; ++numBodies) {
            int numPruned = 0;
            for (RCTuple cand : this.tupEnum.enumerateUnprunedTuples(numBodies)) {
                double E = numBodies == 1
                        ? this.emat.getOneBody(cand.pos.get(0), cand.RCs.get(0)).doubleValue()
                        : this.emat.getPairwise(cand.pos.get(0), cand.RCs.get(0), cand.pos.get(1), cand.RCs.get(1)).doubleValue();
                if (E > stericThresh) {
                    this.pruneMat.markAsPruned(cand);
                    ++numPruned;
                }
            }
            if (this.verbose) {
                System.out.println("Pruned " + numPruned + " in " + numBodies + "-body steric pruning");
            }
        }
    }
}

