/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.kstar;

import edu.duke.cs.osprey.kstar.KSAbstract;
import edu.duke.cs.osprey.kstar.KSAllowedSeqs;
import edu.duke.cs.osprey.kstar.KSCalc;
import edu.duke.cs.osprey.kstar.impl.KSImplKAStar;
import edu.duke.cs.osprey.kstar.pfunc.PFAbstract;
import edu.duke.cs.osprey.kstar.pfunc.impl.PFTraditional;
import edu.duke.cs.osprey.kstar.pfunc.impl.PFUB;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.IntStream;

/**
 * A node in the KA* sequence-expansion tree.
 *
 * <p>Strand indices used throughout: {@code 0} = P, {@code 1} = L, and {@code 2} = PL, whose
 * sequence is the P sequence followed by the L sequence (see {@link #getChildren}). Each node
 * carries a lower-bound calculation {@link #lb} and, except for tight-bound leaves, an
 * upper-bound calculation {@link #ub}.
 *
 * <p>Class-wide state ({@link #init} arguments and the counters) is static and shared by all
 * nodes.
 */
public class KAStarNode {
    /** Strand index of P. */
    private static final int STRAND_P = 0;
    /** Strand index of L. */
    private static final int STRAND_L = 1;
    /** Strand index of PL (P sequence followed by L sequence). */
    private static final int STRAND_PL = 2;
    /** L then P: the order in which the two single strands are processed. */
    private static final int[] LP_STRANDS = {STRAND_L, STRAND_P};
    /** L, P, then PL. */
    private static final int[] ALL_STRANDS = {STRAND_L, STRAND_P, STRAND_PL};

    /** Reference calculation whose PFs are passed to every runPF/run call (presumably wild type). */
    private static KSCalc wt;
    private static HashMap<Integer, KSAllowedSeqs> strand2AllowedSeqs;
    private static KSAbstract ksObj;
    private static int numCreated = 0;
    private static int numExpanded = 0;
    private static int numPruned = 0;
    public static int numLeavesCreated = 0;
    public static int numLeavesCompleted = 0;

    /** Upper-bound calculation; {@code null} for tight-bound leaves created by {@link #getChildren}. */
    public KSCalc ub;
    /** Lower-bound calculation (or the tight calculation for tight-bound leaves). */
    public KSCalc lb;
    /** {@link #lbScore} of the node that created this one. */
    protected double parentlbScore;
    /** Negated log10 K* score of {@link #lb}; lower values sort first. */
    protected double lbScore;
    /** Negated log10 K* score of {@link #ub}, cached by {@link #getUBScore()}. */
    protected double ubScore;
    protected boolean scoreNeedsRefinement;

    /**
     * Orders nodes by ascending {@link #lbScore}.
     *
     * <p>Returns 0 for equal scores, as the {@link Comparator} contract requires. Any collection
     * that treats {@code compare == 0} as "duplicate" (e.g. a sorted set) will therefore collapse
     * equal-score nodes.
     */
    public static Comparator<KAStarNode> KUStarNodeComparator =
            (lhs, rhs) -> Double.compare(lhs.lbScore, rhs.lbScore);

    /**
     * Installs the class-wide state shared by all nodes.
     *
     * @param ksObj K* driver used to create PFs and query sequence counts
     * @param strand2AllowedSeqs allowed sub-sequences per strand index
     * @param wt reference calculation passed to every PF run
     */
    public static void init(KSAbstract ksObj, HashMap<Integer, KSAllowedSeqs> strand2AllowedSeqs, KSCalc wt) {
        KAStarNode.ksObj = ksObj;
        KAStarNode.strand2AllowedSeqs = strand2AllowedSeqs;
        KAStarNode.wt = wt;
    }

    /**
     * Creates a node with unset scores (lb/parent lb = -inf, ub = +inf).
     *
     * @param lb lower-bound (or tight) calculation
     * @param ub upper-bound calculation, may be {@code null}
     * @param scoreNeedsRefinement whether {@link #computeScore} must run the refinement path
     */
    public KAStarNode(KSCalc lb, KSCalc ub, boolean scoreNeedsRefinement) {
        this.lb = lb;
        this.ub = ub;
        this.parentlbScore = Double.NEGATIVE_INFINITY;
        this.lbScore = Double.NEGATIVE_INFINITY;
        this.ubScore = Double.POSITIVE_INFINITY;
        this.scoreNeedsRefinement = scoreNeedsRefinement;
    }

