/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.gmec;

import edu.duke.cs.osprey.astar.conf.ConfAStarTree;
import edu.duke.cs.osprey.astar.conf.ConfSearchCache;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.astar.seq.RTs;
import edu.duke.cs.osprey.astar.seq.SeqAStarTree;
import edu.duke.cs.osprey.astar.seq.nodes.SeqAStarNode;
import edu.duke.cs.osprey.astar.seq.order.SequentialSeqAStarOrder;
import edu.duke.cs.osprey.astar.seq.scoring.NOPSeqAStarScorer;
import edu.duke.cs.osprey.astar.seq.scoring.SeqAStarScorer;
import edu.duke.cs.osprey.confspace.Conf;
import edu.duke.cs.osprey.confspace.ConfDB;
import edu.duke.cs.osprey.confspace.ConfSearch;
import edu.duke.cs.osprey.confspace.FragmentEnergies;
import edu.duke.cs.osprey.confspace.SeqSpace;
import edu.duke.cs.osprey.confspace.Sequence;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.tools.HashCalculator;
import edu.duke.cs.osprey.tools.Log;
import edu.duke.cs.osprey.tools.MathTools;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

public class Comets {
    /** The objective LME to minimize over sequences. */
    public final LME objective;
    /** Constraint LMEs; a sequence is pruned as soon as any constraint's optimistic bound is greater than zero. */
    public final List<LME> constraints;
    /** Search stops once the next candidate's score exceeds the first found sequence's objective by more than this (kcal/mol). */
    public final double objectiveWindowSize;
    /** Search stops once the next candidate's score exceeds this absolute objective value (kcal/mol). */
    public final double objectiveWindowMax;
    /** Maximum number of simultaneously mutated positions explored by the sequence tree. */
    public final int maxSimultaneousMutations;
    /** Passed to the {@link ConfSearchCache} that holds the states' conformation trees (may be null). */
    public final Integer minNumConfTrees;
    /** Whether progress messages are printed via {@link Log}. */
    public final boolean printToConsole;
    /** Optional tab-separated results file, truncated when a search reports its first sequence; null disables it. */
    public final File logFile;
    /** All distinct states referenced by the objective and the constraints, in first-seen order. */
    public final List<State> states;
    /** Union of the sequence spaces of all states. */
    public final SeqSpace seqSpace;
    /** Per-(sequence, state) conformation progress shared between sequence nodes; cleared at the start of each search. */
    private final Map<StateConfs.Key, StateConfs> stateConfsCache = new HashMap<StateConfs.Key, StateConfs>();
    /** Cache of conformation trees, created with {@link #minNumConfTrees}. */
    private final ConfSearchCache confTrees;

    /**
     * Creates a COMETS design; instances are made through {@link Builder}.
     *
     * <p>The design's states are gathered from the objective first and then from each
     * constraint. Duplicates are dropped and first-seen order is kept. The design's sequence
     * space is the union of the states' sequence spaces.
     *
     * @throws IllegalArgumentException if neither the objective nor any constraint names a state
     */
    private Comets(LME objective, List<LME> constraints, double objectiveWindowSize, double objectiveWindowMax, int maxSimultaneousMutations, Integer minNumConfTrees, boolean printToConsole, File logFile) {
        this.objective = objective;
        this.constraints = constraints;
        this.objectiveWindowSize = objectiveWindowSize;
        this.objectiveWindowMax = objectiveWindowMax;
        this.maxSimultaneousMutations = maxSimultaneousMutations;
        this.minNumConfTrees = minNumConfTrees;
        this.printToConsole = printToConsole;
        this.logFile = logFile;

        // a LinkedHashSet removes repeated states while remembering the order they first appeared in
        LinkedHashSet<State> uniqueStates = new LinkedHashSet<State>();
        objective.states.forEach(weighted -> uniqueStates.add(weighted.state));
        constraints.forEach(lme -> lme.states.forEach(weighted -> uniqueStates.add(weighted.state)));
        this.states = new ArrayList<State>(uniqueStates);

        if (this.states.isEmpty()) {
            throw new IllegalArgumentException("COMETS found no states");
        }

        this.seqSpace = SeqSpace.union(
            this.states.stream()
                .map(state -> state.confSpace.seqSpace)
                .collect(Collectors.toList())
        );
        this.confTrees = new ConfSearchCache(minNumConfTrees);
        this.log("sequence space has %s sequences\n%s", Log.formatBig(new RTs(this.seqSpace).getNumSequences()), this.seqSpace);
    }

