/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.kstar.pfunc;

import edu.duke.cs.osprey.astar.conf.ConfAStarTree;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.confspace.ConfSearch;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.kstar.pfunc.BoltzmannCalculator;
import edu.duke.cs.osprey.kstar.pfunc.PartitionFunction;
import edu.duke.cs.osprey.tools.Log;
import edu.duke.cs.osprey.tools.MathTools;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

public class PfuncSurface {
    public final int scoreBatch;
    public final int numScoreBatches;
    public final int numEnergies;
    public final int numScores;
    private double[][] delta = null;
    private List<Trace> traces = new ArrayList<Trace>();

    public PfuncSurface(int scoreBatch, int numScoreBatches, int numEnergies) {
        this.scoreBatch = scoreBatch;
        this.numScoreBatches = numScoreBatches;
        this.numEnergies = numEnergies;
        this.numScores = scoreBatch * numScoreBatches;
    }

    public void sample(ConfEnergyCalculator confEcalc, EnergyMatrix emat) {
        this.sample(confEcalc, emat, new RCs(confEcalc.confSpace));
    }

    /*
     * WARNING - void declaration
     */
    public void sample(ConfEnergyCalculator confEcalc, EnergyMatrix emat, final RCs rcs) {
        void var10_16;
        Log.log("collecting scores...", new Object[0]);
        ArrayList<ConfSearch.ScoredConf> scoredConfs = new ArrayList<ConfSearch.ScoredConf>(this.numScores);
        ConfAStarTree astar = new ConfAStarTree.Builder(emat, rcs).setTraditional().build();
        for (int i = 0; i < this.numScoreBatches * this.scoreBatch; ++i) {
            scoredConfs.add(astar.nextConf());
        }
        BoltzmannCalculator bcalc = new BoltzmannCalculator(PartitionFunction.decimalPrecision);
        Log.log("computing energies...", new Object[0]);
        List<ConfSearch.EnergiedConf> energiedConfs = confEcalc.calcAllEnergies(scoredConfs.subList(0, this.numEnergies));
        Log.log("computing Boltzmann weights...", new Object[0]);
        ArrayList<BigDecimal> scoreWeights = new ArrayList<BigDecimal>(this.numEnergies);
        for (ConfSearch.ScoredConf scoredConf : scoredConfs) {
            scoreWeights.add(bcalc.calc(scoredConf.getScore()));
        }
        ArrayList<BigDecimal> energyWeights = new ArrayList<BigDecimal>(energiedConfs.size());
        for (ConfSearch.EnergiedConf energiedConf : energiedConfs) {
            energyWeights.add(bcalc.calc(energiedConf.getEnergy()));
        }
        class State {
            BigDecimal numConfs;
            int numScoredConfs;
            BigDecimal upperScoreWeightSum;
            BigDecimal lastSWeight;
            BigDecimal lowerScoreWeightSum;
            BigDecimal energyWeightSum;

            State() {
                this.numConfs = new BigDecimal(rcs.getNumConformations());
                this.numScoredConfs = 0;
                this.upperScoreWeightSum = BigDecimal.ZERO;
                this.lastSWeight = BigDecimal.ZERO;
                this.lowerScoreWeightSum = BigDecimal.ZERO;
                this.energyWeightSum = BigDecimal.ZERO;
            }

            double calcDelta() {
                BigDecimal unscoredBound = this.numConfs.subtract(BigDecimal.valueOf(this.numScoredConfs)).multiply(this.lastSWeight);
                BigDecimal adjustedUpperBound = this.upperScoreWeightSum.subtract(this.lowerScoreWeightSum).add(this.energyWeightSum).add(unscoredBound);
                return MathTools.bigDivide(adjustedUpperBound.subtract(this.energyWeightSum), adjustedUpperBound, PartitionFunction.decimalPrecision).doubleValue();
            }
        }
        State state = new State();
        for (double[] d : this.delta = new double[this.numScoreBatches + 1][this.numEnergies + 1]) {
            Arrays.fill(d, 1.0);
        }
        Log.log("sampling...", new Object[0]);
        boolean bl = true;
        while (var10_16 <= this.numScoreBatches) {
            for (int i = 0; i < this.scoreBatch; ++i) {
                BigDecimal scoreWeight = (BigDecimal)scoreWeights.get((int)((var10_16 - true) * this.scoreBatch + i));
                state.upperScoreWeightSum = state.upperScoreWeightSum.add(scoreWeight);
                state.lastSWeight = scoreWeight;
            }
            state.numScoredConfs += this.scoreBatch;
            state.lowerScoreWeightSum = BigDecimal.ZERO;
            state.energyWeightSum = BigDecimal.ZERO;
            this.delta[var10_16][0] = state.calcDelta();
            for (int e = 1; e <= Math.min((int)(var10_16 * this.scoreBatch), this.numEnergies); ++e) {
                BigDecimal scoreWeight = (BigDecimal)scoreWeights.get(e - 1);
                BigDecimal energyWeight = (BigDecimal)energyWeights.get(e - 1);
                state.lowerScoreWeightSum = state.lowerScoreWeightSum.add(scoreWeight);
                state.energyWeightSum = state.energyWeightSum.add(energyWeight);
                this.delta[var10_16][e] = state.calcDelta();
            }
            ++var10_16;
        }
        Log.log("surface sampling complete!", new Object[0]);
    }

