/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.kstar;

import edu.duke.cs.osprey.astar.conf.ConfAStarTree;
import edu.duke.cs.osprey.astar.conf.ConfSearchCache;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.astar.seq.RTs;
import edu.duke.cs.osprey.astar.seq.nodes.SeqAStarNode;
import edu.duke.cs.osprey.astar.seq.scoring.SeqAStarScorer;
import edu.duke.cs.osprey.confspace.ConfDB;
import edu.duke.cs.osprey.confspace.FragmentEnergies;
import edu.duke.cs.osprey.confspace.SeqSpace;
import edu.duke.cs.osprey.confspace.Sequence;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.kstar.KStar;
import edu.duke.cs.osprey.kstar.pfunc.BoltzmannCalculator;
import edu.duke.cs.osprey.kstar.pfunc.LowerBoundCalculator;
import edu.duke.cs.osprey.kstar.pfunc.PartitionFunction;
import edu.duke.cs.osprey.kstar.pfunc.UpperBoundCalculator;
import edu.duke.cs.osprey.tools.HashCalculator;
import edu.duke.cs.osprey.tools.Log;
import edu.duke.cs.osprey.tools.MathTools;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

public class MSKStar {
    /** The LMFE being optimized; its states are included in {@link #states}. */
    public final LMFE objective;
    /** Constraint LMFEs; their states are included in {@link #states}. */
    public final List<LMFE> constraints;
    /** Approximation epsilon handed to every partition function's {@code init}. */
    public final double epsilon;
    public final double objectiveWindowSize;
    public final double objectiveWindowMax;
    public final int maxSimultaneousMutations;
    /** Forwarded to the {@link ConfSearchCache} constructor; may be null. */
    public final Integer minNumConfTrees;
    /** When true, progress messages are printed through {@link Log}. */
    public final boolean printToConsole;
    /** Optional tab-separated results file; null disables file logging. */
    public final File logFile;
    /** Every distinct state referenced by the objective or any constraint, in first-seen order. */
    public final List<State> states;
    /** Union of the sequence spaces of all {@link #states}. */
    public final SeqSpace seqSpace;
    /** Partition-function progress per (sequence, state), reused across sequence nodes. */
    private final Map<StateConfs.Key, StateConfs> stateConfsCache = new HashMap<>();
    private final ConfSearchCache confTrees;

    /**
     * Builds the design from an objective and constraints. Use {@link Builder} instead of calling this directly.
     *
     * @throws IllegalArgumentException if neither the objective nor any constraint references a state
     */
    private MSKStar(LMFE objective, List<LMFE> constraints, double epsilon, double objectiveWindowSize, double objectiveWindowMax, int maxSimultaneousMutations, Integer minNumConfTrees, boolean printToConsole, File logFile) {
        this.objective = objective;
        this.constraints = constraints;
        this.epsilon = epsilon;
        this.objectiveWindowSize = objectiveWindowSize;
        this.objectiveWindowMax = objectiveWindowMax;
        this.maxSimultaneousMutations = maxSimultaneousMutations;
        this.minNumConfTrees = minNumConfTrees;
        this.printToConsole = printToConsole;
        this.logFile = logFile;

        // gather states from the objective first, then each constraint; the linked set de-duplicates
        // while keeping first-seen order
        LinkedHashSet<State> uniqueStates = new LinkedHashSet<>();
        objective.states.forEach(wstate -> uniqueStates.add(wstate.state));
        constraints.forEach(constraint -> constraint.states.forEach(wstate -> uniqueStates.add(wstate.state)));
        this.states = new ArrayList<>(uniqueStates);
        if (this.states.isEmpty()) {
            throw new IllegalArgumentException("MSK* found no states");
        }

        // the design's sequence space covers every state's sequence space
        List<SeqSpace> perStateSeqSpaces = new ArrayList<>(this.states.size());
        for (State state : this.states) {
            perStateSeqSpaces.add(state.confSpace.seqSpace);
        }
        this.seqSpace = SeqSpace.union(perStateSeqSpaces);

        this.confTrees = new ConfSearchCache(minNumConfTrees);
        this.log("sequence space has %s sequences\n%s", Log.formatBig(new RTs(this.seqSpace).getNumSequences()), this.seqSpace);
    }

