/*
 * Decompiled with CFR 0.152.
 */
package smile.sequence;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.Attribute;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.regression.RegressionTree;
import smile.sequence.SequenceLabeler;
import smile.sort.QuickSort;
import smile.util.MulticoreExecutor;

/**
 * Linear-chain conditional random field whose potential functions are additive
 * ensembles of regression trees.
 * <p>
 * There is one {@link TreePotentialFunction} per class {@code i}. It scores the
 * features of the current position augmented with the label of the previous
 * position; at the first position the "previous label" is the sentinel value
 * {@code numClasses}. Two feature representations are supported:
 * <ul>
 * <li>dense {@code double[]} features, where the previous label is appended as
 *     one extra trailing value (see {@link #featureset(double[], int)});</li>
 * <li>sparse {@code int[]} feature indices, where the previous label is appended
 *     as the extra index {@code numFeatures + label}
 *     (see {@link #featureset(int[], int)}).</li>
 * </ul>
 * Sequences are labeled either by per-position posterior marginals computed with
 * forward-backward (the default) or by Viterbi decoding. Models are created with
 * {@link Trainer}.
 */
public class CRF implements SequenceLabeler<double[]> {
    private static final Logger logger = LoggerFactory.getLogger(CRF.class);

    /** The number of classes (labels). */
    private final int numClasses;
    /** The number of sparse features; -1 for a model trained on dense features. */
    private final int numFeatures;
    /** One potential function per class. */
    private final TreePotentialFunction[] potentials;
    /** If true, {@code predict} uses Viterbi decoding instead of forward-backward. */
    private boolean viterbi = false;

    /**
     * Returns a copy of a dense feature vector with {@code label} appended as one
     * extra trailing value.
     *
     * @param features the features of one position.
     * @param label the label of the previous position, or {@code numClasses} for
     *              the first position of a sequence.
     * @return a new array of length {@code features.length + 1}.
     */
    public double[] featureset(double[] features, int label) {
        double[] fs = new double[features.length + 1];
        System.arraycopy(features, 0, fs, 0, features.length);
        fs[features.length] = label;
        return fs;
    }

    /**
     * Returns a copy of a sparse feature vector with the index
     * {@code numFeatures + label} appended, so the previous label occupies an
     * index range after the regular features.
     *
     * @param features the feature indices of one position.
     * @param label the label of the previous position, or {@code numClasses} for
     *              the first position of a sequence.
     * @return a new array of length {@code features.length + 1}.
     */
    public int[] featureset(int[] features, int label) {
        int[] fs = new int[features.length + 1];
        System.arraycopy(features, 0, fs, 0, features.length);
        fs[features.length] = numFeatures + label;
        return fs;
    }

    /** Creates an untrained model for dense features. */
    private CRF(int numClasses, double eta) {
        this(-1, numClasses, eta);
    }

    /** Creates an untrained model; {@code numFeatures <= 0} means dense features. */
    private CRF(int numFeatures, int numClasses, double eta) {
        this.numFeatures = numFeatures;
        this.numClasses = numClasses;
        this.potentials = new TreePotentialFunction[numClasses];
        for (int i = 0; i < numClasses; i++) {
            potentials[i] = new TreePotentialFunction(eta);
        }
    }

    /** Returns true if {@code predict} uses Viterbi decoding. */
    public boolean isViterbi() {
        return viterbi;
    }

    /**
     * Chooses the decoding algorithm used by {@code predict}.
     *
     * @param viterbi true for Viterbi, false for forward-backward posterior decoding.
     * @return this model.
     */
    public CRF setViterbi(boolean viterbi) {
        this.viterbi = viterbi;
        return this;
    }

    /**
     * Labels a sequence of dense feature vectors.
     *
     * @param x the sequence; may be empty.
     * @return one label per position (an empty array for an empty sequence).
     */
    public int[] predict(double[][] x) {
        if (x.length == 0) {
            return new int[0];
        }
        return viterbi ? predictViterbi(x) : predictForwardBackward(x);
    }

    /**
     * Labels a sequence of sparse feature vectors.
     *
     * @param x the sequence; may be empty.
     * @return one label per position (an empty array for an empty sequence).
     */
    public int[] predict(int[][] x) {
        if (x.length == 0) {
            return new int[0];
        }
        return viterbi ? predictViterbi(x) : predictForwardBackward(x);
    }

