/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.voxq;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.jet.math.Functions;
import edu.duke.cs.osprey.minimization.MoleculeModifierAndScorer;
import edu.duke.cs.osprey.voxq.IntraVoxelSampler;
import java.util.ArrayList;
import java.util.Collections;

/**
 * Estimates the free-energy difference between two voxels, each described by a
 * {@link MoleculeModifierAndScorer}, by sampling both voxels and mapping every sample into the other
 * voxel through an affine normalization (center + per-DOF scaling).
 *
 * <p>The estimator has the form of Bennett's acceptance ratio: Fermi-function averages over the
 * samples of each voxel are combined, and the offset {@link #estDeltaG} is re-estimated until it is
 * self-consistent. The Jacobian determinants of the two affine maps enter as a volume correction.
 */
public class VoxelsDeltaG {
    /**
     * Energy scale used in the Fermi function and the log-ratio. Its value equals
     * 1.9891e-3 * 298.15, i.e. presumably RT in kcal/mol at 298.15 K.
     * NOTE(review): confirm this matches the units of the energy function being sampled.
     */
    private static final double RT = 0.593050165;

    /** Number of samples drawn from each voxel per batch (also used to fit the energy alignment). */
    static int sampleBatchSize = 100;
    int numDOFs;
    // Samples drawn from voxel 1 / voxel 2, with their cross-voxel energy differences.
    ArrayList<Sample> samples1;
    ArrayList<Sample> samples2;
    IntraVoxelSampler sampler1;
    IntraVoxelSampler sampler2;
    // Affine maps used to carry a point of one voxel into the other.
    SampleNormalization sn1;
    SampleNormalization sn2;
    // Current offset/estimate of delta G; starts at 0 and persists across estDeltaG(double) calls.
    double estDeltaG = 0.0;
    // Estimated relative standard errors of the two Fermi averages (refreshed by curDeltaGEstimate()).
    double integRelErr1 = Double.POSITIVE_INFINITY;
    double integRelErr2 = Double.POSITIVE_INFINITY;

    /**
     * Sets up samplers for both voxels and the normalizations that map between them.
     *
     * @param mms1 scorer/constraints for voxel 1
     * @param mms2 scorer/constraints for voxel 2; must have the same number of DOFs as {@code mms1}
     * @param alignByEnergy if true, each normalization is fit from {@link #sampleBatchSize} samples of
     *     its voxel (sample mean as center, spread between the 1/4 and 3/4 sorted positions as scale);
     *     otherwise it is derived from the voxel's constraint box
     * @throws IllegalArgumentException if the two voxels have different numbers of DOFs
     */
    public VoxelsDeltaG(MoleculeModifierAndScorer mms1, MoleculeModifierAndScorer mms2, boolean alignByEnergy) {
        this.sampler1 = new IntraVoxelSampler(mms1);
        this.sampler2 = new IntraVoxelSampler(mms2);
        this.numDOFs = this.sampler1.numDOFs;
        if (this.sampler2.numDOFs != this.numDOFs) {
            throw new IllegalArgumentException("ERROR: Not supporting delta G for voxels w/ different # DOFs currently...");
        }
        if (alignByEnergy) {
            ArrayList<DoubleMatrix1D> fullSamples1 = new ArrayList<>(sampleBatchSize);
            ArrayList<DoubleMatrix1D> fullSamples2 = new ArrayList<>(sampleBatchSize);
            for (int n = 0; n < sampleBatchSize; ++n) {
                fullSamples1.add(this.sampler1.nextSample());
                fullSamples2.add(this.sampler2.nextSample());
            }
            this.sn1 = new SampleNormalization(this, fullSamples1);
            this.sn2 = new SampleNormalization(this, fullSamples2);
        } else {
            this.sn1 = new SampleNormalization(this, mms1.getConstraints());
            this.sn2 = new SampleNormalization(this, mms2.getConstraints());
        }
    }

