/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.tupexp;

import edu.duke.cs.osprey.tupexp.CGTupleFitter;
import edu.duke.cs.osprey.tupexp.TupleIndexMatrix;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.ConjugateGradient;
import org.apache.commons.math3.linear.RealLinearOperator;
import org.apache.commons.math3.linear.RealVector;

public class IterativeCGTupleFitter
extends CGTupleFitter {
    ArrayList<double[]> goodRegionBounds;
    double[] curFitVals = null;
    RealVector curCoeffs = null;
    double curResid = Double.POSITIVE_INFINITY;
    double damperLambda = 1.0E-4;

    public IterativeCGTupleFitter(TupleIndexMatrix tim, ArrayList<int[]> samp, int numTuples, ArrayList<double[]> goodRegionBounds) {
        this.samples = samp;
        this.numSamp = this.samples.size();
        this.numTup = numTuples;
        this.tupIndMat = tim;
        this.goodRegionBounds = goodRegionBounds;
        this.AtA = new RealLinearOperator(){

            public int getRowDimension() {
                return IterativeCGTupleFitter.this.numTup;
            }

            public int getColumnDimension() {
                return IterativeCGTupleFitter.this.numTup;
            }

            public RealVector operate(RealVector rv) throws DimensionMismatchException {
                return IterativeCGTupleFitter.this.applyAtA(rv);
            }
        };
    }

    double[] calcFitVals(RealVector rv) {
        double[] fitVals = new double[this.numSamp];
        for (int s = 0; s < this.numSamp; ++s) {
            ArrayList<Integer> sampTup = this.tupIndMat.calcSampleTuples((int[])this.samples.get(s));
            for (int t : sampTup) {
                int n = s;
                fitVals[n] = fitVals[n] + rv.getEntry(t);
            }
        }
        return fitVals;
    }

    RealVector applyAtA(RealVector rv) {
        double[] Arv = this.calcFitVals(rv);
        double[] ans = new double[this.numTup];
        for (int s = 0; s < this.numSamp; ++s) {
            if (!this.isSampleRestrained(s)) continue;
            ArrayList<Integer> sampTup = this.tupIndMat.calcSampleTuples((int[])this.samples.get(s));
            Iterator<Integer> iterator2 = sampTup.iterator();
            while (iterator2.hasNext()) {
                int t;
                int n = t = iterator2.next().intValue();
                ans[n] = ans[n] + Arv[s];
            }
        }
        if (this.curCoeffs != null) {
            for (int t = 0; t < this.numTup; ++t) {
                int n = t;
                ans[n] = ans[n] + this.damperLambda * rv.getEntry(t);
            }
        }
        return new ArrayRealVector(ans, false);
    }

    double getCurTarget(int s) {
        double[] bounds = this.goodRegionBounds.get(s);
        if (this.curFitVals == null) {
            if (bounds[0] == bounds[1]) {
                return bounds[0];
            }
            return Double.NaN;
        }
        double curFitVal = this.curFitVals[s];
        if (curFitVal < bounds[0]) {
            return bounds[0];
        }
        if (curFitVal > bounds[1]) {
            return bounds[1];
        }
        return Double.NaN;
    }

    boolean isSampleRestrained(int s) {
        double[] bounds = this.goodRegionBounds.get(s);
        if (this.curFitVals == null) {
            return bounds[0] == bounds[1];
        }
        double curFitVal = this.curFitVals[s];
        return curFitVal < bounds[0] || curFitVal > bounds[1];
    }

    RealVector calcRHS() {
        double[] atb = new double[this.numTup];
        for (int s = 0; s < this.numSamp; ++s) {
            double curTarget = this.getCurTarget(s);
            if (Double.isNaN(curTarget)) continue;
            ArrayList<Integer> sampTup = this.tupIndMat.calcSampleTuples((int[])this.samples.get(s));
            Iterator<Integer> iterator2 = sampTup.iterator();
            while (iterator2.hasNext()) {
                int t;
                int n = t = iterator2.next().intValue();
                atb[n] = atb[n] + curTarget;
            }
        }
        if (this.curCoeffs != null) {
            for (int t = 0; t < this.numTup; ++t) {
                int n = t;
                atb[n] = atb[n] + this.damperLambda * this.curCoeffs.getEntry(t);
            }
        }
        this.Atb = new ArrayRealVector(atb);
        return this.Atb;
    }

    double calcResidual(double[] fitVals) {
        double resid = 0.0;
        for (int s = 0; s < this.numSamp; ++s) {
            double dev = 0.0;
            double[] bounds = this.goodRegionBounds.get(s);
            if (fitVals[s] < bounds[0]) {
                dev = fitVals[s] - bounds[0];
            } else if (fitVals[s] > bounds[1]) {
                dev = fitVals[s] - bounds[1];
            }
            resid += dev * dev;
        }
        return resid / (double)this.numSamp;
    }

    boolean checkDone(double[] oldFitVals, double[] newFitVals) {
        if (oldFitVals == null) {
            return false;
        }
        double tol = 1.0E-6;
        for (int s = 0; s < this.numSamp; ++s) {
            double[] bounds = this.goodRegionBounds.get(s);
            if (!(bounds[0] < bounds[1]) || !(oldFitVals[s] < bounds[0] ? newFitVals[s] > bounds[0] + tol : (oldFitVals[s] > bounds[1] ? newFitVals[s] < bounds[1] - tol : newFitVals[s] < bounds[0] - tol || newFitVals[s] > bounds[1] + tol))) continue;
            return false;
        }
        return true;
    }

    @Override
    double[] doFit() {
        ConjugateGradient cg = new ConjugateGradient(100000, 1.0E-6, false);
        long startTime = System.currentTimeMillis();
        while (true) {
            double iterStartTime = System.currentTimeMillis();
            this.Atb = this.calcRHS();
            RealVector ans = cg.solve(this.AtA, this.Atb);
            double[] newFitVals = this.calcFitVals(ans);
            System.out.println("Conjugate gradient fitting time (ms): " + ((double)System.currentTimeMillis() - iterStartTime));
            double resid = this.calcResidual(newFitVals);
            System.out.println("Step residual: " + resid);
            if (resid > this.curResid) {
                System.out.println("Iterative conjugate gradient fitting time (ms): " + (System.currentTimeMillis() - startTime));
                return this.curCoeffs.toArray();
            }
            if (resid > this.curResid - 1.0E-4) {
                System.out.println("Iterative conjugate gradient fitting time (ms): " + (System.currentTimeMillis() - startTime));
                return ans.toArray();
            }
            this.curCoeffs = ans;
            this.curFitVals = newFitVals;
            this.curResid = resid;
        }
    }
}