    /** Posterior (per-position marginal) decoding of a dense sequence. */
    private int[] predictForwardBackward(double[][] x) {
        return posteriorDecode(getTrellis(x));
    }

    /** Posterior (per-position marginal) decoding of a sparse sequence. */
    private int[] predictForwardBackward(int[][] x) {
        return posteriorDecode(getTrellis(x));
    }

    /**
     * Runs forward-backward on a fresh trellis and picks, at every position, the
     * class with the largest {@code alpha * beta}.
     */
    private int[] posteriorDecode(TrellisNode[][] trellis) {
        int n = trellis.length;
        forward(trellis, new double[n]);
        backward(trellis);

        int[] label = new int[n];
        double[] p = new double[numClasses];
        for (int t = 0; t < n; t++) {
            TrellisNode[] column = trellis[t];
            for (int j = 0; j < numClasses; j++) {
                p[j] = column[j].alpha * column[j].beta;
            }
            label[t] = argmax(p);
        }
        return label;
    }

    /**
     * Viterbi decoding of a dense sequence: returns the label sequence that
     * maximizes the sum of potentials along the chain.
     */
    private int[] predictViterbi(double[][] x) {
        int n = x.length;
        double[][] delta = new double[n][numClasses];
        int[][] psy = new int[n][numClasses];

        double[] features = featureset(x[0], numClasses);
        for (int j = 0; j < numClasses; j++) {
            delta[0][j] = potentials[j].f(features);
        }

        for (int t = 1; t < n; t++) {
            // Built per position (as getTrellis does) so each position's own
            // features are used in full, whatever their length.
            int p = x[t].length;
            features = featureset(x[t], 0);
            for (int i = 0; i < numClasses; i++) {
                TreePotentialFunction potential = potentials[i];
                double max = Double.NEGATIVE_INFINITY;
                int maxPsy = 0;
                for (int j = 0; j < numClasses; j++) {
                    features[p] = j;
                    double score = potential.f(features) + delta[t - 1][j];
                    if (max < score) {
                        max = score;
                        maxPsy = j;
                    }
                }
                delta[t][i] = max;
                psy[t][i] = maxPsy;
            }
        }

        return backtrack(delta, psy);
    }

    /**
     * Viterbi decoding of a sparse sequence: returns the label sequence that
     * maximizes the sum of potentials along the chain.
     */
    private int[] predictViterbi(int[][] x) {
        int n = x.length;
        double[][] delta = new double[n][numClasses];
        int[][] psy = new int[n][numClasses];

        int[] features = featureset(x[0], numClasses);
        for (int j = 0; j < numClasses; j++) {
            delta[0][j] = potentials[j].f(features);
        }

        for (int t = 1; t < n; t++) {
            // Sparse positions may carry different numbers of indices, so the
            // augmented vector is rebuilt for every position.
            int p = x[t].length;
            features = featureset(x[t], 0);
            for (int i = 0; i < numClasses; i++) {
                TreePotentialFunction potential = potentials[i];
                double max = Double.NEGATIVE_INFINITY;
                int maxPsy = 0;
                for (int j = 0; j < numClasses; j++) {
                    features[p] = numFeatures + j;
                    double score = potential.f(features) + delta[t - 1][j];
                    if (max < score) {
                        max = score;
                        maxPsy = j;
                    }
                }
                delta[t][i] = max;
                psy[t][i] = maxPsy;
            }
        }

        return backtrack(delta, psy);
    }

    /**
     * Recovers the best path from Viterbi scores and back-pointers.
     *
     * @param delta best path score ending in each (position, class).
     * @param psy back-pointer to the best previous class for each (position, class).
     */
    private int[] backtrack(double[][] delta, int[][] psy) {
        int n = delta.length;
        int[] label = new int[n];
        label[n - 1] = argmax(delta[n - 1]);
        for (int t = n - 2; t >= 0; t--) {
            label[t] = psy[t + 1][label[t + 1]];
        }
        return label;
    }

