/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.astar.conf;

import edu.duke.cs.osprey.astar.conf.ConfAStarNode;
import edu.duke.cs.osprey.astar.conf.ConfIndex;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.astar.conf.scoring.AStarScorer;
import edu.duke.cs.osprey.astar.conf.scoring.PairwiseGScorer;
import edu.duke.cs.osprey.astar.conf.scoring.TraditionalPairwiseHScorer;
import edu.duke.cs.osprey.confspace.Conf;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.ematrix.NegatedEnergyMatrix;
import edu.duke.cs.osprey.tools.TimeFormatter;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/**
 * Counts how many conformations of a conformation space have a score at or below a query score,
 * without enumerating every conformation individually.
 *
 * <p>The search walks a tree of partial conformations. For each child of an expanded node it
 * computes a lower bound (g-score plus the {@code hscorer} estimate) and an upper bound (g-score
 * minus the h-score computed on a {@link NegatedEnergyMatrix}). Children whose lower bound exceeds
 * the query are counted wholesale as "above". Children whose upper bound is at or below the query
 * are counted wholesale as "below". Only children whose bounds straddle the query are expanded
 * further. Exact counts therefore rely on the scorers producing valid bounds.
 * NOTE(review): confirm this for any custom {@link ScorerFactory}.
 *
 * <p>Not thread-safe: every query re-indexes one shared {@link ConfIndex}.
 */
public class ConfRanker {

    /** Sentinel stored in {@code Node.assignments} for positions that have no RC assigned yet. */
    private static final int UNASSIGNED = -1;

    public final SimpleConfSpace confSpace;
    public final EnergyMatrix emat;
    public final RCs rcs;
    public final ScorerFactory gscorerFactory;
    public final ScorerFactory hscorerFactory;
    public final boolean reportProgress;

    private final AStarScorer gscorer;
    private final AStarScorer hscorer;
    /** h-scorer built on the negated matrix; its negation serves as an upper bound on h. */
    private final AStarScorer negatedHScorer;
    /** Scratch index shared by the whole search, re-indexed to whichever node is being expanded. */
    private final ConfIndex confIndex;
    private final Node rootNode;

    private ConfRanker(SimpleConfSpace confSpace, EnergyMatrix emat, RCs rcs, ScorerFactory gscorerFactory, ScorerFactory hscorerFactory, boolean reportProgress) {
        this.confSpace = confSpace;
        this.rcs = rcs;
        this.emat = emat;
        this.gscorerFactory = gscorerFactory;
        this.hscorerFactory = hscorerFactory;
        this.reportProgress = reportProgress;
        this.gscorer = gscorerFactory.make(emat);
        this.hscorer = hscorerFactory.make(emat);
        this.negatedHScorer = hscorerFactory.make(new NegatedEnergyMatrix(confSpace, emat));
        this.confIndex = new ConfIndex(confSpace.positions.size());

        // Score the empty (fully unassigned) conformation once; every query starts from it.
        this.rootNode = new Node(confSpace.positions.size());
        this.rootNode.index(this.confIndex);
        this.rootNode.gscore = this.gscorer.calc(this.confIndex, rcs);
        this.rootNode.minHScore = this.hscorer.calc(this.confIndex, rcs);
        this.rootNode.maxHScore = -this.negatedHScorer.calc(this.confIndex, rcs);
    }

    /**
     * Counts the conformations whose score is less than or equal to {@code queryScore}.
     *
     * @param queryScore inclusive score threshold
     * @return the number of conformations scoring at most {@code queryScore}
     */
    public BigInteger getNumConfsAtMost(double queryScore) {
        Progress progress = new Progress(this.rcs.getNumConformations());
        this.numConfsAtMost(this.rootNode, queryScore, progress);
        return progress.below;
    }

    /** Recursively counts the subtree under {@code node}, accumulating into {@code progress}. */
    private void numConfsAtMost(Node node, double queryScore, Progress progress) {
        node.index(this.confIndex);
        assert (this.confIndex.numUndefined > 0);
        if (this.confIndex.numUndefined == 1) {
            // One position left: every child is a full conformation, so score them exactly.
            this.countLeaves(queryScore, progress);
        } else {
            this.countBranches(node, queryScore, progress);
        }
    }

