/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.partcr;

import edu.duke.cs.osprey.confspace.ConfSearch;
import edu.duke.cs.osprey.confspace.RC;
import edu.duke.cs.osprey.confspace.SearchProblem;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.ematrix.SimpleEnergyCalculator;
import edu.duke.cs.osprey.energy.forcefield.ForcefieldParams;
import edu.duke.cs.osprey.partcr.SplitWorld;
import edu.duke.cs.osprey.partcr.pickers.ConfPicker;
import edu.duke.cs.osprey.partcr.pickers.WalkingConfPicker;
import edu.duke.cs.osprey.partcr.scorers.RCScorer;
import edu.duke.cs.osprey.partcr.scorers.VolumeRCScorer;
import edu.duke.cs.osprey.partcr.splitters.NAryRCSplitter;
import edu.duke.cs.osprey.partcr.splitters.RCSplitter;
import edu.duke.cs.osprey.tools.TimeFormatter;
import java.util.Iterator;
import java.util.List;
import java.util.TreeMap;

/**
 * Iteratively tightens the bound energies of a list of candidate conformations and prunes those
 * that can no longer fall inside the energy window.
 *
 * <p>Each {@link #iterate()} does the following:
 * <ol>
 *   <li>Minimizes one conformation chosen by the {@link ConfPicker}.</li>
 *   <li>Decomposes its bound and minimized energies per position.</li>
 *   <li>Splits the residue conformation (RC) at the position the {@link RCScorer} ranks highest,
 *       using the {@link RCSplitter}.</li>
 *   <li>Removes every conformation whose bound energy in the updated {@link SplitWorld} exceeds
 *       the best minimized energy seen so far plus {@code Ew}.</li>
 * </ol>
 *
 * <p>Holds mutable state without any synchronization.
 */
public class PartCR {

    /** Tolerance for the per-position energy decomposition cross-checks (relative, absolute near 0). */
    private static final double ENERGY_CHECK_EPSILON = 1.0E-10;

    private final SearchProblem search;
    private final double Ew;
    private final SimpleEnergyCalculator ecalc;
    private final List<ConfSearch.ScoredConf> confs;
    private final SplitWorld splitWorld;
    private ConfPicker picker;
    private RCScorer scorer;
    private RCSplitter splitter;
    private double bestMinimizedEnergy;
    private long minimizationNs;
    private long iterationNs;
    private int numIterations;

    /**
     * Creates a PartCR driver. The defaults are a {@link WalkingConfPicker}, a
     * {@link VolumeRCScorer} and an {@link NAryRCSplitter}.
     *
     * @param search2 the search problem whose conformations are minimized
     * @param ffparams forcefield parameters, used for the per-position energy calculator and the split world
     * @param Ew energy window: conformations whose improved bound exceeds
     *     (best minimized energy + Ew) are pruned
     * @param confs candidate conformations; this list is pruned in place, so it must support removal
     */
    public PartCR(SearchProblem search2, ForcefieldParams ffparams, double Ew, List<ConfSearch.ScoredConf> confs) {
        this.search = search2;
        this.Ew = Ew;
        this.ecalc = new SimpleEnergyCalculator.Cpu(ffparams, search2.confSpace, search2.shellResidues);
        this.confs = confs;
        this.splitWorld = new SplitWorld(search2, ffparams);
        this.picker = new WalkingConfPicker();
        this.scorer = new VolumeRCScorer();
        this.splitter = new NAryRCSplitter();
        this.bestMinimizedEnergy = Double.POSITIVE_INFINITY;
        this.minimizationNs = 0L;
        this.iterationNs = 0L;
        this.numIterations = 0;
    }

    /** Sets the strategy that picks which conformation to minimize each iteration. */
    public void setPicker(ConfPicker val) {
        this.picker = val;
    }

    /** Sets the strategy that scores positions to decide which RC to split. */
    public void setScorer(RCScorer val) {
        this.scorer = val;
    }

    /** Sets the strategy that splits the chosen RC. */
    public void setSplitter(RCSplitter val) {
        this.splitter = val;
    }