    /**
     * Placeholder for the multi-state sequence search (presumably returning the best
     * {@code numSequences} sequences); not implemented in this version.
     *
     * @throws UnsupportedOperationException always
     */
    public List<SequenceInfo> findBestSequences(int numSequences) {
        final String notReady = "This MSK* implementation doesn't work yet, don't use it!";
        throw new UnsupportedOperationException(notReady);
    }

    /** Forwards a formatted message to {@link Log#log} unless console printing is disabled. */
    private void log(String msg, Object ... args) {
        if (!this.printToConsole) {
            return;
        }
        Log.log(msg, args);
    }

    /*
     * WARNING - void declaration
     */
    private void reportSequence(boolean isFirstSequence, SequenceInfo info2) {
        if (this.printToConsole) {
            int cellSize = info2.sequence.calcCellSize();
            this.log("\nSequence calculation complete: %s    objective: %s   %s\n                               %s", info2.sequence.toString(Sequence.Renderer.ResNum, cellSize), info2.objective.toString(4, 9), info2.sequence.isWildType() ? "Wild-type" : "", info2.sequence.toString(Sequence.Renderer.ResTypeMutations, cellSize));
            for (State state : this.states) {
                this.log("\tState: %-20s    Free Energy: %s", state.name, info2.pfuncResults.get((Object)state).values.calcFreeEnergyBounds().toString(4, 9));
            }
        }
        if (this.logFile != null) {
            void var5_14;
            StringBuilder header = null;
            if (isFirstSequence) {
                header = new StringBuilder();
                for (SeqSpace.Position position : this.seqSpace.positions) {
                    if (position.index > 0) {
                        header.append("\t");
                    }
                    header.append(position.resNum);
                }
                header.append("\tObjective Min");
                header.append("\tObjective Max");
                for (int i = 0; i < this.constraints.size(); ++i) {
                    header.append("\tConstraint ");
                    header.append(i);
                    header.append(" Min\tConstraint ");
                    header.append(i);
                    header.append(" Max");
                }
                for (State state : this.states) {
                    header.append("\t");
                    header.append(state.name);
                    header.append(" Free Energy Min\t");
                    header.append(state.name);
                    header.append(" Free Energy Max");
                }
            }
            StringBuilder buf = new StringBuilder();
            for (SeqSpace.Position pos : this.seqSpace.positions) {
                if (pos.index > 0) {
                    buf.append("\t");
                }
                buf.append(info2.sequence.get(pos).mutationName());
            }
            buf.append(String.format("\t%.4f\t%.4f", info2.objective.lower, info2.objective.upper));
            boolean bl = false;
            while (var5_14 < this.constraints.size()) {
                MathTools.DoubleBounds constraint = info2.constraints.get(this.constraints.get((int)var5_14));
                buf.append(String.format("\t%.4f\t%.4f", constraint.lower, constraint.upper));
                ++var5_14;
            }
            for (State state : this.states) {
                PartitionFunction.Result pfuncResult = info2.pfuncResults.get(state);
                buf.append(String.format("\t%.4f\t%.4f", pfuncResult.values.calcFreeEnergyLowerBound(), pfuncResult.values.calcFreeEnergyUpperBound()));
            }
            try (FileWriter fileWriter = new FileWriter(this.logFile, !isFirstSequence);){
                if (header != null) {
                    fileWriter.write(header.toString());
                    fileWriter.write("\n");
                }
                fileWriter.write(buf.toString());
                fileWriter.write("\n");
            }
            catch (IOException iOException) {
                iOException.printStackTrace(System.err);
                if (header != null) {
                    System.err.println(header);
                }
                System.err.println(buf);
            }
        }
    }

    /**
     * A linear multi-state free energy: {@code offset + sum(weight_i * G_i)} over weighted states.
     *
     * <p>Bounds are propagated with interval arithmetic. A positive weight maps the state's
     * [lower, upper] free-energy bounds directly. A negative weight swaps them, because
     * multiplying by a negative number flips the interval.
     */
    public static class LMFE {
        public final double offset;
        public final List<WeightedState> states;