    /**
     * Searches for up to {@code numSequences} sequences with the best objective values that
     * satisfy every constraint.
     *
     * <p>Candidates come from a sequence A* tree. Each candidate's per-state GMEC bounds are
     * refined one batch at a time, and the candidate is re-queued with its updated score until
     * every state's GMEC is known. At that point the sequence is reported and recorded. The
     * search ends early once the next candidate's score exceeds {@link #objectiveWindowMax}, or
     * exceeds the first recorded sequence's objective by more than {@link #objectiveWindowSize}.
     *
     * @param numSequences the maximum number of sequences to collect
     * @return the completed sequences, in the order they were found
     * @throws State.InitException if a state is missing required configuration
     */
    public List<SequenceInfo> findBestSequences(int numSequences) {
        // conformation progress from a previous search must not leak into this one
        this.stateConfsCache.clear();
        this.states.forEach(State::checkConfig);

        SeqAStarTree seqTree = new SeqAStarTree.Builder(new RTs(this.seqSpace))
            .setHeuristics(new SequentialSeqAStarOrder(), new NOPSeqAStarScorer(), new SeqHScorer())
            .setNumMutable(this.maxSimultaneousMutations)
            .build();
        List<SequenceInfo> infos = new ArrayList<SequenceInfo>();

        this.log("\nCOMETS searching for the %d best sequences among %s with up to %d simultaneous mutations ...", numSequences, Log.formatBig(new RTs(this.seqSpace).getNumSequences()), this.maxSimultaneousMutations);
        this.log("(up to objective function value %.6f kcal/mol, or +%.6f kcal/mol relative to the best sequence)", this.objectiveWindowMax, this.objectiveWindowSize);
        this.log("");

        try (ConfDBs confDBs = new ConfDBs(this)) {
            for (SeqAStarNode node = seqTree.nextLeafNode(); node != null; node = seqTree.nextLeafNode()) {

                // past the absolute cap, or too far behind the best sequence found so far?
                boolean outsideWindow = node.getScore() > this.objectiveWindowMax
                    || (!infos.isEmpty() && node.getScore() > infos.get(0).objective + this.objectiveWindowSize);
                if (outsideWindow) {
                    this.log("\nCOMETS exiting early: exhausted all conformations in energy window");
                    break;
                }

                // first visit to this sequence? attach its per-state conformation progress
                SeqConfs confs = (SeqConfs) node.getData();
                if (confs == null) {
                    this.log("Discovered promising sequence: %s   objective lower bound: %12.6f", node.makeSequence(this.seqSpace), node.getScore());
                    confs = new SeqConfs(node);
                    node.setData(confs);
                }

                if (!confs.hasAllGMECs()) {
                    // tighten the bounds and re-queue, unless a constraint pruned the sequence
                    node.setHScore(confs.refineBounds(confDBs));
                    if (node.getScore() != Double.POSITIVE_INFINITY) {
                        seqTree.add(node);
                    }
                    continue;
                }

                SequenceInfo info = new SequenceInfo(this, node, confs);
                infos.add(info);
                this.reportSequence(infos.size() == 1, info);
                if (infos.size() >= numSequences) {
                    break;
                }
            }
        }

        this.log("");
        if (infos.isEmpty()) {
            this.log("COMETS didn't find any sequences within the window that satisfy all the constraints.");
        } else {
            this.log("COMETS found the best %d within the window that satisfy all the constraints", infos.size());
        }
        return infos;
    }

    /** Forwards a printf-style message to {@link Log#log}, but only when console output is enabled. */
    private void log(String msg, Object ... args) {
        if (!this.printToConsole) {
            return;
        }
        Log.log(msg, args);
    }