    /**
     * Returns the live conformation list passed to the constructor. {@link #iterate()} removes
     * pruned conformations from it in place.
     */
    public List<ConfSearch.ScoredConf> getConfs() {
        return this.confs;
    }

    /**
     * Returns the average wall-clock time of the single minimization done per iteration, in
     * nanoseconds, or 0 if no iteration has run yet.
     */
    public long getAvgMinimizationTimeNs() {
        return this.numIterations == 0 ? 0L : this.minimizationNs / this.numIterations;
    }

    /**
     * Returns the average wall-clock time of a whole iteration, in nanoseconds, or 0 if no
     * iteration has run yet.
     */
    public long getAvgIterationTimeNs() {
        return this.numIterations == 0 ? 0L : this.iterationNs / this.numIterations;
    }

    /** Runs {@link #autoIterate(int)} allowing at most 3 strikes. */
    public void autoIterate() {
        this.autoIterate(3);
    }

    /**
     * Repeats {@link #iterate()} for as long as it seems to pay for itself, then prints a summary
     * of the time spent versus the minimization time saved.
     *
     * <p>The pruning target for an iteration is how many minimizations one iteration costs, on
     * average. Iteration stops when either condition holds:
     * <ul>
     *   <li>no more than that many conformations remain;</li>
     *   <li>an iteration has pruned fewer than the target {@code maxNumStrikes} times. Strikes
     *       are cumulative and are never reset.</li>
     * </ul>
     *
     * @param maxNumStrikes number of under-performing iterations tolerated before stopping
     */
    public void autoIterate(int maxNumStrikes) {
        final int initialNumConfs = this.confs.size();
        if (initialNumConfs == 0) {
            // nothing to pick or prune; iterate() would hand an empty list to the picker
            System.out.println("PartCR: no conformations to prune, nothing to do.");
            return;
        }
        int numConfs = initialNumConfs;
        long startTimeNs = System.nanoTime();
        int numStrikes = 0;
        while (true) {
            this.iterate();
            // guard against a (theoretical) zero-duration minimization
            long avgMinNs = Math.max(1L, this.getAvgMinimizationTimeNs());
            int targetPruning = (int) (this.getAvgIterationTimeNs() / avgMinNs);
            int numPruned = numConfs - this.confs.size();
            numConfs = this.confs.size();
            if (numConfs <= targetPruning) {
                System.out.println(String.format("Pruned %d/%d conformations. Only %d conformations left, time to stop.", numPruned, targetPruning, numConfs));
                break;
            }
            if (numPruned < targetPruning) {
                numStrikes++;
                boolean shouldStop = numStrikes >= maxNumStrikes;
                System.out.println(String.format("Pruned %d/%d conformations. %d/%d strikes.%s", numPruned, targetPruning, numStrikes, maxNumStrikes, shouldStop ? " time to stop." : ""));
                if (shouldStop) {
                    break;
                }
            } else {
                System.out.println(String.format("Pruned %d/%d conformations. keep iterating", numPruned, targetPruning));
            }
        }
        long avgMinNs = this.getAvgMinimizationTimeNs();
        long initialTimeNs = avgMinNs * initialNumConfs;
        long afterTimeNs = avgMinNs * numConfs;
        long diffTimeNs = System.nanoTime() - startTimeNs;
        long savingsNs = initialTimeNs - afterTimeNs;
        long netSavingsNs = savingsNs - diffTimeNs;
        System.out.println(String.format("\nPartCR took %s and saved %s of minimization time for a net %s", TimeFormatter.format(diffTimeNs, 1), TimeFormatter.format(savingsNs, 1), String.format("%s of %s", netSavingsNs > 0L ? "SAVINGS" : "LOSS", TimeFormatter.format(Math.abs(netSavingsNs), 1))));
        System.out.println(String.format("PartCR pruned %.1f%% of low-energy conformations", 100.0 * (double) (initialNumConfs - numConfs) / (double) initialNumConfs));
    }

