/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.astar.conf.scoring;

import edu.duke.cs.osprey.astar.conf.ConfAStarNode;
import edu.duke.cs.osprey.astar.conf.ConfIndex;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.astar.conf.scoring.AStarScorer;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.tools.MathTools;

/**
 * Pairwise A* heuristic (h-score) scorer.
 *
 * <p>For every undefined position, each candidate RC is given the energy
 * <pre>
 *   oneBody(pos, rc)
 *     + sum over defined positions of pairwise(pos, rc, definedPos, definedRc)
 *     + sum over undefined positions pos2 &lt; pos of opt over rc2 of pairwise(pos, rc, pos2, rc2)
 * </pre>
 * and the h-score is the sum, over undefined positions, of the optimum
 * (per {@link #optimizer}) of those per-RC energies.
 *
 * <p>The per-RC energies of the most recently scored node are cached so that
 * {@link #calcDifferential} can cheaply score a child of that node.
 *
 * <p>Instances hold mutable cache state; use {@link #make()} to obtain an
 * independent copy rather than sharing one instance between callers.
 */
public class TraditionalPairwiseHScorer
implements AStarScorer {
    /** Source of one-body and pairwise energies. */
    public final EnergyMatrix emat;
    /** RCs used at construction time to precompute {@link #undefinedEnergies}. */
    public final RCs rcs;
    /** Whether energies are minimized or maximized; defaults to {@code Minimize}. */
    public final MathTools.Optimizer optimizer;

    /**
     * {@code undefinedEnergies[pos1][i][pos2]} (for {@code pos2 < pos1}) is the optimum over all
     * RCs at {@code pos2} of the pairwise energy with the {@code i}-th RC at {@code pos1}.
     * Entries with {@code pos2 >= pos1} are never written and stay 0.
     * Read-only after construction.
     */
    private final double[][][] undefinedEnergies;
    /** Node whose per-RC energies are currently held in {@link #cachedEnergies}, or null. */
    private ConfAStarNode cachedNode;
    /** {@code cachedEnergies[pos][j]}: energy of the {@code j}-th RC at undefined position {@code pos}. */
    private final double[][] cachedEnergies;

    /**
     * Creates a minimizing scorer.
     *
     * @param emat energy matrix to read energies from
     * @param rcs RCs available at each position
     */
    public TraditionalPairwiseHScorer(EnergyMatrix emat, RCs rcs) {
        this(emat, rcs, MathTools.Optimizer.Minimize);
    }

    /**
     * Creates a scorer and precomputes the undefined-pair bounds.
     *
     * @param emat energy matrix to read energies from
     * @param rcs RCs available at each position
     * @param optimizer whether to minimize or maximize energies
     */
    public TraditionalPairwiseHScorer(EnergyMatrix emat, RCs rcs, MathTools.Optimizer optimizer) {
        this.emat = emat;
        this.rcs = rcs;
        this.optimizer = optimizer;
        int numPos = emat.getNumPos();
        this.undefinedEnergies = new double[numPos][][];
        for (int pos1 = 0; pos1 < numPos; ++pos1) {
            int[] rcs1 = rcs.get(pos1);
            int numRCs = rcs1.length;
            this.undefinedEnergies[pos1] = new double[numRCs][];
            for (int i = 0; i < numRCs; ++i) {
                int rc1 = rcs1[i];
                this.undefinedEnergies[pos1][i] = new double[numPos];
                // only lower-index positions are stored; calcCachedEnergies pairs each
                // undefined position with the undefined positions below it exactly once
                for (int pos2 = 0; pos2 < pos1; ++pos2) {
                    double optEnergy = optimizer.initDouble();
                    for (int rc2 : rcs.get(pos2)) {
                        optEnergy = optimizer.opt(optEnergy, emat.getPairwise(pos1, rc1, pos2, rc2));
                    }
                    this.undefinedEnergies[pos1][i][pos2] = optEnergy;
                }
            }
        }
        this.cachedNode = null;
        this.cachedEnergies = new double[numPos][];
        for (int pos = 0; pos < numPos; ++pos) {
            this.cachedEnergies[pos] = new double[rcs.get(pos).length];
        }
    }

    /**
     * Returns a new, independent scorer over the same energy matrix, RCs and optimizer.
     * The new instance has its own (empty) cache.
     */
    @Override
    public TraditionalPairwiseHScorer make() {
        return new TraditionalPairwiseHScorer(this.emat, this.rcs, this.optimizer);
    }

    /**
     * Computes the h-score of the partial conformation described by {@code confIndex}.
     *
     * <p>NOTE(review): indices into {@code rcs.get(pos)} are also used to index the
     * tables precomputed from {@link #rcs}; this presumably assumes both list the RCs
     * in the same order — confirm with callers.
     *
     * @param confIndex defined/undefined positions of the conformation
     * @param rcs RCs available at each position
     * @return sum over undefined positions of the optimal per-RC energy
     */
    @Override
    public double calc(ConfIndex confIndex, RCs rcs) {
        this.calcCachedEnergies(confIndex, rcs);
        // The cache now holds this node's energies. Record that, otherwise a later
        // calcDifferential() for the previously cached node would reuse these
        // (wrong) energies instead of recomputing them.
        this.cachedNode = confIndex.node;

        double hscore = 0.0;
        for (int i = 0; i < confIndex.numUndefined; ++i) {
            int pos = confIndex.undefinedPos[i];
            int numRCs = rcs.get(pos).length;
            double[] cachedEnergiesAtPos = this.cachedEnergies[pos];
            double optRCEnergy = this.optimizer.initDouble();
            for (int j = 0; j < numRCs; ++j) {
                optRCEnergy = this.optimizer.opt(optRCEnergy, cachedEnergiesAtPos[j]);
            }
            hscore += optRCEnergy;
        }
        return hscore;
    }

    /**
     * Computes the h-score of the child conformation obtained by assigning
     * {@code nextRc} to the (currently undefined) position {@code nextPos}, reusing the
     * cached per-RC energies of {@code confIndex.node} when available.
     *
     * @param confIndex defined/undefined positions of the parent conformation
     * @param rcs RCs available at each position
     * @param nextPos undefined position being assigned
     * @param nextRc RC assigned at {@code nextPos}
     * @return h-score of the child conformation
     */
    @Override
    public double calcDifferential(ConfIndex confIndex, RCs rcs, int nextPos, int nextRc) {
        // A null node cannot identify what the cache holds (and the cache starts out
        // keyed to null), so only trust the cache for a non-null, matching node.
        if (confIndex.node == null || this.cachedNode != confIndex.node) {
            this.calcCachedEnergies(confIndex, rcs);
            this.cachedNode = confIndex.node;
        }
        double hscore = 0.0;
        for (int i = 0; i < confIndex.numUndefined; ++i) {
            int pos = confIndex.undefinedPos[i];
            if (pos == nextPos) {
                continue;
            }
            double optRCEnergy = this.optimizer.initDouble();
            double[] cachedEnergiesAtPos = this.cachedEnergies[pos];
            double[][] undefinedEnergiesAtPos = this.undefinedEnergies[pos];
            int[] rcsAtPos = rcs.get(pos);
            int n = rcsAtPos.length;
            for (int j = 0; j < n; ++j) {
                int rc = rcsAtPos[j];
                double rcEnergy = cachedEnergiesAtPos[j];
                if (pos > nextPos) {
                    // nextPos was counted as a lower-index undefined position for pos;
                    // remove that optimized bound before adding the actual pair energy
                    rcEnergy -= undefinedEnergiesAtPos[j][nextPos];
                }
                rcEnergy += this.emat.getPairwise(pos, rc, nextPos, nextRc).doubleValue();
                optRCEnergy = this.optimizer.opt(optRCEnergy, rcEnergy);
            }
            hscore += optRCEnergy;
        }
        return hscore;
    }

    /**
     * Fills {@link #cachedEnergies} for every RC at every undefined position of
     * {@code confIndex}: one-body energy, plus pairwise energies with all defined
     * positions, plus the precomputed bounds against lower-index undefined positions.
     */
    private void calcCachedEnergies(ConfIndex confIndex, RCs rcs) {
        for (int i = 0; i < confIndex.numUndefined; ++i) {
            int pos1 = confIndex.undefinedPos[i];
            int[] rcs1 = rcs.get(pos1);
            int n1 = rcs1.length;
            for (int j = 0; j < n1; ++j) {
                int rc1 = rcs1[j];
                double energy = this.emat.getOneBody(pos1, rc1);
                for (int k = 0; k < confIndex.numDefined; ++k) {
                    int pos2 = confIndex.definedPos[k];
                    int rc2 = confIndex.definedRCs[k];
                    energy += this.emat.getPairwise(pos1, rc1, pos2, rc2).doubleValue();
                }
                double[] energies = this.undefinedEnergies[pos1][j];
                for (int k = 0; k < confIndex.numUndefined; ++k) {
                    int pos2 = confIndex.undefinedPos[k];
                    // each undefined pair is counted once, at its higher-index position
                    if (pos2 >= pos1) {
                        continue;
                    }
                    energy += energies[pos2];
                }
                this.cachedEnergies[pos1][j] = energy;
            }
        }
    }
}