        public LMFE(double offset, List<WeightedState> states) {
            this.offset = offset;
            this.states = states;
        }

        /** Bounds this LMFE using the current free-energy bounds stored in {@code confs}. */
        private MathTools.DoubleBounds calc(SeqConfs confs) {
            MathTools.DoubleBounds bound = new MathTools.DoubleBounds(this.offset, this.offset);
            for (WeightedState wstate : this.states) {
                accumulate(bound, wstate.weight, confs.statesConfs.get(wstate.state).freeEnergyBounds);
            }
            return bound;
        }

        /**
         * Bounds this LMFE from explicit per-state free-energy bounds.
         *
         * @param stateFreeEnergies free-energy bounds for every state referenced by this LMFE
         * @return a new bounds object: offset plus the weighted state contributions
         */
        public MathTools.DoubleBounds calc(Map<State, MathTools.DoubleBounds> stateFreeEnergies) {
            MathTools.DoubleBounds bound = new MathTools.DoubleBounds(this.offset, this.offset);
            for (WeightedState wstate : this.states) {
                accumulate(bound, wstate.weight, stateFreeEnergies.get(wstate.state));
            }
            return bound;
        }

        /** Adds {@code weight * freeEnergy} to {@code bound} using interval arithmetic. */
        private static void accumulate(MathTools.DoubleBounds bound, double weight, MathTools.DoubleBounds freeEnergy) {
            if (weight > 0.0) {
                bound.lower += weight * freeEnergy.lower;
                bound.upper += weight * freeEnergy.upper;
            } else if (weight < 0.0) {
                bound.lower += weight * freeEnergy.upper;
                bound.upper += weight * freeEnergy.lower;
            }
            // weight == 0: the state contributes nothing. Skipping it avoids 0 * +/-infinity = NaN
            // when the state's free-energy bounds are still unbounded.
        }

        /** Fluent builder for {@link LMFE}; offset defaults to 0. */
        public static class Builder {
            private double offset = 0.0;
            private final List<WeightedState> wstates = new ArrayList<WeightedState>();

            public Builder setOffset(double val) {
                this.offset = val;
                return this;
            }

            /**
             * Sets the offset to {@code -val}, so the LMFE evaluates to {@code sum(weight_i * G_i) - val}.
             * This is positive exactly when the weighted sum exceeds {@code val}.
             */
            public Builder constrainLessThan(double val) {
                return this.setOffset(-val);
            }

            public Builder addState(State state, double weight) {
                this.wstates.add(new WeightedState(state, weight));
                return this;
            }

            /** Builds an LMFE from a snapshot of the states added so far; later addState calls don't affect it. */
            public LMFE build() {
                return new LMFE(this.offset, new ArrayList<WeightedState>(this.wstates));
            }
        }
    }

    /** A state paired with its coefficient in an {@link LMFE}. */
    public static class WeightedState {
        public final State state;
        public final double weight;

        public WeightedState(State state, double weight) {
            this.state = state;
            this.weight = weight;
        }

        /** Single-residue fragment energy of {@code rc} at {@code pos}, scaled by {@code |weight|}. */
        public double getSingleEnergy(int pos, int rc) {
            double energy = this.state.fragmentEnergies.getEnergy(pos, rc);
            return this.magnitude() * energy;
        }

        /** Pair fragment energy of ({@code pos1},{@code rc1})-({@code pos2},{@code rc2}), scaled by {@code |weight|}. */
        public double getPairEnergy(int pos1, int rc1, int pos2, int rc2) {
            double energy = this.state.fragmentEnergies.getEnergy(pos1, rc1, pos2, rc2);
            return this.magnitude() * energy;
        }

        /** Absolute value of the weight; the sign is ignored when scaling fragment energies. */
        private double magnitude() {
            return Math.abs(this.weight);
        }
    }

