/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.paste;

import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.confspace.ConfDB;
import edu.duke.cs.osprey.confspace.ConfSearch;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.energy.EnergyCalculator;
import edu.duke.cs.osprey.ewakstar.EWAKStarPartitionFunction;
import edu.duke.cs.osprey.kstar.pfunc.BoltzmannCalculator;
import edu.duke.cs.osprey.kstar.pfunc.PfuncSurface;
import edu.duke.cs.osprey.paste.PastePartitionFunction;
import edu.duke.cs.osprey.tools.BigMath;
import edu.duke.cs.osprey.tools.JvmMem;
import edu.duke.cs.osprey.tools.MathTools;
import edu.duke.cs.osprey.tools.Stopwatch;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * Partition function estimator used by PASTE that interleaves two kinds of work.
 *
 * <ul>
 *   <li><b>Score steps</b> pull batches of conformations from {@code scoreConfs} and fold the
 *       Boltzmann weights of their scores into the upper bound.</li>
 *   <li><b>Energy steps</b> pull one conformation at a time from {@code energyConfs}, compute its
 *       energy with {@link #ecalc}, and add the Boltzmann weight of that energy to the lower
 *       bound.</li>
 * </ul>
 *
 * After every step the change in the relative bound gap ("delta", {@code (upper - lower) / upper})
 * is recorded as a slope for that kind of step. The next step is chosen from those slopes: while
 * scoring is ahead of energy calculation, an energy step is taken when its last slope is at least
 * as steep as the scoring slope. This is a simple gradient-descent heuristic over the delta
 * surface.
 *
 * <p>Work is submitted to {@code ecalc.tasks}. Results are folded into the shared {@code State}
 * by {@code onEnergy} and {@code onScores}. Those methods, like the step-selection logic in
 * {@link #compute}, synchronize on this instance.
 */
public class PasteGradientDescentPfunc
        implements PastePartitionFunction.WithConfTable, PastePartitionFunction.WithExternalMemory {

    /** Energy calculator; also supplies the task executor all work is submitted to. */
    public final ConfEnergyCalculator ecalc;

    private double targetEpsilon = Double.NaN;
    private double targetEnergy = 0.0;
    private BigDecimal stabilityThreshold = BigDecimal.ZERO;
    private PastePartitionFunction.ConfListener confListener = null;
    private boolean isReportingProgress = false;
    private Stopwatch stopwatch = new Stopwatch().start();

    /** Source of conformations for score steps (upper bound). */
    private ConfSearch scoreConfs = null;
    /** Source of conformations for energy steps (lower bound). */
    private ConfSearch energyConfs = null;

    private BoltzmannCalculator bcalc = new BoltzmannCalculator(EWAKStarPartitionFunction.decimalPrecision);
    private PastePartitionFunction.Status status = null;
    private PastePartitionFunction.Values values = null;
    private State state = null;

    // Enumeration bookkeeping; only touched by the thread running compute().
    private boolean hasEnergyConfs = true;
    private boolean hasScoreConfs = true;
    private long numEnergyConfsEnumerated = 0L;
    private long numScoreConfsEnumerated = 0L;

    private boolean useWindowCriterion;
    /** Reference result compared against in the window criterion (presumably the wild type). */
    private PastePartitionFunction.Result wtResult = null;

    // Stored by the WithConfTable / WithExternalMemory setters; not read anywhere in this class.
    private ConfDB.ConfTable confTable = null;
    private boolean useExternalMemory = false;
    private RCs rcs = null;

    private PfuncSurface surf = null;
    private PfuncSurface.Trace trace = null;

    public PasteGradientDescentPfunc(ConfEnergyCalculator ecalc) {
        this.ecalc = ecalc;
    }

    @Override
    public void setReportProgress(boolean val) {
        this.isReportingProgress = val;
    }

    @Override
    public void setConfListener(PastePartitionFunction.ConfListener val) {
        this.confListener = val;
    }

    @Override
    public PastePartitionFunction.Status getStatus() {
        return this.status;
    }

    /** Returns the live list of molecules from energied conformations (requires {@link #init}). */
    @Override
    public ArrayList<EnergyCalculator.EnergiedParametricMolecule> getEpMols() {
        return this.state.epMols;
    }

    /** Returns the live map of recorded conformations keyed by energy (requires {@link #init}). */
    @Override
    public HashMap<Double, ConfSearch.ScoredConf> getSConfs() {
        return this.state.sConfs;
    }

    @Override
    public PastePartitionFunction.Values getValues() {
        return this.values;
    }

    /** Returns the cumulative number of energied conformations, narrowed to {@code int}. */
    @Override
    public int getNumConfsEvaluated() {
        return (int) this.state.numEnergiedConfs;
    }

    @Override
    public int getParallelism() {
        return this.ecalc.tasks.getParallelism();
    }

    @Override
    public void setConfTable(ConfDB.ConfTable val) {
        this.confTable = val;
    }

    @Override
    public void setUseExternalMemory(boolean val, RCs rcs) {
        this.useExternalMemory = val;
        this.rcs = rcs;
    }

    /** Records a surface to which each subsequent {@link #compute} call traces its steps. */
    public void traceTo(PfuncSurface val) {
        this.surf = val;
    }

    /**
     * Resets all estimation state and prepares for {@link #compute}.
     *
     * @param scoreConfs source of conformations for score steps
     * @param energyConfs source of conformations for energy steps
     * @param numConfsBeforePruning total conformation count used when bounding unscored confs
     * @param targetEpsilon stop once {@code (upper - lower) / upper <= targetEpsilon}; must be > 0
     * @param targetEnergy stop once (score of the latest energied conf) - (lowest energy seen so
     *        far) {@code >= targetEnergy}; must be >= 0
     * @param wtResult reference result used by the window criterion; may be {@code null}
     * @param useWindowCriterion whether to stop when the bounds no longer overlap {@code wtResult}
     * @throws IllegalArgumentException if {@code targetEpsilon <= 0} or {@code targetEnergy < 0}
     */
    @Override
    public void init(ConfSearch scoreConfs, ConfSearch energyConfs, BigInteger numConfsBeforePruning, double targetEpsilon, double targetEnergy, PastePartitionFunction.Result wtResult, boolean useWindowCriterion) {
        if (targetEpsilon <= 0.0 || targetEnergy < 0.0) {
            // The check allows targetEnergy == 0, so the message must say "non-negative" for it.
            throw new IllegalArgumentException("target epsilon must be greater than zero and target energy must not be negative");
        }
        this.useWindowCriterion = useWindowCriterion;
        this.wtResult = wtResult;
        this.energyConfs = energyConfs;
        this.targetEpsilon = targetEpsilon;
        this.targetEnergy = targetEnergy;
        this.status = PastePartitionFunction.Status.Estimating;
        this.state = new State(numConfsBeforePruning);
        this.values = PastePartitionFunction.Values.makeFullRange();
        this.values.pstar = BigDecimal.ZERO;
        this.hasEnergyConfs = true;
        this.hasScoreConfs = true;
        this.numEnergyConfsEnumerated = 0L;
        this.numScoreConfsEnumerated = 0L;
        this.scoreConfs = scoreConfs;
    }

    @Override
    public void setStabilityThreshold(BigDecimal val) {
        this.stabilityThreshold = val;
    }

    /**
     * Alternates score and energy steps until a stopping criterion is met or {@code maxNumConfs}
     * conformations have been submitted for energy calculation in this call. It then waits for
     * all outstanding tasks and updates {@link #getValues()} and {@link #getStatus()}.
     *
     * @param maxNumConfs maximum number of conformations to energy during this call
     * @param numPDBs an energied conformation is recorded in {@link #getSConfs()} while that map
     *        holds at most {@code numPDBs} entries, so up to {@code numPDBs + 1} may be kept
     * @throws IllegalStateException if {@link #init} has not been called
     */
    @Override
    public void compute(int maxNumConfs, int numPDBs) {
        if (this.status == null) {
            throw new IllegalStateException("pfunc was not initialized. Call init() before compute()");
        }
        if (!this.status.canContinue()) {
            return;
        }
        if (this.surf != null) {
            this.trace = new PfuncSurface.Trace(this.surf);
        }
        boolean keepStepping = true;
        int numConfsEnergied = 0;
        while (numConfsEnergied < maxNumConfs) {
            Step step = Step.None;
            int numScores = 0;
            // Task listeners update the state concurrently, so read it under the same monitor.
            synchronized (this) {
                keepStepping = keepStepping && this.shouldKeepStepping();
                if (!keepStepping) {
                    break;
                }
                if (Double.isNaN(this.state.dEnergy) || Double.isNaN(this.state.dScore)) {
                    throw new Error("Can't determine gradient of delta surface. This is a bug.");
                }
                boolean scoreAheadOfEnergy = this.numEnergyConfsEnumerated < this.numScoreConfsEnumerated;
                // Slopes are changes in delta, so "steeper" means "more negative".
                boolean energySteeperThanScore = this.state.dEnergy <= this.state.dScore;
                if (this.hasEnergyConfs && (scoreAheadOfEnergy && energySteeperThanScore || !this.hasScoreConfs)) {
                    step = Step.Energy;
                } else if (this.hasScoreConfs) {
                    step = Step.Score;
                    // Size the score batch to roughly 1/10 of the observed time for one energy
                    // calculation (at least 0.01 s), and never fewer than 10 conformations.
                    double scoringSeconds = Math.max(0.1 / this.state.energyOps, 0.01);
                    numScores = Math.max((int) (scoringSeconds * this.state.scoreOps), 10);
                }
            }
            switch (step) {
                case Energy:
                    if (this.submitEnergyStep(numPDBs)) {
                        numConfsEnergied++;
                    } else {
                        keepStepping = false;
                    }
                    break;
                case Score:
                    this.submitScoreStep(numScores);
                    break;
                case None:
                default:
                    keepStepping = false;
                    break;
            }
        }
        this.ecalc.tasks.waitForFinish();
        this.values.qstar = this.state.getLowerBound();
        this.values.qprime = new BigMath(EWAKStarPartitionFunction.decimalPrecision).set(this.state.getUpperBound()).sub(this.state.getLowerBound()).get();
        this.updateStatus(numConfsEnergied, maxNumConfs);
    }

    /**
     * Evaluates the stopping criteria in a fixed short-circuit order. This method must be called
     * while holding this instance's monitor.
     *
     * @return {@code true} if none of the stopping criteria has been met yet
     */
    private boolean shouldKeepStepping() {
        if (this.state.epsilonReached(this.targetEpsilon)) {
            return false;
        }
        if (this.wtResult != null && this.useWindowCriterion
                && this.state.noWindowOverlaps(this.wtResult, this.numEnergyConfsEnumerated)) {
            return false;
        }
        return this.state.isStable(this.stabilityThreshold)
                && this.state.hasLowEnergies()
                && !this.state.energyReached(this.targetEnergy);
    }

    /**
     * Pulls the next conformation from {@code energyConfs} and submits an energy task for it.
     *
     * @return {@code false} (and marks energy confs as exhausted) if the enumerator returned
     *         {@code null} or an infinite score; {@code true} if a task was submitted
     */
    private boolean submitEnergyStep(int numPDBs) {
        ConfSearch.ScoredConf conf = this.energyConfs.nextConf();
        if (conf != null) {
            this.numEnergyConfsEnumerated++;
        }
        if (conf == null || conf.getScore() == Double.POSITIVE_INFINITY) {
            this.hasEnergyConfs = false;
            return false;
        }
        this.ecalc.tasks.submit(() -> {
            EnergyResult result = new EnergyResult();
            result.stopwatch.start();
            result.epmol = this.ecalc.calcEnergy(new RCTuple(conf.getAssignments()));
            result.econf = new ConfSearch.EnergiedConf(conf, result.epmol.energy);
            result.scoreWeight = this.bcalc.calc(result.econf.getScore());
            result.energyWeight = this.bcalc.calc(result.econf.getEnergy());
            result.stopwatch.stop();
            // This lambda runs on executor workers and sConfs is a plain HashMap. Guard the
            // check-then-put with the same monitor that protects all other state updates.
            synchronized (this) {
                // NOTE(review): "<=" admits numPDBs + 1 entries; confirm this is intended.
                if (this.state.sConfs.size() <= numPDBs) {
                    this.state.sConfs.put(result.econf.getEnergy(), conf);
                }
            }
            return result;
        }, result -> this.onEnergy(result.epmol, result.econf, result.scoreWeight, result.energyWeight, result.stopwatch.getTimeS()));
        return true;
    }

    /**
     * Pulls up to {@code numScores} conformations from {@code scoreConfs} and submits one task
     * that computes their score Boltzmann weights. Score confs are marked exhausted if the
     * enumerator returns {@code null} or an infinite score. A task is submitted even if the batch
     * is empty.
     */
    private void submitScoreStep(int numScores) {
        List<ConfSearch.ScoredConf> confs = new ArrayList<>();
        for (int i = 0; i < numScores; i++) {
            ConfSearch.ScoredConf conf = this.scoreConfs.nextConf();
            if (conf != null) {
                this.numScoreConfsEnumerated++;
            }
            if (conf == null || conf.getScore() == Double.POSITIVE_INFINITY) {
                this.hasScoreConfs = false;
                break;
            }
            confs.add(conf);
        }
        this.ecalc.tasks.submit(() -> {
            ScoreResult result = new ScoreResult();
            result.stopwatch.start();
            for (ConfSearch.ScoredConf conf : confs) {
                result.scoreWeights.add(this.bcalc.calc(conf.getScore()));
            }
            result.stopwatch.stop();
            return result;
        }, result -> this.onScores(result.scoreWeights, result.stopwatch.getTimeS()));
    }

    /**
     * Derives the final status after all tasks have finished. Later checks deliberately override
     * earlier ones, so {@code Unstable} has the highest priority.
     *
     * @param numConfsEnergied number of energy tasks submitted during this compute() call
     * @param maxNumConfs the per-call limit passed to compute()
     */
    private void updateStatus(int numConfsEnergied, int maxNumConfs) {
        if (!this.state.hasLowEnergies()) {
            this.status = PastePartitionFunction.Status.OutOfLowEnergies;
        }
        if (!this.hasEnergyConfs) {
            this.status = PastePartitionFunction.Status.OutOfConformations;
        }
        if (this.state.epsilonReached(this.targetEpsilon)) {
            this.status = PastePartitionFunction.Status.EpsilonReached;
        }
        if (this.useWindowCriterion && this.wtResult != null && this.state.noWindowOverlaps(this.wtResult, this.numEnergyConfsEnumerated)) {
            this.status = PastePartitionFunction.Status.NoWindowOverlap;
        }
        if (this.state.energyReached(this.targetEnergy)) {
            this.status = PastePartitionFunction.Status.EnergyReached;
        }
        // Compare against the per-call count: state.numEnergiedConfs is cumulative across calls
        // and would misdetect the limit on any compute() call after the first.
        if (numConfsEnergied >= maxNumConfs) {
            this.status = PastePartitionFunction.Status.ConfLimitReached;
        }
        if (!this.state.isStable(this.stabilityThreshold)) {
            this.status = PastePartitionFunction.Status.Unstable;
        }
    }

    /**
     * Task listener for an energy step. It folds one energied conformation into the bounds,
     * updates the energy slope and progress output, then notifies the conf listener outside the
     * lock.
     */
    private void onEnergy(EnergyCalculator.EnergiedParametricMolecule epmol, ConfSearch.EnergiedConf econf, BigDecimal scoreWeight, BigDecimal energyWeight, double seconds) {
        synchronized (this) {
            this.state.energyWeightSum = this.state.energyWeightSum.add(energyWeight);
            this.state.epMols.add(epmol);
            if (this.state.curGMEC >= econf.getEnergy()) {
                this.state.setGMECEnergy(econf.getEnergy());
            }
            this.state.curScore = econf.getScore();
            this.state.lowerScoreWeightSum = this.state.lowerScoreWeightSum.add(scoreWeight);
            ++this.state.numEnergiedConfs;
            this.state.energyOps = 1.0 / seconds;
            if (MathTools.isLessThan(scoreWeight, this.state.minLowerScoreWeight)) {
                this.state.minLowerScoreWeight = scoreWeight;
            }
            double delta = this.state.calcDelta();
            this.state.dEnergy = PasteGradientDescentPfunc.calcSlope(delta, this.state.prevDelta, this.state.dScore);
            this.state.prevDelta = delta;
            // Make the other step type look steeper so it is eventually retried.
            this.state.dScore *= 2.0;
            if (this.isReportingProgress) {
                System.out.println(String.format("conf:%4d, score:%12.6f, energy:%12.6f, bounds:[%12e,%12e], delta:%.6f, time:%10s, heapMem:%s", this.state.numEnergiedConfs, econf.getScore(), econf.getEnergy(), this.state.getLowerBound().doubleValue(), this.state.getUpperBound().doubleValue(), delta, this.stopwatch.getTime(2), JvmMem.getOldPool()));
            }
            if (this.trace != null) {
                this.trace.step(this.state.numScoredConfs, this.state.numEnergiedConfs, delta);
            }
        }
        if (this.confListener != null) {
            this.confListener.onConf(econf);
        }
    }

    /**
     * Task listener for a score step. It folds a batch of score weights into the upper bound and
     * updates the scoring slope.
     */
    private void onScores(List<BigDecimal> scoreWeights, double seconds) {
        synchronized (this) {
            for (BigDecimal weight : scoreWeights) {
                this.state.upperScoreWeightSum = this.state.upperScoreWeightSum.add(weight);
                if (MathTools.isLessThan(weight, this.state.minUpperScoreWeight)) {
                    this.state.minUpperScoreWeight = weight;
                }
            }
            this.state.numScoredConfs += scoreWeights.size();
            this.state.scoreOps = (double) scoreWeights.size() / seconds;
            double delta = this.state.calcDelta();
            this.state.dScore = PasteGradientDescentPfunc.calcSlope(delta, this.state.prevDelta, this.state.dEnergy);
            this.state.prevDelta = delta;
            // Make the other step type look steeper so it is eventually retried.
            this.state.dEnergy *= 2.0;
            if (this.trace != null) {
                this.trace.step(this.state.numScoredConfs, this.state.numEnergiedConfs, delta);
            }
        }
    }

    /**
     * Computes the slope of delta for the step just taken. A step that did not decrease delta
     * gets a slope of one tenth of the other step's slope instead.
     */
    private static double calcSlope(double delta, double prevDelta, double otherSlope) {
        double slope = delta - prevDelta;
        if (slope >= 0.0) {
            slope = otherSlope / 10.0;
        }
        return slope;
    }

    /** Output of one energy task, handed from a worker to {@code onEnergy}. */
    private static class EnergyResult {
        ConfSearch.EnergiedConf econf;
        EnergyCalculator.EnergiedParametricMolecule epmol;
        BigDecimal scoreWeight;
        BigDecimal energyWeight;
        final Stopwatch stopwatch = new Stopwatch();
    }

    /** Output of one score task, handed from a worker to {@code onScores}. */
    private static class ScoreResult {
        final List<BigDecimal> scoreWeights = new ArrayList<>();
        final Stopwatch stopwatch = new Stopwatch();
    }

    /** Accumulated bound bookkeeping. Mutations after init happen under the outer instance's monitor. */
    private static class State {
        /** Presumably -RT in kcal/mol near room temperature; confirm units before reuse. */
        static final double CONST_RT = -0.593050165;

        final ArrayList<EnergyCalculator.EnergiedParametricMolecule> epMols = new ArrayList<>();
        final HashMap<Double, ConfSearch.ScoredConf> sConfs = new HashMap<>();
        BigDecimal numConfs;
        /** Lowest energy observed so far. */
        double curGMEC = Double.POSITIVE_INFINITY;
        /** Score of the most recently energied conformation. */
        double curScore = Double.POSITIVE_INFINITY;
        long numScoredConfs = 0L;
        BigDecimal upperScoreWeightSum = BigDecimal.ZERO;
        BigDecimal minUpperScoreWeight = MathTools.BigPositiveInfinity;
        long numEnergiedConfs = 0L;
        BigDecimal lowerScoreWeightSum = BigDecimal.ZERO;
        BigDecimal energyWeightSum = BigDecimal.ZERO;
        BigDecimal minLowerScoreWeight = MathTools.BigPositiveInfinity;
        // Initial throughput guesses (ops per second) and slopes used before any measurement.
        double scoreOps = 100.0;
        double energyOps = 1.0;
        double prevDelta = 1.0;
        double dEnergy = -1.0;
        double dScore = -1.0;

        State(BigInteger numConfs) {
            this.numConfs = new BigDecimal(numConfs);
        }

        public void setGMECEnergy(double newEnergy) {
            this.curGMEC = newEnergy;
        }

        /** Score of the latest energied conformation minus the lowest energy seen so far. */
        double calcDiff() {
            return this.curScore - this.curGMEC;
        }

        /** Relative bound gap {@code (upper - lower) / upper}; 1 if upper is zero or infinite. */
        double calcDelta() {
            BigDecimal upperBound = this.getUpperBound();
            if (MathTools.isZero(upperBound) || MathTools.isInf(upperBound)) {
                return 1.0;
            }
            return new BigMath(EWAKStarPartitionFunction.decimalPrecision).set(upperBound).sub(this.getLowerBound()).div(upperBound).get().doubleValue();
        }

        /** Sum of energy Boltzmann weights of energied conformations. */
        public BigDecimal getLowerBound() {
            return this.energyWeightSum;
        }

        /**
         * Returns (unscored count) * (min score weight seen), plus the score weights of scored
         * conformations, with energied conformations' score weights replaced by their energy
         * weights.
         */
        public BigDecimal getUpperBound() {
            return new BigMath(EWAKStarPartitionFunction.decimalPrecision).set(this.numConfs).sub(this.numScoredConfs).mult(this.minUpperScoreWeight).add(this.upperScoreWeightSum).sub(this.lowerScoreWeightSum).add(this.energyWeightSum).get();
        }

        /**
         * Returns true once at least 10 confs have been enumerated and the bounds, expressed as
         * {@code CONST_RT * ln(ratio)} against {@code wt}'s bounds, fall entirely on one side of
         * zero.
         */
        boolean noWindowOverlaps(PastePartitionFunction.Result wt, long numConfs) {
            if (wt != null && numConfs >= 10L) {
                // NOTE(review): uses PastePartitionFunction.decimalPrecision, whereas the rest of
                // the class uses EWAKStarPartitionFunction.decimalPrecision.
                double lowerBoundScore = CONST_RT * Math.log(new BigMath(PastePartitionFunction.decimalPrecision).set(this.getUpperBound()).div(wt.values.calcLowerBound()).get().doubleValue());
                double upperBoundScore = CONST_RT * Math.log(new BigMath(PastePartitionFunction.decimalPrecision).set(this.getLowerBound()).div(wt.values.calcUpperBound()).get().doubleValue());
                if (lowerBoundScore > 0.0) {
                    return true;
                }
                return upperBoundScore < 0.0;
            }
            return false;
        }

        boolean epsilonReached(double targetEpsilon) {
            return this.calcDelta() <= targetEpsilon;
        }

        boolean energyReached(double targetEnergy) {
            return this.calcDiff() >= targetEnergy;
        }

        /** Stable while nothing has been energied, no threshold is set, or upper >= threshold. */
        boolean isStable(BigDecimal stabilityThreshold) {
            return this.numEnergiedConfs <= 0L || stabilityThreshold == null || MathTools.isGreaterThanOrEqual(this.getUpperBound(), stabilityThreshold);
        }

        /** False once an energied conformation's score weight has reached zero. */
        boolean hasLowEnergies() {
            return MathTools.isGreaterThan(this.minLowerScoreWeight, BigDecimal.ZERO);
        }

        @Override
        public String toString() {
            return String.format("upper: count %d  sum %e  min %e     lower: count %d  score sum %e  energy sum %e", this.numScoredConfs, this.upperScoreWeightSum, this.minUpperScoreWeight, this.numEnergiedConfs, this.lowerScoreWeightSum, this.energyWeightSum);
        }
    }

    /** Kind of work chosen for the next iteration of the compute() loop. */
    private enum Step {
        None,
        Score,
        Energy
    }
}

