/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.markstar.framework;

import edu.duke.cs.osprey.astar.conf.ConfIndex;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.astar.conf.order.AStarOrder;
import edu.duke.cs.osprey.astar.conf.pruning.AStarPruner;
import edu.duke.cs.osprey.astar.conf.scoring.AStarScorer;
import edu.duke.cs.osprey.astar.conf.scoring.MPLPPairwiseHScorer;
import edu.duke.cs.osprey.astar.conf.scoring.PairwiseGScorer;
import edu.duke.cs.osprey.astar.conf.scoring.TraditionalPairwiseHScorer;
import edu.duke.cs.osprey.astar.conf.scoring.mplp.EdgeUpdater;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.ematrix.NegatedEnergyMatrix;
import edu.duke.cs.osprey.ematrix.UpdatingEnergyMatrix;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.gmec.ConfAnalyzer;
import edu.duke.cs.osprey.kstar.BBKStar;
import edu.duke.cs.osprey.kstar.pfunc.BoltzmannCalculator;
import edu.duke.cs.osprey.kstar.pfunc.PartitionFunction;
import edu.duke.cs.osprey.markstar.MARKStarProgress;
import edu.duke.cs.osprey.markstar.framework.MARKStarBoundFastQueues;
import edu.duke.cs.osprey.markstar.framework.MARKStarNode;
import edu.duke.cs.osprey.markstar.framework.StaticBiggestLowerboundDifferenceOrder;
import edu.duke.cs.osprey.parallelism.Parallelism;
import edu.duke.cs.osprey.parallelism.TaskExecutor;
import edu.duke.cs.osprey.pruning.PruningMatrix;
import edu.duke.cs.osprey.tools.MathTools;
import edu.duke.cs.osprey.tools.ObjectPool;
import edu.duke.cs.osprey.tools.Stopwatch;
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;

/**
 * Rigid-scoring variant of the MARK* partition-function bound calculator.
 *
 * <p>Partition-function bounds are read from the root of the MARK* tree: after
 * {@link #compute(int)} the root's lower bound is stored in {@code qstar} and its
 * upper bound in {@code pstar} and {@code qprime}. Bounds are tightened by first
 * drilling a single path down to a leaf ({@code runUntilNonZero}) and then
 * repeatedly expanding the node at the head of the priority queue
 * ({@code tightenBoundRigid}). Child scores come only from energy-matrix based
 * A* scorers; this class does not itself minimize conformations, and leaf nodes
 * polled from the queue are dropped without further work.
 */
public class MARKStarBoundRigid implements PartitionFunction {
    private double targetEpsilon = 1.0;
    public boolean debug = false;
    public boolean profileOutput = false;
    private PartitionFunction.Status status = null;
    private Values values = null;
    // Never incremented in this class; kept for workDone()/isStable() bookkeeping.
    private int numConfsEnergied = 0;
    private int maxNumConfs = -1;
    // Incremented once per leaf scored in drillDown().
    private int numConfsScored = 0;
    // Incremented once per internal node expanded in tightenBoundRigid().
    private int numInternalNodesProcessed = 0;
    private boolean printMinimizedConfs;
    private MARKStarProgress progress;
    public String stateName = String.format("%4f", Math.random());
    private int numPartialMinimizations;
    private ArrayList<Integer> minList;
    private double internalTimeAverage;
    private double leafTimeAverage;
    private double cleanupTime;
    // True once runUntilNonZero() has pushed at least one full path to a leaf.
    private boolean nonZeroLower;
    private TaskExecutor loopTasks;
    private UpdatingEnergyMatrix corrections;
    private MARKStarNode rootNode;
    private final Queue<MARKStarNode> queue;
    private double epsilonBound = Double.POSITIVE_INFINITY;
    // Shared index used only by processPartialConfNodeRigid() on the calling thread.
    private ConfIndex confIndex;
    public final AStarOrder order;
    public final AStarPruner pruner;
    private RCs RCs;
    private Parallelism parallelism;
    private TaskExecutor internalTasks;
    private TaskExecutor leafTasks;
    private TaskExecutor drillTasks;
    // Pool of per-task scorer contexts so concurrent tasks do not share scorer state.
    private ObjectPool<ScoreContext> contexts;
    private MARKStarNode.ScorerFactory gscorerFactory;
    private MARKStarNode.ScorerFactory hscorerFactory;
    private AStarScorer hscorer;
    public boolean reduceMinimizations = true;
    private ConfAnalyzer confAnalyzer;
    EnergyMatrix minimizingEmat;
    EnergyMatrix rigidEmat;
    private Stopwatch stopwatch = new Stopwatch().start();
    BigDecimal cumulativeZCorrection = BigDecimal.ZERO;
    BigDecimal ZReductionFromMin = BigDecimal.ZERO;
    BoltzmannCalculator bc = new BoltzmannCalculator(PartitionFunction.decimalPrecision);
    private boolean computedCorrections = false;
    // Accumulated seconds spent scoring children; a double so fractional seconds are not
    // truncated on every addition. NOTE(review): written from internalTasks workers
    // without synchronization, so the value is approximate under parallelism.
    private double loopPartialTime = 0.0;
    private Set<String> correctedTuples = Collections.synchronizedSet(new HashSet<>());
    private BigDecimal stabilityThreshold;
    private double leafTimeSum = 0.0;
    private double internalTimeSum = 0.0;
    private int numLeavesScored = 0;
    private int numInternalScored = 0;