    /** One design state: a conf space plus the calculators that must be configured before running. */
    public static class State {
        public KStar.PfuncFactory pfuncFactory;
        public final String name;
        public final SimpleConfSpace confSpace;
        public FragmentEnergies fragmentEnergies;
        public ConfEnergyCalculator confEcalc;
        public Function<RCs, ConfAStarTree> confTreeFactory;
        /** Optional conformation DB file; null means no DB for this state. */
        public File confDBFile = null;

        public State(String name, SimpleConfSpace confSpace) {
            this.name = name;
            this.confSpace = confSpace;
        }

        /**
         * Verifies that the required fields have been assigned.
         * Checked in order: fragmentEnergies, confEcalc, pfuncFactory.
         *
         * @throws InitException naming the first unset field
         */
        public void checkConfig() {
            this.requireSet(this.fragmentEnergies, "fragmentEnergies");
            this.requireSet(this.confEcalc, "confEcalc");
            this.requireSet(this.pfuncFactory, "pfuncFactory");
        }

        private void requireSet(Object value, String fieldName) {
            if (value == null) {
                throw new InitException(this, fieldName);
            }
        }

        /** Thrown when a required State field was not set before running. */
        public static class InitException
        extends RuntimeException {
            public InitException(State state, String name) {
                super(String.format("set %s for state %s before running", name, state.name));
            }
        }
    }

    /**
     * Snapshot of one sequence's results: the sequence itself, each state's partition-function
     * result, and the objective and constraint bounds computed from the given {@link SeqConfs}.
     */
    public class SequenceInfo {
        public final Sequence sequence;
        public final Map<State, PartitionFunction.Result> pfuncResults = new HashMap<>();
        public final MathTools.DoubleBounds objective;
        public final Map<LMFE, MathTools.DoubleBounds> constraints = new HashMap<>();

        /**
         * @param this$0 the MSKStar instance whose states, objective and constraints are evaluated
         *               (passed explicitly; looks like a decompiler artifact of the implicit outer reference)
         * @param node sequence node to materialize
         * @param confs per-state results for that sequence; a state's pfuncResult may be null if not finished
         */
        public SequenceInfo(MSKStar this$0, SeqAStarNode node, SeqConfs confs) {
            MSKStar owner = this$0;
            this.sequence = node.makeSequence(owner.seqSpace);
            for (State state : owner.states) {
                StateConfs stateConfs = confs.statesConfs.get(state);
                this.pfuncResults.put(state, stateConfs.pfuncResult);
            }
            this.objective = owner.objective.calc(confs);
            owner.constraints.forEach(constraint -> this.constraints.put(constraint, constraint.calc(confs)));
        }
    }

    /**
     * Fluent builder for {@link MSKStar}.
     *
     * <p>Defaults:
     * <ul>
     * <li>epsilon 0.68</li>
     * <li>objectiveWindowSize 10.0</li>
     * <li>objectiveWindowMax 0.0</li>
     * <li>maxSimultaneousMutations 1</li>
     * <li>minNumConfTrees null</li>
     * <li>printToConsole true</li>
     * <li>no log file</li>
     * </ul>
     */
    public static class Builder {
        private final LMFE objective;
        private final List<LMFE> constraints = new ArrayList<LMFE>();
        private double epsilon = 0.68;
        private double objectiveWindowSize = 10.0;
        private double objectiveWindowMax = 0.0;
        private int maxSimultaneousMutations = 1;
        private Integer minNumConfTrees = null;
        private boolean printToConsole = true;
        private File logFile = null;
        // NOTE(review): not read by build() or anywhere else in MSKStar as visible here
        public EnergyMatrix rigidEmat;
        public EnergyMatrix minimizingEmat;

        /**
         * @param objective the LMFE to optimize
         * @throws NullPointerException if {@code objective} is null
         */
        public Builder(LMFE objective) {
            this.objective = Objects.requireNonNull(objective, "objective");
        }

        public Builder addConstraint(LMFE constraint) {
            this.constraints.add(constraint);
            return this;
        }

        public Builder setEpsilon(double val) {
            this.epsilon = val;
            return this;
        }