    /*
     * WARNING - void declaration
     */
    private void reportSequence(boolean isFirstSequence, SequenceInfo info2) {
        if (this.printToConsole) {
            int cellSize = info2.sequence.calcCellSize();
            this.log("\nSequence calculation complete: %s    objective: %12.6f   %s\n                               %s", info2.sequence.toString(Sequence.Renderer.ResNum, cellSize), info2.objective, info2.sequence.isWildType() ? "Wild-type" : "", info2.sequence.toString(Sequence.Renderer.ResTypeMutations, cellSize));
            for (State state : this.states) {
                this.log("\tState: %-20s    GMEC Energy: %12.6f", state.name, info2.GMECs.get(state).getEnergy());
            }
        }
        if (this.logFile != null) {
            void var5_14;
            StringBuilder header = null;
            if (isFirstSequence) {
                header = new StringBuilder();
                for (SeqSpace.Position position : this.seqSpace.positions) {
                    if (position.index > 0) {
                        header.append("\t");
                    }
                    header.append(position.resNum);
                }
                header.append("\tObjective");
                for (int i = 0; i < this.constraints.size(); ++i) {
                    header.append("\tConstraint ");
                    header.append(i);
                }
                for (State state : this.states) {
                    header.append("\t");
                    header.append(state.name);
                    header.append(" GMEC Energy\t");
                    header.append(state.name);
                    header.append(" GMEC Conf");
                }
            }
            StringBuilder buf = new StringBuilder();
            for (SeqSpace.Position pos : this.seqSpace.positions) {
                if (pos.index > 0) {
                    buf.append("\t");
                }
                buf.append(info2.sequence.get(pos).mutationName());
            }
            buf.append("\t");
            buf.append(String.format("%.6f", info2.objective));
            boolean bl = false;
            while (var5_14 < this.constraints.size()) {
                buf.append("\t");
                buf.append(String.format("%.6f", info2.constraints.get(this.constraints.get((int)var5_14))));
                ++var5_14;
            }
            for (State state : this.states) {
                ConfSearch.EnergiedConf gmec = info2.GMECs.get(state);
                buf.append("\t");
                buf.append(String.format("%.6f", gmec.getEnergy()));
                buf.append("\t");
                buf.append(Conf.toString(gmec.getAssignments()));
            }
            try (FileWriter fileWriter = new FileWriter(this.logFile, !isFirstSequence);){
                if (header != null) {
                    fileWriter.write(header.toString());
                    fileWriter.write("\n");
                }
                fileWriter.write(buf.toString());
                fileWriter.write("\n");
            }
            catch (IOException iOException) {
                iOException.printStackTrace(System.err);
                if (header != null) {
                    System.err.println(header);
                }
                System.err.println(buf);
            }
        }
    }

    /**
     * Debugging aid that logs a sequence node's score, objective, and constraint bounds, plus
     * the GMEC bounds of each state. It expects the node to already carry its {@link SeqConfs}.
     */
    private void dump(SeqAStarNode node) {
        final SeqConfs seqConfs = (SeqConfs) node.getData();
        this.log("sequence %s", node.makeSequence(this.seqSpace));
        this.log("\tscore: %.6f   completed? %b", node.getScore(), seqConfs.hasAllGMECs());
        this.log("\tobjective: %.6f", this.objective.calc(seqConfs));
        this.constraints.forEach(lme -> this.log("\tconstraint: %.3f", lme.calc(seqConfs)));
        seqConfs.statesConfs.values().forEach(sc ->
            this.log("\tstate %-20s GMEC bounds [%8.3f,%8.3f]    found gmec? %b", sc.state.name, sc.minScoreConf.getScore(), sc.minEnergyConf.getEnergy(), sc.gmec != null)
        );
    }

    /**
     * A linear multi-state energy (LME): {@code offset + sum_i(weight_i * E_i)}, where
     * {@code E_i} is the GMEC energy of state {@code i}. COMETS uses one as the objective and
     * others as constraints. A constraint is treated as violated when its value is greater
     * than zero.
     */
    public static class LME {
        /** Constant term added to the weighted sum. */
        public final double offset;
        /** The weighted states that make up the sum. */
        public final List<WeightedState> states;

        public LME(double offset, List<WeightedState> states) {
            this.offset = offset;
            this.states = states;
        }