    /**
     * Runs one refinement iteration:
     * <ol>
     *   <li>Picks and minimizes a conformation.</li>
     *   <li>Splits the RC at the highest-scoring position.</li>
     *   <li>Prunes conformations whose improved bound energy exceeds the best minimized energy
     *       seen so far plus {@code Ew}.</li>
     * </ol>
     */
    public void iterate() {
        ++this.numIterations;
        System.out.println("\nPartCR iteration " + this.numIterations);
        long startIterNs = System.nanoTime();
        int numPos = this.splitWorld.getSearchProblem().confSpace.numPos;

        // minimize the picked conf. Its bound score comes from the split world, but the
        // minimization runs on the original search problem with the original assignments.
        ConfSearch.ScoredConf pickedConf = this.picker.pick(this.confs);
        ConfSearch.ScoredConf translatedPickedConf = this.splitWorld.translateConf(pickedConf);
        System.out.println("minimizing conformation...");
        long startMinNs = System.nanoTime();
        ConfSearch.EnergiedConf analyzeConf = new ConfSearch.EnergiedConf(translatedPickedConf, this.calcMinimizedEnergy(pickedConf.getAssignments()));
        long diffMinNs = System.nanoTime() - startMinNs;
        this.minimizationNs += diffMinNs;
        this.bestMinimizedEnergy = Math.min(this.bestMinimizedEnergy, analyzeConf.getEnergy());
        if (this.numIterations == 1) {
            System.out.println(String.format("initial conformations: %d, estimated time to enumerate: %s", this.confs.size(), TimeFormatter.format(this.getAvgMinimizationTimeNs() * (long) this.confs.size(), 1)));
        }

        // must run right after the minimization above; see calcPosMinimizedEnergy
        int splitPos = this.pickSplitPos(analyzeConf, numPos);

        System.out.println("splitting residue conformation...");
        RC splitRc = this.splitWorld.getRC(splitPos, analyzeConf.getAssignments()[splitPos]);
        List<RC> splitRCs = this.splitter.split(splitPos, splitRc);
        this.splitWorld.replaceRc(splitPos, splitRc, splitRCs);

        System.out.println("calculating energies and pruning conformations...");
        this.splitWorld.resizeMatrices();
        this.pruneConfs();

        long diffIterNs = System.nanoTime() - startIterNs;
        this.iterationNs += diffIterNs;
        System.out.println(String.format("finished iteration in %s", TimeFormatter.format(diffIterNs, 1)));
        System.out.println(String.format("conformations remaining: %d, estimated time to enumerate: %s", this.confs.size(), TimeFormatter.format(this.getAvgMinimizationTimeNs() * (long) this.confs.size(), 1)));
    }

    /**
     * Scores each position of the minimized conformation and returns the highest-scoring one.
     *
     * <p>The score is computed from the position's error (minimized energy minus bound energy).
     * Scores are ordered by {@link Double#compare}, and on a tie the higher position index wins.
     *
     * <p>Also cross-checks two sums against {@code analyzeConf}:
     * <ul>
     *   <li>the per-position bound energies plus the original energy matrix's constant term must
     *       equal its score;</li>
     *   <li>the per-position minimized energies must equal its minimized energy.</li>
     * </ul>
     *
     * @throws AssertionError if either energy decomposition does not match
     * @throws IllegalStateException if there are no positions to split
     */
    private int pickSplitPos(ConfSearch.EnergiedConf analyzeConf, int numPos) {
        int[] assignments = analyzeConf.getAssignments();
        // NOTE(review): the const term comes from the original emat, while the per-position
        // terms come from the split world's emat; presumably both share the constant. Confirm.
        double boundEnergyCheck = this.search.emat.getConstTerm();
        double minimizedEnergyCheck = 0.0;
        int bestPos = -1;
        double bestScore = Double.NaN;
        for (int pos = 0; pos < numPos; ++pos) {
            RC rcObj = this.splitWorld.getSearchProblem().confSpace.posFlex.get(pos).RCs.get(assignments[pos]);
            double posBoundEnergy = this.calcPosBoundEnergy(assignments, pos);
            double posMinimizedEnergy = this.calcPosMinimizedEnergy(pos);
            boundEnergyCheck += posBoundEnergy;
            minimizedEnergyCheck += posMinimizedEnergy;
            double err = posMinimizedEnergy - posBoundEnergy;
            double score = this.scorer.calcScore(this.splitWorld, rcObj, err);
            // >= lets a later position win ties
            if (bestPos < 0 || Double.compare(score, bestScore) >= 0) {
                bestScore = score;
                bestPos = pos;
            }
        }
        this.checkEnergy(boundEnergyCheck, analyzeConf.getScore());
        this.checkEnergy(minimizedEnergyCheck, analyzeConf.getEnergy());
        if (bestPos < 0) {
            throw new IllegalStateException("PartCR: conformation space has no positions, nothing to split");
        }
        return bestPos;
    }

