/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.ewakstar;

import edu.duke.cs.osprey.astar.conf.ConfAStarTree;
import edu.duke.cs.osprey.astar.conf.ConfSearchCache;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.confspace.ConfDB;
import edu.duke.cs.osprey.confspace.ConfSearch;
import edu.duke.cs.osprey.confspace.SeqSpace;
import edu.duke.cs.osprey.confspace.Sequence;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.energy.EnergyCalculator;
import edu.duke.cs.osprey.ewakstar.EWAKStar;
import edu.duke.cs.osprey.ewakstar.EWAKStarGradientDescentPfunc;
import edu.duke.cs.osprey.ewakstar.EWAKStarPartitionFunction;
import edu.duke.cs.osprey.ewakstar.EWAKStarScore;
import edu.duke.cs.osprey.ewakstar.EWAKStarScoreWriter;
import edu.duke.cs.osprey.ewakstar.EwakstarDoer;
import edu.duke.cs.osprey.ewakstar.EwakstarLimitedSequenceTrie;
import edu.duke.cs.osprey.gmec.ConfAnalyzer;
import edu.duke.cs.osprey.kstar.pfunc.BoltzmannCalculator;
import edu.duke.cs.osprey.kstar.pfunc.PartitionFunction;
import edu.duke.cs.osprey.kstar.pfunc.UpperBoundCalculator;
import edu.duke.cs.osprey.tools.BigMath;
import edu.duke.cs.osprey.tools.MathTools;
import java.io.File;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.Function;

/**
 * BBK*-style best-first search over sequences for EWAK*.
 *
 * <p>Nodes are kept in a priority queue ordered by an optimistic score (log10 of an upper bound
 * on the K* score). {@link MultiSequenceNode}s are partially assigned sequences that get expanded
 * one design position at a time. {@link SingleSequenceNode}s are fully specified sequences whose
 * protein, ligand, and complex partition functions are refined in batches until they finish or
 * block, at which point the sequence is reported.
 */
public class EWAKStarBBKStar {
    public final ConfSpaceInfo protein;
    public final ConfSpaceInfo ligand;
    public final ConfSpaceInfo complex;
    public final EWAKStar.Settings kstarSettings;
    public final Settings bbkstarSettings;

    // Pfunc caches keyed by the sequence filtered to each state's own sequence space, so all full
    // sequences that share, e.g., the same ligand sub-sequence share a single ligand pfunc.
    private final Map<Sequence, EWAKStarPartitionFunction> proteinPfuncs;
    private final Map<Sequence, EWAKStarPartitionFunction> ligandPfuncs;
    private final Map<Sequence, EWAKStarPartitionFunction> complexPfuncs;

    // Supplies the A* trees handed to each partition function.
    private final ConfSearchCache confTrees;

    /**
     * Creates a BBK* search over the given protein, ligand, and complex states.
     *
     * @param P the protein state
     * @param L the ligand state
     * @param PL the complex state
     * @param kstarSettings EWAK* settings; external memory must be disabled
     * @param bbkstarSettings BBK*-specific settings
     * @param minNumConfTrees passed straight through to {@link ConfSearchCache}
     * @throws IllegalArgumentException if {@code kstarSettings.useExternalMemory} is set
     */
    public EWAKStarBBKStar(EwakstarDoer.State P, EwakstarDoer.State L, EwakstarDoer.State PL, EWAKStar.Settings kstarSettings, Settings bbkstarSettings, Integer minNumConfTrees) {
        if (kstarSettings.useExternalMemory) {
            throw new IllegalArgumentException("BBK* is not compatible with external memory. Please switch to regular K* with external memory, or keep using BBK* and disable external memory.");
        }
        this.protein = new ConfSpaceInfo(this, P.confSpace, EWAKStar.ConfSpaceType.Protein);
        this.ligand = new ConfSpaceInfo(this, L.confSpace, EWAKStar.ConfSpaceType.Ligand);
        this.complex = new ConfSpaceInfo(this, PL.confSpace, EWAKStar.ConfSpaceType.Complex);
        this.kstarSettings = kstarSettings;
        this.bbkstarSettings = bbkstarSettings;
        this.confTrees = new ConfSearchCache(minNumConfTrees);
        this.proteinPfuncs = new HashMap<>();
        this.ligandPfuncs = new HashMap<>();
        this.complexPfuncs = new HashMap<>();
    }

