/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.tupexp;

import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.tupexp.CGTupleFitter;
import edu.duke.cs.osprey.tupexp.ConfETupleExpander;
import edu.duke.cs.osprey.tupexp.FittingObjFcn;
import edu.duke.cs.osprey.tupexp.LUTESettings;
import edu.duke.cs.osprey.tupexp.TESampleSet;
import edu.duke.cs.osprey.tupexp.TupleIndexMatrix;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;
import java.util.TreeSet;

public abstract class TupleExpander
implements Serializable {
    int numPos;
    int[] numAllowed;
    ArrayList<RCTuple> tuples = new ArrayList();
    double[] tupleTerms;
    double constTerm = Double.NaN;
    TESampleSet trainingSamples = null;
    TESampleSet CVSamples = null;
    FittingObjFcn fof = new FittingObjFcn();
    double pruningInterval;
    static int printedUpdateNumTuples = 5;
    boolean canCheckPartialPruning = true;
    int ezPruneCount = 0;
    int numSampsPerTuple = 10;
    ArrayList<ArrayList<ArrayList<Integer>>> assignmentSets = new ArrayList();

    public TupleExpander(int numPos, int[] numAllowed, double pruningInterval, LUTESettings luteSettings) {
        this.numPos = numPos;
        this.numAllowed = numAllowed;
        this.pruningInterval = pruningInterval;
        for (int pos = 0; pos < numPos; ++pos) {
            this.assignmentSets.add(new ArrayList());
        }
        this.fof = new FittingObjFcn(pruningInterval, 0.5, luteSettings.useRelWt, luteSettings.useThreshWt);
    }

    double computeInitGMECEst() {
        double ans = Double.POSITIVE_INFINITY;
        TESampleSet tss = new TESampleSet(this);
        System.out.println("Computing initial GMEC estimate...");
        int[] testSamp = new int[this.numPos];
        Arrays.fill(testSamp, -1);
        if (tss.finishSampleDFS(testSamp) == null) {
            return Double.POSITIVE_INFINITY;
        }
        for (int iter = 0; iter < 50; ++iter) {
            boolean success;
            int[] sample = new int[this.numPos];
            do {
                Arrays.fill(sample, -1);
            } while (!(success = tss.finishSample(sample)));
            double score = this.scoreAssignmentList(sample);
            ans = Math.min(ans, score);
        }
        System.out.println("Initial GMEC estimate: " + ans);
        return ans;
    }

    public double calcExpansion(ArrayList<RCTuple> tuplesToFit) {
        if (Double.isNaN(this.constTerm)) {
            this.constTerm = this.computeInitGMECEst();
        }
        if (this.constTerm == Double.POSITIVE_INFINITY) {
            System.out.println("No conformations found for tuple expansion.  ");
            this.tupleTerms = new double[0];
            return 0.0;
        }
        System.out.println("About to calculate tuple expansion with " + tuplesToFit.size() + " tuples.");
        this.setupSamples(tuplesToFit);
        this.fitLeastSquares();
        this.trainingSamples.updateFitVals(this.fof);
        this.CVSamples.updateFitVals(this.fof);
        System.out.println("TRAINING SAMPLES: ");
        this.trainingSamples.printResids();
        System.out.println("CV SAMPLES: ");
        this.CVSamples.printResids();
        return this.CVSamples.totalResid;
    }

    private void checkHighEnergyConfs() {
        System.out.println("Checking high-energy confs");
        double lowestSampleE = Collections.min(this.trainingSamples.trueVals);
        System.out.println("Lowest sample E: " + lowestSampleE);
        double[] tupleBestE = new double[this.tuples.size()];
        Arrays.fill(tupleBestE, Double.POSITIVE_INFINITY);
        double[] tupleBestELB = new double[this.tuples.size()];
        double[] tupleBestEContE = new double[this.tuples.size()];
        double[] tupleBestLB = new double[this.tuples.size()];
        double[] tupleBestContE = new double[this.tuples.size()];
        Arrays.fill(tupleBestLB, Double.POSITIVE_INFINITY);
        Arrays.fill(tupleBestContE, Double.POSITIVE_INFINITY);
        for (int s = 0; s < this.trainingSamples.samples.size(); ++s) {
            double E = this.trainingSamples.trueVals.get(s);
            ArrayList<Integer> sampTups = this.trainingSamples.calcSampleTuples(this.trainingSamples.samples.get(s));
            double LB = ((ConfETupleExpander)this).sp.lowerBound(this.trainingSamples.samples.get(s));
            double contE = E - LB;
            for (int t : sampTups) {
                if (E < tupleBestE[t]) {
                    tupleBestE[t] = E;
                    tupleBestELB[t] = LB;
                    tupleBestEContE[t] = contE;
                }
                tupleBestLB[t] = Math.min(tupleBestLB[t], LB);
                tupleBestContE[t] = Math.min(tupleBestContE[t], contE);
            }
        }
        TreeSet<Integer> badTuples100 = new TreeSet<Integer>();
        TreeSet<Integer> badTuples50 = new TreeSet<Integer>();
        System.out.println("Tuple Best E's (then bestE LB and cont E; best LB, contE): ");
        for (int t = 0; t < this.tuples.size(); ++t) {
            System.out.println(tupleBestE[t] + " " + tupleBestELB[t] + " " + tupleBestEContE[t] + " " + tupleBestLB[t] + " " + tupleBestContE[t] + " " + this.tuples.get(t).stringListing());
            if (tupleBestE[t] > lowestSampleE + 50.0) {
                badTuples50.add(t);
            }
            if (!(tupleBestE[t] > lowestSampleE + 100.0)) continue;
            badTuples100.add(t);
        }
        System.out.println("End Tuple Best E's");
        TreeSet<Integer> suspectTuples = new TreeSet<Integer>();
        suspectTuples.addAll(badTuples50);
        System.out.println("Bad samples w/o bad tuples (E, LB, worstTupE, worstTup): ");
        for (int s = 0; s < this.trainingSamples.samples.size(); ++s) {
            if (!(this.trainingSamples.trueVals.get(s) > lowestSampleE + 100.0)) continue;
            ArrayList<Integer> sampTups = this.trainingSamples.calcSampleTuples(this.trainingSamples.samples.get(s));
            double worstTupE = Double.NEGATIVE_INFINITY;
            int worstTup = -1;
            for (int t : sampTups) {
                if (!(tupleBestE[t] > worstTupE)) continue;
                worstTup = t;
                worstTupE = tupleBestE[t];
                if (!(worstTupE > lowestSampleE + 50.0)) continue;
                break;
            }
            if (!(worstTupE <= lowestSampleE + 50.0)) continue;
            double E = this.trainingSamples.trueVals.get(s);
            double LB = ((ConfETupleExpander)this).sp.lowerBound(this.trainingSamples.samples.get(s));
            System.out.println(E + " " + LB + " " + worstTupE + " " + this.tuples.get(worstTup).stringListing());
            suspectTuples.add(worstTup);
        }
        System.out.println("End Bad samples w/o bad tuples");
        ArrayList<RCTuple> badList100 = new ArrayList<RCTuple>();
        Iterator sampTups = badTuples100.iterator();
        while (sampTups.hasNext()) {
            int t = (Integer)sampTups.next();
            badList100.add(this.tuples.get(t));
        }
        ArrayList<RCTuple> badList50 = new ArrayList<RCTuple>();
        Iterator t = badTuples50.iterator();
        while (t.hasNext()) {
            int t2 = (Integer)t.next();
            badList50.add(this.tuples.get(t2));
        }
        ArrayList<RCTuple> suspectList = new ArrayList<RCTuple>();
        Iterator iterator2 = suspectTuples.iterator();
        while (iterator2.hasNext()) {
            int t3 = (Integer)iterator2.next();
            suspectList.add(this.tuples.get(t3));
        }
        System.out.println("REDOING FITTING W/O BAD TUPLES100");
        this.redoFittingWithoutTuples(badList100);
        System.out.println("REDOING FITTING W/O BAD TUPLES50");
        this.redoFittingWithoutTuples(badList50);
        System.out.println("REDOING FITTING W/O SUSPECT TUPLES");
        this.redoFittingWithoutTuples(suspectList);
        System.out.println("HIGH CONF-E CHECK DONE");
    }

    private void redoFittingWithoutTuples(ArrayList<RCTuple> badTuples) {
        this.trainingSamples = null;
        this.CVSamples = null;
        for (RCTuple t : badTuples) {
            this.tuples.remove(t);
            this.pruneTuple(t);
        }
        ArrayList<RCTuple> tuplesToFit = this.tuples;
        this.tuples = new ArrayList();
        this.numSampsPerTuple = 10;
        this.setupSamples(tuplesToFit);
        this.fitLeastSquares();
        this.trainingSamples.updateFitVals(this.fof);
        this.CVSamples.updateFitVals(this.fof);
        System.out.println("TRAINING SAMPLES: ");
        this.trainingSamples.printResids();
        System.out.println("CV SAMPLES: ");
        this.CVSamples.printResids();
    }

    void setupSamples(ArrayList<RCTuple> tuplesToFit) {
        int t;
        int tupNum;
        this.numSampsPerTuple = 10;
        if (this.trainingSamples == null) {
            this.trainingSamples = new TESampleSet(this);
        } else {
            block0: for (tupNum = tuplesToFit.size() - 1; tupNum >= 0; --tupNum) {
                RCTuple tup = tuplesToFit.get(tupNum);
                boolean removeTuple = false;
                for (RCTuple tupHere : this.tuples) {
                    if (!tupHere.isSameTuple(tup)) continue;
                    tuplesToFit.remove(tupNum);
                    removeTuple = true;
                    break;
                }
                if (removeTuple) continue;
                for (int tupNum2 = 0; tupNum2 < tupNum; ++tupNum2) {
                    if (!tup.isSameTuple(tuplesToFit.get(tupNum2))) continue;
                    tuplesToFit.remove(tupNum);
                    continue block0;
                }
            }
        }
        System.out.println("Adding tuples to expansion and drawing training samples...");
        for (tupNum = 0; tupNum < tuplesToFit.size(); ++tupNum) {
            if (tupNum > 0 && tupNum % printedUpdateNumTuples == 0) {
                System.out.println(tupNum + " tuples added");
            }
            this.tryAddingTuple(tuplesToFit.get(tupNum));
        }
        System.out.println("EZPRUNE COUNT: " + this.ezPruneCount);
        System.out.println("Updating samples to finish training set...");
        for (t = 0; t < this.tuples.size(); ++t) {
            this.trainingSamples.updateSamples(t);
        }
        if (this.trainingSamples.samples.size() < 2 * this.tuples.size()) {
            this.numSampsPerTuple *= 2 * this.tuples.size() / this.trainingSamples.samples.size() + 1;
            for (t = 0; t < this.tuples.size(); ++t) {
                this.trainingSamples.updateSamples(t);
            }
        }
        System.out.println("Training set done.");
        System.out.println("Drawing CV samples.");
        if (this.CVSamples == null) {
            this.CVSamples = new TESampleSet(this);
            for (t = 0; t < this.tuples.size(); ++t) {
                this.CVSamples.updateSamples(t);
            }
        }
        System.out.println("CV set done.");
    }

    void fitLeastSquares() {
        int numTrainingSamples = this.trainingSamples.samples.size();
        double[] trueVals = new double[numTrainingSamples];
        for (int s = 0; s < numTrainingSamples; ++s) {
            trueVals[s] = this.trainingSamples.trueVals.get(s) - this.constTerm;
        }
        TupleIndexMatrix tim = this.getTupleIndexMatrix();
        CGTupleFitter fitter = this.fof.makeTupleFitter(tim, this.trainingSamples.samples, this.tuples.size(), trueVals);
        double[] fitTerms = fitter.doFit();
        this.tupleTerms = fitTerms;
    }

    int numSamplesNeeded(int tup) {
        return this.numSampsPerTuple;
    }

    public void tryAddingTuple(RCTuple tup) {
        boolean tupFeas = this.trainingSamples.tupleFeasible(tup);
        if (!tupFeas) {
            block0: for (TESampleSet tss : new TESampleSet[]{this.trainingSamples, this.CVSamples}) {
                if (tss == null) continue;
                for (int[] sample : tss.samples) {
                    if (!this.sampleMatchesTuple(sample, tup)) continue;
                    tupFeas = true;
                    continue block0;
                }
            }
        }
        if (tupFeas) {
            this.tuples.add(tup);
            int newTupleIndex = this.tuples.size() - 1;
            this.trainingSamples.addTuple(newTupleIndex);
            if (this.CVSamples != null) {
                this.CVSamples.addTuple(newTupleIndex);
            }
        } else {
            this.pruneTuple(tup);
            ++this.ezPruneCount;
        }
    }

    double fitValueForTuples(ArrayList<Integer> tuples) {
        double ans = this.constTerm;
        for (int tup : tuples) {
            ans += this.tupleTerms[tup];
        }
        return ans;
    }

    int getAssignmentSet(int res, ArrayList<Integer> assignmentList) {
        Collections.sort(assignmentList);
        ArrayList<ArrayList<Integer>> resASets = this.assignmentSets.get(res);
        for (int i = 0; i < resASets.size(); ++i) {
            if (resASets.get(i).size() != assignmentList.size()) continue;
            boolean listsEqual = true;
            for (int j = 0; j < assignmentList.size(); ++j) {
                if (resASets.get(i).get(j) == assignmentList.get(j)) continue;
                listsEqual = false;
            }
            if (!listsEqual) continue;
            return i;
        }
        resASets.add(assignmentList);
        return resASets.size() - 1;
    }

    void assignTupleInSample(int[] sample, RCTuple tuple) {
        for (int posCount = 0; posCount < tuple.pos.size(); ++posCount) {
            int pos = tuple.pos.get(posCount);
            int rc = tuple.RCs.get(posCount);
            if (rc >= 0) {
                sample[pos] = rc;
                continue;
            }
            ArrayList<Integer> aSet = this.assignmentSets.get(pos).get(-rc);
            int assignment = aSet.get(new Random().nextInt(aSet.size()));
            while (!this.checkAssignmentUnpruned(sample, pos, assignment)) {
                ArrayList<Integer> aSetRed = new ArrayList<Integer>();
                for (int a : aSet) {
                    if (a == assignment) continue;
                    aSetRed.add(a);
                }
                aSet = aSetRed;
                if (aSet.isEmpty()) {
                    throw new RuntimeException("ERROR: Can't find compatible assignment for tuple in sample...");
                }
                assignment = aSet.get(new Random().nextInt(aSet.size()));
            }
            sample[pos] = assignment;
        }
    }

    boolean checkAssignmentUnpruned(int[] sample, int pos, int assignment) {
        for (int pos2 = 0; pos2 < sample.length; ++pos2) {
            RCTuple pair = new RCTuple(pos, assignment, pos2, sample[pos2]);
            if (sample[pos2] != -1 && this.isPruned(pair)) {
                return false;
            }
            for (RCTuple prunedTup : this.higherOrderPrunedTuples(pair)) {
                if (!this.sampleMatchesTuple(sample, prunedTup)) continue;
                return false;
            }
        }
        return true;
    }

    boolean sampleMatchesTuple(int[] sample, RCTuple tup) {
        boolean termApplies = true;
        for (int posNum = 0; posNum < tup.pos.size(); ++posNum) {
            int pos = tup.pos.get(posNum);
            int rc = tup.RCs.get(posNum);
            if (sample[pos] == -1) {
                return false;
            }
            if (rc >= 0) {
                if (sample[pos] == rc) continue;
                termApplies = false;
                break;
            }
            ArrayList<Integer> aSet = this.assignmentSets.get(pos).get(-rc);
            if (aSet.contains(sample[pos])) continue;
            termApplies = false;
            break;
        }
        return termApplies;
    }

    public EnergyMatrix getEnergyMatrix() {
        EnergyMatrix ans = new EnergyMatrix(this.numPos, this.numAllowed, this.pruningInterval);
        ans.setConstTerm(this.constTerm);
        for (int pos = 0; pos < this.numPos; ++pos) {
            for (int rc = 0; rc < this.numAllowed[pos]; ++rc) {
                if (this.isPruned(new RCTuple(pos, rc))) {
                    ans.setOneBody(pos, rc, Double.POSITIVE_INFINITY);
                } else {
                    ans.setOneBody(pos, rc, 0.0);
                }
                for (int pos2 = 0; pos2 < pos; ++pos2) {
                    for (int rc2 = 0; rc2 < this.numAllowed[pos2]; ++rc2) {
                        RCTuple pair = new RCTuple(pos, rc, pos2, rc2);
                        if (this.isPruned(pair)) {
                            ans.setPairwise(pos, rc, pos2, rc2, Double.POSITIVE_INFINITY);
                        } else {
                            ans.setPairwise(pos, rc, pos2, rc2, 0.0);
                        }
                        for (RCTuple prunedTup : this.higherOrderPrunedTuples(pair)) {
                            ans.setTupleValue(prunedTup, Double.POSITIVE_INFINITY);
                        }
                    }
                }
            }
        }
        for (int tupNum = 0; tupNum < this.tuples.size(); ++tupNum) {
            ans.setTupleValue(this.tuples.get(tupNum), this.tupleTerms[tupNum]);
        }
        return ans;
    }

    public TupleIndexMatrix getTupleIndexMatrix() {
        TupleIndexMatrix ans = new TupleIndexMatrix(this.numPos, this.numAllowed, this.pruningInterval);
        for (int pos = 0; pos < this.numPos; ++pos) {
            for (int rc = 0; rc < this.numAllowed[pos]; ++rc) {
                ans.setOneBody(pos, rc, Integer.valueOf(-1));
                for (int pos2 = 0; pos2 < pos; ++pos2) {
                    for (int rc2 = 0; rc2 < this.numAllowed[pos2]; ++rc2) {
                        ans.setPairwise(pos, rc, pos2, rc2, Integer.valueOf(-1));
                    }
                }
            }
        }
        for (int tupNum = 0; tupNum < this.tuples.size(); ++tupNum) {
            ans.setTupleValue(this.tuples.get(tupNum), tupNum);
        }
        return ans;
    }

    abstract double scoreAssignmentList(int[] var1);

    abstract boolean isPruned(RCTuple var1);

    abstract void pruneTuple(RCTuple var1);

    abstract ArrayList<RCTuple> higherOrderPrunedTuples(RCTuple var1);
}