    /**
     * Removes from {@link #confs} every conformation whose bound energy in the current split
     * world exceeds the best minimized energy so far plus {@code Ew}. Conformations with a NaN
     * bound are kept.
     */
    private void pruneConfs() {
        final double pruneThreshold = this.bestMinimizedEnergy + this.Ew;
        // removeIf is linear on ArrayList, unlike repeated Iterator.remove
        this.confs.removeIf(conf -> this.splitWorld.translateConf(conf).getScore() > pruneThreshold);
    }

    /**
     * Throws if {@code observed} differs from {@code expected} by more than the tolerance.
     *
     * <p>The tolerance is relative when {@code |expected| >= 1} and absolute below that, so an
     * expected value of 0 does not blow the relative error up to infinity. Exactly equal values
     * pass, including equal infinities. Any NaN difference fails.
     *
     * @throws AssertionError when the values disagree, which indicates a bug in the decomposition
     */
    private void checkEnergy(double observed, double expected) {
        if (observed == expected) {
            return;
        }
        double absErr = Math.abs(observed - expected);
        double relErr = absErr / Math.abs(expected);
        double tolerance = ENERGY_CHECK_EPSILON * Math.max(1.0, Math.abs(expected));
        if (!(absErr <= tolerance)) {
            throw new AssertionError(String.format("Energies don't match. This is a bug!\n\texpected: %f\n\tobserved: %f\n\tabs err:  %f\n\trel err:  %f", expected, observed, absErr, relErr));
        }
    }

    /** Minimizes {@code conf} in the original (unsplit) search problem and returns its energy. */
    private double calcMinimizedEnergy(int[] conf) {
        return this.search.minimizedEnergy(conf);
    }

    /**
     * Returns the bound energy attributed to position {@code pos1} in the split world's energy
     * matrix. This is its one-body term plus half of each pairwise term with every other
     * position.
     *
     * <p>Summed over all positions, each pair is counted once, assuming {@code getPairwise} is
     * symmetric (TODO confirm). The matrix's constant term is not included.
     */
    private double calcPosBoundEnergy(int[] conf, int pos1) {
        EnergyMatrix emat = this.splitWorld.getSearchProblem().emat;
        int numPos = emat.getNumPos();
        int rc1 = conf[pos1];
        double energy = emat.getOneBody(pos1, rc1);
        for (int pos2 = 0; pos2 < numPos; ++pos2) {
            if (pos1 == pos2) {
                continue;
            }
            int rc2 = conf[pos2];
            energy += emat.getPairwise(pos1, rc1, pos2, rc2) / 2.0;
        }
        return energy;
    }

    /**
     * Returns the minimized energy attributed to position {@code pos1}. This is its single
     * energy plus half of each pair energy with every other position.
     *
     * <p>The energy functions take no conformation argument, so presumably they evaluate the
     * molecule state left by the preceding minimization (TODO confirm). The sum check in
     * pickSplitPos against the minimized energy guards that assumption.
     */
    private double calcPosMinimizedEnergy(int pos1) {
        double energy = this.ecalc.makeSingleEfunc(pos1).getEnergy();
        for (int pos2 = 0; pos2 < this.ecalc.confSpace.numPos; ++pos2) {
            if (pos2 == pos1) {
                continue;
            }
            energy += this.ecalc.makePairEfunc(pos1, pos2).getEnergy() / 2.0;
        }
        return energy;
    }
}