    /** Returns the protein, ligand, and complex infos, in that order. */
    public Iterable<ConfSpaceInfo> confSpaceInfos() {
        return Arrays.asList(this.protein, this.ligand, this.complex);
    }

    /**
     * Looks up the info whose conf space is (by identity) the given one.
     *
     * @throws IllegalArgumentException if no state uses {@code confSpace}
     */
    public ConfSpaceInfo getConfSpaceInfo(SimpleConfSpace confSpace) {
        if (confSpace == this.protein.confSpace) {
            return this.protein;
        }
        if (confSpace == this.ligand.confSpace) {
            return this.ligand;
        }
        if (confSpace == this.complex.confSpace) {
            return this.complex;
        }
        throw new IllegalArgumentException("conf space does not match any known by this K* instance");
    }

    /**
     * Runs the best-first sequence search and returns the reported sequences in report order.
     *
     * <p>In wild-type benchmark mode the search stops as soon as the wild-type node is polled from
     * the queue. Otherwise it stops once {@code numTopOverallSeqs} sequences have been reported.
     * In both modes it also stops when the queue runs empty.
     *
     * @throws EWAKStar.InitException if any state is missing required configuration
     */
    public List<Sequence> run() {
        this.protein.check();
        this.ligand.check();
        this.complex.check();

        // start every run with empty pfunc caches
        this.proteinPfuncs.clear();
        this.ligandPfuncs.clear();
        this.complexPfuncs.clear();

        List<Sequence> scoredSequences = new ArrayList<>();
        try (ConfDB.DBs confDBs = new ConfDB.DBs()
                .add(this.protein.confSpace, this.protein.confDBFile)
                .add(this.ligand.confSpace, this.ligand.confDBFile)
                .add(this.complex.confSpace, this.complex.confDBFile)) {

            PriorityQueue<Node> tree = new PriorityQueue<>();
            tree.add(new MultiSequenceNode(this.complex.confSpace.makeUnassignedSequence(), confDBs));

            System.out.println("Computing K* scores for the best sequences to within an energy window of " + this.kstarSettings.eW + " kcal, with a max of " + this.kstarSettings.maxPFConfs + " conformations, and an epsilon of " + this.kstarSettings.epsilon + "...");
            this.kstarSettings.scoreWriters.writeHeader();

            boolean wtSeqFound = false;
            if (this.kstarSettings.wtBenchmark) {
                while (!tree.isEmpty() && !wtSeqFound) {
                    Node node = tree.poll();
                    // stop after processing the first node that carries the wild-type sequence
                    if (node.isWTSeq) {
                        wtSeqFound = true;
                    }
                    processNode(node, tree, scoredSequences);
                }
            } else {
                while (!tree.isEmpty() && scoredSequences.size() < this.kstarSettings.numTopOverallSeqs) {
                    processNode(tree.poll(), tree, scoredSequences);
                }
            }

            reportCompletion(tree, scoredSequences, wtSeqFound);
        }
        return scoredSequences;
    }