    /** Replaces the energy-correction matrix held by this calculator. */
    public void setCorrections(UpdatingEnergyMatrix corrections) {
        this.corrections = corrections;
    }

    /** Returns the energy-correction matrix, or {@code null} if none was set. */
    public UpdatingEnergyMatrix getCorrections() {
        return this.corrections;
    }

    /** Replaces the residue-conformation set used when expanding nodes. */
    public void setRCs(RCs rcs) {
        this.RCs = rcs;
    }

    /**
     * Records whether progress should be reported.
     *
     * @param showPfuncProgress the requested setting; previously ignored (always
     *     {@code true}), now stored as given
     */
    @Override
    public void setReportProgress(boolean showPfuncProgress) {
        this.printMinimizedConfs = showPfuncProgress;
    }

    /** Conformation listeners are not supported; the argument is ignored. */
    @Override
    public void setConfListener(PartitionFunction.ConfListener val) {
    }

    /** Sets the upper-bound threshold below which the computation is considered unstable. */
    @Override
    public void setStabilityThreshold(BigDecimal threshold) {
        this.stabilityThreshold = threshold;
    }

    /** Stores a conformation budget; {@link #compute(int)} uses its own argument instead. */
    public void setMaxNumConfs(int maxNumConfs) {
        this.maxNumConfs = maxNumConfs;
    }

    /**
     * Resets status and values for a new computation.
     *
     * @param targetEpsilon the epsilon at which tightening stops
     */
    @Override
    public void init(double targetEpsilon) {
        this.targetEpsilon = targetEpsilon;
        this.status = PartitionFunction.Status.Estimating;
        this.values = new Values();
    }

    /**
     * Resets status and values for a new computation and sets the stability threshold.
     *
     * @param epsilon the epsilon at which tightening stops
     * @param stabilityThreshold upper-bound threshold for stability, or {@code null} to disable
     */
    public void init(double epsilon, BigDecimal stabilityThreshold) {
        this.targetEpsilon = epsilon;
        this.status = PartitionFunction.Status.Estimating;
        this.values = new Values();
        this.stabilityThreshold = stabilityThreshold;
    }

    @Override
    public PartitionFunction.Status getStatus() {
        return this.status;
    }

    @Override
    public PartitionFunction.Values getValues() {
        return this.values;
    }

    /** Always reports 0; the configured {@link Parallelism} is not exposed here. */
    @Override
    public int getParallelism() {
        return 0;
    }

    @Override
    public int getNumConfsEvaluated() {
        return this.numConfsEnergied;
    }

    /** Returns the number of leaf conformations scored during drill-down. */
    public int getNumConfsScored() {
        return this.numConfsScored;
    }

    /** Total units of work performed so far; used to enforce the budget in {@link #compute(int)}. */
    private int workDone() {
        return this.numInternalNodesProcessed + this.numConfsEnergied + this.numConfsScored + this.numPartialMinimizations;
    }