    /**
     * Estimates delta G (voxel 2 relative to voxel 1), drawing sample batches until converged.
     *
     * <p>Converged means two things. Successive estimates differ by less than {@code stdErr}. The
     * combined relative error of the two Fermi averages is below {@code stdErr / RT}, because an error
     * of r in an average shifts {@code RT * ln(average)} by about {@code RT * r}.
     *
     * <p>NOTE(review): if either average is 0 (e.g. no sample maps into the other voxel), the
     * estimate becomes infinite/NaN and this loop never terminates. Consider an iteration cap.
     *
     * @param stdErr target precision, in the same energy units as the scorers
     * @return the converged estimate
     */
    public double estDeltaG(double stdErr) {
        double integRelErrTarget = stdErr / RT;
        this.samples1 = new ArrayList<>();
        this.samples2 = new ArrayList<>();
        this.addSampleBatch();
        while (true) {
            // Also refreshes integRelErr1/integRelErr2, so it must run before totIntegRelErr().
            double newDeltaG = this.curDeltaGEstimate();
            boolean converged = Math.abs(newDeltaG - this.estDeltaG) < stdErr
                    && this.totIntegRelErr() < integRelErrTarget;
            if (converged) {
                return newDeltaG;
            }
            this.estDeltaG = newDeltaG;
            this.addSampleBatch();
        }
    }

    /** Draws {@link #sampleBatchSize} samples from each voxel (alternating) and records them. */
    private void addSampleBatch() {
        for (int n = 0; n < sampleBatchSize; ++n) {
            this.samples1.add(new Sample(this, this.sampler1.nextSample(), true));
            this.samples2.add(new Sample(this, this.sampler2.nextSample(), false));
        }
    }

    /** Combines the two relative errors in quadrature. */
    private double totIntegRelErr() {
        return Math.sqrt(this.integRelErr1 * this.integRelErr1 + this.integRelErr2 * this.integRelErr2);
    }

    /** Fermi function {@code 1 / (1 + exp(E / RT))}; tends to 0 as E goes to +infinity. */
    private static double fd(double E) {
        return 1.0 / (1.0 + Math.exp(E / RT));
    }

    /** Arithmetic mean of {@code arr} (NaN if empty). */
    static double mean(ArrayList<Double> arr) {
        double sum = 0.0;
        for (double a : arr) {
            sum += a;
        }
        return sum / arr.size();
    }

    /**
     * Sample standard deviation (n - 1 denominator) of {@code arr}, divided by {@code mean}.
     * Yields NaN/infinity when {@code mean} is 0 or {@code arr} has fewer than two elements.
     */
    static double relStdDev(ArrayList<Double> arr, double mean) {
        double sumSq = 0.0;
        for (double a : arr) {
            double d = a - mean;
            sumSq += d * d;
        }
        double stdDev = Math.sqrt(sumSq / (arr.size() - 1));
        return stdDev / mean;
    }

    /**
     * Computes an updated delta-G estimate from all samples so far, using the current
     * {@link #estDeltaG} as the offset. As a side effect it refreshes {@link #integRelErr1} and
     * {@link #integRelErr2}. It does not modify {@link #estDeltaG}.
     *
     * @return the updated estimate
     */
    double curDeltaGEstimate() {
        ArrayList<Double> f1 = new ArrayList<>(this.samples1.size());
        for (Sample s : this.samples1) {
            f1.add(fd(s.Ediff - this.estDeltaG));
        }
        ArrayList<Double> f2 = new ArrayList<>(this.samples2.size());
        for (Sample s : this.samples2) {
            f2.add(fd(this.estDeltaG - s.Ediff));
        }
        double integ1 = mean(f1);
        double integ2 = mean(f2);
        this.integRelErr1 = relStdDev(f1, integ1) / Math.sqrt(f1.size());
        this.integRelErr2 = relStdDev(f2, integ2) / Math.sqrt(f2.size());
        // Each voxel's average is weighted by the other voxel's Jacobian determinant (volume correction).
        return this.estDeltaG
                - RT * (Math.log(integ1 * this.sn2.jacDet) - Math.log(integ2 * this.sn1.jacDet));
    }

    /**
     * Number of samples per voxel drawn by the most recent {@link #estDeltaG(double)} call, or 0 if
     * no estimate has been run yet.
     */
    public int numSamplesNeeded() {
        return this.samples1 == null ? 0 : this.samples1.size();
    }

    /**
     * Affine map between a voxel's DOF space and a normalized space:
     * {@code normalized = (x - center) / scaling}, applied elementwise.
     */
    private static class SampleNormalization {
        DoubleMatrix1D center;
        DoubleMatrix1D scaling;
        // Product of the per-DOF scalings, i.e. the Jacobian determinant of unnormalize().
        double jacDet;