        public Builder setObjectiveWindowSize(double val) {
            this.objectiveWindowSize = val;
            return this;
        }

        public Builder setObjectiveWindowMax(double val) {
            this.objectiveWindowMax = val;
            return this;
        }

        public Builder setMaxSimultaneousMutations(int val) {
            this.maxSimultaneousMutations = val;
            return this;
        }

        public Builder setMinNumConfTrees(Integer val) {
            this.minNumConfTrees = val;
            return this;
        }

        public Builder setPrintToConsole(boolean val) {
            this.printToConsole = val;
            return this;
        }

        /** Sets the tab-separated results file; null disables file logging. */
        public Builder setLogFile(File val) {
            this.logFile = val;
            return this;
        }

        /**
         * Builds the design.
         *
         * <p>The constraint list is copied. MSKStar derives its state set from the constraints once,
         * at construction. Without the copy, a later addConstraint() on this builder would silently
         * change an already-built instance.
         */
        public MSKStar build() {
            return new MSKStar(this.objective, new ArrayList<LMFE>(this.constraints), this.epsilon, this.objectiveWindowSize, this.objectiveWindowMax, this.maxSimultaneousMutations, this.minNumConfTrees, this.printToConsole, this.logFile);
        }
    }

    /** Per-state partition-function progress for one sequence node, shared through the MSKStar cache. */
    private class SeqConfs {
        final Map<State, StateConfs> statesConfs = new HashMap<>();

        SeqConfs(SeqAStarNode seqNode, ConfDBs confDBs) {
            for (State state : MSKStar.this.states) {
                // each state only sees the part of the sequence inside its own sequence space
                Sequence stateSeq = seqNode.makeSequence(MSKStar.this.seqSpace).filter(state.confSpace.seqSpace);
                StateConfs.Key cacheKey = new StateConfs.Key(stateSeq, state);
                StateConfs shared = MSKStar.this.stateConfsCache.computeIfAbsent(cacheKey,
                    ignored -> new StateConfs(stateSeq, state, MSKStar.this.epsilon, confDBs.tables.get(state), MSKStar.this.confTrees));
                this.statesConfs.put(state, shared);
            }
        }

        /** True once every state's partition function has produced a final result. */
        boolean isRefinementComplete() {
            return MSKStar.this.states.stream()
                .allMatch(state -> this.statesConfs.get(state).isRefinementComplete());
        }

        /**
         * Runs one refinement step on every state, then bounds the objective.
         *
         * @return objective bounds, or [+inf, +inf] if any constraint's lower bound is already > 0
         */
        public MathTools.DoubleBounds refineBounds() {
            MSKStar.this.states.forEach(state -> this.statesConfs.get(state).refineBounds());
            boolean constraintViolated = MSKStar.this.constraints.stream()
                .anyMatch(constraint -> constraint.calc(this).lower > 0.0);
            if (constraintViolated) {
                return new MathTools.DoubleBounds(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
            }
            return MSKStar.this.objective.calc(this);
        }
    }

    /** Incrementally refined partition function for one (sequence, state) pair. */
    private static class StateConfs {
        final State state;
        final Sequence sequence;
        /** Latest free-energy bounds, overwritten after each refinement step. */
        final MathTools.DoubleBounds freeEnergyBounds = new MathTools.DoubleBounds();
        /** Live partition function; set to null once a final result has been captured. */
        PartitionFunction pfunc = null;
        /** Final result; non-null exactly when refinement is complete. */
        PartitionFunction.Result pfuncResult = null;

        StateConfs(Sequence sequence, State state, double epsilon, ConfDB.ConfTable confTable, ConfSearchCache confTrees) {
            this.state = state;
            this.sequence = sequence;
            // NOTE(review): confTable and confTrees are accepted but never used here
            RCs sequenceRCs = sequence.makeRCs(state.confSpace);
            this.pfunc = state.pfuncFactory.make(sequenceRCs);
            this.pfunc.init(epsilon);
        }