    /**
     * Handles one node polled from the queue.
     *
     * <p>A single-sequence node is either reported (Estimated/Blocked) or refined by one batch
     * and re-queued (Estimating). A multi-sequence node is expanded, and every child that is not
     * unbound-unstable is scored and queued.
     */
    private void processNode(Node node, PriorityQueue<Node> tree, List<Sequence> scoredSequences) {
        if (node instanceof SingleSequenceNode) {
            SingleSequenceNode ssnode = (SingleSequenceNode) node;
            switch (ssnode.getStatus()) {
                case Estimated:
                case Blocked:
                    // the node is dropped after reporting; its pfuncs stay in the caches
                    this.reportSequence(ssnode, scoredSequences);
                    break;
                case Estimating:
                    ssnode.estimateScore();
                    if (!ssnode.isUnboundUnstable) {
                        tree.add(ssnode);
                    }
                    break;
            }
        } else if (node instanceof MultiSequenceNode) {
            for (Node child : ((MultiSequenceNode) node).makeChildren()) {
                child.estimateScore();
                if (!child.isUnboundUnstable) {
                    tree.add(child);
                }
            }
        }
    }

    /**
     * Prints why the search loop stopped.
     *
     * @throws Error if none of the expected termination conditions holds (an internal bug)
     */
    private void reportCompletion(PriorityQueue<Node> tree, List<Sequence> scoredSequences, boolean wtSeqFound) {
        if (tree.isEmpty()) {
            System.out.println("All " + scoredSequences.size() + " sequences calculated. EWAK* complete.");
        } else if (wtSeqFound && scoredSequences.size() == 1) {
            System.out.println("No K* scores found that are better than the wild-type sequence.");
        } else if (wtSeqFound) {
            System.out.println("Found K* score estimates for all " + scoredSequences.size() + " sequences with scores greater than that of the wild-type sequence.");
        } else if (scoredSequences.size() >= this.kstarSettings.numTopOverallSeqs) {
            System.out.println("Found K* score estimates for top " + this.kstarSettings.numTopOverallSeqs + " sequences.");
        } else {
            throw new Error("EWAK* ended, but the tree isn't empty and we didn't return all of the sequences. This is a bug.");
        }
    }

    /**
     * Records a finished (or blocked) sequence.
     *
     * <p>The sequence is appended to {@code scoredSequences}. If {@code printPDBs} is enabled, its
     * complex ensemble is written as PDBs. Finally, its K* score is passed to the configured
     * score writers.
     */
    private void reportSequence(SingleSequenceNode ssnode, List<Sequence> scoredSequences) {
        EWAKStarScore kstarScore = ssnode.makeKStarScore();
        scoredSequences.add(ssnode.sequence);
        if (this.kstarSettings.printPDBs) {
            writeEnsemblePdbs(ssnode);
        }
        this.kstarSettings.scoreWriters.writeScore(new EWAKStarScoreWriter.ScoreInfo(scoredSequences.size() - 1, 0, ssnode.sequence, this.complex.confSpace, kstarScore));
    }

    /**
     * Writes the complex conformations of {@code ssnode} to {@code pdbs/<sequence>/conf.*.pdb}.
     *
     * <p>Spaces in the sequence string are replaced by underscores. {@code 10} is passed to
     * {@code analyzeEnsemble} as its count argument.
     */
    private void writeEnsemblePdbs(SingleSequenceNode ssnode) {
        Iterator<EnergyCalculator.EnergiedParametricMolecule> econfs = ssnode.complex.getEpMols().iterator();
        HashMap<Double, ConfSearch.ScoredConf> sconfs = ssnode.complex.getSConfs();
        ConfAnalyzer analyzer = new ConfAnalyzer(this.complex.confEcalcMinimized);
        ConfAnalyzer.EnsembleAnalysis analysis = analyzer.analyzeEnsemble(sconfs, econfs, 10);

        String pdbString = "pdbs";
        String seqDir = ssnode.sequence.toString().replaceAll(" ", "_");
        File directory = new File(pdbString + "/" + seqDir);
        if (!directory.exists()) {
            // best-effort, as before: if creation fails, writePdbs reports the problem
            directory.mkdirs();
        }
        analysis.writePdbs(pdbString + "/" + seqDir + "/conf.*.pdb");
    }