        /**
         * Builds the map from a constraint box. The center is the box midpoint and the scaling is
         * {@code constr[1] - constr[0]}.
         * NOTE(review): assumes constr[0] holds lower bounds and constr[1] upper bounds; reversed
         * bounds would give a negative jacDet.
         */
        SampleNormalization(VoxelsDeltaG voxelsDeltaG, DoubleMatrix1D[] constr) {
            this.center = constr[0].copy();
            this.center.assign(constr[1], Functions.plus);
            this.center.assign(Functions.mult(0.5));
            this.scaling = constr[1].copy();
            this.scaling.assign(constr[0], Functions.minus);
            this.jacDet = 1.0;
            for (double el : this.scaling.toArray()) {
                this.jacDet *= el;
            }
        }

        /**
         * Builds the map from samples. The center is the per-DOF sample mean. The scaling is the gap
         * between the values at sorted positions n/4 and 3n/4, a rough interquartile range.
         * NOTE(review): a zero spread in any DOF makes normalize() divide by zero.
         */
        SampleNormalization(VoxelsDeltaG voxelsDeltaG, ArrayList<DoubleMatrix1D> fullSamples) {
            int numDOFs = voxelsDeltaG.numDOFs;
            int numSamples = fullSamples.size();
            this.center = DoubleFactory1D.dense.make(numDOFs);
            this.scaling = DoubleFactory1D.dense.make(numDOFs);
            this.jacDet = 1.0;
            for (int dofNum = 0; dofNum < numDOFs; ++dofNum) {
                ArrayList<Double> vals = new ArrayList<>(numSamples);
                double sum = 0.0;
                for (DoubleMatrix1D samp : fullSamples) {
                    double v = samp.get(dofNum);
                    vals.add(v);
                    sum += v;
                }
                this.center.set(dofNum, sum / numSamples);
                Collections.sort(vals);
                double sc = vals.get(3 * numSamples / 4) - vals.get(numSamples / 4);
                this.jacDet *= sc;
                this.scaling.set(dofNum, sc);
            }
        }

        /** Maps a normalized point back to DOF space: {@code y * scaling + center}. Returns a new vector. */
        DoubleMatrix1D unnormalize(DoubleMatrix1D y) {
            return y.copy().assign(this.scaling, Functions.mult).assign(this.center, Functions.plus);
        }

        /** Maps a DOF-space point to normalized space: {@code (z - center) / scaling}. Returns a new vector. */
        DoubleMatrix1D normalize(DoubleMatrix1D z) {
            return z.copy().assign(this.center, Functions.minus).assign(this.scaling, Functions.div);
        }
    }

    /**
     * A sample from one voxel, carried into the other voxel, together with the energy difference
     * {@code E2(z2) - E1(z1)}.
     */
    private static class Sample {
        double Ediff;

        /**
         * @param DOFVals the sampled DOF values
         * @param isVox1 true if {@code DOFVals} was drawn from voxel 1, false if from voxel 2
         */
        Sample(VoxelsDeltaG voxelsDeltaG, DoubleMatrix1D DOFVals, boolean isVox1) {
            DoubleMatrix1D z1;
            DoubleMatrix1D z2;
            if (isVox1) {
                z1 = DOFVals;
                z2 = voxelsDeltaG.sn2.unnormalize(voxelsDeltaG.sn1.normalize(DOFVals));
                // Mapped outside voxel 2: +infinity makes this sample's Fermi weight exactly 0.
                if (voxelsDeltaG.sampler2.mms.isOutOfRange(z2)) {
                    this.Ediff = Double.POSITIVE_INFINITY;
                    return;
                }
            } else {
                z2 = DOFVals;
                z1 = voxelsDeltaG.sn1.unnormalize(voxelsDeltaG.sn2.normalize(DOFVals));
                // Mapped outside voxel 1: -infinity likewise gives a Fermi weight of 0 for voxel-2 samples.
                if (voxelsDeltaG.sampler1.mms.isOutOfRange(z1)) {
                    this.Ediff = Double.NEGATIVE_INFINITY;
                    return;
                }
            }
            this.Ediff = voxelsDeltaG.sampler2.mms.getValue(z2) - voxelsDeltaG.sampler1.mms.getValue(z1);
        }
    }
}