        /**
         * Computes an optimistic (lower) bound on this LME for a sequence that is still being
         * searched. Positively weighted states contribute their GMEC lower bound; all other
         * states contribute their upper bound. A state whose GMEC is already known contributes
         * its exact energy either way.
         */
        private double calc(SeqConfs confs) {
            double val = this.offset;
            for (WeightedState wstate : this.states) {
                StateConfs stateConfs = confs.statesConfs.get(wstate.state);
                if (wstate.weight > 0.0) {
                    val += wstate.weight * stateConfs.getObjectiveLowerBound();
                    continue;
                }
                val += wstate.weight * stateConfs.getObjectiveUpperBound();
            }
            return val;
        }

        /**
         * Evaluates this LME on explicit per-state energies.
         *
         * @param stateEnergies energy for each state of this LME
         * @return {@code offset + sum(weight * energy)}
         * @throws NullPointerException if one of this LME's states has no entry in {@code stateEnergies}
         */
        public double calc(Map<State, Double> stateEnergies) {
            double val = this.offset;
            for (WeightedState wstate : this.states) {
                val += wstate.weight * stateEnergies.get(wstate.state);
            }
            return val;
        }

        /** Incrementally builds an {@link LME}. */
        public static class Builder {
            private double offset = 0.0;
            private final List<WeightedState> wstates = new ArrayList<WeightedState>();

            /** Sets the constant term (default 0). */
            public Builder setOffset(double val) {
                this.offset = val;
                return this;
            }

            /**
             * Makes the built LME a "less than {@code val}" constraint by setting the offset to
             * {@code -val}. The LME is then non-positive (satisfied) exactly when the weighted
             * sum is at most {@code val}. This replaces any previously set offset.
             */
            public Builder constrainLessThan(double val) {
                return this.setOffset(-val);
            }

            /** Adds a state with the given weight to the sum. */
            public Builder addState(State state, double weight) {
                this.wstates.add(new WeightedState(state, weight));
                return this;
            }

            /**
             * Builds the LME. The weighted-state list is copied, so later {@link #addState}
             * calls on this builder do not alter LMEs that were already built.
             */
            public LME build() {
                return new LME(this.offset, new ArrayList<WeightedState>(this.wstates));
            }
        }
    }

    /**
     * A state paired with its weight in an {@link LME}.
     *
     * <p>NOTE(review): the fragment-energy helpers scale by {@code |weight|}, not
     * {@code weight}. For negatively weighted states this is not {@code weight * energy}.
     * Confirm this is intended for the heuristic that uses them.
     */
    public static class WeightedState {
        public final State state;
        public final double weight;

        public WeightedState(State state, double weight) {
            this.state = state;
            this.weight = weight;
        }

        /** Returns {@code |weight|} times the state's fragment energy for residue conf {@code rc} at position {@code pos}. */
        public double getSingleEnergy(int pos, int rc) {
            double magnitude = Math.abs(this.weight);
            return magnitude * this.state.fragmentEnergies.getEnergy(pos, rc);
        }

        /** Returns {@code |weight|} times the state's pair fragment energy between ({@code pos1}, {@code rc1}) and ({@code pos2}, {@code rc2}). */
        public double getPairEnergy(int pos1, int rc1, int pos2, int rc2) {
            double magnitude = Math.abs(this.weight);
            return magnitude * this.state.fragmentEnergies.getEnergy(pos1, rc1, pos2, rc2);
        }
    }

    /**
     * One design state: a named conformation space plus the components needed to search it.
     *
     * <p>{@link #fragmentEnergies}, {@link #confEcalc} and {@link #confTreeFactory} must be set
     * before a search; {@link #checkConfig} verifies this. {@link #confDBFile} may stay null.
     * Presumably no conformation database is used for the state in that case (TODO confirm
     * against {@code ConfDB.DBs}).
     */
    public static class State {
        public final String name;
        public final SimpleConfSpace confSpace;
        public FragmentEnergies fragmentEnergies;
        public ConfEnergyCalculator confEcalc;
        public Function<RCs, ConfAStarTree> confTreeFactory;
        public File confDBFile = null;

        public State(String name, SimpleConfSpace confSpace) {
            this.name = name;
            this.confSpace = confSpace;
        }

        /**
         * Checks that the required components have been set.
         *
         * @throws InitException naming the first missing component, checked in the order
         *     fragmentEnergies, confEcalc, confTreeFactory
         */
        public void checkConfig() {
            this.requireSet(this.fragmentEnergies, "fragmentEnergies");
            this.requireSet(this.confEcalc, "confEcalc");
            this.requireSet(this.confTreeFactory, "confTreeFactory");
        }

