/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.ematrix.epic;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.jet.math.Functions;
import edu.duke.cs.osprey.ematrix.epic.EPICSettings;
import edu.duke.cs.osprey.ematrix.epic.EPoly;
import edu.duke.cs.osprey.ematrix.epic.EPolyPC;
import edu.duke.cs.osprey.ematrix.epic.FitParams;
import edu.duke.cs.osprey.ematrix.epic.GaussianLowEnergySampler;
import edu.duke.cs.osprey.ematrix.epic.SAPE;
import edu.duke.cs.osprey.ematrix.epic.SeriesFitter;
import edu.duke.cs.osprey.minimization.MoleculeModifierAndScorer;
import edu.duke.cs.osprey.minimization.ObjectiveFunction;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * Fits EPIC polynomial approximations ({@link EPoly}) of an objective function over a voxel.
 *
 * <p>Sample points are drawn inside the voxel bounds taken from the objective function's
 * constraints. Their energies are measured relative to {@link #minE}, and a polynomial series
 * is fit to those energies. The series may optionally include higher-order principal-component
 * terms ({@link EPolyPC}) and/or a separately evaluated SAPE term.
 */
public class EPICFitter {
    /** Settings supplying the energy thresholds ({@code EPICThresh1/2}) and fit-selection flags. */
    EPICSettings es;
    /** Number of degrees of freedom of {@link #objFcn}. */
    public int numDOFs;
    /** Upper voxel bounds (element 1 of the objective's constraints). */
    DoubleMatrix1D DOFmax;
    /** Lower voxel bounds (element 0 of the objective's constraints). */
    DoubleMatrix1D DOFmin;
    /** Absolute DOF values about which relative sample offsets are measured. */
    DoubleMatrix1D center;
    /** Energy subtracted from every objective evaluation to obtain relative energies. */
    double minE;
    /**
     * Template handed to {@link EPolyPC} for principal-component fits. It is null unless set
     * externally. NOTE(review): presumably it must be set before requesting a fit with
     * {@code PCOrder > order}; confirm with EPolyPC.
     */
    public EPoly PCTemplate = null;
    /** Objective function being approximated. */
    MoleculeModifierAndScorer objFcn;
    /** Number of samples drawn per fit parameter. */
    static int sampPerParam = 10;

    /** Largest number of fit parameters {@link #doFit} will attempt before giving up. */
    private static final int MAX_FIT_PARAMS = 2000;

    /**
     * Creates a fitter for one voxel.
     *
     * @param mof  objective function; its constraints define the voxel bounds
     * @param eset EPIC settings
     * @param cen  center point of the voxel (absolute DOF values)
     * @param me   energy subtracted from all objective evaluations
     */
    public EPICFitter(MoleculeModifierAndScorer mof, EPICSettings eset, DoubleMatrix1D cen, double me) {
        this.objFcn = mof;
        this.numDOFs = mof.getNumDOFs();
        this.DOFmax = mof.getConstraints()[1];
        this.DOFmin = mof.getConstraints()[0];
        this.es = eset;
        this.center = cen;
        this.minE = me;
    }

    /**
     * Weight given to a sample in the least-squares fit and in cross-validation.
     * High-energy samples are down-weighted by 1/E; all others get weight 1.
     */
    private static double sampleWeight(double relEnergy) {
        return relEnergy > 1.0 ? 1.0 / relEnergy : 1.0;
    }

    /**
     * Fits a polynomial approximation of the objective over the voxel.
     *
     * @param fp fit parameters. Note that {@code fp.numPCParams} is overwritten when a
     *           principal-component fit is requested ({@code fp.PCOrder > fp.order}).
     * @return the fitted polynomial, or {@code null} if the fit would need more than
     *         {@value #MAX_FIT_PARAMS} parameters
     */
    public EPoly doFit(FitParams fp) {
        int numParams = SeriesFitter.getNumParams(this.numDOFs, false, fp.order);
        EPolyPC pcFit = null;
        if (fp.PCOrder > fp.order) {
            pcFit = new EPolyPC(this.PCTemplate, fp.order, fp.PCOrder, fp.PCFac);
            int numPCs = SeriesFitter.countTrue(pcFit.isPC);
            // Extra parameters for the PC-only terms of orders fp.order+1 .. fp.PCOrder.
            fp.numPCParams = 0;
            for (int n = fp.order + 1; n <= fp.PCOrder; ++n) {
                fp.numPCParams += SeriesFitter.getNumParamsForOrder(numPCs, n);
            }
            numParams += fp.numPCParams;
        }
        if (numParams > MAX_FIT_PARAMS) {
            System.out.println("ABORTING EPICFITTER.DOFIT BECAUSE THERE ARE TOO MANY PARAMETERS: " + numParams);
            return null;
        }
        int numSamples = sampPerParam * numParams;

        DoubleMatrix1D[] sampRel = new DoubleMatrix1D[numSamples];
        DoubleMatrix1D[] sampAbs = new DoubleMatrix1D[numSamples];
        double[] trueVal = new double[numSamples];
        // At most half of the samples may lie above EPICThresh1.
        this.generateSamples(numSamples, sampRel, sampAbs, trueVal, numSamples / 2);

        // Any sample at/above the threshold requires the iterative (cutoff-aware) fit.
        boolean allBelowCutoff = true;
        for (int s = 0; s < numSamples; ++s) {
            if (trueVal[s] >= this.es.EPICThresh1) {
                allBelowCutoff = false;
                break;
            }
        }

        SAPE sapeTerm = null;
        double[] bCutoffs = new double[numSamples];
        double[] bCutoffs2 = new double[numSamples];
        double baseShift = 0.0;
        if (fp.SAPECutoff == 0.0) {
            Arrays.fill(bCutoffs, this.es.EPICThresh1);
            Arrays.fill(bCutoffs2, this.es.EPICThresh2);
        } else {
            // Fit only what the SAPE term does not explain: subtract its energy (relative to
            // its value at the center) from each sample, and shift the cutoffs to match.
            sapeTerm = new SAPE(this.objFcn, fp.SAPECutoff, sampAbs);
            baseShift = sapeTerm.getEnergyStandalone(this.center);
            for (int s = 0; s < numSamples; ++s) {
                double shift = sapeTerm.getEnergyStandalone(sampAbs[s]);
                trueVal[s] -= shift - baseShift;
                bCutoffs[s] = this.es.EPICThresh1 - shift + baseShift;
                bCutoffs2[s] = this.es.EPICThresh2 - shift + baseShift;
            }
        }

        double[] weights = new double[numSamples];
        for (int s = 0; s < numSamples; ++s) {
            weights[s] = sampleWeight(trueVal[s]);
        }

        double lambda = 0.0;
        EPoly ans;
        if (pcFit != null) {
            DoubleMatrix1D[] ySamp = new DoubleMatrix1D[numSamples];
            for (int s = 0; s < numSamples; ++s) {
                ySamp[s] = pcFit.toPCBasis(sampRel[s]);
            }
            if (Double.isInfinite(fp.SAPECutoff)) {
                System.out.println("No fit needed: SAPE is full energy");
                pcFit.coeffs = new double[numParams];
            } else if (allBelowCutoff) {
                System.out.println("Analytical:");
                pcFit.coeffs = SeriesFitter.fitSeries(ySamp, trueVal, weights, lambda, false, pcFit.fullOrder, pcFit.PCOrder, pcFit.isPC, false, null, null);
            } else {
                pcFit.coeffs = SeriesFitter.fitSeriesIterative(ySamp, trueVal, weights, lambda, false, pcFit.fullOrder, bCutoffs, bCutoffs2, pcFit.PCOrder, pcFit.isPC);
            }
            ans = pcFit;
        } else {
            double[] seriesCoeffs;
            if (Double.isInfinite(fp.SAPECutoff)) {
                System.out.println("No fit needed: SAPE is full energy");
                seriesCoeffs = new double[numParams];
            } else if (allBelowCutoff) {
                System.out.println("Analytical:");
                seriesCoeffs = SeriesFitter.fitSeries(sampRel, trueVal, weights, lambda, false, fp.order);
            } else {
                seriesCoeffs = SeriesFitter.fitSeriesIterative(sampRel, trueVal, weights, lambda, false, fp.order, bCutoffs, bCutoffs2, fp.order, null);
            }
            ans = new EPoly(this.numDOFs, this.objFcn.getDOFs(), this.DOFmax, this.DOFmin, this.center, this.minE, seriesCoeffs, fp.order);
        }
        if (fp.SAPECutoff > 0.0) {
            ans.sapeTerm = sapeTerm;
            ans.baseSAPE = baseShift;
        }
        ans.fitDescription = fp.getDescription();
        return ans;
    }

    /**
     * Estimates fit quality on a fresh set of samples.
     *
     * <p>The result is the weighted mean squared residual. Samples below the EPICThresh1
     * cutoff always contribute. Samples at or above it contribute only the residual terms
     * that {@link SeriesFitter#isRestraintTypeActive} reports as active. The result is
     * also printed.
     *
     * @param fit the fit to evaluate; {@code null} yields {@link Double#POSITIVE_INFINITY}
     * @param fp  the fit's parameters (used to size the sample set)
     * @return weighted mean squared residual
     */
    public double crossValidateSeries(EPoly fit, FitParams fp) {
        if (fit == null) {
            return Double.POSITIVE_INFINITY;
        }
        double baseShift = 0.0;
        if (fit.sapeTerm != null) {
            baseShift = fit.sapeTerm.getEnergyStandalone(this.center);
        }
        int numSamples = sampPerParam * fp.numParams();
        double[] trueVal = new double[numSamples];
        DoubleMatrix1D[] sampRel = new DoubleMatrix1D[numSamples];
        DoubleMatrix1D[] sampAbs = new DoubleMatrix1D[numSamples];
        this.generateSamples(numSamples, sampRel, sampAbs, trueVal, numSamples / 2);

        double meanResidual = 0.0;
        double weightSum = 0.0;
        for (int s = 0; s < numSamples; ++s) {
            DoubleMatrix1D x = sampAbs[s];
            double realVal = trueVal[s];
            double sampBCutoff = this.es.EPICThresh1;
            double sampBCutoff2 = this.es.EPICThresh2;
            double serVal = fit.evaluate(x, false, false);
            if (fit.sapeTerm != null) {
                // Work in the same SAPE-shifted frame used by doFit.
                double relShift = fit.sapeTerm.getEnergyStandalone(x) - baseShift;
                realVal -= relShift;
                sampBCutoff -= relShift;
                sampBCutoff2 -= relShift;
                serVal -= relShift;
            }
            double weight = sampleWeight(realVal);
            weightSum += weight;
            if (realVal >= sampBCutoff) {
                if (SeriesFitter.isRestraintTypeActive(realVal, serVal, sampBCutoff, sampBCutoff2, false)) {
                    meanResidual += weight * (serVal - sampBCutoff) * (serVal - sampBCutoff);
                }
                if (SeriesFitter.isRestraintTypeActive(realVal, serVal, sampBCutoff, sampBCutoff2, true)) {
                    meanResidual += weight * (realVal - serVal) * (realVal - serVal);
                }
            } else {
                meanResidual += weight * (realVal - serVal) * (realVal - serVal);
            }
        }
        meanResidual /= weightSum;
        System.out.println("CV MEAN RESIDUAL:" + meanResidual);
        return meanResidual;
    }

    /**
     * Prints summary statistics for a list of lower-bound records.
     *
     * <p>Index meanings are inferred from the printed labels. NOTE(review): confirm against
     * the producer of these records.
     * <ul>
     *   <li>rec[1] is the LSB.</li>
     *   <li>rec[2] is the "true E".</li>
     *   <li>rec[0] is the reference point for slack recovery, (rec[1]-rec[0])/(rec[2]-rec[0]).</li>
     *   <li>rec[3]/rec[4] is the "normal/EPIC" minimization time ratio.</li>
     * </ul>
     *
     * @param LSBRecord records with at least five entries each
     */
    static void analyzeLSBRecord(ArrayList<double[]> LSBRecord) {
        if (LSBRecord.isEmpty()) {
            // Avoid printing NaN averages from a division by zero.
            System.out.println("ANALYSIS OF LSB:");
            System.out.println("Total conformation count: 0");
            System.out.println("No LSB records to analyze");
            return;
        }
        double[] binMaxs = new double[]{0.1, 0.5, 0.75, 0.9, 0.99, 1.0, 1.01, 1.1, 1.5, Double.POSITIVE_INFINITY};
        int numBins = binMaxs.length;
        // Bin i covers (binMins[i], binMaxs[i]], where binMins[i] = binMaxs[i-1].
        double[] binMins = new double[numBins];
        binMins[0] = Double.NEGATIVE_INFINITY;
        System.arraycopy(binMaxs, 0, binMins, 1, numBins - 1);

        double minMinE = Double.POSITIVE_INFINITY;
        for (double[] rec : LSBRecord) {
            minMinE = Math.min(minMinE, rec[2]);
        }

        double avgSR = 0.0;
        double avgTimeRat = 0.0;
        int[] binCounts = new int[numBins];
        int numOverHundredth = 0;
        int numOverTenth = 0;
        int numOverHalf = 0;
        int numEnum = 0;
        for (double[] rec : LSBRecord) {
            double slackRecovered = (rec[1] - rec[0]) / (rec[2] - rec[0]);
            avgSR += slackRecovered;
            if (rec[1] <= minMinE) {
                ++numEnum;
            }
            if (rec[1] > rec[2] + 0.01) {
                ++numOverHundredth;
            }
            if (rec[1] > rec[2] + 0.1) {
                ++numOverTenth;
            }
            if (rec[1] > rec[2] + 0.5) {
                ++numOverHalf;
            }
            for (int bin = 0; bin < numBins; ++bin) {
                if (slackRecovered <= binMaxs[bin] && slackRecovered > binMins[bin]) {
                    binCounts[bin]++;
                    break;
                }
            }
            avgTimeRat += rec[3] / rec[4];
        }
        avgTimeRat /= (double) LSBRecord.size();
        avgSR /= (double) LSBRecord.size();

        System.out.println("ANALYSIS OF LSB:");
        System.out.println("Total conformation count: " + LSBRecord.size());
        System.out.println("Average minimization time ratio (normal/EPIC): " + avgTimeRat);
        System.out.println("Average slack recovery fraction: " + avgSR);
        System.out.println(numOverHundredth + " LSBs > 0.01 over true E; " + numOverTenth + " >0.1 over, " + numOverHalf + " >0.5 over");
        System.out.println(numEnum + " LSBs <= lowest true E (" + minMinE + ")");
        System.out.println("Bin_max Bin_count");
        for (int bin = 0; bin < numBins; ++bin) {
            System.out.println(binMaxs[bin] + " " + binCounts[bin]);
        }
    }

    /**
     * Draws sample {@code s} into the output arrays. It samples uniformly over the voxel
     * when {@code gs} is null, and from {@code gs} otherwise.
     *
     * @param of objective function used to evaluate the sample
     */
    void sampleFromVoxel(int s, DoubleMatrix1D[] sampRel, DoubleMatrix1D[] sampAbs, double[] trueVal, ObjectiveFunction of, double[] relMin, double[] relMax, GaussianLowEnergySampler gs) {
        // Forward the caller's objective; it was previously ignored in favour of this.objFcn.
        if (gs == null) {
            this.uniformVoxelSample(s, sampRel, sampAbs, trueVal, of, relMin, relMax);
        } else {
            this.gaussianVoxelSample(s, sampRel, sampAbs, trueVal, of, gs);
        }
    }

    /**
     * Fills the sample arrays with {@code numSamples} points and their relative energies.
     *
     * <p>Once {@code maxOverCutoff} samples above {@code EPICThresh1} have been drawn,
     * later samples are redrawn until they fall at or below the threshold. If every sample
     * so far is above the threshold and at least {@code numSamples/4} have been drawn, the
     * fitter switches from uniform sampling to a {@link GaussianLowEnergySampler}.
     */
    void generateSamples(int numSamples, DoubleMatrix1D[] sampRel, DoubleMatrix1D[] sampAbs, double[] trueVal, int maxOverCutoff) {
        GaussianLowEnergySampler gs = null;
        int countOverCutoff = 0;
        double[] relMax = new double[this.numDOFs];
        double[] relMin = new double[this.numDOFs];
        for (int dof = 0; dof < this.numDOFs; ++dof) {
            relMax[dof] = this.DOFmax.get(dof) - this.center.get(dof);
            relMin[dof] = this.DOFmin.get(dof) - this.center.get(dof);
        }
        for (int s = 0; s < numSamples; ++s) {
            if (countOverCutoff < maxOverCutoff) {
                this.sampleFromVoxel(s, sampRel, sampAbs, trueVal, this.objFcn, relMin, relMax, gs);
                if (trueVal[s] > this.es.EPICThresh1) {
                    ++countOverCutoff;
                }
            } else {
                do {
                    this.sampleFromVoxel(s, sampRel, sampAbs, trueVal, this.objFcn, relMin, relMax, gs);
                } while (trueVal[s] > this.es.EPICThresh1);
            }
            // Uniform sampling is hitting only high-energy points: switch to a low-energy sampler.
            if (gs == null && countOverCutoff == s + 1 && countOverCutoff >= numSamples / 4) {
                gs = new GaussianLowEnergySampler(this.es.EPICThresh1, this.objFcn, this.DOFmin, this.DOFmax, this.center);
            }
        }
        System.out.println("Drew " + numSamples + " samples of which " + countOverCutoff + " are over bCutoff");
    }

    /**
     * Draws sample {@code s} uniformly from the box [center+relMin, center+relMax] and records
     * its offset, absolute position, and energy relative to {@link #minE}.
     */
    void uniformVoxelSample(int s, DoubleMatrix1D[] sampRel, DoubleMatrix1D[] sampAbs, double[] trueVal, ObjectiveFunction of, double[] relMin, double[] relMax) {
        DoubleMatrix1D dx = DoubleFactory1D.dense.make(this.numDOFs);
        DoubleMatrix1D x = DoubleFactory1D.dense.make(this.numDOFs);
        for (int dof = 0; dof < this.numDOFs; ++dof) {
            double top = relMax[dof];
            double bottom = relMin[dof];
            dx.set(dof, bottom + Math.random() * (top - bottom));
            x.set(dof, this.center.get(dof) + dx.get(dof));
        }
        trueVal[s] = of.getValue(x) - this.minE;
        sampRel[s] = dx;
        sampAbs[s] = x;
    }

    /**
     * Draws sample {@code s} from {@code gs}. With probability 0.5, a sample above
     * {@code EPICThresh1} has its offset from {@link #center} halved repeatedly until it
     * falls at or below the threshold.
     * NOTE(review): that loop terminates only if energies near the center are at or below
     * EPICThresh1.
     */
    void gaussianVoxelSample(int s, DoubleMatrix1D[] sampRel, DoubleMatrix1D[] sampAbs, double[] trueVal, ObjectiveFunction of, GaussianLowEnergySampler gs) {
        sampAbs[s] = gs.nextSample();
        trueVal[s] = of.getValue(sampAbs[s]) - this.minE;
        sampRel[s] = sampAbs[s].copy();
        sampRel[s].assign(this.center, Functions.minus);
        if (Math.random() > 0.5) {
            while (trueVal[s] > this.es.EPICThresh1) {
                sampRel[s].assign(Functions.mult(0.5));
                sampAbs[s].assign(this.center);
                sampAbs[s].assign(sampRel[s], Functions.plus);
                trueVal[s] = of.getValue(sampAbs[s]) - this.minE;
            }
        }
    }

    /**
     * Returns the next, more expressive fit to try after {@code fp}, or {@code null} when
     * no further fit is available. The SAPE-heavy schedule is used when {@code es.useSAPE}
     * is set.
     */
    public FitParams raiseFitOrder(FitParams fp) {
        if (this.es.useSAPE) {
            return this.raiseFitOrderSAPEHeavy(fp);
        }
        if (this.es.usePC && fp.order < 6) {
            if (fp.PCOrder == fp.order) {
                return new FitParams(this.numDOFs, fp.order, 0.1, fp.order + 2, false, 0.0);
            }
            if (fp.PCFac == 0.1) {
                return new FitParams(this.numDOFs, fp.order, 0.01, fp.order + 2, false, 0.0);
            }
        }
        if (fp.order >= 6) {
            return null;
        }
        return new FitParams(this.numDOFs, fp.order + 2, 0.0, fp.order + 2, false, 0.0);
    }

    /**
     * SAPE-heavy fit schedule. It steps through SAPE cutoffs 0→3→4 at order 2, then
     * 4→5→7→10→∞ at order 4, and finally returns {@code null}. The quad-only schedule is
     * used when {@code es.quadOnly} is set.
     *
     * @throws RuntimeException if {@code fp} uses principal components or is not on the schedule
     */
    public FitParams raiseFitOrderSAPEHeavy(FitParams fp) {
        if (fp.PCOrder > fp.order) {
            throw new RuntimeException("ERROR: SAPE-heavy EPIC fit selection shouldn't have principal components");
        }
        if (this.es.quadOnly) {
            return this.raiseFitOrderQuadOnly(fp);
        }
        if (fp.order == 2) {
            if (fp.SAPECutoff == 0.0) {
                return new FitParams(this.numDOFs, 2, 0.0, 2, false, 3.0);
            }
            if (fp.SAPECutoff == 3.0) {
                return new FitParams(this.numDOFs, 2, 0.0, 2, false, 4.0);
            }
            if (fp.SAPECutoff == 4.0) {
                return new FitParams(this.numDOFs, 4, 0.0, 4, false, 4.0);
            }
        }
        if (fp.order == 4) {
            if (fp.SAPECutoff == 4.0) {
                return new FitParams(this.numDOFs, 4, 0.0, 4, false, 5.0);
            }
            if (fp.SAPECutoff == 5.0) {
                return new FitParams(this.numDOFs, 4, 0.0, 4, false, 7.0);
            }
            if (fp.SAPECutoff == 7.0) {
                return new FitParams(this.numDOFs, 4, 0.0, 4, false, 10.0);
            }
            if (fp.SAPECutoff == 10.0) {
                return new FitParams(this.numDOFs, 4, 0.0, 4, false, Double.POSITIVE_INFINITY);
            }
            if (Double.isInfinite(fp.SAPECutoff)) {
                return null;
            }
        }
        throw new RuntimeException("ERROR: SAPE-heavy EPIC fit selection shouldn't have this fit order: " + fp.getDescription());
    }

    /**
     * Quad-only fit schedule. It stays at order 2 and steps through SAPE cutoffs
     * 0→3→4→5→7→10→∞, then returns {@code null}.
     *
     * @throws RuntimeException if {@code fp} uses principal components or is not on the schedule
     */
    public FitParams raiseFitOrderQuadOnly(FitParams fp) {
        if (fp.PCOrder > fp.order) {
            throw new RuntimeException("ERROR: SAPE-heavy EPIC fit selection shouldn't have principal components");
        }
        if (fp.order == 2) {
            if (fp.SAPECutoff == 0.0) {
                return new FitParams(this.numDOFs, 2, 0.0, 2, false, 3.0);
            }
            if (fp.SAPECutoff == 3.0) {
                return new FitParams(this.numDOFs, 2, 0.0, 2, false, 4.0);
            }
            if (fp.SAPECutoff == 4.0) {
                return new FitParams(this.numDOFs, 2, 0.0, 2, false, 5.0);
            }
            if (fp.SAPECutoff == 5.0) {
                return new FitParams(this.numDOFs, 2, 0.0, 2, false, 7.0);
            }
            if (fp.SAPECutoff == 7.0) {
                return new FitParams(this.numDOFs, 2, 0.0, 2, false, 10.0);
            }
            if (fp.SAPECutoff == 10.0) {
                return new FitParams(this.numDOFs, 2, 0.0, 2, false, Double.POSITIVE_INFINITY);
            }
            if (Double.isInfinite(fp.SAPECutoff)) {
                return null;
            }
        }
        throw new RuntimeException("ERROR: Quad-only EPIC fit selection shouldn't have this fit order: " + fp.getDescription());
    }

    /** Returns an order-2 {@link EPoly} with no coefficients, labelled "No DOFs". */
    public EPoly blank() {
        EPoly ans = new EPoly(this.numDOFs, this.objFcn.getDOFs(), this.DOFmax, this.DOFmin, this.center, this.minE, null, 2);
        ans.fitDescription = "No DOFs";
        return ans;
    }

    /**
     * Debugging aid that prints true vs. fitted energies on a 1-degree grid. For
     * {@code numDOFs == 2} the grid covers both DOFs; otherwise it scans DOF 0 only, with
     * the other DOFs held at the center. It then deliberately halts by throwing {@link Error}.
     */
    void makeVoxelFigureData(EPoly ep) {
        System.out.println("Making voxel figure data!");
        System.out.println("minE: " + ep.minE);
        DoubleMatrix1D x = ep.center.copy();
        if (ep.numDOFs != 2) {
            System.out.println("chi trueval fitval");
            for (double chi = ep.DOFmin.get(0); chi <= ep.DOFmax.get(0); chi += 1.0) {
                x.set(0, chi);
                double trueVal = this.objFcn.getValue(x);
                double fitVal = ep.evaluate(x, true, false);
                System.out.println(chi + " " + trueVal + " " + fitVal);
            }
        } else {
            System.out.println("chi1 chi2 trueval fitval");
            for (double chi1 = ep.DOFmin.get(0); chi1 <= ep.DOFmax.get(0); chi1 += 1.0) {
                for (double chi2 = ep.DOFmin.get(1); chi2 <= ep.DOFmax.get(1); chi2 += 1.0) {
                    x.set(0, chi1);
                    x.set(1, chi2);
                    double trueVal = this.objFcn.getValue(x);
                    double fitVal = ep.evaluate(x, true, false);
                    System.out.println(chi1 + " " + chi2 + " " + trueVal + " " + fitVal);
                }
            }
        }
        // Intentional hard stop: this routine is only used to dump figure data.
        throw new Error("Voxel figure data complete.");
    }
}