    public static int getNumLeavesCreated() {
        return numLeavesCreated;
    }

    public static int getNumLeavesCompleted() {
        return numLeavesCompleted;
    }

    public static int getNumCreated() {
        return numCreated;
    }

    public static int getNumExpanded() {
        return numExpanded;
    }

    public static int getNumPruned() {
        return numPruned;
    }

    /**
     * Expands this node.
     *
     * <p>A fully defined node whose score needs no refinement is not split. Instead its own score
     * is advanced (if it can continue), and it is returned as its only "child", unless pruning
     * removes it. Otherwise each of P and L is extended by one position (capped at the strand's
     * max depth), and every P/L combination whose concatenation is an allowed PL sub-sequence
     * becomes a child. The children are then scored and pruned.
     *
     * @return surviving children; possibly empty
     */
    public ArrayList<KAStarNode> expand() {
        if (isFullyDefined() && !scoreNeedsRefinement()) {
            if (!isFullyProcessed() && lb.canContinue()) {
                computeScore(this);
            }
            ArrayList<KAStarNode> self = new ArrayList<>();
            self.add(this);
            pruneIncompatibleSuccessors(self, null);
            return self;
        }

        ++numExpanded;

        final int currentDepth = depth();
        final int[] nextDepths = new int[3];
        ArrayList<String> currentPSeq = null;
        ArrayList<String> currentLSeq = null;

        if (currentDepth != 0) {
            currentLSeq = lb.getPF(STRAND_L).getSequence();
            currentPSeq = lb.getPF(STRAND_P).getSequence();
            nextDepths[STRAND_L] = Math.min(currentLSeq.size() + 1,
                    strand2AllowedSeqs.get(STRAND_L).getStrandSubSeqsMaxDepth());
            nextDepths[STRAND_P] = Math.min(currentPSeq.size() + 1,
                    strand2AllowedSeqs.get(STRAND_P).getStrandSubSeqsMaxDepth());
        } else {
            // Root node. The results of these calls are discarded. NOTE(review): presumably they
            // are made for a side effect (e.g. populating depth-1 sub-sequences) - confirm in
            // KSAllowedSeqs.
            KSAllowedSeqs pAllowed = strand2AllowedSeqs.get(STRAND_P);
            KSAllowedSeqs lAllowed = strand2AllowedSeqs.get(STRAND_L);
            lAllowed.getStrandSubSeqsAtDepth(1);
            pAllowed.getStrandSubSeqsAtDepth(1);
            strand2AllowedSeqs.get(STRAND_PL).getStrandSubSeqsAtDepth(1, pAllowed, lAllowed);
            nextDepths[STRAND_L] = 1;
            nextDepths[STRAND_P] = 1;
        }
        nextDepths[STRAND_PL] = nextDepths[STRAND_P] + nextDepths[STRAND_L];

        // Copies, so filtering below does not mutate the shared allowed-sequence sets.
        HashSet<ArrayList<String>> nextPSeqs = new HashSet<>(
                strand2AllowedSeqs.get(STRAND_P).getStrandSubSeqsAtDepth(nextDepths[STRAND_P]));
        HashSet<ArrayList<String>> nextLSeqs = new HashSet<>(
                strand2AllowedSeqs.get(STRAND_L).getStrandSubSeqsAtDepth(nextDepths[STRAND_L]));
        HashSet<ArrayList<String>> nextPLSeqs = new HashSet<>(
                strand2AllowedSeqs.get(STRAND_PL).getStrandSubSeqsAtDepth(nextDepths[STRAND_PL]));

        if (currentDepth != 0) {
            // Children must extend this node's current P and L sequences.
            retainExtensionsOf(nextLSeqs, currentLSeq);
            retainExtensionsOf(nextPSeqs, currentPSeq);
        }

        ArrayList<KAStarNode> children = getChildren(nextPLSeqs, nextPSeqs, nextLSeqs);
        if (!children.isEmpty()) {
            switch (KSImplKAStar.nodeExpansionMethod) {
                case "serial":
                    computeScoresSerial(children);
                    break;
                case "parallel2":
                    computeScoresComplexParallel(children);
                    break;
                default:
                    computeScoresSimpleParallel(children);
                    break;
            }
            pruneIncompatibleSuccessors(children, nextDepths);
        }
        return children;
    }