        /** Throws an {@link InitException} for {@code fieldName} if {@code value} is null. */
        private void requireSet(Object value, String fieldName) {
            if (value == null) {
                throw new InitException(this, fieldName);
            }
        }

        /** Thrown when a state is used before one of its required components is set. */
        public static class InitException
        extends RuntimeException {
            public InitException(State state, String name) {
                super(String.format("set %s for state %s before running", name, state.name));
            }
        }
    }

    /**
     * Heuristic (h-score) for the sequence A* tree, built only from the objective's states.
     *
     * <p>For every position of every objective state, it adds the minimum over the allowed
     * residue types and their RCs of the weighted single energy. To that it adds, for each
     * earlier position of the same state, the minimum weighted pair energy over that
     * position's allowed types and RCs. The objective's offset is added once.
     *
     * <p>NOTE(review): energies are scaled by {@code |weight|} (see {@link WeightedState}).
     * Confirm this still gives a lower bound when the objective has negative weights.
     */
    private class SeqHScorer
    implements SeqAStarScorer {
        MathTools.Optimizer opt = MathTools.Optimizer.Minimize;
        /** Every conf-space position of every objective state, flattened in state order. */
        List<SimpleConfSpace.Position> allPositions = new ArrayList<SimpleConfSpace.Position>();
        /** The weighted state that owns the position at the same index of {@link #allPositions}. */
        List<WeightedState> statesByAllPosition = new ArrayList<WeightedState>();

        SeqHScorer() {
            for (WeightedState owner : Comets.this.objective.states) {
                for (SimpleConfSpace.Position position : owner.state.confSpace.positions) {
                    this.allPositions.add(position);
                    this.statesByAllPosition.add(owner);
                }
            }
        }

        @Override
        public double calc(SeqAStarNode.Assignments assignments) {
            double total = Comets.this.objective.offset;
            int count = this.allPositions.size();
            for (int p = 0; p < count; ++p) {
                total += this.bestPositionEnergy(this.allPositions.get(p), this.statesByAllPosition.get(p), assignments);
            }
            return total;
        }

        /** Best (single + pairs to earlier positions) energy at {@code pos1} over its allowed types and RCs. */
        private double bestPositionEnergy(SimpleConfSpace.Position pos1, WeightedState owner, SeqAStarNode.Assignments assignments) {
            double bestOverTypes = this.opt.initDouble();
            for (SeqSpace.ResType rt1 : this.getRTs(pos1, assignments)) {
                double bestOverRCs = this.opt.initDouble();
                for (SimpleConfSpace.ResidueConf rc1 : this.getRCs(pos1, rt1, owner.state)) {
                    double energy = 0.0;
                    energy += owner.getSingleEnergy(pos1.index, rc1.index);
                    for (int i2 = 0; i2 < pos1.index; ++i2) {
                        SimpleConfSpace.Position pos2 = owner.state.confSpace.positions.get(i2);
                        energy += this.bestPairEnergy(owner, pos1, rc1, pos2, assignments);
                    }
                    bestOverRCs = this.opt.opt(bestOverRCs, energy);
                }
                bestOverTypes = this.opt.opt(bestOverTypes, bestOverRCs);
            }
            return bestOverTypes;
        }

        /** Best pair energy between the fixed ({@code pos1}, {@code rc1}) and any allowed type/RC at {@code pos2}. */
        private double bestPairEnergy(WeightedState owner, SimpleConfSpace.Position pos1, SimpleConfSpace.ResidueConf rc1, SimpleConfSpace.Position pos2, SeqAStarNode.Assignments assignments) {
            double bestOverTypes = this.opt.initDouble();
            for (SeqSpace.ResType rt2 : this.getRTs(pos2, assignments)) {
                double bestOverRCs = this.opt.initDouble();
                for (SimpleConfSpace.ResidueConf rc2 : this.getRCs(pos2, rt2, owner.state)) {
                    bestOverRCs = this.opt.opt(bestOverRCs, owner.getPairEnergy(pos1.index, rc1.index, pos2.index, rc2.index));
                }
                bestOverTypes = this.opt.opt(bestOverTypes, bestOverRCs);
            }
            return bestOverTypes;
        }

        /**
         * Residue types allowed at a conf-space position given the partial sequence assignment.
         * If the position is assigned, only the assigned type is returned; if it is unassigned,
         * all of the position's types are returned. A position outside the sequence space yields
         * a single {@code null} entry, meaning "any RC".
         */
        List<SeqSpace.ResType> getRTs(SimpleConfSpace.Position confPos, SeqAStarNode.Assignments assignments) {
            SeqSpace.Position seqPos = Comets.this.seqSpace.getPosition(confPos.resNum);
            if (seqPos == null) {
                // not in the sequence space: expected to have exactly one residue type
                assert (confPos.resTypes.size() == 1);
                return Collections.singletonList(null);
            }
            Integer assignedRT = assignments.getAssignment(seqPos.index);
            if (assignedRT == null) {
                return seqPos.resTypes;
            }
            return Collections.singletonList(seqPos.resTypes.get(assignedRT));
        }

        /** RCs at {@code pos} whose template name matches {@code rt}, or all RCs when {@code rt} is null. {@code state} is unused. */
        List<SimpleConfSpace.ResidueConf> getRCs(SimpleConfSpace.Position pos, SeqSpace.ResType rt, State state) {
            if (rt == null) {
                return pos.resConfs;
            }
            return pos.resConfs.stream().filter(rc -> rc.template.name.equals(rt.name)).collect(Collectors.toList());
        }
    }