    /** True when the pfunc's status is ConfLimitReached, EpsilonReached, or EnergyReached. */
    private static boolean isFinished(EWAKStarPartitionFunction pfunc) {
        EWAKStarPartitionFunction.Status status = pfunc.getStatus();
        return status == EWAKStarPartitionFunction.Status.ConfLimitReached
            || status == EWAKStarPartitionFunction.Status.EpsilonReached
            || status == EWAKStarPartitionFunction.Status.EnergyReached;
    }

    /** True when the pfunc's status is Unstable. */
    private static boolean isUnstable(EWAKStarPartitionFunction pfunc) {
        return pfunc.getStatus() == EWAKStarPartitionFunction.Status.Unstable;
    }

    /** Per-state configuration (protein, ligand, or complex) consumed by the search. */
    public class ConfSpaceInfo {
        public final SimpleConfSpace confSpace;
        public final EWAKStar.ConfSpaceType type;
        public final String id;
        public ConfEnergyCalculator confEcalcMinimized = null;
        public Function<RCs, ConfAStarTree> confTreeFactoryMinimized = null;
        public Function<RCs, ConfAStarTree> confTreeFactoryRigid = null;
        public File confDBFile = null;
        private BigDecimal stabilityThreshold = null;

        /**
         * @param owner unused; kept only so existing three-argument callers still compile
         * @param confSpace the conformation space of this state
         * @param type which state this is; its lower-cased name becomes {@link #id}
         */
        public ConfSpaceInfo(EWAKStarBBKStar owner, SimpleConfSpace confSpace, EWAKStar.ConfSpaceType type) {
            this.confSpace = confSpace;
            this.type = type;
            this.id = type.name().toLowerCase();
        }

        /** Throws {@link EWAKStar.InitException} naming the first required field that is unset. */
        private void check() {
            if (this.confEcalcMinimized == null) {
                throw new EWAKStar.InitException(this.type, "confEcalcMinimized");
            }
            if (this.confTreeFactoryMinimized == null) {
                throw new EWAKStar.InitException(this.type, "confTreeFactoryMinimized");
            }
            if (this.confTreeFactoryRigid == null) {
                throw new EWAKStar.InitException(this.type, "confTreeFactoryRigid");
            }
        }

        /** Sets the conformation database file for this state from a path string. */
        public void setConfDBFile(String path) {
            this.confDBFile = new File(path);
        }
    }

    /** BBK*-specific settings. */
    public static class Settings {
        /** Trie of allowed sequences, used to choose residue types while expanding nodes. */
        public final EwakstarLimitedSequenceTrie allowedSeqs;
        public final int numBestSequences;
        /** Number of conformations each partition function processes per refinement step. */
        public final int numConfsPerBatch;

        public Settings(EwakstarLimitedSequenceTrie seqs, int numBestSequences, int numConfsPerBatch) {
            this.numBestSequences = numBestSequences;
            this.numConfsPerBatch = numConfsPerBatch;
            this.allowedSeqs = seqs;
        }

        /** Builder with defaults numBestSequences = 1 and numConfsPerBatch = 8. */
        public static class Builder {
            private int numBestSequences = 1;
            private EwakstarLimitedSequenceTrie allowedSeqs;
            private int numConfsPerBatch = 8;

            public Builder setAllowedSeqs(EwakstarLimitedSequenceTrie seqs) {
                this.allowedSeqs = seqs;
                return this;
            }

            public Builder setNumConfsPerBatch(int val) {
                this.numConfsPerBatch = val;
                return this;
            }

            public Builder setNumBestSequences(int val) {
                this.numBestSequences = val;
                return this;
            }

            public Settings build() {
                return new Settings(this.allowedSeqs, this.numBestSequences, this.numConfsPerBatch);
            }
        }
    }

    /** A partially assigned sequence; expanded one design position at a time. */
    public class MultiSequenceNode extends Node {

        public MultiSequenceNode(Sequence sequence, ConfDB.DBs confdbs) {
            super(sequence, confdbs);
        }

