/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.ematrix;

import edu.duke.cs.osprey.confspace.ConfSpaceIteration;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.ematrix.SimpleReferenceEnergies;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.energy.EnergyCalculator;
import edu.duke.cs.osprey.energy.ResidueInteractions;
import edu.duke.cs.osprey.tools.Log;
import edu.duke.cs.osprey.tools.ObjectIO;
import edu.duke.cs.osprey.tools.Progress;
import java.io.File;
import java.util.ArrayList;
import java.util.List;

public class SimplerEnergyMatrixCalculator {
    /** Energy calculator used for every fragment energy; its {@code tasks} member receives all submitted work. */
    public final ConfEnergyCalculator confEcalc;
    /** Optional cache location; when non-null, {@link #calcEnergyMatrix()} goes through {@code ObjectIO.readOrMake}. */
    public final File cacheFile;
    /** Triple-correction threshold; {@code null} disables triple corrections (also skipped when a quad threshold is set). */
    public final Double tripleCorrectionThreshold;
    /** Quad-correction threshold; {@code null} disables quad corrections. Checked before the triple threshold. */
    public final Double quadCorrectionThreshold;
    /** Whether a zero-size (constant term) fragment is computed and stored in the matrix. */
    public final boolean calcConstantTerm;

    /**
     * Creates a calculator; instances are obtained through {@link Builder#build()}.
     *
     * @param confEcalc                 energy calculator used for all fragment energies
     * @param cacheFile                 cache file for the energy matrix, or {@code null} for no caching
     * @param tripleCorrectionThreshold threshold for triple corrections, or {@code null} to disable them
     * @param quadCorrectionThreshold   threshold for quad corrections, or {@code null} to disable them
     * @param calcConstantTerm          whether the constant term is computed
     */
    private SimplerEnergyMatrixCalculator(
            ConfEnergyCalculator confEcalc,
            File cacheFile,
            Double tripleCorrectionThreshold,
            Double quadCorrectionThreshold,
            boolean calcConstantTerm) {
        this.calcConstantTerm = calcConstantTerm;
        this.quadCorrectionThreshold = quadCorrectionThreshold;
        this.tripleCorrectionThreshold = tripleCorrectionThreshold;
        this.cacheFile = cacheFile;
        this.confEcalc = confEcalc;
    }

    /**
     * Returns the energy matrix, going through the cache file when one is configured.
     *
     * <p>Without a cache file the matrix is always computed. With one, the work is
     * delegated to {@code ObjectIO.readOrMake}, which is given a check that the
     * stored matrix matches {@code confEcalc.confSpace} and a factory that computes
     * the matrix from scratch.
     *
     * @return the energy matrix
     */
    public EnergyMatrix calcEnergyMatrix() {
        if (this.cacheFile == null) {
            return reallyCalcEnergyMatrix();
        }
        return ObjectIO.readOrMake(
            this.cacheFile,
            EnergyMatrix.class,
            "energy matrix",
            cached -> cached.matches(this.confEcalc.confSpace),
            context -> reallyCalcEnergyMatrix()
        );
    }