    /**
     * Scores each choice for the single remaining position exactly and tallies it.
     * Expects {@link #confIndex} to be indexed to the parent node.
     */
    private void countLeaves(double queryScore, Progress progress) {
        int pos = this.confIndex.undefinedPos[0];
        for (int rc : this.rcs.get(pos)) {
            double score = this.gscorer.calcDifferential(this.confIndex, this.rcs, pos, rc);
            if (score <= queryScore) {
                progress.incrementBelow();
            } else {
                progress.incrementAbove();
            }
        }
        progress.writeReportIfNeeded();
    }

    /**
     * Expands {@code node} along the undefined position whose children can be resolved by their
     * bounds most often. Resolved subtrees are counted directly. The remaining subtrees are
     * recursed into.
     */
    private void countBranches(Node node, double queryScore, Progress progress) {
        // Every child of every undefined position is scored to pick the best branching position.
        // Only the current best position's children are retained, rather than all of them.
        List<Node> bestChildren = new ArrayList<>();
        double bestPosScore = Double.NEGATIVE_INFINITY;
        int bestPos = -1;
        for (int i = 0; i < this.confIndex.numUndefined; ++i) {
            int pos = this.confIndex.undefinedPos[i];
            int[] posRCs = this.rcs.get(pos);
            List<Node> children = new ArrayList<>(posRCs.length);
            int numSubTreesResolved = 0;
            for (int rc : posRCs) {
                // confIndex is still indexed to `node` here: no recursion has happened yet.
                Node child = node.assign(pos, rc);
                child.gscore = this.gscorer.calcDifferential(this.confIndex, this.rcs, pos, rc);
                child.minHScore = this.hscorer.calcDifferential(this.confIndex, this.rcs, pos, rc);
                if (child.getMinScore() > queryScore) {
                    // Whole subtree is above the query; the upper bound is not needed.
                    ++numSubTreesResolved;
                } else {
                    child.maxHScore = -this.negatedHScorer.calcDifferential(this.confIndex, this.rcs, pos, rc);
                    if (child.getMaxScore() <= queryScore) {
                        ++numSubTreesResolved;
                    }
                }
                children.add(child);
            }
            double posScore = (double) numSubTreesResolved / (double) posRCs.length;
            // Strict comparison: ties keep the earliest position.
            if (posScore > bestPosScore) {
                bestPosScore = posScore;
                bestPos = pos;
                bestChildren = children;
            }
        }
        assert (bestPos >= 0);

        // Tally subtrees decided by their bounds; only undecided ones need recursion.
        List<Node> undecided = new ArrayList<>();
        for (Node child : bestChildren) {
            if (child.getMinScore() > queryScore) {
                progress.incrementAbove(child.getNumConformations(this.rcs));
            } else if (child.getMaxScore() <= queryScore) {
                progress.incrementBelow(child.getNumConformations(this.rcs));
            } else {
                undecided.add(child);
            }
        }
        progress.writeReportIfNeeded();
        for (Node child : undecided) {
            this.numConfsAtMost(child, queryScore, progress);
        }
    }

    /**
     * Creates an {@link AStarScorer} for an energy matrix. The h-scorer factory is invoked with the
     * ranker's matrix and also with a {@link NegatedEnergyMatrix} of it.
     */
    @FunctionalInterface
    public static interface ScorerFactory {
        public AStarScorer make(EnergyMatrix var1);
    }

    /** A partial conformation together with its g-score and lower/upper h-score bounds. */
    private static class Node
    implements ConfAStarNode {
        public double gscore = Double.NaN;
        public double minHScore = Double.NaN;
        /** Stays NaN when the lower bound alone already decided the node. */
        public double maxHScore = Double.NaN;
        public int[] assignments;
        /** The position/RC assigned to reach this node from its parent. */
        public int pos = UNASSIGNED;
        public int rc = UNASSIGNED;

        public Node(int size) {
            this.assignments = new int[size];
            Arrays.fill(this.assignments, UNASSIGNED);
        }

        /** Returns a new child node with {@code rc} assigned at {@code pos}; this node is unchanged. */
        @Override
        public Node assign(int pos, int rc) {
            Node node = new Node(this.assignments.length);
            node.pos = pos;
            node.rc = rc;
            System.arraycopy(this.assignments, 0, node.assignments, 0, this.assignments.length);
            node.assignments[pos] = rc;
            return node;
        }

        @Override
        public double getGScore() {
            return this.gscore;
        }

        @Override
        public void setGScore(double val) {
            this.gscore = val;
        }

        /** Lower bound on the score of any full conformation below this node. */
        public double getMinScore() {
            return this.gscore + this.minHScore;
        }

        /** Upper bound on the score of any full conformation below this node. */
        public double getMaxScore() {
            return this.gscore + this.maxHScore;
        }

        @Override
        public double getHScore() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setHScore(double val) {
            throw new UnsupportedOperationException();
        }

        @Override
        public int getLevel() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void getConf(int[] conf) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void index(ConfIndex confIndex) {
            Conf.index(this.assignments, confIndex);
            confIndex.node = this;
        }

        /** Number of full conformations below this node: product of RC counts of unassigned positions. */
        public BigInteger getNumConformations(RCs rcs) {
            BigInteger numConfs = BigInteger.ONE;
            for (int pos = 0; pos < this.assignments.length; ++pos) {
                if (this.assignments[pos] == UNASSIGNED) {
                    numConfs = numConfs.multiply(BigInteger.valueOf(rcs.getNum(pos)));
                }
            }
            return numConfs;
        }
    }