        /**
         * Assigns the first unassigned design position in every allowed way.
         *
         * <p>A child becomes a {@link SingleSequenceNode} once it is fully assigned, or once it
         * has exactly {@code numMutations} mutations (in that case the remaining positions are
         * filled with wild type). Otherwise the child stays a {@link MultiSequenceNode}.
         *
         * @throws IllegalStateException if no design position is left to assign
         */
        public List<Node> makeChildren() {
            List<Node> children = new ArrayList<>();
            List<SeqSpace.Position> positions = EWAKStarBBKStar.this.complex.confSpace.seqSpace.positions;
            SeqSpace.Position assignPos = positions.stream()
                .filter(pos -> !this.sequence.isAssigned(pos.resNum))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("no design positions left to choose"));

            // a position with a single res type needs no lookup in the allowed-sequence trie
            Set<String> resTypes = assignPos.resTypes.size() == 1
                ? this.getResTypeList(assignPos.resTypes)
                : this.filterOnPreviousSeqs();

            for (String resType : resTypes) {
                Sequence s = this.sequence.copy().set(assignPos, resType);
                if (s.isFullyAssigned()) {
                    children.add(new SingleSequenceNode(s, this.confDBs));
                } else if (s.countMutations() == EWAKStarBBKStar.this.kstarSettings.numMutations) {
                    // NOTE(review): useExact previously had a separate but identical branch here,
                    // so it does not affect branching.
                    s.fillWildType();
                    children.add(new SingleSequenceNode(s, this.confDBs));
                } else {
                    children.add(new MultiSequenceNode(s, this.confDBs));
                }
            }
            return children;
        }

        /** Collects the names of the given residue types. */
        private Set<String> getResTypeList(List<SeqSpace.ResType> resTypes) {
            Set<String> resTypeList = new HashSet<>();
            for (SeqSpace.ResType r : resTypes) {
                resTypeList.add(r.name);
            }
            return resTypeList;
        }

        /** Asks the allowed-sequence trie which residue types may follow the current prefix. */
        private Set<String> filterOnPreviousSeqs() {
            String subSeq = this.sequence.toString();
            return subSeq.isEmpty()
                ? EWAKStarBBKStar.this.bbkstarSettings.allowedSeqs.getFirstPos()
                : EWAKStarBBKStar.this.bbkstarSettings.allowedSeqs.getSeq(subSeq);
        }

        /**
         * Scores this node as log10 of (complex upper bound / (protein lower bound * ligand lower
         * bound)).
         *
         * <p>The score is +infinity when either unbound lower bound is zero. Using log10 puts this
         * on the same scale as {@link SingleSequenceNode} scores, so both node kinds compare
         * meaningfully in the shared priority queue.
         */
        @Override
        public void estimateScore() {
            // the same fixed conf count is used for every bound computed here
            final int numConfs = 1000;
            this.isUnboundUnstable = false;

            BigDecimal proteinLowerBound = this.calcLowerBound(EWAKStarBBKStar.this.protein, this.sequence, numConfs);
            if (MathTools.isZero(proteinLowerBound)) {
                this.score = Double.POSITIVE_INFINITY;
                return;
            }
            BigDecimal ligandLowerBound = this.calcLowerBound(EWAKStarBBKStar.this.ligand, this.sequence, numConfs);
            if (MathTools.isZero(ligandLowerBound)) {
                this.score = Double.POSITIVE_INFINITY;
                return;
            }
            BigDecimal complexUpperBound = this.calcUpperBound(EWAKStarBBKStar.this.complex, this.sequence, numConfs);
            BigDecimal kstarUpperBound = MathTools.bigDivideDivide(complexUpperBound, proteinLowerBound, ligandLowerBound, PartitionFunction.decimalPrecision);
            this.score = Math.log10(kstarUpperBound.doubleValue());
        }

