/*
 * Decompiled with CFR 0.152.
 */
package com.joptimizer.optimizers;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.SeqBlas;
import cern.jet.math.Functions;
import cern.jet.math.Mult;
import com.joptimizer.functions.FunctionsUtils;
import com.joptimizer.optimizers.BasicPhaseIPDM;
import com.joptimizer.optimizers.OptimizationRequestHandler;
import com.joptimizer.optimizers.OptimizationResponse;
import com.joptimizer.solvers.BasicKKTSolver;
import com.joptimizer.solvers.KKTSolver;
import com.joptimizer.util.ColtUtils;
import com.joptimizer.util.Utils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Primal-dual interior-point optimizer.
 *
 * <p>Starting from a strictly feasible point (every inequality value {@code fi(X) < 0}
 * and a primal residual norm within {@code getToleranceFeas()}), each iteration:
 * <ol>
 *   <li>evaluates the primal, centrality and dual residuals and stops when the primal and
 *       dual residual norms are within {@code getToleranceFeas()} and the surrogate duality
 *       gap is within {@code getTolerance()} (or a custom exit condition holds);</li>
 *   <li>solves the reduced KKT system for the primal step {@code stepX} and, when equality
 *       constraints are present, the equality-multiplier step {@code stepV};</li>
 *   <li>recovers the inequality-multiplier step {@code stepL} in closed form;</li>
 *   <li>backtracks the step length to keep {@code L > 0} and all inequalities strictly
 *       negative, then until the residual norm decreases sufficiently.</li>
 * </ol>
 */