    /**
     * Returns the index of the first strictly largest value, or 0 if no value is
     * greater than negative infinity (e.g. all NaN).
     */
    private static int argmax(double[] values) {
        int best = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < values.length; j++) {
            if (max < values[j]) {
                max = values[j];
                best = j;
            }
        }
        return best;
    }

    /**
     * Scaled forward pass. Scores are updated incrementally: only trees added
     * since a node's {@code age} are evaluated, which lets the trainer reuse the
     * same trellis across boosting iterations.
     *
     * @param trellis the trellis; alpha, scores and expScores are updated in place.
     * @param scaling receives the per-position sum of unnormalized alphas.
     */
    private void forward(TrellisNode[][] trellis, double[] scaling) {
        int n = trellis.length;

        TrellisNode[] t0 = trellis[0];
        for (int i = 0; i < numClasses; i++) {
            TrellisNode node = t0[i];
            TreePotentialFunction potential = potentials[i];
            updateScore(node, potential, 0);
            node.alpha = node.expScores[0];
            node.age = potential.trees.size();
        }
        normalizeAlpha(t0, scaling, 0);

        for (int t = 1; t < n; t++) {
            TrellisNode[] tt = trellis[t];
            TrellisNode[] previous = trellis[t - 1];
            for (int i = 0; i < numClasses; i++) {
                TrellisNode node = tt[i];
                TreePotentialFunction potential = potentials[i];
                node.alpha = 0.0;
                for (int j = 0; j < numClasses; j++) {
                    updateScore(node, potential, j);
                    node.alpha += node.expScores[j] * previous[j].alpha;
                }
                // Set only after all predecessors j have been brought up to date.
                node.age = potential.trees.size();
            }
            normalizeAlpha(tt, scaling, t);
        }
    }

    /**
     * Adds the contribution of the trees added since {@code node.age} to
     * {@code node.scores[j]} and refreshes {@code node.expScores[j]}.
     */
    private void updateScore(TrellisNode node, TreePotentialFunction potential, int j) {
        List<RegressionTree> trees = potential.trees;
        double score = node.scores[j];
        for (int k = node.age; k < trees.size(); k++) {
            RegressionTree tree = trees.get(k);
            double prediction = numFeatures <= 0
                    ? tree.predict(node.samples[j])
                    : tree.predict(node.sparseSamples[j]);
            score += potential.eta * prediction;
        }
        node.scores[j] = score;
        node.expScores[j] = Math.exp(score);
    }

    /** Stores the sum of alphas at position {@code t} in {@code scaling[t]} and divides the alphas by it. */
    private void normalizeAlpha(TrellisNode[] column, double[] scaling, int t) {
        double sum = 0.0;
        for (int i = 0; i < numClasses; i++) {
            sum += column[i].alpha;
        }
        scaling[t] = sum;
        for (int i = 0; i < numClasses; i++) {
            column[i].alpha /= sum;
        }
    }

    /**
     * Backward pass. Betas at the last position are 1; at each earlier position
     * they are normalized to sum to 1. Uses the expScores computed by
     * {@link #forward}, so it must run after it.
     */
    private void backward(TrellisNode[][] trellis) {
        int last = trellis.length - 1;
        TrellisNode[] tn = trellis[last];
        for (int i = 0; i < numClasses; i++) {
            tn[i].beta = 1.0;
        }

        for (int t = last - 1; t >= 0; t--) {
            TrellisNode[] tt = trellis[t];
            TrellisNode[] next = trellis[t + 1];
            double sum = 0.0;
            for (int i = 0; i < numClasses; i++) {
                // next[j].expScores[i] scores the transition i (at t) -> j (at t + 1).
                double beta = 0.0;
                for (int j = 0; j < numClasses; j++) {
                    beta += next[j].expScores[i] * next[j].beta;
                }
                tt[i].beta = beta;
                sum += beta;
            }
            for (int i = 0; i < numClasses; i++) {
                tt[i].beta /= sum;
            }
        }
    }

    /** Builds the trellis for a dense sequence (must be non-empty). */
    private TrellisNode[][] getTrellis(double[][] sequence) {
        TrellisNode[][] trellis = new TrellisNode[sequence.length][numClasses];

        TrellisNode[] t0 = trellis[0];
        for (int i = 0; i < numClasses; i++) {
            t0[i] = new TrellisNode(false);
            t0[i].samples[0] = featureset(sequence[0], numClasses);
        }

        for (int t = 1; t < sequence.length; t++) {
            TrellisNode first = new TrellisNode(false);
            trellis[t][0] = first;
            for (int j = 0; j < numClasses; j++) {
                first.samples[j] = featureset(sequence[t], j);
            }
            // The feature vectors depend only on the previous label, so every
            // class at this position shares the same arrays.
            for (int i = 1; i < numClasses; i++) {
                TrellisNode node = new TrellisNode(false);
                trellis[t][i] = node;
                System.arraycopy(first.samples, 0, node.samples, 0, numClasses);
            }
        }
        return trellis;
    }

    /** Builds the trellis for a sparse sequence (must be non-empty). */
    private TrellisNode[][] getTrellis(int[][] sequence) {
        TrellisNode[][] trellis = new TrellisNode[sequence.length][numClasses];

        TrellisNode[] t0 = trellis[0];
        for (int i = 0; i < numClasses; i++) {
            t0[i] = new TrellisNode(true);
            t0[i].sparseSamples[0] = featureset(sequence[0], numClasses);
        }

        for (int t = 1; t < sequence.length; t++) {
            TrellisNode first = new TrellisNode(true);
            trellis[t][0] = first;
            for (int j = 0; j < numClasses; j++) {
                first.sparseSamples[j] = featureset(sequence[t], j);
            }
            // The feature vectors depend only on the previous label, so every
            // class at this position shares the same arrays.
            for (int i = 1; i < numClasses; i++) {
                TrellisNode node = new TrellisNode(true);
                trellis[t][i] = node;
                System.arraycopy(first.sparseSamples, 0, node.sparseSamples, 0, numClasses);
            }
        }
        return trellis;
    }

    /**
     * Sets each node's regression targets to (label indicator - model marginal),
     * using the node marginal at the first position and the pairwise
     * (previous, current) marginal at later positions. Requires
     * {@link #forward} and {@link #backward} to have run on {@code trellis}.
     *
     * @param trellis the trellis of one training sequence.
     * @param scaling the scaling factors produced by {@link #forward}.
     * @param label the true labels; {@code label.length} must equal the sequence length.
     */
    private void setTargets(TrellisNode[][] trellis, double[] scaling, int[] label) {
        TrellisNode[] t0 = trellis[0];
        double normalizer = 0.0;
        for (int i = 0; i < numClasses; i++) {
            normalizer += t0[i].expScores[0] * t0[i].beta;
        }
        for (int i = 0; i < numClasses; i++) {
            double p = t0[i].expScores[0] * t0[i].beta / normalizer;
            t0[i].target[0] = label[0] == i ? 1.0 - p : -p;
        }

        for (int t = 1; t < label.length; t++) {
            TrellisNode[] tt = trellis[t];
            TrellisNode[] previous = trellis[t - 1];
            normalizer = 0.0;
            for (int i = 0; i < numClasses; i++) {
                normalizer += tt[i].alpha * tt[i].beta;
            }
            // Undo the scaling applied to alpha at position t.
            normalizer *= scaling[t];
            for (int i = 0; i < numClasses; i++) {
                TrellisNode node = tt[i];
                for (int j = 0; j < numClasses; j++) {
                    double p = node.expScores[j] * previous[j].alpha * node.beta / normalizer;
                    node.target[j] = label[t] == i && label[t - 1] == j ? 1.0 - p : -p;
                }
            }
        }
    }

    /**
     * Builds {@link CRF} models by gradient tree boosting.
     * <p>
     * Each iteration runs forward-backward on every training sequence to set
     * per-node regression targets (label indicator minus model marginal), then
     * fits one new regression tree per class to those targets. Both stages are
     * submitted to {@code MulticoreExecutor}.
     */
    public static class Trainer {
        /** The number of classes. */
        private final int numClasses;
        /** The number of sparse features, or -1 when training on dense features. */
        private final int numFeatures;
        /** Dense feature attributes plus the appended previous-label attribute; null for sparse. */
        private final Attribute[] attributes;
        /** Maximum number of leaves per regression tree. */
        private int maxLeaves = 100;
        /** Shrinkage applied to each tree's output. */
        private double eta = 1.0;
        /** Number of boosting iterations (trees per class). */
        private int iters = 100;

        /**
         * Creates a trainer for dense features. A nominal attribute named
         * "Previous Position Label" with values "0".."numClasses" is appended to
         * {@code attributes} to describe the previous-label feature.
         *
         * @param attributes the attributes of the dense features.
         * @param numClasses the number of classes; must be at least 2.
         * @throws IllegalArgumentException if {@code numClasses < 2}.
         */
        public Trainer(Attribute[] attributes, int numClasses) {
            if (numClasses < 2) {
                throw new IllegalArgumentException("Invalid number of classes: " + numClasses);
            }
            this.numClasses = numClasses;
            this.numFeatures = -1;
            this.attributes = new Attribute[attributes.length + 1];
            System.arraycopy(attributes, 0, this.attributes, 0, attributes.length);
            String[] values = new String[numClasses + 1];
            for (int i = 0; i <= numClasses; i++) {
                values[i] = Integer.toString(i);
            }
            this.attributes[attributes.length] = new NominalAttribute("Previous Position Label", values);
        }

        /**
         * Creates a trainer for sparse features. The previous label is encoded as
         * index {@code numFeatures + label}, so feature indices are presumably
         * expected to lie in {@code [0, numFeatures)} — confirm with callers.
         *
         * @param numFeatures the number of sparse features; must be at least 2.
         * @param numClasses the number of classes; must be at least 2.
         * @throws IllegalArgumentException on an invalid argument.
         */
        public Trainer(int numFeatures, int numClasses) {
            if (numFeatures < 2) {
                throw new IllegalArgumentException("Invalid number of features: " + numFeatures);
            }
            if (numClasses < 2) {
                throw new IllegalArgumentException("Invalid number of classes: " + numClasses);
            }
            this.numFeatures = numFeatures;
            this.numClasses = numClasses;
            this.attributes = null;
        }

        /**
         * Sets the maximum number of leaf nodes of each regression tree.
         *
         * @throws IllegalArgumentException if {@code maxLeaves < 2}.
         */
        public Trainer setMaxNodes(int maxLeaves) {
            if (maxLeaves < 2) {
                throw new IllegalArgumentException("Invalid number of leaf nodes: " + maxLeaves);
            }
            this.maxLeaves = maxLeaves;
            return this;
        }

        /**
         * Sets the learning rate that multiplies every tree's output.
         *
         * @throws IllegalArgumentException if {@code eta <= 0}.
         */
        public Trainer setLearningRate(double eta) {
            if (eta <= 0.0) {
                throw new IllegalArgumentException("Invalid learning rate: " + eta);
            }
            this.eta = eta;
            return this;
        }

        /**
         * Sets the number of boosting iterations; each adds one tree per class.
         *
         * @throws IllegalArgumentException if {@code iters < 1}.
         */
        public Trainer setNumTrees(int iters) {
            if (iters < 1) {
                throw new IllegalArgumentException("Invalid number of iterations: " + iters);
            }
            this.iters = iters;
            return this;
        }

        /**
         * Trains a CRF on dense sequences. Requires a trainer built with
         * {@link #Trainer(Attribute[], int)}.
         *
         * @param sequences the training sequences; each must be non-empty.
         * @param labels the labels of each sequence, in {@code [0, numClasses)}.
         * @return the trained model.
         * @throws IllegalStateException if this trainer was built for sparse features.
         * @throws IllegalArgumentException if the data are empty or inconsistent.
         */
        public CRF train(double[][][] sequences, int[][] labels) {
            if (attributes == null) {
                throw new IllegalStateException("Trainer was created for sparse features; use train(int[][][], int[][])");
            }
            checkTrainingData(sequences.length, labels);
            CRF crf = new CRF(numClasses, eta);
            TrellisNode[][][] trellis = new TrellisNode[sequences.length][][];
            for (int i = 0; i < sequences.length; i++) {
                checkSequence(i, sequences[i].length, labels[i]);
                trellis[i] = crf.getTrellis(sequences[i]);
            }
            return boost(crf, trellis, labels);
        }

        /**
         * Trains a CRF on sparse sequences. Requires a trainer built with
         * {@link #Trainer(int, int)}.
         *
         * @param sequences the training sequences; each must be non-empty.
         * @param labels the labels of each sequence, in {@code [0, numClasses)}.
         * @return the trained model.
         * @throws IllegalStateException if this trainer was built for dense features.
         * @throws IllegalArgumentException if the data are empty or inconsistent.
         */
        public CRF train(int[][][] sequences, int[][] labels) {
            if (numFeatures <= 0) {
                throw new IllegalStateException("Trainer was created for dense features; use train(double[][][], int[][])");
            }
            checkTrainingData(sequences.length, labels);
            CRF crf = new CRF(numFeatures, numClasses, eta);
            TrellisNode[][][] trellis = new TrellisNode[sequences.length][][];
            for (int i = 0; i < sequences.length; i++) {
                checkSequence(i, sequences[i].length, labels[i]);
                trellis[i] = crf.getTrellis(sequences[i]);
            }
            return boost(crf, trellis, labels);
        }

        /** Rejects an empty training set or a sequence/label count mismatch. */
        private void checkTrainingData(int size, int[][] labels) {
            if (size == 0) {
                throw new IllegalArgumentException("Empty training data");
            }
            if (labels.length != size) {
                throw new IllegalArgumentException("The number of sequences and label arrays don't match: " + size + " != " + labels.length);
            }
        }

        /** Rejects an empty sequence, a length mismatch, or an out-of-range label. */
        private void checkSequence(int index, int length, int[] label) {
            if (length == 0) {
                throw new IllegalArgumentException("Empty sequence at index " + index);
            }
            if (label.length != length) {
                throw new IllegalArgumentException("Sequence " + index + " has " + length + " positions but " + label.length + " labels");
            }
            for (int y : label) {
                if (y < 0 || y >= numClasses) {
                    throw new IllegalArgumentException("Invalid label " + y + " in sequence " + index);
                }
            }
        }

        /**
         * Runs the boosting iterations on prepared trellises and returns {@code crf}.
         * A failed iteration is logged and training continues with the next one.
         */
        private CRF boost(CRF crf, TrellisNode[][][] trellis, int[][] labels) {
            ArrayList<GradientTask> gradientTasks = new ArrayList<GradientTask>(trellis.length);
            for (int i = 0; i < trellis.length; i++) {
                gradientTasks.add(new GradientTask(crf, trellis[i], new double[trellis[i].length], labels[i]));
            }
            ArrayList<BoostingTask> boostingTasks = new ArrayList<BoostingTask>(numClasses);
            for (int i = 0; i < numClasses; i++) {
                boostingTasks.add(new BoostingTask(crf.potentials[i], trellis, i));
            }

            for (int iter = 0; iter < iters; iter++) {
                try {
                    MulticoreExecutor.run(gradientTasks);
                    MulticoreExecutor.run(boostingTasks);
                } catch (Exception e) {
                    // Best-effort, as before: log and proceed to the next iteration.
                    logger.error("Failed to train CRF on multi-core", e);
                }
            }
            return crf;
        }

        /**
         * Fits one regression tree for class {@code i} to the current targets of
         * all trellis nodes of that class and adds it to the class's potential.
         * The training matrix (and, for dense features, the per-attribute sort
         * order) is built once in the constructor and reused every iteration;
         * only {@code y} is refreshed.
         */
        class BoostingTask implements Callable<Object> {
            /** The class whose potential this task boosts. */
            int i;
            TreePotentialFunction potential;
            TrellisNode[][][] trellis;
            /** Sparse training instances (sparse mode only). */
            int[][] sparseX;
            /** Dense training instances (dense mode only). */
            double[][] x;
            /** Regression targets, refreshed on every call. */
            double[] y;
            /** Per-attribute sample order for numeric attributes (dense mode only). */
            int[][] order;
            /** Never assigned, so always passed to RegressionTree as null. */
            int[] samples;

            BoostingTask(TreePotentialFunction potential, TrellisNode[][][] trellis, int i) {
                this.potential = potential;
                this.trellis = trellis;
                this.i = i;

                // One instance at the first position, numClasses per later position.
                int n = 0;
                for (TrellisNode[][] tl : trellis) {
                    n += 1 + (tl.length - 1) * numClasses;
                }
                this.y = new double[n];

                if (Trainer.this.numFeatures <= 0) {
                    this.x = new double[n][];
                    int m = 0;
                    for (TrellisNode[][] tl : trellis) {
                        x[m++] = tl[0][i].samples[0];
                        for (int t = 1; t < tl.length; t++) {
                            TrellisNode node = tl[t][i];
                            for (int j = 0; j < numClasses; j++) {
                                x[m++] = node.samples[j];
                            }
                        }
                    }

                    int p = x[0].length;
                    double[] column = new double[n];
                    this.order = new int[p][];
                    for (int j = 0; j < p; j++) {
                        if (!(attributes[j] instanceof NumericAttribute)) {
                            continue;
                        }
                        for (int l = 0; l < n; l++) {
                            column[l] = x[l][j];
                        }
                        order[j] = QuickSort.sort(column);
                    }
                } else {
                    this.sparseX = new int[n][];
                    int m = 0;
                    for (TrellisNode[][] tl : trellis) {
                        sparseX[m++] = tl[0][i].sparseSamples[0];
                        for (int t = 1; t < tl.length; t++) {
                            TrellisNode node = tl[t][i];
                            for (int j = 0; j < numClasses; j++) {
                                sparseX[m++] = node.sparseSamples[j];
                            }
                        }
                    }
                }
            }

            @Override
            public Object call() {
                // Same instance order as in the constructor.
                int m = 0;
                for (TrellisNode[][] tl : trellis) {
                    y[m++] = tl[0][i].target[0];
                    for (int t = 1; t < tl.length; t++) {
                        TrellisNode node = tl[t][i];
                        for (int j = 0; j < numClasses; j++) {
                            y[m++] = node.target[j];
                        }
                    }
                }

                // NOTE(review): the literal 5 is presumably RegressionTree's
                // minimum node size — confirm against RegressionTree.
                RegressionTree tree;
                if (x != null) {
                    tree = new RegressionTree(attributes, x, y, maxLeaves, 5, attributes.length, order, samples, null);
                } else {
                    tree = new RegressionTree(Trainer.this.numFeatures + numClasses + 1, sparseX, y, maxLeaves, 5, samples, null);
                }
                potential.add(tree);
                return null;
            }
        }

        /** Runs forward, backward and target computation for one training sequence. */
        class GradientTask implements Callable<Object> {
            final CRF crf;
            final TrellisNode[][] trellis;
            final double[] scaling;
            final int[] label;

            GradientTask(CRF crf, TrellisNode[][] trellis, double[] scaling, int[] label) {
                this.crf = crf;
                this.trellis = trellis;
                this.scaling = scaling;
                this.label = label;
            }

            @Override
            public Object call() {
                crf.forward(trellis, scaling);
                crf.backward(trellis);
                crf.setTargets(trellis, scaling, label);
                return null;
            }
        }
    }

    /**
     * One (position, class) cell of the trellis. At the first position only index
     * 0 of the per-predecessor arrays is used; at later positions index {@code j}
     * refers to previous label {@code j}.
     */
    class TrellisNode {
        /** Scaled forward variable. */
        double alpha = 1.0;
        /** Normalized backward variable. */
        double beta = 1.0;
        /** Dense feature vectors, one per previous label (null for sparse nodes). */
        double[][] samples;
        /** Sparse feature vectors, one per previous label (null for dense nodes). */
        int[][] sparseSamples;
        /** Regression targets set by setTargets, one per previous label. */
        double[] target = new double[numClasses];
        /** Accumulated potential scores, one per previous label. */
        double[] scores = new double[numClasses];
        /** {@code exp(scores[j])}. */
        double[] expScores = new double[numClasses];
        /** Number of trees already accumulated into {@link #scores}. */
        int age = 0;

        TrellisNode(boolean sparse) {
            if (sparse) {
                this.sparseSamples = new int[numClasses][];
            } else {
                this.samples = new double[numClasses][];
            }
        }
    }

    /** Potential function of one class: {@code eta} times the sum of its trees' outputs. */
    class TreePotentialFunction {
        /** Shrinkage applied to every tree's output. */
        private final double eta;
        /** The trees added so far, in insertion order. */
        private final List<RegressionTree> trees = new ArrayList<RegressionTree>();

        public TreePotentialFunction(double eta) {
            this.eta = eta;
        }

        /** Evaluates the potential on dense features. */
        public double f(double[] features) {
            double score = 0.0;
            for (RegressionTree tree : trees) {
                score += eta * tree.predict(features);
            }
            return score;
        }

        /** Evaluates the potential on sparse features. */
        public double f(int[] features) {
            double score = 0.0;
            for (RegressionTree tree : trees) {
                score += eta * tree.predict(features);
            }
            return score;
        }

        /** Appends a tree to the ensemble. */
        public void add(RegressionTree tree) {
            trees.add(tree);
        }
    }
}