        /** Sums Boltzmann weights of up to {@code numConfs} confs from the rigid A* tree. */
        private BigDecimal calcLowerBound(ConfSpaceInfo info, Sequence sequence, int numConfs) {
            RCs rcs = sequence.makeRCs(info.confSpace);
            ConfSearch astar = info.confTreeFactoryRigid.apply(rcs);
            BoltzmannCalculator bcalc = new BoltzmannCalculator(PartitionFunction.decimalPrecision);
            BigMath m = new BigMath(PartitionFunction.decimalPrecision).set(0.0);
            ConfSearch.ScoredConf conf;
            for (int i = 0; i < numConfs && (conf = astar.nextConf()) != null; ++i) {
                m.add(bcalc.calc(conf.getScore()));
            }
            return m.get();
        }

        /** Runs an {@link UpperBoundCalculator} over the minimized A* tree for {@code numConfs} steps. */
        private BigDecimal calcUpperBound(ConfSpaceInfo info, Sequence sequence, int numConfs) {
            RCs rcs = sequence.makeRCs(info.confSpace);
            UpperBoundCalculator calc = new UpperBoundCalculator(info.confTreeFactoryMinimized.apply(rcs), rcs.getNumConformations());
            calc.run(numConfs);
            return calc.totalBound;
        }