    /**
     * Computes the energy matrix from scratch, without consulting the cache file.
     *
     * <p>The optional constant term, every single, and every pair with
     * {@code pos2 < pos1} are collected into batches. A batch is submitted to
     * {@code confEcalc.tasks} as soon as its accumulated cost reaches the batching
     * threshold, and any remainder is submitted after the loops. Once all tasks
     * have finished, quad corrections are computed if a quad threshold is set;
     * otherwise triple corrections are computed if a triple threshold is set.
     *
     * @return the populated energy matrix
     */
    private EnergyMatrix reallyCalcEnergyMatrix() {
        final EnergyMatrix emat = new EnergyMatrix(this.confEcalc.confSpaceIteration());

        // Work is estimated from interaction counts, using position 0 / conf 0 as the
        // representative single and pair (skipped when there are no positions).
        final boolean hasPositions = emat.getNumPos() > 0;
        final int constCost = this.confEcalc.makeShellInters().size();
        final int singleCost = hasPositions ? this.confEcalc.makeSingleInters(0, 0).size() : 0;
        final int pairCost = hasPositions ? this.confEcalc.makePairInters(0, 0, 0, 0).size() : 0;
        final int numConst = this.calcConstantTerm ? 1 : 0;
        final int numSingles = this.confEcalc.confSpaceIteration().countSingles();
        final int numPairs = this.confEcalc.confSpaceIteration().countPairs();

        // The constant term's cost is part of the total because addConst() adds it to
        // the batch cost that is later reported through incrementProgress().
        final Progress progress = new Progress(
            numConst * constCost + numSingles * singleCost + numPairs * pairCost);

        // A batch is submitted once its accumulated cost reaches this value.
        final int costThreshold = 100;

        /** A group of fragments whose energies are computed together in one task. */
        class Batch {
            final List<RCTuple> fragments = new ArrayList<>();
            int cost = 0;

            void addConst() {
                this.fragments.add(new RCTuple());
                this.cost += constCost;
            }

            void addSingle(int pos, int rc) {
                this.fragments.add(new RCTuple(pos, rc));
                this.cost += singleCost;
            }

            void addPair(int pos1, int rc1, int pos2, int rc2) {
                this.fragments.add(new RCTuple(pos1, rc1, pos2, rc2));
                this.cost += pairCost;
            }

            /** Energy of one fragment; parametrically incompatible fragments get +infinity. */
            double fragmentEnergy(RCTuple frag) {
                if (isParametricallyIncompatible(confEcalc, frag)) {
                    return Double.POSITIVE_INFINITY;
                }
                switch (frag.size()) {
                    case 0:
                        return confEcalc.calcShellEnergy(frag).energy;
                    case 1:
                        return confEcalc.calcSingleEnergy(frag).energy;
                    case 2:
                        return confEcalc.calcPairEnergy(frag).energy;
                    default:
                        return confEcalc.calcEnergy(frag).energy;
                }
            }

            void submitTask() {
                confEcalc.tasks.submit(() -> {
                    // task: compute one energy per fragment, in fragment order
                    List<Double> energies = new ArrayList<>();
                    for (RCTuple frag : this.fragments) {
                        energies.add(fragmentEnergy(frag));
                    }
                    return energies;
                }, energies -> {
                    // listener: store each energy in the matrix slot matching the fragment size
                    for (int i = 0; i < this.fragments.size(); ++i) {
                        RCTuple frag = this.fragments.get(i);
                        Double energy = energies.get(i);
                        switch (frag.size()) {
                            case 0: {
                                emat.setConstTerm(energy);
                                break;
                            }
                            case 1: {
                                int pos = frag.pos.get(0);
                                int rc = frag.RCs.get(0);
                                emat.setOneBody(pos, rc, energy);
                                break;
                            }
                            case 2: {
                                int pos1 = frag.pos.get(0);
                                int rc1 = frag.RCs.get(0);
                                int pos2 = frag.pos.get(1);
                                int rc2 = frag.RCs.get(1);
                                emat.setPairwise(pos1, rc1, pos2, rc2, energy);
                                break;
                            }
                            default: {
                                emat.setTuple(frag, energy);
                            }
                        }
                    }
                    progress.incrementProgress(this.cost);
                });
            }
        }

        /** Holds the batch currently being filled and submits it on demand. */
        class Batcher {
            Batch batch = null;

            Batch getBatch() {
                if (this.batch == null) {
                    this.batch = new Batch();
                }
                return this.batch;
            }

            void submitIfFull() {
                if (this.batch != null && this.batch.cost >= costThreshold) {
                    this.submit();
                }
            }

            void submit() {
                if (this.batch != null) {
                    this.batch.submitTask();
                    this.batch = null;
                }
            }
        }

        Batcher batcher = new Batcher();
        Log.log("Calculating energy matrix with %d entries", numConst + numSingles + numPairs);
        if (this.calcConstantTerm) {
            batcher.getBatch().addConst();
            batcher.submitIfFull();
        }
        for (int pos1 = 0; pos1 < emat.getNumPos(); ++pos1) {
            for (int rc1 = 0; rc1 < emat.getNumConfAtPos(pos1); ++rc1) {
                batcher.getBatch().addSingle(pos1, rc1);
                batcher.submitIfFull();
                for (int pos2 = 0; pos2 < pos1; ++pos2) {
                    for (int rc2 = 0; rc2 < emat.getNumConfAtPos(pos2); ++rc2) {
                        batcher.getBatch().addPair(pos1, rc1, pos2, rc2);
                        batcher.submitIfFull();
                    }
                }
            }
        }
        // flush the last, possibly partial, batch and wait for every task to finish
        batcher.submit();
        this.confEcalc.tasks.waitForFinish();

        // quad corrections take precedence over triple corrections
        if (this.quadCorrectionThreshold != null) {
            this.calcQuadCorrections(emat);
        } else if (this.tripleCorrectionThreshold != null) {
            this.calcTripleCorrections(emat);
        }
        return emat;
    }