    /**
     * Conformation databases for all states of a design. Each state's {@code confDBFile} is
     * registered first. A "COMETS" table is then opened for every state whose database
     * {@code get} returns non-null; the other states get no table.
     */
    private class ConfDBs
    extends ConfDB.DBs {
        public Map<State, ConfDB.ConfTable> tables = new HashMap<State, ConfDB.ConfTable>();

        public ConfDBs(Comets comets) {
            for (State registering : comets.states) {
                this.add(registering.confSpace, registering.confDBFile);
            }
            for (State state : comets.states) {
                ConfDB db = this.get(state.confSpace);
                if (db != null) {
                    this.tables.put(state, new ConfDB.ConfTable(db, "COMETS"));
                }
            }
        }
    }

    /** Result record for a completed sequence: its GMEC per state, objective value, and constraint values. */
    public class SequenceInfo {
        public final Sequence sequence;
        public final Map<State, ConfSearch.EnergiedConf> GMECs = new HashMap<State, ConfSearch.EnergiedConf>();
        public final double objective;
        public final Map<LME, Double> constraints = new HashMap<LME, Double>();

        /**
         * Snapshots a sequence node whose states all have known GMECs.
         *
         * @param comets the design the node belongs to
         * @param node the completed sequence node
         * @param confs the node's per-state conformation progress
         */
        public SequenceInfo(Comets comets, SeqAStarNode node, SeqConfs confs) {
            this.sequence = node.makeSequence(comets.seqSpace);
            comets.states.forEach(state -> this.GMECs.put(state, confs.statesConfs.get(state).gmec));
            this.objective = comets.objective.calc(confs);
            comets.constraints.forEach(lme -> this.constraints.put(lme, lme.calc(confs)));
        }
    }

    /**
     * Conformation-search progress for every state at one sequence-tree node.
     * {@link StateConfs} instances come from {@code stateConfsCache}. Nodes whose sequences
     * filter to the same per-state sequence therefore share that state's progress.
     */
    private class SeqConfs {
        final Map<State, StateConfs> statesConfs = new HashMap<State, StateConfs>();

        SeqConfs(SeqAStarNode seqNode) {
            for (State state : Comets.this.states) {
                // filter the full sequence down to this state's own sequence space
                Sequence stateSequence = seqNode.makeSequence(Comets.this.seqSpace).filter(state.confSpace.seqSpace);
                StateConfs shared = Comets.this.stateConfsCache.computeIfAbsent(
                    new StateConfs.Key(stateSequence, state),
                    ignoredKey -> new StateConfs(stateSequence, state, Comets.this.confTrees)
                );
                this.statesConfs.put(state, shared);
            }
        }

        /** True once every state's GMEC has been proven. */
        boolean hasAllGMECs() {
            return Comets.this.states.stream().allMatch(state -> this.statesConfs.get(state).gmec != null);
        }

