/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.kstar;

import edu.duke.cs.osprey.astar.ConfTree;
import edu.duke.cs.osprey.astar.FullAStarNode;
import edu.duke.cs.osprey.confspace.HigherTupleFinder;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.kstar.KSSearchProblem;
import edu.duke.cs.osprey.pruning.PruningMatrix;
import java.util.ArrayList;

/**
 * A* conformation tree used by K* that scores partial conformations of a reduced
 * search problem, while drawing candidate RCs for still-undefined positions from a
 * separate ("pan") pruning matrix.
 *
 * <p>NOTE(review): from the usage below, "undefined" positions appear to be expressed
 * in the index space returned by {@code sp.getMaxPosNums()}, with {@code sp.posNums}
 * mapping reduced-problem position indices into that space — confirm against
 * {@link KSSearchProblem}.
 */
public class KAStarConfTree
extends ConfTree<FullAStarNode> {
    /** The reduced search problem this tree searches over. */
    protected KSSearchProblem sp;
    /** Pruning matrix of the reduced search problem. */
    protected PruningMatrix reducedPmat;
    /** Pruning matrix whose unpruned RCs are used for undefined positions. */
    protected PruningMatrix panPmat;
    /** Cached {@code panPmat.unprunedRCsAtPos(i)} for every position {@code i} of {@code panPmat}. */
    protected ArrayList<ArrayList<Integer>> panUprunedRCsAtPos;
    /**
     * When true, contributions are combined with {@code min} (lower-bound style);
     * when false, with {@code max}. Overwritten in the constructor with
     * {@code reducedSP.contSCFlex}.
     */
    protected boolean energyLB = true;

    /**
     * Builds the tree over a reduced search problem.
     *
     * @param reducedSP   reduced search problem; also supplies the energy matrices
     * @param reducedPmat pruning matrix for {@code reducedSP}; also sizes the node factory
     * @param panPmat     pruning matrix consulted for RCs at undefined positions
     */
    public KAStarConfTree(KSSearchProblem reducedSP, PruningMatrix reducedPmat, PruningMatrix panPmat) {
        super(new FullAStarNode.Factory(reducedPmat.getNumPos()), reducedSP, reducedPmat);
        this.sp = reducedSP;
        this.reducedPmat = reducedPmat;
        this.panPmat = panPmat;
        this.energyLB = reducedSP.contSCFlex;
        this.initVars();
    }

    /**
     * Replaces the inherited energy matrix with the reduced one and caches the
     * unpruned RC lists of {@link #panPmat}, one list per position.
     */
    private void initVars() {
        this.emat = this.sp.getReducedEnergyMatrix();
        int numPanPos = this.panPmat.getNumPos();
        this.panUprunedRCsAtPos = new ArrayList<ArrayList<Integer>>(numPanPos);
        for (int index = 0; index < numPanPos; ++index) {
            ArrayList<Integer> rcs = this.panPmat.unprunedRCsAtPos(index);
            rcs.trimToSize();
            this.panUprunedRCsAtPos.add(rcs);
        }
    }

    /**
     * Returns the positions of {@code sp.getMaxPosNums()} that are not covered by
     * {@code definedTuple}. Each defined position is first translated through
     * {@code sp.posNums}.
     *
     * @param definedTuple tuple whose {@code pos} entries are reduced-problem position indices
     * @return a new list of undefined positions
     */
    protected ArrayList<Integer> getUndefinedPos(RCTuple definedTuple) {
        // Work on a copy: removeAll below must not mutate a list that sp may share
        // with other callers (NOTE(review): whether getMaxPosNums returns a fresh
        // list is not visible here).
        ArrayList<Integer> ans = new ArrayList<Integer>(this.sp.getMaxPosNums());
        ArrayList<Integer> definedPos = new ArrayList<Integer>(definedTuple.pos.size());
        for (int pos : definedTuple.pos) {
            definedPos.add(this.sp.posNums.get(pos));
        }
        ans.removeAll(definedPos);
        ans.trimToSize();
        return ans;
    }

    /**
     * Lists the RCs that may be placed at {@code level}.
     *
     * <p>For an undefined level this is the cached pan-pruning list; that list is
     * shared, so callers must not modify it. For a defined level, the level is mapped
     * through {@code sp.posNums.indexOf}. The result is then the assigned RC if
     * {@code partialConf} has one, else the inherited unpruned RCs.
     *
     * @param level        position to query
     * @param partialConf  partial assignment; {@code -1} marks an unassigned position
     * @param undefinedPos positions considered undefined
     * @return allowed RC indices at {@code level}
     */
    protected ArrayList<Integer> allowedRCsAtLevel(int level, int[] partialConf, ArrayList<Integer> undefinedPos) {
        if (undefinedPos.contains(level)) {
            return this.panUprunedRCsAtPos.get(level);
        }
        ArrayList<Integer> allowedRCs = new ArrayList<Integer>();
        int level2 = this.sp.posNums.indexOf(level);
        if (partialConf[level2] != -1) {
            allowedRCs.add(partialConf[level2]);
        } else {
            for (int i : this.unprunedRCsAtPos[level2]) {
                allowedRCs.add(i);
            }
        }
        return allowedRCs;
    }

    /**
     * Computes the contribution of placing {@code rc} at {@code level}.
     *
     * <p>The result is the one-body energy of {@code (level, rc)}. To it is added, for
     * every position {@code level2} that is defined or precedes {@code level}, the best
     * value over that position's allowed RCs of (pairwise + higher-order) energy.
     * "Best" means min when {@link #energyLB} is true and max otherwise.
     *
     * @return the RC's contribution; may be infinite if some position has no allowed RCs
     */
    protected double RCContribution(int level, int rc, RCTuple definedTuple, int[] partialConf, ArrayList<Integer> undefinedPos) {
        double rcContrib = this.sp.getEnergyMatrix().getOneBody(level, rc);
        for (int level2 = 0; level2 < this.sp.confSpace.numPos; ++level2) {
            // Undefined positions at or after 'level' get their own contribution later.
            if (undefinedPos.contains(level2) && level2 >= level) continue;
            double levelBestE = this.energyLB ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
            ArrayList<Integer> allowedRCs = this.allowedRCsAtLevel(level2, partialConf, undefinedPos);
            for (int rc2 : allowedRCs) {
                double interactionE = this.sp.getEnergyMatrix().getPairwise(level, rc, level2, rc2);
                // Fold higher-order terms in before comparing, for both min and max, so
                // that neither bound direction ignores them.
                interactionE += this.higherOrderContrib(level, rc, level2, rc2, partialConf, undefinedPos);
                levelBestE = this.energyLB ? Math.min(levelBestE, interactionE) : Math.max(levelBestE, interactionE);
            }
            rcContrib += levelBestE;
        }
        return rcContrib;
    }

    /**
     * Returns the higher-order energy terms involving the pair
     * {@code (pos1, rc1), (pos2, rc2)}, or {@code 0.0} if the energy matrix has none
     * for that pair.
     */
    protected double higherOrderContrib(int pos1, int rc1, int pos2, int rc2, int[] partialConf, ArrayList<Integer> undefinedPos) {
        HigherTupleFinder<Double> htf = this.sp.getEnergyMatrix().getHigherOrderTerms(pos1, rc1, pos2, rc2);
        if (htf == null) {
            return 0.0;
        }
        RCTuple curPair = new RCTuple(pos1, rc1, pos2, rc2);
        return this.higherOrderContrib(htf, curPair, partialConf, undefinedPos);
    }

    /**
     * Recursively accumulates higher-order terms from {@code htf}.
     *
     * <p>Only interacting positions that come before the last position of
     * {@code startingTuple} are considered; see {@link #posComesBefore}. For each such
     * position, the best value (min or max per {@link #energyLB}) over its allowed RCs
     * is added.
     */
    double higherOrderContrib(HigherTupleFinder<Double> htf, RCTuple startingTuple, int[] partialConf, ArrayList<Integer> undefinedPos) {
        double contrib = 0.0;
        int startingLevel = startingTuple.pos.get(startingTuple.pos.size() - 1);
        for (int iPos : htf.getInteractingPos()) {
            if (!this.posComesBefore(iPos, startingLevel, undefinedPos)) continue;
            double levelBestE = this.energyLB ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
            ArrayList<Integer> allowedRCs = this.allowedRCsAtLevel(iPos, partialConf, undefinedPos);
            for (int rc : allowedRCs) {
                RCTuple augTuple = startingTuple.addRC(iPos, rc);
                double interactionE = htf.getInteraction(iPos, rc);
                HigherTupleFinder<Double> htf2 = htf.getHigherInteractions(iPos, rc);
                if (htf2 != null) {
                    interactionE += this.higherOrderContrib(htf2, augTuple, partialConf, undefinedPos);
                }
                levelBestE = this.energyLB ? Math.min(levelBestE, interactionE) : Math.max(levelBestE, interactionE);
            }
            contrib += levelBestE;
        }
        return contrib;
    }

    /**
     * Orders positions so that all defined positions precede all undefined ones,
     * with index order inside each group.
     *
     * @return true if {@code pos1} comes strictly before {@code pos2} in that order
     */
    protected boolean posComesBefore(int pos1, int pos2, ArrayList<Integer> undefinedPos) {
        if (!undefinedPos.contains(pos2)) {
            return pos1 < pos2 && !undefinedPos.contains(pos1);
        }
        return pos1 < pos2 || !undefinedPos.contains(pos1);
    }

    /**
     * Scores a child of {@code parentNode} by rescoring the whole child conformation.
     * The child is the parent with {@code childRc} at {@code childPos}; a negative
     * {@code childPos} scores the parent itself.
     */
    @Override
    protected double scoreConfDifferential(FullAStarNode parentNode, int childPos, int childRc) {
        this.assertSplitPositions();
        int[] conf = parentNode.getNodeAssignments();
        if (childPos >= 0) {
            assert (conf[childPos] < 0);
            System.arraycopy(conf, 0, this.childConf, 0, this.numPos);
            this.childConf[childPos] = childRc;
            conf = this.childConf;
        }
        return this.scoreNode(conf);
    }

    /**
     * Scores a partial conformation. The score is the constant term, plus the internal
     * energy of the defined tuple, plus, for each undefined position, the best
     * {@link #RCContribution} over that position's pan-unpruned RCs.
     *
     * @throws RuntimeException if the inherited {@code traditionalScore} flag is false
     */
    @Override
    protected double scoreNode(int[] partialConf) {
        if (!this.traditionalScore) {
            throw new RuntimeException("Advanced A* scoring methods not implemented yet!");
        }
        this.rcTuple.set(partialConf);
        double score = this.emat.getConstTerm() + this.emat.getInternalEnergy(this.rcTuple);
        ArrayList<Integer> undefinedLevels = this.getUndefinedPos(this.rcTuple);
        for (int level : undefinedLevels) {
            double bestInteractionE = this.energyLB ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
            ArrayList<Integer> rcsAtUndefinedLevel = this.panUprunedRCsAtPos.get(level);
            for (int rc : rcsAtUndefinedLevel) {
                double rcContribution = this.RCContribution(level, rc, this.rcTuple, partialConf, undefinedLevels);
                bestInteractionE = this.energyLB ? Math.min(bestInteractionE, rcContribution) : Math.max(bestInteractionE, rcContribution);
            }
            score += bestInteractionE;
        }
        return score;
    }

    /** Public entry point for {@link #scoreNode(int[])}. */
    public double confEnergyBound(int[] partialConf) {
        return this.scoreNode(partialConf);
    }
}