    /**
     * Removes from {@code candidates} every sequence whose first {@code prefix.size()} elements
     * differ from {@code prefix}.
     */
    private static void retainExtensionsOf(HashSet<ArrayList<String>> candidates, ArrayList<String> prefix) {
        final int prefixLen = prefix.size();
        candidates.removeIf(candidate -> !candidate.subList(0, prefixLen).equals(prefix));
    }

    /**
     * Removes children that cannot be pursued.
     *
     * <p>A child whose upper-bound epsilon status is not FALSE is removed. Before removal, for each
     * of L and P whose upper-bound PF is NOT_STABLE or has zero unpruned entries, that strand's
     * sequence is pruned from the L/P and PL allowed-sequence sets from the next depth onward
     * (only when {@code nextDepths} is non-null). A child whose lower bound neither is doing KA*
     * nor can continue is also removed.
     *
     * @param children nodes to filter in place
     * @param nextDepths next depth per strand index, or {@code null} to skip sequence pruning
     */
    private void pruneIncompatibleSuccessors(ArrayList<KAStarNode> children, int[] nextDepths) {
        Iterator<KAStarNode> it = children.iterator();
        while (it.hasNext()) {
            KAStarNode child = it.next();
            if (child.ub != null && child.ub.getEpsilonStatus() != PFAbstract.EApproxReached.FALSE) {
                for (int strand : LP_STRANDS) {
                    PFAbstract pf = child.ub.getPF(strand);
                    // Sanity check: NOT_POSSIBLE should imply an all-zero partition function.
                    if (pf.getEpsilonStatus() == PFAbstract.EApproxReached.NOT_POSSIBLE
                            && pf.getQStar().add(pf.getQPrime().add(pf.getPStar())).compareTo(BigDecimal.ZERO) > 0) {
                        System.out.println("ERROR: " + KSAbstract.list1D2String(pf.getSequence(), " ") + " " + pf.getFlexibility() + " is wrong!!!");
                    }
                    boolean prunable = pf.getEpsilonStatus() == PFAbstract.EApproxReached.NOT_STABLE
                            || pf.getNumUnPruned().compareTo(BigInteger.ZERO) == 0;
                    if (prunable && nextDepths != null) {
                        ArrayList<String> seq = pf.getSequence();
                        pruneSequences(seq, strand, nextDepths[strand]);
                        pruneSequences(seq, STRAND_PL, nextDepths[STRAND_PL]);
                    }
                }
                it.remove();
            } else if (!child.lb.doingKAStar() && !child.lb.canContinue()) {
                it.remove();
            }
        }
    }

    /**
     * Deletes {@code seq} from the allowed sub-sequence sets of {@code strand} at every depth from
     * {@code depth} through the strand's max depth, adding the number removed to
     * {@link #numPruned}.
     */
    private void pruneSequences(ArrayList<String> seq, int strand, int depth) {
        KSAllowedSeqs allowed = strand2AllowedSeqs.get(strand);
        int maxDepth = allowed.getStrandSubSeqsMaxDepth();
        for (int d = depth; d <= maxDepth; ++d) {
            HashSet<ArrayList<String>> set = allowed.getStrandSubSeqsAtDepth(d);
            int oldSize = set.size();
            KSAllowedSeqs.deleteFromSet(seq, set);
            numPruned += oldSize - set.size();
        }
    }

    /**
     * Runs the upper-bound PFs of L then P for {@code child}, skipping remaining strands once the
     * upper bound's epsilon status is no longer FALSE.
     *
     * @return {@code true} if the upper bound's epsilon status is still FALSE afterwards (such a
     *     child is not removed by {@link #pruneIncompatibleSuccessors})
     */
    private static boolean runLPUpperBounds(KAStarNode child) {
        for (int strand : LP_STRANDS) {
            if (child.ub.getEpsilonStatus() == PFAbstract.EApproxReached.FALSE) {
                child.ub.runPF(child.ub.getPF(strand), wt.getPF(strand), true, true);
            }
        }
        return child.ub.getEpsilonStatus() == PFAbstract.EApproxReached.FALSE;
    }