        @Override
        public String toString() {
            return String.format("MultiSequenceNode[score=%12.6f, seq=%s]", this.score, this.sequence);
        }
    }

    /** Common state of search nodes. Higher scores are polled first. */
    private abstract class Node implements Comparable<Node> {
        public final Sequence sequence;
        public final ConfDB.DBs confDBs;
        public boolean isWTSeq;
        public double score = 0.0;
        /** Set by {@link #estimateScore()}; such nodes are not (re-)queued. */
        public boolean isUnboundUnstable;

        protected Node(Sequence sequence, ConfDB.DBs confDBs) {
            this.isWTSeq = sequence.isWildType();
            this.sequence = sequence;
            this.confDBs = confDBs;
        }

        /** Reversed ordering so the PriorityQueue (least-first) yields the highest score first. */
        @Override
        public int compareTo(Node other) {
            return Double.compare(other.score, this.score);
        }

        /** Updates {@link #score} and {@link #isUnboundUnstable}. */
        public abstract void estimateScore();
    }

    /** A fully specified sequence whose three partition functions are refined in batches. */
    public class SingleSequenceNode extends Node {
        public EWAKStarPartitionFunction protein;
        public EWAKStarPartitionFunction ligand;
        public EWAKStarPartitionFunction complex;

        public SingleSequenceNode(Sequence sequence, ConfDB.DBs confDBs) {
            super(sequence, confDBs);
            this.protein = this.makePfunc(EWAKStarBBKStar.this.proteinPfuncs, EWAKStarBBKStar.this.protein, confDBs.get(EWAKStarBBKStar.this.protein.confSpace));
            this.ligand = this.makePfunc(EWAKStarBBKStar.this.ligandPfuncs, EWAKStarBBKStar.this.ligand, confDBs.get(EWAKStarBBKStar.this.ligand.confSpace));
            this.complex = this.makePfunc(EWAKStarBBKStar.this.complexPfuncs, EWAKStarBBKStar.this.complex, confDBs.get(EWAKStarBBKStar.this.complex.confSpace));
        }

        /**
         * Returns the cached pfunc for this node's sequence filtered to {@code info}'s sequence
         * space, or creates, initializes, and caches a new one.
         */
        private EWAKStarPartitionFunction makePfunc(Map<Sequence, EWAKStarPartitionFunction> pfuncCache, ConfSpaceInfo info, ConfDB confdb) {
            Sequence sequence = this.sequence.filter(info.confSpace.seqSpace);
            EWAKStarPartitionFunction pfunc = pfuncCache.get(sequence);
            if (pfunc != null) {
                return pfunc;
            }
            pfunc = new EWAKStarGradientDescentPfunc(info.confEcalcMinimized);
            pfunc.setReportProgress(EWAKStarBBKStar.this.kstarSettings.showPfuncProgress);
            if (confdb != null) {
                EWAKStarPartitionFunction.WithConfTable.setOrThrow(pfunc, confdb.getSequence(sequence));
            }
            RCs rcs = sequence.makeRCs(info.confSpace);
            if (EWAKStarBBKStar.this.kstarSettings.useExternalMemory) {
                // unreachable in practice: the outer constructor rejects external memory
                EWAKStarPartitionFunction.WithExternalMemory.setOrThrow(pfunc, true, rcs);
            }
            // two independent trees built from the same minimized-tree factory
            pfunc.init(EWAKStarBBKStar.this.confTrees.make(() -> info.confTreeFactoryMinimized.apply(rcs)), EWAKStarBBKStar.this.confTrees.make(() -> info.confTreeFactoryMinimized.apply(rcs)), rcs.getNumConformations(), EWAKStarBBKStar.this.kstarSettings.epsilon, EWAKStarBBKStar.this.kstarSettings.eW, EWAKStarBBKStar.this.kstarSettings.maxPFConfs, EWAKStarBBKStar.this.kstarSettings.printPDBs);
            pfunc.setStabilityThreshold(info.stabilityThreshold);
            pfuncCache.put(sequence, pfunc);
            return pfunc;
        }

        /**
         * Refines each pfunc that can continue by one batch, then sets the score to log10 of the
         * K* upper bound.
         *
         * <p>If the protein or ligand pfunc is (or becomes) Unstable, the node is marked
         * unbound-unstable with score -infinity. A Blocked node with an infinite upper bound is
         * given score -infinity.
         */
        @Override
        public void estimateScore() {
            if (isUnstable(this.protein) || isUnstable(this.ligand)) {
                this.markUnboundUnstable();
                return;
            }
            if (this.protein.getStatus().canContinue()) {
                this.protein.compute(EWAKStarBBKStar.this.bbkstarSettings.numConfsPerBatch);
                if (isUnstable(this.protein)) {
                    this.markUnboundUnstable();
                    return;
                }
            }
            if (this.ligand.getStatus().canContinue()) {
                this.ligand.compute(EWAKStarBBKStar.this.bbkstarSettings.numConfsPerBatch);
                if (isUnstable(this.ligand)) {
                    this.markUnboundUnstable();
                    return;
                }
            }
            if (this.complex.getStatus().canContinue()) {
                this.complex.compute(EWAKStarBBKStar.this.bbkstarSettings.numConfsPerBatch);
            }
            this.score = Math.log10(this.makeKStarScore().upperBound.doubleValue());
            this.isUnboundUnstable = false;
            if (this.getStatus() == PfuncsStatus.Blocked && this.score == Double.POSITIVE_INFINITY) {
                this.score = Double.NEGATIVE_INFINITY;
            }
        }

        private void markUnboundUnstable() {
            this.score = Double.NEGATIVE_INFINITY;
            this.isUnboundUnstable = true;
        }

        public EWAKStarScore makeKStarScore() {
            return new EWAKStarScore(this.protein.makeResult(), this.ligand.makeResult(), this.complex.makeResult());
        }

        /**
         * Returns Estimated if all three pfuncs are finished, Estimating if any of them is still
         * Estimating, and Blocked otherwise.
         */
        public PfuncsStatus getStatus() {
            if (isFinished(this.protein) && isFinished(this.ligand) && isFinished(this.complex)) {
                return PfuncsStatus.Estimated;
            }
            if (this.protein.getStatus() == EWAKStarPartitionFunction.Status.Estimating
                    || this.ligand.getStatus() == EWAKStarPartitionFunction.Status.Estimating
                    || this.complex.getStatus() == EWAKStarPartitionFunction.Status.Estimating) {
                return PfuncsStatus.Estimating;
            }
            return PfuncsStatus.Blocked;
        }

        @Override
        public String toString() {
            return String.format("SingleSequenceNode[score=%12.6f, seq=%s, K*=%s]", this.score, this.sequence, this.makeKStarScore());
        }
    }

    /** Aggregate status of a single-sequence node's three partition functions. */
    public static enum PfuncsStatus {
        Estimating,
        Estimated,
        Blocked
    }
}