    /**
     * Tightens the partition-function bounds.
     *
     * <p>On the first call, one path is drilled down to a leaf so the lower bound becomes
     * non-zero. Nodes are then expanded until one of these holds: epsilon is at or below
     * the target, {@code maxNumConfs} units of work (see {@code workDone}) have been done
     * during this call, the stability check fails, or the queue is exhausted. Finally the
     * root bounds are copied into the values object.
     *
     * @param maxNumConfs maximum units of work to perform during this call
     */
    @Override
    public void compute(int maxNumConfs) {
        this.debugPrint("Num conformations: " + String.valueOf(this.rootNode.getConfSearchNode().getNumConformations()));
        double lastEps = 1.0;
        int previousConfCount = this.workDone();
        if (!this.nonZeroLower) {
            this.runUntilNonZero();
            this.updateBound();
        }
        while (this.epsilonBound > this.targetEpsilon
                && this.workDone() - previousConfCount < maxNumConfs
                && this.isStable(this.stabilityThreshold)
                // Leaf nodes are dropped when polled, so the queue can drain; stop rather than poll null.
                && !this.queue.isEmpty()) {
            this.debugPrint("Tightening from epsilon of " + this.epsilonBound);
            this.tightenBoundRigid();
            this.debugPrint("Errorbound is now " + this.epsilonBound);
            if (lastEps < this.epsilonBound && this.epsilonBound - lastEps > 0.01) {
                System.err.println("Error. Bounds got looser.");
            }
            lastEps = this.epsilonBound;
        }
        if (!this.isStable(this.stabilityThreshold)) {
            this.status = PartitionFunction.Status.Unstable;
        }
        this.loopTasks.waitForFinish();
        BigDecimal averageReduction = BigDecimal.ZERO;
        int totalMinimizations = this.numConfsEnergied + this.numPartialMinimizations;
        if (totalMinimizations > 0) {
            averageReduction = this.cumulativeZCorrection.divide(new BigDecimal(totalMinimizations), new MathContext(4));
        }
        this.debugPrint(String.format("Average Z reduction per minimization: %12.6e", averageReduction));
        // Same criterion as the loop exit and tightenBoundRigid(): reaching the target counts.
        if (this.epsilonBound <= this.targetEpsilon) {
            this.status = PartitionFunction.Status.Estimated;
        }
        this.values.qstar = this.rootNode.getLowerBound();
        this.values.pstar = this.rootNode.getUpperBound();
        this.values.qprime = this.rootNode.getUpperBound();
    }

    /** Prints {@code s} to stdout when {@link #debug} is set. */
    private void debugPrint(String s) {
        if (this.debug) {
            System.out.println(s);
        }
    }

    /** Prints {@code s} to stdout when {@link #profileOutput} is set. Not called within this class. */
    private void profilePrint(String s) {
        if (this.profileOutput) {
            System.out.println(s);
        }
    }

    /** Runs {@link #compute(int)} with no work budget. */
    @Override
    public void compute() {
        this.compute(Integer.MAX_VALUE);
    }

    @Override
    public PartitionFunction.Result makeResult() {
        return new PartitionFunction.Result(this.getStatus(), this.getValues(), this.getNumConfsEvaluated());
    }

    /**
     * Placeholder for BBK* integration.
     *
     * @throws UnsupportedOperationException always
     */
    public static MARKStarBoundFastQueues makeFromConfSpaceInfo(BBKStar.ConfSpaceInfo info2, RCs rcs) {
        throw new UnsupportedOperationException("MARK* is not yet integrated into BBK*. Coming soon!");
    }