    public void write(File file) throws Exception {
        try (FileWriter out = new FileWriter(file);){
            int i;
            out.write("# vtk DataFile Version 3.0\n");
            out.write("whatever\n");
            out.write("ASCII\n");
            out.write("DATASET RECTILINEAR_GRID\n");
            out.write(String.format("DIMENSIONS %d %d %d\n", this.numScoreBatches + 1, this.numEnergies + 1, 1));
            out.write(String.format("X_COORDINATES %d float\n", this.numScoreBatches + 1));
            for (i = 0; i <= this.numScoreBatches; ++i) {
                if (i % 10 > 0) {
                    out.write(" ");
                } else if (i > 0) {
                    out.write("\n");
                }
                out.write(String.format("%d", i));
            }
            out.write("\n");
            out.write(String.format("Y_COORDINATES %d float\n", this.numEnergies + 1));
            for (i = 0; i <= this.numEnergies; ++i) {
                if (i % 10 > 0) {
                    out.write(" ");
                } else if (i > 0) {
                    out.write("\n");
                }
                out.write(String.format("%d", i));
            }
            out.write("\n");
            out.write("Z_COORDINATES 1 float\n");
            out.write("0\n");
            out.write(String.format("POINT_DATA %d\n", (this.numScoreBatches + 1) * (this.numEnergies + 1)));
            out.write("FIELD fieldDelta 1\n");
            out.write(String.format("delta 1 %d float\n", (this.numScoreBatches + 1) * (this.numEnergies + 1)));
            for (int e = 0; e <= this.numEnergies; ++e) {
                for (int s = 0; s <= this.numScoreBatches; ++s) {
                    if (s % 10 > 0) {
                        out.write(" ");
                    } else if (s > 0) {
                        out.write("\n");
                    }
                    out.write(String.format("%.4f", this.delta[s][e]));
                }
                out.write("\n");
            }
        }
        Log.log("wrote to file %s", file.getAbsolutePath());
    }

    public void writeTraces(File file) {
        try (FileWriter out = new FileWriter(file);){
            int i;
            out.write("# vtk DataFile Version 3.0\n");
            out.write("whatever\n");
            out.write("ASCII\n");
            out.write("DATASET POLYDATA\n");
            int numPoints = this.traces.stream().mapToInt(trace -> trace.points.size()).sum();
            out.write(String.format("POINTS %d float\n", numPoints));
            for (Trace trace2 : this.traces) {
                for (Trace.Point p : trace2.points) {
                    out.write(String.format("%d %d %d\n", p.scores / (long)this.scoreBatch, p.energies, 0));
                }
            }
            int numLines = numPoints - this.traces.size();
            out.write(String.format("LINES %d %d\n", numLines, numLines * 3));
            int offset = 0;
            for (Trace trace3 : this.traces) {
                for (i = 0; i < trace3.points.size() - 1; ++i) {
                    out.write(String.format("2 %d %d\n", offset + i, offset + i + 1));
                }
                offset += trace3.points.size();
            }
            out.write(String.format("POINT_DATA %d\n", numPoints));
            out.write("FIELD fieldDelta 1\n");
            out.write(String.format("delta 1 %d float\n", numPoints));
            for (Trace trace3 : this.traces) {
                for (i = 0; i < trace3.points.size(); ++i) {
                    if (i % 10 > 0) {
                        out.write(" ");
                    } else if (i > 0) {
                        out.write("\n");
                    }
                    out.write(String.format("%.4f", trace3.points.get((int)i).delta));
                }
            }
            out.write("\n");
        }
        catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    public class Trace {
        public final List<Point> points = new ArrayList<Point>();

        public Trace(PfuncSurface this$0) {
            this$0.traces.add(this);
        }

        public void step(long scores, long energies, double delta) {
            Point p = new Point(this);
            p.scores = scores;
            p.energies = energies;
            p.delta = delta;
            this.points.add(p);
        }

        class Point {
            long scores;
            long energies;
            double delta;

            Point(Trace this$1) {
            }
        }
    }
}