    /** Running tallies for one query, with optional periodic console reporting. */
    private class Progress {
        public final BigInteger total;
        /** Conformations known to score at most the query. */
        public BigInteger below = BigInteger.ZERO;
        /** Conformations known to score above the query. */
        public BigInteger above = BigInteger.ZERO;
        public final long startTimeNs = System.nanoTime();
        public long reportIntervalMs = 5000L;
        private long lastReportTimeMs = 0L;

        public Progress(BigInteger total) {
            this.total = total;
        }

        public void incrementBelow() {
            this.incrementBelow(BigInteger.ONE);
        }

        public void incrementAbove() {
            this.incrementAbove(BigInteger.ONE);
        }

        public void incrementBelow(BigInteger val) {
            this.below = this.below.add(val);
        }

        public void incrementAbove(BigInteger val) {
            this.above = this.above.add(val);
        }

        /** Prints a report if reporting is enabled and the interval has elapsed since the last one. */
        public void writeReportIfNeeded() {
            if (!ConfRanker.this.reportProgress) {
                return;
            }
            long timeMs = System.currentTimeMillis();
            if (timeMs - this.lastReportTimeMs >= this.reportIntervalMs) {
                this.lastReportTimeMs = timeMs;
                System.out.println(this.getReport());
            }
        }

        /** Formats the current [lower, upper] bracket on the answer and the fraction resolved so far. */
        public String getReport() {
            return String.format("progress: [%e,%e] of %e  (%.6f%%)   %s", this.below.doubleValue(), this.total.doubleValue() - this.above.doubleValue(), this.total.doubleValue(), this.below.add(this.above).doubleValue() / this.total.doubleValue() * 100.0, TimeFormatter.format(System.nanoTime() - this.startTimeNs, 2));
        }
    }

    /** Builds a {@link ConfRanker}; unset options fall back to defaults at {@link #build()} time. */
    public static class Builder {
        private final SimpleConfSpace confSpace;
        private final EnergyMatrix emat;
        private RCs rcs = null;
        private ScorerFactory gscorerFactory = null;
        private ScorerFactory hscorerFactory = null;
        private boolean reportProgress = false;

        public Builder(SimpleConfSpace confSpace, EnergyMatrix emat) {
            this.confSpace = confSpace;
            this.emat = emat;
        }

        public Builder setRCs(RCs val) {
            this.rcs = val;
            return this;
        }

        public Builder setGScorerFactory(ScorerFactory val) {
            this.gscorerFactory = val;
            return this;
        }

        public Builder setHScorerFactory(ScorerFactory val) {
            this.hscorerFactory = val;
            return this;
        }

        public Builder setReportProgress(boolean val) {
            this.reportProgress = val;
            return this;
        }

        /**
         * Creates the ranker. Defaults are resolved into locals instead of being written back into
         * this builder. The default h-scorer factory captures the resolved {@code RCs} rather than
         * reading the builder's mutable field later. This keeps the ranker's public
         * {@code hscorerFactory} consistent with its {@code rcs} even if the builder is reused.
         */
        public ConfRanker build() {
            final RCs resolvedRCs = this.rcs != null ? this.rcs : new RCs(this.confSpace);
            ScorerFactory resolvedGFactory = this.gscorerFactory;
            if (resolvedGFactory == null) {
                resolvedGFactory = matrix -> new PairwiseGScorer(matrix);
            }
            ScorerFactory resolvedHFactory = this.hscorerFactory;
            if (resolvedHFactory == null) {
                resolvedHFactory = matrix -> new TraditionalPairwiseHScorer(matrix, resolvedRCs);
            }
            return new ConfRanker(this.confSpace, this.emat, resolvedRCs, resolvedGFactory, resolvedHFactory, this.reportProgress);
        }
    }
}