    /**
     * Builds the root node, scorer factories, scorer-context pool and task executors,
     * then computes the initial epsilon bound.
     *
     * @param confSpace the conformation space
     * @param rigidEmat energy matrix used for the rigid (upper-bound side) scorers
     * @param minimizingEmat energy matrix used for the g/h (lower-bound side) scorers
     * @param minimizingConfEcalc stored in each scoring context and the conf analyzer
     * @param rcs residue conformations to search
     * @param parallelism executor configuration; {@code null} means one CPU thread
     */
    public MARKStarBoundRigid(SimpleConfSpace confSpace, EnergyMatrix rigidEmat, EnergyMatrix minimizingEmat, ConfEnergyCalculator minimizingConfEcalc, RCs rcs, Parallelism parallelism) {
        this.queue = new PriorityQueue<>();
        this.gscorerFactory = emats -> new PairwiseGScorer(emats);
        EdgeUpdater updater = new EdgeUpdater();
        this.hscorerFactory = emats -> new MPLPPairwiseHScorer(updater, emats, 1, 1.0E-4);
        this.rootNode = MARKStarNode.makeRoot(confSpace, rigidEmat, minimizingEmat, rcs,
                this.gscorerFactory.make(minimizingEmat), this.hscorerFactory.make(minimizingEmat),
                this.gscorerFactory.make(rigidEmat),
                new TraditionalPairwiseHScorer(new NegatedEnergyMatrix(confSpace, rigidEmat), rcs), true);
        this.confIndex = new ConfIndex(rcs.getNumPos());
        this.minimizingEmat = minimizingEmat;
        this.rigidEmat = rigidEmat;
        this.RCs = rcs;
        this.order = new StaticBiggestLowerboundDifferenceOrder();
        this.order.setScorers(this.gscorerFactory.make(minimizingEmat), this.hscorerFactory.make(minimizingEmat));
        this.pruner = null;
        this.contexts = new ObjectPool<>(ignored -> {
            ScoreContext context = new ScoreContext();
            context.index = new ConfIndex(rcs.getNumPos());
            context.gscorer = this.gscorerFactory.make(minimizingEmat);
            context.hscorer = this.hscorerFactory.make(minimizingEmat);
            context.rigidscorer = this.gscorerFactory.make(rigidEmat);
            context.negatedhscorer = this.hscorerFactory.make(new NegatedEnergyMatrix(confSpace, rigidEmat));
            context.ecalc = minimizingConfEcalc;
            return context;
        });
        this.progress = new MARKStarProgress(this.RCs.getNumPos());
        this.confAnalyzer = new ConfAnalyzer(minimizingConfEcalc);
        this.setParallelism(parallelism);
        this.updateBound();
        this.minList = new ArrayList<>(Collections.nCopies(rcs.getNumPos(), 0));
    }

    /**
     * Creates new task executors from {@code val} and sizes the scorer-context pool to
     * its parallelism. Executors created by an earlier call are replaced without being
     * cleaned up here.
     *
     * @param val executor configuration; {@code null} means one CPU thread
     */
    public void setParallelism(Parallelism val) {
        if (val == null) {
            val = Parallelism.makeCpu(1);
        }
        this.parallelism = val;
        this.leafTasks = this.parallelism.makeTaskExecutor(1000);
        this.internalTasks = this.parallelism.makeTaskExecutor(1000);
        this.drillTasks = this.parallelism.makeTaskExecutor(1000);
        this.loopTasks = this.parallelism.makeTaskExecutor(1000);
        this.contexts.allocate(this.parallelism.getParallelism());
    }

    /** In debug mode, warns when the new epsilon is larger than {@code curEpsilon}. */
    private void debugEpsilon(double curEpsilon) {
        if (this.debug && curEpsilon < this.epsilonBound) {
            System.err.println("Epsilon just got bigger.");
        }
    }

    /** True for a full-length (leaf) node that is not yet minimized. Not called within this class. */
    private boolean shouldMinimize(MARKStarNode.Node node) {
        return node.getLevel() == this.RCs.getNumPos() && !node.isMinimized();
    }

    /** Adds the Boltzmann-weight drop caused by {@code correction} to {@link #cumulativeZCorrection}. */
    private void recordCorrection(double lowerBound, double correction) {
        BigDecimal upper = this.bc.calc(lowerBound);
        BigDecimal corrected = this.bc.calc(lowerBound + correction);
        this.cumulativeZCorrection = this.cumulativeZCorrection.add(upper.subtract(corrected));
    }

    /** Adds the Boltzmann-weight difference between {@code score} and {@code energy} to {@link #ZReductionFromMin}. */
    private void recordReduction(double score, double energy) {
        BigDecimal scoreWeight = this.bc.calc(score);
        BigDecimal energyWeight = this.bc.calc(energy);
        this.ZReductionFromMin = this.ZReductionFromMin.add(scoreWeight.subtract(energyWeight));
    }

    /** Debugging aid: prints a message when {@code conf} equals a hard-coded conformation. */
    private void debugBreakOnConf(int[] conf) {
        int[] confOfInterest = new int[]{1, 7, 5, 9, 2, 27, 3, 7, 3, 10, 3};
        if (java.util.Arrays.equals(conf, confOfInterest)) {
            System.out.println("Matched " + SimpleConfSpace.formatConfRCs(conf));
        }
    }