    /**
     * Computes triple corrections and stores the positive ones in {@code emat}.
     *
     * <p>Every triple with {@code pos3 < pos2 < pos1} is visited. A triple is skipped
     * (progress is still advanced) when any of its one-body or pairwise energies
     * exceeds the triple threshold, or when it is parametrically incompatible.
     * Otherwise a task computes its energy; the listener subtracts the offset given
     * by {@code confEcalc.epart.offsetTripleEnergy} and stores the difference only
     * when it is positive.
     *
     * <p>NOTE: the decompiler (CFR) reported "Removed try catching itself - possible
     * behaviour change" for this method.
     *
     * @param emat the energy matrix to read from and add corrections to
     */
    private void calcTripleCorrections(EnergyMatrix emat) {
        final Progress progress = new Progress(this.confEcalc.confSpace.getNumResConfTriples());
        Log.log("calculating triple corrections for up to %d triples", progress.getTotalWork());

        final double threshold = this.tripleCorrectionThreshold;
        final int[] numCorrections = {0};

        // progress is advanced under its own monitor from both this thread and the listeners
        final Runnable tick = () -> {
            synchronized (progress) {
                progress.incrementProgress();
            }
        };

        for (int pos1 = 0; pos1 < emat.getNumPos(); ++pos1) {
            for (int rc1 = 0; rc1 < emat.getNumConfAtPos(pos1); ++rc1) {
                for (int pos2 = 0; pos2 < pos1; ++pos2) {
                    for (int rc2 = 0; rc2 < emat.getNumConfAtPos(pos2); ++rc2) {
                        for (int pos3 = 0; pos3 < pos2; ++pos3) {
                            for (int rc3 = 0; rc3 < emat.getNumConfAtPos(pos3); ++rc3) {

                                // skip triples containing any component above the threshold
                                boolean aboveThreshold =
                                    emat.getOneBody(pos1, rc1) > threshold
                                    || emat.getOneBody(pos2, rc2) > threshold
                                    || emat.getOneBody(pos3, rc3) > threshold
                                    || emat.getPairwise(pos1, rc1, pos2, rc2) > threshold
                                    || emat.getPairwise(pos1, rc1, pos3, rc3) > threshold
                                    || emat.getPairwise(pos2, rc2, pos3, rc3) > threshold;
                                if (aboveThreshold) {
                                    tick.run();
                                    continue;
                                }

                                final RCTuple triple = new RCTuple(pos3, rc3, pos2, rc2, pos1, rc1);
                                if (isParametricallyIncompatible(this.confEcalc, triple)) {
                                    tick.run();
                                    continue;
                                }

                                final ResidueInteractions inters =
                                    this.confEcalc.makeTripleCorrectionInters(pos1, rc1, pos2, rc2, pos3, rc3);
                                final double tripleEnergyOffset =
                                    this.confEcalc.epart.offsetTripleEnergy(pos1, rc1, pos2, rc2, pos3, rc3, emat);

                                this.confEcalc.tasks.submit(
                                    () -> this.confEcalc.calcEnergy(triple, inters).energy,
                                    tripleEnergy -> {
                                        double correction = tripleEnergy - tripleEnergyOffset;
                                        if (correction > 0.0) {
                                            emat.setTuple(triple, correction);
                                            numCorrections[0]++;
                                        }
                                        tick.run();
                                    }
                                );
                            }
                        }
                    }
                }
            }
        }

        this.confEcalc.tasks.waitForFinish();
        Log.log("calculated %d/%d useful triple corrections", numCorrections[0], progress.getTotalWork());
    }