    /**
     * Computes {@code child}'s bounds and {@link #lbScore}.
     *
     * <p><b>Refinement path:</b> runs the L/P upper bounds with output suppressed. If the child
     * survives, runs the lower bound and checks it against the parent's score.
     *
     * <p><b>Leaf path:</b> sets {@code KSAbstract.doCheckPoint} to
     * {@code KSImplKAStar.useTightBounds} for the duration. It runs the L/P upper bounds only when
     * tight bounds are off, then runs the lower bound. The global flags are restored on every exit
     * path, including exceptions.
     */
    private void computeScore(KAStarNode child) {
        if (child.scoreNeedsRefinement()) {
            PFAbstract.suppressOutput = true;
            try {
                if (!runLPUpperBounds(child)) {
                    return;
                }
                child.lb.run(wt, true, false);
                child.lbScore = -1.0 * child.lb.getKStarScoreLog10(true);
                checkConsistency(child);
            } finally {
                PFAbstract.suppressOutput = false;
            }
            return;
        }

        boolean lbStabilityCheck = KSImplKAStar.useTightBounds;
        KSAbstract.doCheckPoint = lbStabilityCheck;
        try {
            if (!KSImplKAStar.useTightBounds && !runLPUpperBounds(child)) {
                return;
            }
            child.lb.run(wt, false, lbStabilityCheck);
            if (!child.lb.doingKAStar() && !child.lb.canContinue()) {
                // Child will be removed by pruneIncompatibleSuccessors; no score needed.
                return;
            }
            if (child.lb.getEpsilonStatus() == PFAbstract.EApproxReached.TRUE) {
                numLeavesCompleted = ksObj.getNumSeqsCompleted(1);
            }
            child.lbScore = -1.0 * child.lb.getKStarScoreLog10(true);
        } finally {
            // Previously left set on the "cannot continue" early return.
            KSAbstract.doCheckPoint = false;
        }
    }

    /** Scores, one at a time, every child that is not yet fully processed. */
    protected void computeScoresSerial(ArrayList<KAStarNode> children) {
        for (KAStarNode child : children) {
            if (!child.isFullyProcessed()) {
                computeScore(child);
            }
        }
    }

    /**
     * Scores children with a parallel stream when {@link #canParallelize} allows it, otherwise
     * serially.
     */
    protected void computeScoresSimpleParallel(ArrayList<KAStarNode> children) {
        if (children.isEmpty()) {
            return;
        }
        if (canParallelize(children)) {
            children.parallelStream()
                    .filter(child -> !child.isFullyProcessed())
                    .forEach(this::computeScore);
        } else {
            computeScoresSerial(children);
        }
    }

    /**
     * Scores children by running PFs in parallel, one calculation per distinct strand sequence.
     *
     * <p>Used only when the first child needs refinement; otherwise falls back to serial scoring.
     * Phase 1 runs the L and P upper bounds. Phase 2 runs the L, P and PL lower bounds for
     * children whose upper-bound status is still FALSE, then sets their scores.
     * {@code PFAbstract.suppressOutput} is restored even if a run throws.
     */
    protected void computeScoresComplexParallel(ArrayList<KAStarNode> children) {
        if (children.isEmpty()) {
            return;
        }
        if (!children.get(0).scoreNeedsRefinement()) {
            computeScoresSerial(children);
            return;
        }

        PFAbstract.suppressOutput = true;
        try {
            ArrayList<CalcParams> jobs = new ArrayList<>();
            for (int strand : LP_STRANDS) {
                for (KSCalc calc : getCalcs4Strand(children, false, strand)) {
                    jobs.add(new CalcParams(calc, strand, true, true));
                }
            }
            jobs.parallelStream().forEach(CalcParams::run);

            ArrayList<KAStarNode> survivors = new ArrayList<>();
            for (KAStarNode child : children) {
                if (child.ub.getEpsilonStatus() == PFAbstract.EApproxReached.FALSE) {
                    survivors.add(child);
                }
            }

            jobs.clear();
            for (int strand : ALL_STRANDS) {
                for (KSCalc calc : getCalcs4Strand(survivors, true, strand)) {
                    jobs.add(new CalcParams(calc, strand, true, false));
                }
            }
            jobs.parallelStream().forEach(CalcParams::run);

            for (KAStarNode child : survivors) {
                child.lbScore = -1.0 * child.lb.getKStarScoreLog10(true);
                checkConsistency(child);
            }
        } finally {
            PFAbstract.suppressOutput = false;
        }
    }