        /**
         * Refines every state's GMEC bounds by one batch and returns the node's new score.
         *
         * @return positive infinity if any constraint's optimistic bound is greater than zero
         *     (the sequence is pruned), otherwise the objective's optimistic bound
         */
        public double refineBounds(ConfDBs confDBs) {
            for (State state : Comets.this.states) {
                this.statesConfs.get(state).refineBounds(confDBs.tables.get(state));
            }
            boolean violated = Comets.this.constraints.stream().anyMatch(lme -> lme.calc(this) > 0.0);
            return violated ? Double.POSITIVE_INFINITY : Comets.this.objective.calc(this);
        }
    }

    /**
     * Conformation-search progress for one state at one sequence, already filtered to that
     * state's sequence space by the caller.
     *
     * <p>{@link #refineBounds} tightens bounds on the state's GMEC energy in batches. The lower
     * bound is the score of the first conformation the tree returned. The upper bound is the
     * lowest energy computed so far. NOTE(review): the lower bound is only valid if the tree
     * returns conformations in non-decreasing score order and scores bound energies from
     * below, as A* does. Confirm this for custom {@code confTreeFactory} implementations.
     */
    private static class StateConfs {
        final State state;
        final Sequence sequence;
        /** Conformation tree for this state/sequence; released (set to null) once the GMEC is proven. */
        ConfSearch confTree = null;
        /** First conformation returned by the tree; its score is the GMEC lower bound. */
        ConfSearch.ScoredConf minScoreConf = null;
        /** Lowest-energy conformation computed so far; its energy is the GMEC upper bound. */
        ConfSearch.EnergiedConf minEnergyConf = null;
        /** The proven GMEC, or null while the search is still in progress. */
        ConfSearch.EnergiedConf gmec = null;
        /** Scratch buffer for the current batch of conformations. */
        List<ConfSearch.ScoredConf> confs = new ArrayList<ConfSearch.ScoredConf>();

        StateConfs(Sequence sequence, State state, ConfSearchCache confTrees) {
            this.state = state;
            this.sequence = sequence;
            RCs rcs = sequence.makeRCs(state.confSpace);
            this.confTree = confTrees.make(() -> state.confTreeFactory.apply(rcs));
        }

        /**
         * Pulls the next batch of conformations (one per parallel energy task), computes their
         * energies, and tightens the GMEC bounds.
         *
         * <p>The lowest energy seen is proven to be the GMEC in two cases. The first is when
         * the batch's highest score reaches it. The second is when the tree has run out of
         * conformations: every conformation has then been energy-evaluated, so the lowest
         * energy is exact. Without the second case, a tree that runs out first (for example, a
         * single-conformation state, whose score is at most its energy) would never be proven.
         * Its sequence would then be re-queued forever with an unchanged score.
         *
         * @param confTable conf DB table handed to the energy calculator; may be null when the
         *     state has no conf DB table
         */
        void refineBounds(ConfDB.ConfTable confTable) {
            // already complete? nothing left to refine
            if (this.gmec != null) {
                return;
            }

            // get the next batch of confs
            this.confs.clear();
            boolean treeExhausted = false;
            int batchSize = this.state.confEcalc.tasks.getParallelism();
            while (this.confs.size() < batchSize) {
                ConfSearch.ScoredConf conf = this.confTree.nextConf();
                if (conf == null) {
                    treeExhausted = true;
                    break;
                }
                if (this.minScoreConf == null) {
                    this.minScoreConf = conf;
                }
                this.confs.add(conf);
            }

            // compute the batch's energies, tracking the lowest one seen so far
            if (!this.confs.isEmpty()) {
                for (ConfSearch.ScoredConf conf : this.confs) {
                    this.state.confEcalc.calcEnergyAsync(conf, confTable, econf -> {
                        if (this.minEnergyConf == null || econf.getEnergy() < this.minEnergyConf.getEnergy()) {
                            this.minEnergyConf = econf;
                        }
                    });
                }
                this.state.confEcalc.tasks.waitForFinish();
            }

            if (this.isGmecProven(treeExhausted)) {
                this.gmec = this.minEnergyConf;
                this.confTree = null;
            }
            this.confs.clear();
        }