    /**
     * Drills one path from the root to a leaf so the lower bound becomes non-zero, then
     * enqueues every node generated along the way.
     */
    private void runUntilNonZero() {
        System.out.println("Running until leaf is found...");
        List<MARKStarNode> newNodes = new ArrayList<>();
        this.boundLowestBoundConfUnderNode(this.rootNode, newNodes);
        this.queue.addAll(newNodes);
        System.out.println("Found a leaf!");
        this.nonZeroLower = true;
    }

    /**
     * Polls one node; expands it if it is internal (leaves are dropped), then recomputes
     * the epsilon bound. Safe to call when the queue is empty.
     */
    private void tightenBoundRigid() {
        MARKStarNode curNode = this.queue.poll();
        if (curNode != null && curNode.getConfSearchNode().getLevel() < this.RCs.getNumPos()) {
            this.processPartialConfNodeRigid(curNode, curNode.getConfSearchNode());
            this.internalTasks.waitForFinish();
            // Counted so the work budget in compute(int) actually limits this loop.
            this.numInternalNodesProcessed++;
        }
        this.updateBound();
    }

    /**
     * Submits one scoring task per unpruned RC at the next position chosen by
     * {@link #order}. Each task scores the child with a pooled context; its listener
     * attaches the child to {@code curNode} and enqueues it unless it is minimized.
     */
    private void processPartialConfNodeRigid(MARKStarNode curNode, MARKStarNode.Node node) {
        node.index(this.confIndex);
        int nextPos = this.order.getNextPos(this.confIndex, this.RCs);
        assert !this.confIndex.isDefined(nextPos);
        assert this.confIndex.isUndefined(nextPos);
        // Children with negative conformation lower bound; only its size is reported to progress.
        List<MARKStarNode> children = new ArrayList<>();
        for (int nextRc : this.RCs.get(nextPos)) {
            if (this.hasPrunedPair(this.confIndex, nextPos, nextRc)
                    || (this.pruner != null && this.pruner.isPruned(node, nextPos, nextRc))) {
                continue;
            }
            this.internalTasks.submit(() -> {
                try (ObjectPool.Checkout<ScoreContext> checkout = this.contexts.autoCheckout()) {
                    Stopwatch partialTime = new Stopwatch().start();
                    ScoreContext context = checkout.get();
                    node.index(context.index);
                    MARKStarNode.Node child = node.assign(nextPos, nextRc);
                    double gscore = context.gscorer.calcDifferential(context.index, this.RCs, nextPos, nextRc);
                    double hdiff = context.hscorer.calcDifferential(context.index, this.RCs, nextPos, nextRc);
                    double maxhdiff = -context.negatedhscorer.calcDifferential(context.index, this.RCs, nextPos, nextRc);
                    // Rigid score advances the parent's rigid score by the change in g-score.
                    double rigidScore = gscore - node.gscore + node.rigidScore;
                    child.gscore = gscore;
                    child.rigidScore = rigidScore;
                    double confLowerBound = gscore + hdiff;
                    double confUpperBound = rigidScore + maxhdiff;
                    child.computeNumConformations(this.RCs);
                    child.setBoundsFromConfLowerAndUpper(confLowerBound, confUpperBound);
                    this.progress.reportInternalNode(child.level, child.gscore, child.getHScore(), this.queue.size(), children.size(), this.epsilonBound);
                    partialTime.stop();
                    this.loopPartialTime += partialTime.getTimeS();
                    return child;
                }
            }, child -> {
                if (Double.isNaN(child.rigidScore)) {
                    System.out.println("Huh!?");
                }
                MARKStarNode markStarChild = curNode.makeChild(child);
                if (markStarChild.getConfSearchNode().getConfLowerBound() < 0.0) {
                    children.add(markStarChild);
                }
                if (!child.isMinimized()) {
                    this.queue.add(markStarChild);
                } else {
                    markStarChild.computeEpsilonErrorBounds();
                }
                curNode.markUpdated();
            });
        }
    }

    /** Prints up to ten nodes from the head of {@code queue}, restores them, then prints the tree. */
    private void debugHeap(Queue<MARKStarNode> queue) {
        int maxNodes = 10;
        System.out.println("Node heap:");
        List<MARKStarNode> nodes = new ArrayList<>();
        while (!queue.isEmpty() && nodes.size() < maxNodes) {
            MARKStarNode node = queue.poll();
            System.out.println(node.getConfSearchNode());
            nodes.add(node);
        }
        queue.addAll(nodes);
        this.rootNode.printTree();
    }