    /**
     * Returns one calculation (lower or upper bound) per distinct sequence of {@code strand} among
     * {@code children}. When several children share a sequence, the last one seen wins.
     */
    private ArrayList<KSCalc> getCalcs4Strand(ArrayList<KAStarNode> children, boolean lb, int strand) {
        HashMap<ArrayList<String>, KSCalc> seq2KSCalc = new HashMap<>();
        for (KAStarNode child : children) {
            KSCalc calc = lb ? child.lb : child.ub;
            seq2KSCalc.put(calc.getPF(strand).getSequence(), calc);
        }
        return new ArrayList<>(seq2KSCalc.values());
    }

    /**
     * Returns {@code true} when there are at least two children and no two children share a PL,
     * P, or L lower-bound sequence.
     */
    private boolean canParallelize(ArrayList<KAStarNode> children) {
        if (children.size() < 2) {
            return false;
        }
        HashSet<String> seenPL = new HashSet<>(children.size());
        HashSet<String> seenP = new HashSet<>(children.size());
        HashSet<String> seenL = new HashSet<>(children.size());
        for (KAStarNode child : children) {
            String plSeq = KSAbstract.list1D2String(child.lb.getPF(STRAND_PL).getSequence(), " ");
            String pSeq = KSAbstract.list1D2String(child.lb.getPF(STRAND_P).getSequence(), " ");
            String lSeq = KSAbstract.list1D2String(child.lb.getPF(STRAND_L).getSequence(), " ");
            if (!seenPL.add(plSeq) || !seenP.add(pSeq) || !seenL.add(lSeq)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds one child per (pSeq, lSeq) pair whose concatenation is in {@code nextPLSeqs}.
     *
     * <p>For a node that is not fully defined, each child gets separate lower- and upper-bound
     * calculations. For a fully defined node, children are created only with tight bounds
     * (a single calculation, no upper bound); otherwise an exception is thrown.
     *
     * @throws RuntimeException if this node is fully defined, tight bounds are off, and at least
     *     one matching child exists
     */
    private ArrayList<KAStarNode> getChildren(HashSet<ArrayList<String>> nextPLSeqs, HashSet<ArrayList<String>> pSeqs, HashSet<ArrayList<String>> lSeqs) {
        // Per-strand (P, L, PL) settings passed to createPFs4Seqs.
        ArrayList<Boolean> lbContSCFlexVals = new ArrayList<>(Arrays.asList(false, false, true));
        ArrayList<String> lbPFImplVals = new ArrayList<>(Arrays.asList(new PFTraditional().getImpl(), new PFTraditional().getImpl(), new PFUB().getImpl()));
        ArrayList<Boolean> ubContSCFlexVals = new ArrayList<>(Arrays.asList(true, true, false));
        ArrayList<String> ubPFImplVals = new ArrayList<>(Arrays.asList(new PFUB().getImpl(), new PFUB().getImpl(), new PFTraditional().getImpl()));
        ArrayList<Boolean> tightContSCFlexVals = new ArrayList<>(Arrays.asList(true, true, true));
        ArrayList<String> tightPFImplVals = new ArrayList<>(Arrays.asList(PFAbstract.getCFGImpl(), PFAbstract.getCFGImpl(), PFAbstract.getCFGImpl()));

        final boolean fullyDefined = isFullyDefined();
        ArrayList<KAStarNode> ans = new ArrayList<>();
        for (ArrayList<String> pSeq : pSeqs) {
            for (ArrayList<String> lSeq : lSeqs) {
                ArrayList<String> plSeq = new ArrayList<>(pSeq.size() + lSeq.size());
                plSeq.addAll(pSeq);
                plSeq.addAll(lSeq);
                if (!nextPLSeqs.contains(plSeq)) {
                    continue;
                }
                ++numCreated;
                // Fresh list per child, so createPFs4Seqs never sees a list that is later mutated.
                ArrayList<ArrayList<String>> strandSeqs = new ArrayList<>(Arrays.asList(pSeq, lSeq, plSeq));

                KAStarNode child;
                if (!fullyDefined) {
                    ConcurrentHashMap<Integer, PFAbstract> lbPFs = ksObj.createPFs4Seqs(strandSeqs, lbContSCFlexVals, lbPFImplVals);
                    ConcurrentHashMap<Integer, PFAbstract> ubPFs = ksObj.createPFs4Seqs(strandSeqs, ubContSCFlexVals, ubPFImplVals);
                    child = new KAStarNode(new KSCalc(numCreated, lbPFs), new KSCalc(numCreated, ubPFs), childScoreNeedsRefinement(lbPFs));
                } else if (KSImplKAStar.useTightBounds) {
                    ConcurrentHashMap<Integer, PFAbstract> tightPFs = ksObj.createPFs4Seqs(strandSeqs, tightContSCFlexVals, tightPFImplVals);
                    KSAllowedSeqs complexSeqs = ksObj.strand2AllowedSeqs.get(STRAND_PL);
                    int seqID = complexSeqs.getPosOfSeq(tightPFs.get(STRAND_PL).getSequence());
                    child = new KAStarNode(new KSCalc(seqID, tightPFs), null, false);
                    numLeavesCreated = ksObj.getNumSeqsCreated(1);
                } else {
                    throw new RuntimeException("ERROR: cannot expand a fully assigned node");
                }
                child.parentlbScore = this.lbScore;
                ans.add(child);
            }
        }
        return ans;
    }

    /** Length of this node's PL lower-bound sequence, or 0 when {@link #lb} is {@code null}. */
    private int depth() {
        return lb == null ? 0 : lb.getPF(STRAND_PL).getSequence().size();
    }

    public boolean scoreNeedsRefinement() {
        return this.scoreNeedsRefinement;
    }

    public double getParentLBScore() {
        return this.parentlbScore;
    }

    public double getLBScore() {
        return this.lbScore;
    }

    /**
     * Returns the upper-bound score.
     *
     * <p>Without an upper-bound calculation, returns the cached {@link #ubScore}. Otherwise runs
     * the PL upper-bound PF (with output suppressed, restored even on failure), then caches and
     * returns its negated log10 K* score.
     */
    public double getUBScore() {
        if (ub == null) {
            return ubScore;
        }
        PFAbstract.suppressOutput = true;
        try {
            ub.runPF(ub.getPF(STRAND_PL), null, true, false);
        } finally {
            PFAbstract.suppressOutput = false;
        }
        ubScore = -1.0 * ub.getKStarScoreLog10(true);
        return ubScore;
    }

    /** {@code true} when this node's depth reaches the length of the reference PL sequence. */
    private boolean isFullyDefined() {
        int maxDepth = wt.getPF(STRAND_PL).getSequence().size();
        return depth() >= maxDepth;
    }

    /**
     * Verifies that a node's lower-bound score does not fall below its parent's.
     *
     * @throws RuntimeException if the node's score is set (not -inf) and is less than the
     *     parent's lower-bound score
     */
    public void checkConsistency(KAStarNode node) {
        double nodeLB = node.getLBScore();
        double parentLB = node.getParentLBScore();
        if (nodeLB == Double.NEGATIVE_INFINITY) {
            return;
        }
        if (parentLB > nodeLB) {
            throw new RuntimeException("ERROR: parentLB: " + parentLB + " must be <= nodeLB: " + nodeLB);
        }
    }

    /** A child needs refinement if its PL PF is not fully defined, or when tight bounds are on. */
    private boolean childScoreNeedsRefinement(ConcurrentHashMap<Integer, PFAbstract> lbPFs) {
        return !lbPFs.get(STRAND_PL).isFullyDefined() || KSImplKAStar.useTightBounds;
    }

    /** {@code true} when no refinement is pending and the lower bound's epsilon status is not FALSE. */
    public boolean isFullyProcessed() {
        return !scoreNeedsRefinement() && lb.getEpsilonStatus() != PFAbstract.EApproxReached.FALSE;
    }

    /** One PF run: {@code calc.runPF(calc.getPF(strand), wt.getPF(strand), complete, stabilityCheck)}. */
    private static final class CalcParams {
        final KSCalc calc;
        final int strand;
        final boolean complete;
        final boolean stabilityCheck;

        CalcParams(KSCalc calc, int strand, boolean complete, boolean stabilityCheck) {
            this.calc = calc;
            this.strand = strand;
            this.complete = complete;
            this.stabilityCheck = stabilityCheck;
        }

        void run() {
            calc.runPF(calc.getPF(strand), wt.getPF(strand), complete, stabilityCheck);
        }
    }
}