        /** Runs one compute step and updates the bounds; captures the result once the pfunc can't continue. */
        void refineBounds() {
            if (this.isRefinementComplete()) {
                return;
            }
            int parallelism = this.state.confEcalc.tasks.getParallelism();
            this.pfunc.compute(parallelism);
            this.pfunc.getValues().calcFreeEnergyBounds(this.freeEnergyBounds);
            if (this.pfunc.getStatus().canContinue()) {
                return;
            }
            this.pfuncResult = this.pfunc.makeResult();
            this.pfunc = null;
        }

        boolean isRefinementComplete() {
            return this.pfuncResult != null;
        }

        /** Cache key: sequence by value, state by identity. */
        private static class Key {
            final Sequence sequence;
            final State state;

            Key(Sequence sequence, State state) {
                this.sequence = sequence;
                this.state = state;
            }

            @Override
            public int hashCode() {
                return HashCalculator.combineHashes(this.sequence.hashCode(), this.state.hashCode());
            }

            @Override
            public boolean equals(Object other) {
                if (!(other instanceof Key)) {
                    return false;
                }
                return this.equals((Key)other);
            }

            public boolean equals(Key other) {
                boolean sameSequence = this.sequence.equals(other.sequence);
                return sameSequence && this.state == other.state;
            }
        }
    }

    /** Conformation DBs for all states, plus an "MSK*" table for each state that has a DB. */
    private class ConfDBs
    extends ConfDB.DBs {
        public Map<State, ConfDB.ConfTable> tables = new HashMap<>();

        public ConfDBs(MSKStar mSKStar) {
            // register every state's DB file first (confDBFile may be null)
            for (State s : mSKStar.states) {
                this.add(s.confSpace, s.confDBFile);
            }
            // then open a table only for states that actually got a DB
            for (State s : mSKStar.states) {
                ConfDB db = this.get(s.confSpace);
                if (db != null) {
                    this.tables.put(s, new ConfDB.ConfTable(db, "MSK*"));
                }
            }
        }
    }

    /**
     * Sequence A* heuristic. It sums {@code weight * G} over the objective's states, where each
     * {@code G} comes from a bound-calculator run on the partial sequence assignment.
     * NOTE(review): objective.offset and the constraints are not included here.
     */
    private class SeqHScorer
    implements SeqAStarScorer {
        private static final int upperBatchSize = 1000;
        private static final int numLowerBatches = 1;
        final ConfDBs confDBs;
        BoltzmannCalculator bcalc = new BoltzmannCalculator(PartitionFunction.decimalPrecision);

        SeqHScorer(ConfDBs confDBs) {
            this.confDBs = confDBs;
        }

        @Override
        public double calc(SeqAStarNode.Assignments assignments) {
            double total = 0.0;
            for (WeightedState wstate : MSKStar.this.objective.states) {
                total += wstate.weight * this.boundingFreeEnergy(wstate, assignments);
            }
            return total;
        }

        /**
         * Free energy for one state, chosen by the sign of its weight.
         * Positive weight: taken from UpperBoundCalculator.totalBound.
         * Otherwise: taken from LowerBoundCalculator.weightedEnergySum.
         */
        private double boundingFreeEnergy(WeightedState wstate, SeqAStarNode.Assignments assignments) {
            RCs rcs = assignments.makeRCs(MSKStar.this.seqSpace, wstate.state.confSpace);
            ConfAStarTree confTree = wstate.state.confTreeFactory.apply(rcs);
            if (wstate.weight > 0.0) {
                UpperBoundCalculator upper = new UpperBoundCalculator(confTree, rcs.getNumConformations());
                upper.run(upperBatchSize);
                return this.bcalc.freeEnergy(upper.totalBound);
            }
            LowerBoundCalculator lower = new LowerBoundCalculator(confTree, wstate.state.confEcalc);
            lower.confTable = this.confDBs.tables.get(wstate.state);
            int parallelism = wstate.state.confEcalc.tasks.getParallelism();
            for (int batch = 0; batch < numLowerBatches; batch++) {
                lower.run(parallelism);
            }
            return this.bcalc.freeEnergy(lower.weightedEnergySum);
        }
    }
}