    /**
     * Computes quad corrections and stores the positive ones in {@code emat}.
     *
     * <p>Every quad with {@code pos4 < pos3 < pos2 < pos1} is visited. A quad is
     * skipped (progress is still advanced) when any of its one-body or pairwise
     * energies exceeds the quad threshold, or when it is parametrically
     * incompatible. Otherwise a task computes its energy; the listener subtracts
     * the offset given by {@code confEcalc.epart.offsetQuadEnergy} and stores the
     * difference only when it is positive.
     *
     * <p>NOTE: the decompiler (CFR) reported "Removed try catching itself - possible
     * behaviour change" for this method.
     *
     * @param emat the energy matrix to read from and add corrections to
     */
    private void calcQuadCorrections(EnergyMatrix emat) {
        final Progress progress = new Progress(this.confEcalc.confSpace.getNumResConfQuads());
        Log.log("calculating quad corrections for up to %d quads", progress.getTotalWork());

        final double threshold = this.quadCorrectionThreshold;
        final int[] numCorrections = {0};

        // progress is advanced under its own monitor from both this thread and the listeners
        final Runnable tick = () -> {
            synchronized (progress) {
                progress.incrementProgress();
            }
        };

        for (int pos1 = 0; pos1 < emat.getNumPos(); ++pos1) {
            for (int rc1 = 0; rc1 < emat.getNumConfAtPos(pos1); ++rc1) {
                for (int pos2 = 0; pos2 < pos1; ++pos2) {
                    for (int rc2 = 0; rc2 < emat.getNumConfAtPos(pos2); ++rc2) {
                        for (int pos3 = 0; pos3 < pos2; ++pos3) {
                            for (int rc3 = 0; rc3 < emat.getNumConfAtPos(pos3); ++rc3) {
                                for (int pos4 = 0; pos4 < pos3; ++pos4) {
                                    for (int rc4 = 0; rc4 < emat.getNumConfAtPos(pos4); ++rc4) {

                                        // skip quads containing any component above the threshold
                                        boolean aboveThreshold =
                                            emat.getOneBody(pos1, rc1) > threshold
                                            || emat.getOneBody(pos2, rc2) > threshold
                                            || emat.getOneBody(pos3, rc3) > threshold
                                            || emat.getOneBody(pos4, rc4) > threshold
                                            || emat.getPairwise(pos1, rc1, pos2, rc2) > threshold
                                            || emat.getPairwise(pos1, rc1, pos3, rc3) > threshold
                                            || emat.getPairwise(pos1, rc1, pos4, rc4) > threshold
                                            || emat.getPairwise(pos2, rc2, pos3, rc3) > threshold
                                            || emat.getPairwise(pos2, rc2, pos4, rc4) > threshold
                                            || emat.getPairwise(pos3, rc3, pos4, rc4) > threshold;
                                        if (aboveThreshold) {
                                            tick.run();
                                            continue;
                                        }

                                        final RCTuple quad = new RCTuple(pos4, rc4, pos3, rc3, pos2, rc2, pos1, rc1);
                                        if (isParametricallyIncompatible(this.confEcalc, quad)) {
                                            tick.run();
                                            continue;
                                        }

                                        final ResidueInteractions inters =
                                            this.confEcalc.makeQuadCorrectionInters(pos1, rc1, pos2, rc2, pos3, rc3, pos4, rc4);
                                        final double quadEnergyOffset =
                                            this.confEcalc.epart.offsetQuadEnergy(pos1, rc1, pos2, rc2, pos3, rc3, pos4, rc4, emat);

                                        this.confEcalc.tasks.submit(
                                            () -> this.confEcalc.calcEnergy(quad, inters).energy,
                                            quadEnergy -> {
                                                double correction = quadEnergy - quadEnergyOffset;
                                                if (correction > 0.0) {
                                                    emat.setTuple(quad, correction);
                                                    numCorrections[0]++;
                                                }
                                                tick.run();
                                            }
                                        );
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        this.confEcalc.tasks.waitForFinish();
        Log.log("calculated %d/%d useful quad corrections", numCorrections[0], progress.getTotalWork());
    }

    /**
     * Calculates reference energies: for each (position, conf type) the lowest
     * intra energy seen over all confs of that type at that position.
     *
     * <p>One task per (position, conf) is submitted to {@code confEcalc.tasks}; the
     * listener keeps the smaller of the stored value and the new energy. The method
     * returns after all tasks have finished.
     *
     * @return the collected reference energies
     */
    public SimpleReferenceEnergies calcReferenceEnergies() {
        final SimpleReferenceEnergies refEnergies = new SimpleReferenceEnergies();
        final ConfSpaceIteration space = this.confEcalc.confSpaceIteration();
        final Progress progress = new Progress(space.countSingles());
        System.out.println("Calculating reference energies for " + progress.getTotalWork() + " residue confs...");

        for (int posi = 0; posi < space.numPos(); ++posi) {
            for (int confi = 0; confi < space.numConf(posi); ++confi) {
                // effectively-final copies for capture by the lambdas
                final int pos = posi;
                final int conf = confi;
                final String resType = space.confType(pos, conf);

                this.confEcalc.tasks.submit(
                    () -> this.confEcalc.calcIntraEnergy(pos, conf),
                    epmol -> {
                        Double current = refEnergies.get(pos, resType);
                        Double lowest = (current == null || epmol.energy < current)
                            ? Double.valueOf(epmol.energy)
                            : current;
                        refEnergies.set(pos, resType, lowest);
                        progress.incrementProgress();
                    }
                );
            }
        }

        this.confEcalc.tasks.waitForFinish();
        return refEnergies;
    }

    /**
     * Tells whether any two residue confs in {@code tuple} are parametrically
     * incompatible with each other.
     *
     * @param confEcalc supplies the conf space; when its {@code confSpace} is
     *                  {@code null} nothing is checked and {@code false} is returned
     * @param tuple     the tuple whose members are compared pairwise
     * @return {@code true} if at least one pair fails
     *         {@link #isPairParametricallyCompatible}, {@code false} otherwise
     */
    private static boolean isParametricallyIncompatible(ConfEnergyCalculator confEcalc, RCTuple tuple) {
        final SimpleConfSpace space = confEcalc.confSpace;
        if (space == null) {
            return false;
        }
        for (int outer = 0; outer < tuple.size(); ++outer) {
            final SimpleConfSpace.ResidueConf outerConf = getRC(space, tuple, outer);
            for (int inner = 0; inner < outer; ++inner) {
                final SimpleConfSpace.ResidueConf innerConf = getRC(space, tuple, inner);
                if (!isPairParametricallyCompatible(outerConf, innerConf)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Looks up the residue conf that the {@code index}-th member of {@code tuple}
     * refers to.
     *
     * @param confSpace conf space whose {@code positions} are searched
     * @param tuple     tuple supplying the position and conf indices
     * @param index     which member of the tuple to resolve
     * @return the residue conf at the tuple member's position and conf index
     */
    private static SimpleConfSpace.ResidueConf getRC(SimpleConfSpace confSpace, RCTuple tuple, int index) {
        final int posIndex = tuple.pos.get(index).intValue();
        return confSpace.positions.get(posIndex).resConfs.get(tuple.RCs.get(index));
    }

    /**
     * Tells whether two residue confs agree on the bounds of every degree of
     * freedom they have in common.
     *
     * <p>Only DOF names present in both confs' {@code dofBounds} are compared. The
     * first two entries of each bounds array must match within {@code 1e-8}.
     *
     * @param rc1 first residue conf
     * @param rc2 second residue conf
     * @return {@code false} as soon as a shared DOF has differing bounds,
     *         {@code true} otherwise
     */
    private static boolean isPairParametricallyCompatible(SimpleConfSpace.ResidueConf rc1, SimpleConfSpace.ResidueConf rc2) {
        final double tolerance = 1.0E-8;
        for (String dofName : rc1.dofBounds.keySet()) {
            if (rc2.dofBounds.containsKey(dofName)) {
                final double[] bounds1 = rc1.dofBounds.get(dofName);
                final double[] bounds2 = rc2.dofBounds.get(dofName);
                // compare entry 0 first, then entry 1, stopping at the first mismatch
                if (Math.abs(bounds1[0] - bounds2[0]) > tolerance
                        || Math.abs(bounds1[1] - bounds2[1]) > tolerance) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Fluent builder for {@link SimplerEnergyMatrixCalculator}.
     *
     * <p>Defaults: no cache file, no triple or quad corrections, and no constant term.
     */
    public static class Builder {
        private final ConfEnergyCalculator confEcalc;
        private File cacheFile;
        private Double tripleCorrectionThreshold;
        private Double quadCorrectionThreshold;
        private boolean calcConstantTerm;

        /**
         * Builds a {@code ConfEnergyCalculator} from the given conf space and energy
         * calculator, then delegates to {@link #Builder(ConfEnergyCalculator)}.
         *
         * @param confSpace the conf space
         * @param ecalc     the energy calculator
         * @deprecated use {@link #Builder(ConfEnergyCalculator)} instead
         */
        @Deprecated
        public Builder(SimpleConfSpace confSpace, EnergyCalculator ecalc) {
            this(new ConfEnergyCalculator.Builder(confSpace, ecalc).build());
        }

        /**
         * @param confEcalc the conf energy calculator the built instance will use
         */
        public Builder(ConfEnergyCalculator confEcalc) {
            this.confEcalc = confEcalc;
        }

        /**
         * @param val cache file for the energy matrix, or {@code null} for no caching
         * @return this builder
         */
        public Builder setCacheFile(File val) {
            cacheFile = val;
            return this;
        }

        /**
         * @param val triple-correction threshold, or {@code null} to disable triple corrections
         * @return this builder
         */
        public Builder setTripleCorrectionThreshold(Double val) {
            tripleCorrectionThreshold = val;
            return this;
        }

        /**
         * @param val quad-correction threshold, or {@code null} to disable quad corrections
         * @return this builder
         */
        public Builder setQuadCorrectionThreshold(Double val) {
            quadCorrectionThreshold = val;
            return this;
        }

        /**
         * @param val whether the constant term should be computed
         * @return this builder
         */
        public Builder setCalcConstantTerm(boolean val) {
            calcConstantTerm = val;
            return this;
        }

        /**
         * @return a new calculator configured with the current builder settings
         */
        public SimplerEnergyMatrixCalculator build() {
            return new SimplerEnergyMatrixCalculator(
                confEcalc,
                cacheFile,
                tripleCorrectionThreshold,
                quadCorrectionThreshold,
                calcConstantTerm
            );
        }
    }
}

