/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.markstar.framework;

import edu.duke.cs.osprey.astar.conf.ConfIndex;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.kstar.pfunc.PartitionFunction;
import edu.duke.cs.osprey.markstar.framework.MARKStarBound;
import edu.duke.cs.osprey.markstar.framework.MARKStarNode;
import edu.duke.cs.osprey.parallelism.Parallelism;
import edu.duke.cs.osprey.tools.MathTools;
import edu.duke.cs.osprey.tools.Stopwatch;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;

public class MARKStarBoundFastQueues
extends MARKStarBound {
    public String stateName = String.format("%4f", Math.random());
    private Queue<MARKStarNode> leafQueue = new PriorityQueue<MARKStarNode>();
    private Queue<MARKStarNode> internalQueue = new PriorityQueue<MARKStarNode>();

    public MARKStarBoundFastQueues(SimpleConfSpace confSpace, EnergyMatrix rigidEmat, EnergyMatrix minimizingEmat, ConfEnergyCalculator minimizingConfEcalc, RCs rcs, Parallelism parallelism) {
        super(confSpace, rigidEmat, minimizingEmat, minimizingConfEcalc, rcs, parallelism);
    }

    @Override
    protected void tightenBoundInPhases() {
        System.out.println(String.format("Current overall error bound: %12.10f, spread of [%12.6e, %12.6e]", this.epsilonBound, this.rootNode.getLowerBound(), this.rootNode.getUpperBound()));
        ArrayList<MARKStarNode> internalNodes = new ArrayList<MARKStarNode>();
        ArrayList<MARKStarNode> leafNodes = new ArrayList<MARKStarNode>();
        List<MARKStarNode> newNodes = Collections.synchronizedList(new ArrayList());
        BigDecimal internalZ = BigDecimal.ONE;
        BigDecimal leafZ = BigDecimal.ONE;
        int numNodes = 0;
        Stopwatch loopWatch = new Stopwatch();
        loopWatch.start();
        Stopwatch internalTime = new Stopwatch();
        Stopwatch leafTime = new Stopwatch();
        double leafTimeSum = 0.0;
        double internalTimeSum = 0.0;
        BigDecimal[] ZSums = new BigDecimal[]{internalZ, leafZ};
        this.populateQueues(this.queue, internalNodes, leafNodes, internalZ, leafZ, ZSums);
        this.updateBound();
        this.debugPrint(String.format("After corrections, bounds are now [%12.6e,%12.6e]", this.rootNode.getLowerBound(), this.rootNode.getUpperBound()));
        internalZ = ZSums[0];
        leafZ = ZSums[1];
        System.out.println(String.format("Z Comparison: %12.6e, %12.6e", internalZ, leafZ));
        if (MathTools.isLessThan(internalZ, leafZ)) {
            numNodes = leafNodes.size();
            System.out.println("Processing " + numNodes + " leaf nodes...");
            leafTime.reset();
            leafTime.start();
            for (MARKStarNode leafNode : leafNodes) {
                this.processFullConfNode(newNodes, leafNode, leafNode.getConfSearchNode());
                leafNode.markUpdated();
                this.debugPrint("Processing Node: " + leafNode.getConfSearchNode().toString());
            }
            loopTasks.waitForFinish();
            leafTime.stop();
            this.leafTimeAverage = leafTime.getTimeS();
            System.out.println("Processed " + numNodes + " leaves in " + this.leafTimeAverage + " seconds.");
            if (this.maxMinimizations < this.parallelism.numThreads) {
                ++this.maxMinimizations;
            }
            this.internalQueue.addAll(internalNodes);
        } else {
            numNodes = internalNodes.size();
            System.out.println("Processing " + numNodes + " internal nodes...");
            internalTime.reset();
            internalTime.start();
            for (MARKStarNode internalNode : internalNodes) {
                if (!MathTools.isGreaterThan(internalNode.getLowerBound(), BigDecimal.ONE) && MathTools.isGreaterThan(MathTools.bigDivide(internalNode.getUpperBound(), this.rootNode.getUpperBound(), PartitionFunction.decimalPrecision), new BigDecimal(1.0 - this.targetEpsilon))) {
                    loopTasks.submit(() -> {
                        this.boundLowestBoundConfUnderNode(internalNode, newNodes);
                        return null;
                    }, ignored -> {});
                } else {
                    this.processPartialConfNode(newNodes, internalNode, internalNode.getConfSearchNode());
                }
                internalNode.markUpdated();
            }
            loopTasks.waitForFinish();
            internalTime.stop();
            internalTimeSum = internalTime.getTimeS();
            this.internalTimeAverage = internalTimeSum / (double)Math.max(1, internalNodes.size());
            this.debugPrint("Internal node time :" + internalTimeSum + ", average " + this.internalTimeAverage);
            this.numInternalNodesProcessed += internalNodes.size();
            this.leafQueue.addAll(leafNodes);
        }
        if (this.epsilonBound <= this.targetEpsilon) {
            return;
        }
        this.loopCleanup(newNodes, loopWatch, numNodes);
    }

    @Override
    protected void populateQueues(Queue<MARKStarNode> queue, List<MARKStarNode> internalNodes, List<MARKStarNode> leafNodes, BigDecimal internalZ, BigDecimal leafZ, BigDecimal[] ZSums) {
        ArrayList<MARKStarNode> leftoverLeaves = new ArrayList<MARKStarNode>();
        int maxNodes = 1000;
        if (this.leafTimeAverage > 0.0) {
            maxNodes = Math.max(maxNodes, (int)Math.floor(0.1 * this.leafTimeAverage / this.internalTimeAverage));
        }
        while (!(queue.isEmpty() || this.internalQueue.size() >= maxNodes && this.leafQueue.size() >= this.maxMinimizations)) {
            MARKStarNode curNode = queue.poll();
            MARKStarNode.Node node = curNode.getConfSearchNode();
            ConfIndex index = new ConfIndex(this.RCs.getNumPos());
            node.index(index);
            double correctgscore = this.correctionMatrix.confE(node.assignments);
            double hscore = node.getConfLowerBound() - node.gscore;
            double confCorrection = Math.min(correctgscore, node.rigidScore) + hscore;
            if (!node.isMinimized() && node.getConfLowerBound() < confCorrection && node.getConfLowerBound() - confCorrection > 1.0E-5) {
                if (confCorrection < node.getConfLowerBound()) {
                    System.out.println("huh!?");
                }
                System.out.println("Correction from " + String.valueOf(this.correctionMatrix.sourceECalc) + ":" + node.gscore + "->" + correctgscore);
                this.recordCorrection(node.getConfLowerBound(), correctgscore - node.gscore);
                node.gscore = correctgscore;
                if (confCorrection > node.rigidScore) {
                    System.out.println("Overcorrected" + SimpleConfSpace.formatConfRCs(node.assignments) + ": " + confCorrection + " > " + node.rigidScore);
                    node.gscore = node.rigidScore;
                    confCorrection = node.rigidScore + hscore;
                }
                node.setBoundsFromConfLowerAndUpper(confCorrection, node.getConfUpperBound());
                curNode.markUpdated();
                leftoverLeaves.add(curNode);
                continue;
            }
            if (node.getLevel() < this.RCs.getNumPos()) {
                this.internalQueue.add(curNode);
                continue;
            }
            if (!this.shouldMinimize(node) || this.correctedNode(leftoverLeaves, curNode, node)) continue;
            this.leafQueue.add(curNode);
        }
        ZSums[0] = this.fillListFromQueue(internalNodes, this.internalQueue, maxNodes);
        ZSums[1] = this.fillListFromQueue(leafNodes, this.leafQueue, this.maxMinimizations);
        queue.addAll(leftoverLeaves);
    }

    private BigDecimal fillListFromQueue(List<MARKStarNode> list, Queue<MARKStarNode> queue, int max) {
        BigDecimal sum = BigDecimal.ZERO;
        ArrayList<MARKStarNode> leftovers = new ArrayList<MARKStarNode>();
        while (!queue.isEmpty() && list.size() < max) {
            MARKStarNode curNode = queue.poll();
            if (this.correctedNode(leftovers, curNode, curNode.getConfSearchNode())) continue;
            BigDecimal diff = curNode.getUpperBound().subtract(curNode.getLowerBound());
            sum = sum.add(diff);
            list.add(curNode);
        }
        queue.addAll(leftovers);
        return sum;
    }
}