public class PrimalDualMethod
extends OptimizationRequestHandler {
    /** Solver for the Newton/KKT system; lazily created in {@link #optimize()} when unset. */
    private KKTSolver kktSolver;
    private final Log log = LogFactory.getLog((String)this.getClass().getName());

    /**
     * Runs the primal-dual iterations and stores the solution in the optimization response.
     *
     * @return the response return code ({@code 0} when the tolerances or a custom exit
     *         condition are met)
     * @throws IllegalArgumentException if a user-supplied initial Lagrangian has a component
     *         that is not strictly positive
     * @throws Exception if no strictly feasible initial point is available, the iteration
     *         limit is reached, no progress is achieved (when progress checks are enabled),
     *         or the line search cannot keep the iterate strictly feasible
     */
    public int optimize() throws Exception {
        this.log.info((Object)"optimize");
        long tStart = System.currentTimeMillis();
        OptimizationResponse response = new OptimizationResponse();
        DoubleMatrix1D X0 = this.resolveStrictlyFeasibleStartingPoint();
        // One equality multiplier per row of A (none when there are no equalities).
        DoubleMatrix1D V0 = this.getA() != null ? this.F1.make(this.getA().rows()) : this.F1.make(0);
        DoubleMatrix1D L0 = this.resolveInitialLagrangian();
        if (this.log.isDebugEnabled()) {
            this.log.debug((Object)("X0:  " + ArrayUtils.toString((Object)X0.toArray())));
            this.log.debug((Object)("V0:  " + ArrayUtils.toString((Object)V0.toArray())));
            this.log.debug((Object)("L0:  " + ArrayUtils.toString((Object)L0.toArray())));
            this.log.debug((Object)("toleranceFeas:  " + this.getToleranceFeas()));
            this.log.debug((Object)("tolerance    :  " + this.getTolerance()));
        }
        DoubleMatrix1D X = X0;
        DoubleMatrix1D V = V0;
        DoubleMatrix1D L = L0;
        // Values from the previous iteration, read only by the optional progress check.
        double previousRPriXNorm = Double.NaN;
        double previousRDualXLVNorm = Double.NaN;
        double previousSurrDG = Double.NaN;
        int iteration = 0;
        while (true) {
            if (++iteration == this.getMaxIteration() + 1) {
                response.setReturnCode(2);
                this.log.error((Object)"Max iterations limit reached");
                throw new Exception("Max iterations limit reached");
            }
            // Evaluated unconditionally (not only when debugging) to keep the original call pattern.
            double F0X = this.getF0(X);
            if (this.log.isDebugEnabled()) {
                this.log.debug((Object)("iteration: " + iteration));
                this.log.debug((Object)("X=" + ArrayUtils.toString((Object)X.toArray())));
                this.log.debug((Object)("L=" + ArrayUtils.toString((Object)L.toArray())));
                this.log.debug((Object)("V=" + ArrayUtils.toString((Object)V.toArray())));
                this.log.debug((Object)("f0(X)=" + F0X));
            }
            DoubleMatrix1D gradF0X = this.getGradF0(X);
            DoubleMatrix1D fiX = this.getFi(X);
            DoubleMatrix2D GradFiX = this.getGradFi(X);
            DoubleMatrix2D[] HessFiX = this.getHessFi(X);
            double surrDG = this.getSurrogateDualityGap(fiX, L);
            // Barrier parameter: t = mu * m / surrogate duality gap.
            double t = this.getMu() * (double)this.getMieq() / surrDG;
            this.log.debug((Object)("t:  " + t));
            DoubleMatrix1D rPriX = this.rPri(X);
            DoubleMatrix1D rCentXLt = this.rCent(fiX, L, t);
            DoubleMatrix1D rDualXLV = this.rDual(GradFiX, gradF0X, L, V);
            double rPriXNorm = Math.sqrt(this.ALG.norm2(rPriX));
            double rCentXLtNorm = Math.sqrt(this.ALG.norm2(rCentXLt));
            double rDualXLVNorm = Math.sqrt(this.ALG.norm2(rDualXLV));
            double normRXLVt = Math.sqrt(Math.pow(rPriXNorm, 2.0) + Math.pow(rCentXLtNorm, 2.0) + Math.pow(rDualXLVNorm, 2.0));
            this.log.debug((Object)("rPri  norm: " + rPriXNorm));
            this.log.debug((Object)("rCent norm: " + rCentXLtNorm));
            this.log.debug((Object)("rDual norm: " + rDualXLVNorm));
            this.log.debug((Object)("surrDG    : " + surrDG));
            if (this.checkCustomExitConditions(X)) {
                response.setReturnCode(0);
                break;
            }
            if (rPriXNorm <= this.getToleranceFeas() && rDualXLVNorm <= this.getToleranceFeas() && surrDG <= this.getTolerance()) {
                response.setReturnCode(0);
                break;
            }
            if (this.isCheckProgressConditions()) {
                boolean havePrevious = !Double.isNaN(previousRPriXNorm) && !Double.isNaN(previousRDualXLVNorm) && !Double.isNaN(previousSurrDG);
                // A residual "stalls" when it did not decrease and is still above the feasibility tolerance.
                boolean rPriStalled = previousRPriXNorm <= rPriXNorm && rPriXNorm >= this.getToleranceFeas();
                boolean rDualStalled = previousRDualXLVNorm <= rDualXLVNorm && rDualXLVNorm >= this.getToleranceFeas();
                if (havePrevious && (rPriStalled || rDualStalled)) {
                    this.log.error((Object)"No progress achieved, exit iterations loop without desired accuracy");
                    response.setReturnCode(2);
                    throw new Exception("No progress achieved, exit iterations loop without desired accuracy");
                }
                previousRPriXNorm = rPriXNorm;
                previousRDualXLVNorm = rDualXLVNorm;
                previousSurrDG = surrDG;
            }
            // Reduced-system matrix: H_f0 + sum_j L_j H_fj + sum_j (-L_j / f_j) g_j g_j^T,
            // where g_j is row j of GradFiX. Placeholder (all-zero) Hessians are skipped.
            DoubleMatrix2D HessSum = this.getHessF0(X);
            for (int j = 0; j < this.getMieq(); ++j) {
                if (HessFiX[j] == FunctionsUtils.ZEROES_MATRIX_PLACEHOLDER) continue;
                HessSum = ColtUtils.add(HessSum, HessFiX[j], L.get(j));
            }
            DoubleMatrix2D GradSum = this.F2.make(this.getDim(), this.getDim());
            for (int j = 0; j < this.getMieq(); ++j) {
                double c = -L.getQuick(j) / fiX.getQuick(j);
                DoubleMatrix1D gj = GradFiX.viewRow(j);
                SeqBlas.seqBlas.dger(c, gj, gj, GradSum);
            }
            // NOTE(review): assign(...) adds GradSum into HessSum in place. If getHessF0 can return
            // a shared matrix (e.g. when no inequality Hessian was added above), that matrix is
            // modified as well — confirm getHessF0 always returns a fresh matrix.
            DoubleMatrix2D Hpd = HessSum.assign(GradSum, Functions.plus);
            // Right-hand side: gradF0 + sum_j grad f_j / (-t f_j), plus A^T V with equalities.
            DoubleMatrix1D gradSum = this.F1.make(this.getDim());
            for (int j = 0; j < this.getMieq(); ++j) {
                gradSum = ColtUtils.add(gradSum, GradFiX.viewRow(j), 1.0 / (-t * fiX.get(j)));
            }
            DoubleMatrix1D g;
            if (this.getAT() == null) {
                g = ColtUtils.add(gradF0X, gradSum);
            } else {
                g = ColtUtils.add(ColtUtils.add(gradF0X, gradSum), this.ALG.mult(this.getAT(), V));
            }
            if (this.kktSolver == null) {
                this.kktSolver = new BasicKKTSolver(this.request.isRescalingDisabled());
            }
            if (this.isCheckKKTSolutionAccuracy()) {
                this.kktSolver.setCheckKKTSolutionAccuracy(true);
                this.kktSolver.setToleranceKKT(this.getToleranceKKT());
            }
            this.kktSolver.setHMatrix(Hpd);
            this.kktSolver.setGVector(g);
            if (this.getA() != null) {
                this.kktSolver.setAMatrix(this.getA());
                this.kktSolver.setHVector(rPriX);
            }
            DoubleMatrix1D[] sol = this.kktSolver.solve();
            DoubleMatrix1D stepX = sol[0];
            DoubleMatrix1D stepV = sol[1] != null ? sol[1] : this.F1.make(0);
            if (this.log.isDebugEnabled()) {
                this.log.debug((Object)("stepX: " + ArrayUtils.toString((Object)stepX.toArray())));
                this.log.debug((Object)("stepV: " + ArrayUtils.toString((Object)stepV.toArray())));
            }
            DoubleMatrix1D stepL = this.computeLagrangianStep(rCentXLt, fiX, GradFiX, stepX, L);
            if (this.log.isDebugEnabled()) {
                this.log.debug((Object)("stepL: " + ArrayUtils.toString((Object)stepL.toArray())));
            }
            // Start at 99% of the largest step that keeps L positive, then enforce fi(X1) < 0.
            double s = 0.99 * this.maxLagrangianStepLength(L, stepL);
            s = this.backtrackIntoStrictFeasibility(X, stepX, s);
            this.log.debug((Object)("s: " + s));
            // Backtrack until the full residual norm decreases by the factor (1 - alpha * s).
            // The loop always runs at least once, so X1/L1/V1 are assigned before use.
            DoubleMatrix1D X1 = null;
            DoubleMatrix1D L1 = null;
            DoubleMatrix1D V1 = null;
            double previousNormRX1L1V1t = Double.NaN;
            for (int cnt = 0; cnt < 500; ++cnt) {
                X1 = ColtUtils.add(X, stepX, s);
                L1 = ColtUtils.add(L, stepL, s);
                V1 = ColtUtils.add(V, stepV, s);
                if (this.isInDomainF0(X1)) {
                    DoubleMatrix1D fiX1 = this.getFi(X1);
                    DoubleMatrix1D gradF0X1 = this.getGradF0(X1);
                    DoubleMatrix2D GradFiX1 = this.getGradFi(X1);
                    DoubleMatrix1D rPriX1 = this.rPri(X1);
                    DoubleMatrix1D rCentX1L1t = this.rCent(fiX1, L1, t);
                    DoubleMatrix1D rDualX1L1V1 = this.rDual(GradFiX1, gradF0X1, L1, V1);
                    double normRX1L1V1t = Math.sqrt(this.ALG.norm2(rPriX1) + this.ALG.norm2(rCentX1L1t) + this.ALG.norm2(rDualX1L1V1));
                    if (normRX1L1V1t <= (1.0 - this.getAlpha() * s) * normRXLVt) break;
                    if (!Double.isNaN(previousNormRX1L1V1t) && previousNormRX1L1V1t <= normRX1L1V1t) {
                        this.log.warn((Object)"No progress achieved in backtracking with norm");
                        break;
                    }
                    previousNormRX1L1V1t = normRX1L1V1t;
                }
                s = this.getBeta() * s;
            }
            X = X1;
            V = V1;
            L = L1;
        }
        long tStop = System.currentTimeMillis();
        this.log.debug((Object)("time: " + (tStop - tStart)));
        this.log.debug((Object)("sol : " + ArrayUtils.toString((Object)X.toArray())));
        this.log.debug((Object)("ret code: " + response.getReturnCode()));
        response.setSolution(X.toArray());
        this.setOptimizationResponse(response);
        return response.getReturnCode();
    }

    /**
     * Returns a strictly feasible starting point.
     *
     * <p>Uses the user's initial point if given. Otherwise it tries the user's "not feasible"
     * initial point, accepting it only when all inequality values are negative and the primal
     * residual norm is within {@code getToleranceFeas()}. If that fails, it falls back to
     * {@link BasicPhaseIPDM}. The chosen point is then verified.
     *
     * @throws Exception if the resulting point is not strictly feasible
     */
    private DoubleMatrix1D resolveStrictlyFeasibleStartingPoint() throws Exception {
        DoubleMatrix1D X0 = this.getInitialPoint();
        if (X0 == null) {
            DoubleMatrix1D X0NF = this.getNotFeasibleInitialPoint();
            if (X0NF != null) {
                double rPriX0NFNorm = Math.sqrt(this.ALG.norm2(this.rPri(X0NF)));
                DoubleMatrix1D fiX0NF = this.getFi(X0NF);
                int maxIndex = Utils.getMaxIndex(fiX0NF);
                double maxValue = fiX0NF.get(maxIndex);
                if (this.log.isDebugEnabled()) {
                    this.log.debug((Object)("rPriX0NFNorm :  " + rPriX0NFNorm));
                    this.log.debug((Object)("X0NF         :  " + ArrayUtils.toString((Object)X0NF.toArray())));
                    this.log.debug((Object)("fiX0NF       :  " + ArrayUtils.toString((Object)fiX0NF.toArray())));
                }
                if (maxValue < 0.0 && rPriX0NFNorm <= this.getToleranceFeas()) {
                    this.log.debug((Object)"the provided initial point is already feasible");
                    X0 = X0NF;
                }
            }
            if (X0 == null) {
                BasicPhaseIPDM bf1 = new BasicPhaseIPDM(this);
                X0 = bf1.findFeasibleInitialPoint();
            }
        }
        DoubleMatrix1D fiX0 = this.getFi(X0);
        int maxIndex = Utils.getMaxIndex(fiX0);
        double maxValue = fiX0.get(maxIndex);
        double rPriX0Norm = Math.sqrt(this.ALG.norm2(this.rPri(X0)));
        if (maxValue >= 0.0 || rPriX0Norm > this.getToleranceFeas()) {
            this.log.debug((Object)("rPriX0Norm  : " + rPriX0Norm));
            this.log.debug((Object)("ineqX0      : " + ArrayUtils.toString((Object)fiX0.toArray())));
            this.log.debug((Object)("max ineq index: " + maxIndex));
            this.log.debug((Object)("max ineq value: " + maxValue));
            throw new Exception("initial point must be strictly feasible");
        }
        return X0;
    }

    /**
     * Returns the initial inequality multipliers.
     *
     * <p>If the user supplied them, they are validated. Otherwise every component is set to
     * {@code min(1, dim / mieq)}.
     *
     * @throws IllegalArgumentException if a supplied component is not strictly positive
     *         (NaN included)
     */
    private DoubleMatrix1D resolveInitialLagrangian() {
        DoubleMatrix1D L0 = this.getInitialLagrangian();
        if (L0 == null) {
            return this.F1.make(this.getMieq(), Math.min(1.0, (double)this.getDim() / (double)this.getMieq()));
        }
        for (int j = 0; j < L0.size(); ++j) {
            // Written as !(x > 0) so that NaN is rejected too.
            if (!(L0.get(j) > 0.0)) {
                throw new IllegalArgumentException("initial lagrangian must be strictly > 0");
            }
        }
        return L0;
    }

    /**
     * Computes the inequality-multiplier step in closed form:
     * {@code stepL = rCent / fi - (L * (GradFi stepX)) / fi} (component-wise).
     */
    private DoubleMatrix1D computeLagrangianStep(DoubleMatrix1D rCentXLt, DoubleMatrix1D fiX, DoubleMatrix2D GradFiX, DoubleMatrix1D stepX, DoubleMatrix1D L) {
        DoubleMatrix1D a2 = rCentXLt.copy().assign(fiX, Functions.div);
        // ALG.mult yields a new vector, which is then scaled in place by L and divided by fi.
        DoubleMatrix1D c2 = this.ALG.mult(GradFiX, stepX).assign(L, Functions.mult).assign(fiX, Functions.div);
        return ColtUtils.add(a2, c2, -1.0);
    }

    /**
     * Returns {@code min(1, min_{j : stepL_j < 0} -L_j / stepL_j)}.
     *
     * <p>This is the largest step length, capped at 1, for which {@code L + s * stepL} stays
     * non-negative.
     */
    private double maxLagrangianStepLength(DoubleMatrix1D L, DoubleMatrix1D stepL) {
        double sMax = Double.MAX_VALUE;
        for (int j = 0; j < this.getMieq(); ++j) {
            if (stepL.get(j) < 0.0) {
                sMax = Math.min(-L.get(j) / stepL.get(j), sMax);
            }
        }
        return Math.min(1.0, sMax);
    }

    /**
     * Shrinks {@code s} by {@code getBeta()} until every inequality is strictly negative at
     * {@code X + s * stepX}. At most 500 trials are made.
     *
     * @return the first step length that keeps the iterate strictly feasible
     * @throws Exception if no such step length is found within 500 trials
     */
    private double backtrackIntoStrictFeasibility(DoubleMatrix1D X, DoubleMatrix1D stepX, double s) throws Exception {
        for (int cnt = 0; cnt < 500; ++cnt) {
            DoubleMatrix1D X1 = stepX.copy().assign(Mult.mult(s)).assign(X, Functions.plus);
            if (this.areAllInequalitiesStrictlyNegative(this.getFi(X1))) {
                return s;
            }
            s = this.getBeta() * s;
        }
        throw new Exception("Optimization failed: impossible to remain within the faesible region");
    }

    /**
     * Returns {@code true} iff the first {@code getMieq()} values are all {@code < 0}.
     *
     * <p>The plain {@code <} comparison is used instead of {@code Double.compare}, so that
     * {@code -0.0} counts as on the boundary (as in the initial-point check) and NaN counts as
     * infeasible.
     */
    private boolean areAllInequalitiesStrictlyNegative(DoubleMatrix1D ineqValues) {
        for (int j = 0; j < this.getMieq(); ++j) {
            if (!(ineqValues.get(j) < 0.0)) {
                return false;
            }
        }
        return true;
    }

    /** Returns the surrogate duality gap {@code -fi(X) . L}. */
    private double getSurrogateDualityGap(DoubleMatrix1D fiX, DoubleMatrix1D L) {
        return -this.ALG.mult(fiX, L);
    }

    /**
     * Dual residual: {@code gradF0 + GradFi^T L}, plus {@code A^T V} when equality constraints
     * are present.
     *
     * <p>NOTE(review): assumes {@code ColtUtils.zMultTranspose(M, x, y, beta)} returns
     * {@code M^T x + beta * y} — confirm against ColtUtils.
     */
    private DoubleMatrix1D rDual(DoubleMatrix2D GradFiX, DoubleMatrix1D gradF0X, DoubleMatrix1D L, DoubleMatrix1D V) {
        if (this.getA() == null) {
            return ColtUtils.zMultTranspose(GradFiX, L, gradF0X, 1.0);
        }
        return ColtUtils.zMultTranspose(this.getA(), V, ColtUtils.zMultTranspose(GradFiX, L, gradF0X, 1.0), 1.0);
    }

    /** Centrality residual, component-wise: {@code -L_i * fi_i - 1/t}. */
    private DoubleMatrix1D rCent(DoubleMatrix1D fiX, DoubleMatrix1D L, double t) {
        DoubleMatrix1D ret = this.F1.make(L.size());
        for (int i = 0; i < ret.size(); ++i) {
            ret.setQuick(i, -L.getQuick(i) * fiX.getQuick(i) - 1.0 / t);
        }
        return ret;
    }

    /**
     * Sets the solver used for the KKT system.
     *
     * <p>If none is set, {@link #optimize()} creates a {@code BasicKKTSolver} on first use.
     */
    public void setKKTSolver(KKTSolver kktSolver) {
        this.kktSolver = kktSolver;
    }
}