    /**
     * True when no conformations have been energied, no threshold is set, or the root
     * upper bound is at least {@code stabilityThreshold}.
     */
    boolean isStable(BigDecimal stabilityThreshold) {
        return this.numConfsEnergied <= 0 || stabilityThreshold == null || MathTools.isGreaterThanOrEqual(this.rootNode.getUpperBound(), stabilityThreshold);
    }

    /**
     * Generates and scores every unpruned child of {@code node} at the next position,
     * adds them all to {@code newNodes}, and returns the child with the lowest
     * conformation lower bound.
     *
     * @return the lowest-bounded child, or {@code null} if no child had a finite bound
     *     (e.g. every RC was pruned)
     */
    private MARKStarNode drillDown(List<MARKStarNode> newNodes, MARKStarNode curNode, MARKStarNode.Node node) {
        try (ObjectPool.Checkout<ScoreContext> checkout = this.contexts.autoCheckout()) {
            ScoreContext context = checkout.get();
            ConfIndex index = context.index;
            node.index(index);
            int nextPos = this.order.getNextPos(index, this.RCs);
            assert !index.isDefined(nextPos);
            assert index.isUndefined(nextPos);
            List<MARKStarNode> children = new ArrayList<>();
            double bestChildLower = Double.POSITIVE_INFINITY;
            MARKStarNode bestChild = null;
            for (int nextRc : this.RCs.get(nextPos)) {
                if (this.hasPrunedPair(index, nextPos, nextRc)
                        || (this.pruner != null && this.pruner.isPruned(node, nextPos, nextRc))) {
                    continue;
                }
                Stopwatch partialTime = new Stopwatch().start();
                MARKStarNode.Node child = node.assign(nextPos, nextRc);
                double confLowerBound = Double.POSITIVE_INFINITY;
                if (child.getLevel() < this.RCs.getNumPos()) {
                    double gscore = context.gscorer.calcDifferential(index, this.RCs, nextPos, nextRc);
                    double hdiff = context.hscorer.calcDifferential(index, this.RCs, nextPos, nextRc);
                    double maxhdiff = -context.negatedhscorer.calcDifferential(index, this.RCs, nextPos, nextRc);
                    double rigidScore = gscore - node.gscore + node.rigidScore;
                    child.gscore = gscore;
                    child.rigidScore = rigidScore;
                    confLowerBound = gscore + hdiff;
                    double confUpperBound = rigidScore + maxhdiff;
                    child.computeNumConformations(this.RCs);
                    child.setBoundsFromConfLowerAndUpper(confLowerBound, confUpperBound);
                    this.progress.reportInternalNode(child.level, child.gscore, child.getHScore(), this.queue.size(), children.size(), this.epsilonBound);
                } else if (child.getLevel() == this.RCs.getNumPos()) {
                    // Leaf: lower and upper bounds both collapse to the rigid score.
                    double confRigid = context.rigidscorer.calcDifferential(index, this.RCs, nextPos, nextRc);
                    confRigid = confRigid - node.gscore + node.rigidScore;
                    child.computeNumConformations(this.RCs);
                    child.setBoundsFromConfLowerAndUpper(confRigid, confRigid);
                    child.gscore = child.getConfLowerBound();
                    confLowerBound = confRigid;
                    child.rigidScore = confRigid;
                    ++this.numConfsScored;
                    this.progress.reportLeafNode(child.gscore, this.queue.size(), this.epsilonBound);
                }
                partialTime.stop();
                this.loopPartialTime += partialTime.getTimeS();
                if (Double.isNaN(child.rigidScore)) {
                    System.out.println("Huh!?");
                }
                MARKStarNode markStarChild = curNode.makeChild(child);
                markStarChild.markUpdated();
                // Track the running minimum so the drill follows the lowest-bounded child.
                if (confLowerBound < bestChildLower) {
                    bestChildLower = confLowerBound;
                    bestChild = markStarChild;
                }
                if (markStarChild.getConfSearchNode().getConfLowerBound() < 0.0) {
                    children.add(markStarChild);
                }
                newNodes.add(markStarChild);
            }
            return bestChild;
        }
    }

