/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.energy.approximation;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleMatrix1D;
import edu.duke.cs.osprey.confspace.ParametricMolecule;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.dof.DofInfo;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.energy.EnergyFunction;
import edu.duke.cs.osprey.energy.ResidueInteractions;
import edu.duke.cs.osprey.energy.approximation.ApproximatedObjectiveFunction;
import edu.duke.cs.osprey.energy.approximation.ApproximatorMatrix;
import edu.duke.cs.osprey.energy.approximation.QuadraticApproximator;
import edu.duke.cs.osprey.minimization.Minimizer;
import edu.duke.cs.osprey.minimization.MoleculeObjectiveFunction;
import edu.duke.cs.osprey.minimization.SimpleCCDMinimizer;
import edu.duke.cs.osprey.tools.Log;
import edu.duke.cs.osprey.tools.MathTools;
import edu.duke.cs.osprey.tools.Progress;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class ApproximatorMatrixCalculator {
    public final ConfEnergyCalculator confEcalc;
    private int numSamplesPerParam = 10;
    private ApproximatorType type = ApproximatorType.Quadratic;
    private File cacheFile = null;

    public ApproximatorMatrixCalculator(ConfEnergyCalculator confEcalc) {
        int n = confEcalc.confSpace.positions.size();
        int maxTuplesPerConf = n + n * (n - 1) / 2;
        this.confEcalc = new ConfEnergyCalculator.Builder(confEcalc.confSpace, confEcalc.ecalc).setApproximationErrorBudget(confEcalc.approximationErrorBudget / (double)maxTuplesPerConf).setReferenceEnergies(confEcalc.eref).setEnergyPartition(confEcalc.epart).build();
    }

    public ApproximatorMatrixCalculator setNumSamplesPerParam(int val) {
        this.numSamplesPerParam = val;
        return this;
    }

    public ApproximatorMatrixCalculator setApproximatorType(ApproximatorType val) {
        this.type = val;
        return this;
    }

    public ApproximatorMatrixCalculator setCacheFile(File val) {
        this.cacheFile = val;
        return this;
    }

    public ApproximatorMatrix calc() {
        ApproximatorMatrix amat = new ApproximatorMatrix(this.confEcalc.confSpace);
        if (this.cacheFile != null && this.cacheFile.exists()) {
            amat.readFrom(this.cacheFile);
            Log.log("read Approximator Matrix from file: %s", this.cacheFile.getAbsolutePath());
            return amat;
        }
        int numRCs = this.confEcalc.confSpace.getNumResConfs();
        Progress progress2 = new Progress(numRCs * (1 + this.confEcalc.confSpace.shellResNumbers.size()));
        Log.log("calculating %d approximators for %d RCs ...", progress2.getTotalWork(), numRCs);
        for (SimpleConfSpace.Position pos1 : this.confEcalc.confSpace.positions) {
            for (SimpleConfSpace.ResidueConf rc1 : pos1.resConfs) {
                this.confEcalc.tasks.submit(() -> this.calc(pos1, rc1), approximator -> {
                    amat.set(pos1, rc1, (ApproximatedObjectiveFunction.Approximator.Addable)approximator);
                    progress2.incrementProgress();
                });
                for (String resNum : this.confEcalc.confSpace.shellResNumbers) {
                    this.confEcalc.tasks.submit(() -> this.calc(pos1, rc1, resNum), approximator -> {
                        amat.set(pos1, rc1, resNum, (ApproximatedObjectiveFunction.Approximator.Addable)approximator);
                        progress2.incrementProgress();
                    });
                }
            }
        }
        for (SimpleConfSpace.Position pos1 : this.confEcalc.confSpace.positions) {
            for (SimpleConfSpace.ResidueConf rc1 : pos1.resConfs) {
                for (SimpleConfSpace.Position pos2 : this.confEcalc.confSpace.positions.subList(0, pos1.index)) {
                    for (SimpleConfSpace.ResidueConf rc2 : pos2.resConfs) {
                        this.confEcalc.tasks.submit(() -> this.calc(pos1, rc1, pos2, rc2), approximator -> {
                            amat.set(pos1, rc1, pos2, rc2, (ApproximatedObjectiveFunction.Approximator.Addable)approximator);
                            progress2.incrementProgress();
                        });
                    }
                }
            }
        }
        this.confEcalc.tasks.waitForFinish();
        if (this.cacheFile != null) {
            amat.writeTo(this.cacheFile);
            Log.log("wrote Approximator Matrix to file: %s", this.cacheFile.getAbsolutePath());
        }
        return amat;
    }

    public ApproximatedObjectiveFunction.Approximator.Addable calc(SimpleConfSpace.Position pos, SimpleConfSpace.ResidueConf rc, String fixedResNum) {
        ResidueInteractions inters = new ResidueInteractions();
        inters.addPair(pos.resNum, fixedResNum);
        return this.calc(new RCTuple(pos.index, rc.index), inters);
    }

    public ApproximatedObjectiveFunction.Approximator.Addable calc(SimpleConfSpace.Position pos, SimpleConfSpace.ResidueConf rc) {
        ResidueInteractions inters = new ResidueInteractions();
        inters.addPair(pos.resNum, pos.resNum);
        return this.calc(new RCTuple(pos.index, rc.index), inters);
    }

    public ApproximatedObjectiveFunction.Approximator.Addable calc(SimpleConfSpace.Position pos1, SimpleConfSpace.ResidueConf rc1, SimpleConfSpace.Position pos2, SimpleConfSpace.ResidueConf rc2) {
        ResidueInteractions inters = new ResidueInteractions();
        inters.addPair(pos1.resNum, pos2.resNum);
        return this.calc(new RCTuple(pos1.index, rc1.index, pos2.index, rc2.index), inters);
    }

    public ApproximatedObjectiveFunction.Approximator.Addable calc(RCTuple tuple, ResidueInteractions inters) {
        ParametricMolecule pmol = this.confEcalc.confSpace.makeMolecule(tuple);
        try (EnergyFunction ff = this.confEcalc.ecalc.makeEnergyFunction(pmol, inters);){
            MoleculeObjectiveFunction f = new MoleculeObjectiveFunction(pmol, ff);
            DofInfo dofInfo = this.confEcalc.confSpace.makeDofInfo(tuple);
            QuadraticApproximator approximator = switch (this.type) {
                case ApproximatorType.Quadratic -> new QuadraticApproximator(dofInfo.ids, dofInfo.counts);
                default -> throw new IllegalArgumentException("unknown approximator type: " + String.valueOf((Object)this.type));
            };
            if (pmol.dofs.isEmpty()) {
                approximator.train(ff.getEnergy());
            } else {
                int numSamples = 1 + approximator.numParams() * this.numSamplesPerParam;
                Random rand = new Random(tuple.hashCode());
                approximator.train(this.sampleRandomly(pmol, f, numSamples, rand), this.sampleRandomly(pmol, f, numSamples, rand));
            }
            QuadraticApproximator quadraticApproximator = approximator;
            return quadraticApproximator;
        }
    }

    private List<Minimizer.Result> sampleRandomly(ParametricMolecule pmol, MoleculeObjectiveFunction f, int numSamples, Random rand) {
        ArrayList<Minimizer.Result> samples = new ArrayList<Minimizer.Result>(numSamples);
        samples.add(new SimpleCCDMinimizer(f).minimizeFromCenter());
        for (int i = 1; i < numSamples; ++i) {
            DoubleMatrix1D x = DoubleFactory1D.dense.make(pmol.dofBounds.size());
            for (int d = 0; d < pmol.dofBounds.size(); ++d) {
                double min = pmol.dofBounds.getMin(d);
                double max = pmol.dofBounds.getMax(d);
                x.set(d, min + rand.nextDouble() * (max - min));
            }
            samples.add(new Minimizer.Result(x, f.getValue(x)));
        }
        samples.add(new SimpleCCDMinimizer(f).minimizeFromCenter());
        return samples;
    }

    private List<Minimizer.Result> sampleDensely(ParametricMolecule pmol, MoleculeObjectiveFunction f, int numSamplesPerDof) {
        int numDims = pmol.dofBounds.size();
        int[] dims = new int[numDims];
        Arrays.fill(dims, numSamplesPerDof);
        int numSamples = numSamplesPerDof;
        for (int i = 1; i < numDims; ++i) {
            numSamples *= numSamplesPerDof;
        }
        ArrayList<Minimizer.Result> samples = new ArrayList<Minimizer.Result>(++numSamples);
        samples.add(new SimpleCCDMinimizer(f).minimizeFromCenter());
        for (int[] p : new MathTools.GridIterable(dims)) {
            DoubleMatrix1D x = DoubleFactory1D.dense.make(numDims);
            for (int d = 0; d < numDims; ++d) {
                double min = pmol.dofBounds.getMin(d);
                double max = pmol.dofBounds.getMax(d);
                double xd = min + (max - min) * (double)p[d] / (double)(dims[d] - 1);
                x.set(d, xd);
            }
            samples.add(new Minimizer.Result(x, f.getValue(x)));
        }
        return samples;
    }

    public static enum ApproximatorType {
        Quadratic;

    }
}