        /** True if {@link #minEnergyConf} is provably the GMEC, judged from the batch currently in {@link #confs}. */
        private boolean isGmecProven(boolean treeExhausted) {
            if (this.minEnergyConf == null) {
                // no energies computed at all (e.g. the tree produced no conformations)
                return false;
            }
            if (treeExhausted) {
                // every conformation has been evaluated, so the best energy is exact
                return true;
            }
            if (this.confs.isEmpty()) {
                return false;
            }
            ConfSearch.ScoredConf maxScoreConf = this.confs.get(this.confs.size() - 1);
            return maxScoreConf.getScore() >= this.minEnergyConf.getEnergy();
        }

        /**
         * Lower bound on the state's GMEC energy: exact once the GMEC is known.
         * NOTE(review): throws NPE if the tree never produced a conformation (minScoreConf still null).
         */
        double getObjectiveLowerBound() {
            if (this.gmec != null) {
                return this.gmec.getEnergy();
            }
            return this.minScoreConf.getScore();
        }

        /**
         * Upper bound on the state's GMEC energy: exact once the GMEC is known.
         * NOTE(review): throws NPE if no energy has been computed yet (minEnergyConf still null).
         */
        double getObjectiveUpperBound() {
            if (this.gmec != null) {
                return this.gmec.getEnergy();
            }
            return this.minEnergyConf.getEnergy();
        }

        /** Cache key: a state-filtered sequence paired with the state it belongs to. */
        private static class Key {
            final Sequence sequence;
            final State state;

            Key(Sequence sequence, State state) {
                this.sequence = sequence;
                this.state = state;
            }

            @Override
            public int hashCode() {
                return HashCalculator.combineHashes(this.sequence.hashCode(), this.state.hashCode());
            }

            @Override
            public boolean equals(Object other) {
                if (!(other instanceof Key)) {
                    return false;
                }
                Key that = (Key) other;
                // states are compared by identity
                return this.sequence.equals(that.sequence) && this.state == that.state;
            }
        }
    }

    /** Configures and builds a {@link Comets} design. */
    public static class Builder {
        private final LME objective;
        private final List<LME> constraints = new ArrayList<LME>();
        private double objectiveWindowSize = 10.0;
        private double objectiveWindowMax = 0.0;
        private int maxSimultaneousMutations = 1;
        private Integer minNumConfTrees = null;
        private boolean printToConsole = true;
        private File logFile = null;

        /** @param objective the LME to minimize */
        public Builder(LME objective) {
            this.objective = objective;
        }

        /** Adds a constraint; sequences whose constraint value is provably greater than zero are pruned. */
        public Builder addConstraint(LME constraint) {
            this.constraints.add(constraint);
            return this;
        }

        /** Sets how far (kcal/mol) above the best sequence's objective to keep searching; default 10.0. */
        public Builder setObjectiveWindowSize(double val) {
            this.objectiveWindowSize = val;
            return this;
        }

        /** Sets the absolute objective value (kcal/mol) beyond which the search stops; default 0.0. */
        public Builder setObjectiveWindowMax(double val) {
            this.objectiveWindowMax = val;
            return this;
        }

        /** Sets the maximum number of simultaneously mutated positions; default 1. */
        public Builder setMaxSimultaneousMutations(int val) {
            this.maxSimultaneousMutations = val;
            return this;
        }

        /** Sets the value passed to the conformation-tree cache; default null. */
        public Builder setMinNumConfTrees(Integer val) {
            this.minNumConfTrees = val;
            return this;
        }

        /** Enables or disables console progress output; default enabled. */
        public Builder setPrintToConsole(boolean val) {
            this.printToConsole = val;
            return this;
        }

        /** Sets the tab-separated results file; default null (no file). */
        public Builder setLogFile(File val) {
            this.logFile = val;
            return this;
        }

        /**
         * Builds the design. The constraint list is copied, so later {@link #addConstraint}
         * calls on this builder do not change a design that was already built. That design's
         * state list is fixed at construction.
         */
        public Comets build() {
            return new Comets(this.objective, new ArrayList<LME>(this.constraints), this.objectiveWindowSize, this.objectiveWindowMax, this.maxSimultaneousMutations, this.minNumConfTrees, this.printToConsole, this.logFile);
        }
    }
}