    /**
     * Repeatedly drills from {@code startNode} into its lowest-bounded child until a
     * leaf is reached, printing progress every ten seconds. Every generated node except
     * the ones drilled through is appended to {@code generatedNodes}, and the final leaf
     * is included.
     */
    private void boundLowestBoundConfUnderNode(MARKStarNode startNode, List<MARKStarNode> generatedNodes) {
        Comparator<MARKStarNode> byConfLowerBound = Comparator.comparingDouble(n -> n.getConfSearchNode().getConfLowerBound());
        PriorityQueue<MARKStarNode> drillQueue = new PriorityQueue<>(byConfLowerBound);
        drillQueue.add(startNode);
        List<MARKStarNode> newNodes = new ArrayList<>();
        int numNodes = 0;
        Stopwatch leafLoop = new Stopwatch().start();
        Stopwatch overallLoop = new Stopwatch().start();
        while (!drillQueue.isEmpty()) {
            ++numNodes;
            MARKStarNode curNode = drillQueue.poll();
            MARKStarNode.Node node = curNode.getConfSearchNode();
            if (node.getLevel() < this.RCs.getNumPos()) {
                MARKStarNode nextNode = this.drillDown(newNodes, curNode, node);
                // drillDown returns null when no child survived pruning; PriorityQueue rejects null.
                if (nextNode != null) {
                    newNodes.remove(nextNode);
                    drillQueue.add(nextNode);
                }
            } else {
                newNodes.add(curNode);
            }
            if (leafLoop.getTimeS() > 10.0) {
                leafLoop.stop();
                leafLoop.reset();
                leafLoop.start();
                System.out.println(String.format("Processed %d, %s so far. Bounds are now [%12.6e,%12.6e]", numNodes, overallLoop.getTime(2), this.rootNode.getLowerBound(), this.rootNode.getUpperBound()));
            }
        }
        generatedNodes.addAll(newNodes);
    }

    /**
     * Debug check that reports when {@code upper} is below {@code lower} by more than 1e-5
     * (and {@code upper < 10}). Not called within this class.
     */
    private void checkBounds(double lower, double upper) {
        if (upper < lower && lower - upper > 1.0E-5 && upper < 10.0) {
            this.debugPrint("Bounds incorrect.");
        }
    }

    /** Recomputes {@link #epsilonBound} from the root node and warns (in debug) if it grew. */
    private void updateBound() {
        double previousEpsilon = this.epsilonBound;
        this.epsilonBound = this.rootNode.computeEpsilonErrorBounds();
        this.debugEpsilon(previousEpsilon);
    }

    /**
     * True when assigning {@code nextRc} at {@code nextPos} forms a pruned pair with any
     * assignment already defined in {@code confIndex}; false when there is no pruning matrix.
     */
    private boolean hasPrunedPair(ConfIndex confIndex, int nextPos, int nextRc) {
        PruningMatrix pmat = this.RCs.getPruneMat();
        if (pmat == null) {
            return false;
        }
        for (int i = 0; i < confIndex.numDefined; ++i) {
            int pos = confIndex.definedPos[i];
            int rc = confIndex.definedRCs[i];
            assert pos != nextPos || rc != nextRc;
            if (pmat.getPairwise(pos, rc, nextPos, nextRc)) {
                return true;
            }
        }
        return false;
    }

    /** Partition-function values whose upper bound is {@code pstar} and lower bound is {@code qstar}. */
    public static class Values extends PartitionFunction.Values {
        /** Starts with an infinite upper bound. */
        public Values() {
            this.pstar = MathTools.BigPositiveInfinity;
        }

        @Override
        public BigDecimal calcUpperBound() {
            return this.pstar;
        }

        @Override
        public BigDecimal calcLowerBound() {
            return this.qstar;
        }

        /** Returns {@code (pstar - qstar) / pstar}. */
        @Override
        public double getEffectiveEpsilon() {
            return MathTools.bigDivide(this.pstar.subtract(this.qstar), this.pstar, PartitionFunction.decimalPrecision).doubleValue();
        }
    }

    /** Per-task bundle of scorers and index, checked out from the pool while scoring. */
    private static class ScoreContext {
        public ConfIndex index;
        public AStarScorer gscorer;
        public AStarScorer hscorer;
        public AStarScorer negatedhscorer;
        public AStarScorer rigidscorer;
        public ConfEnergyCalculator ecalc;

        private ScoreContext() {
        }
    }
}

