/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.ClassifierTrainer;
import smile.classification.DecisionTree;
import smile.classification.SoftClassifier;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/**
 * AdaBoost classifier: an additive ensemble of small decision trees.
 *
 * <p>Each tree is trained on a bootstrap sample drawn according to the current
 * instance weights. Misclassified instances are then up-weighted so that later
 * trees concentrate on them.
 *
 * <p>A tree's weight is {@code log((1 - e) / e) + log(k - 1)}, where {@code e} is
 * its weighted training error. A tree is accepted only if its weighted accuracy
 * beats random guessing ({@code 1 - e > 1/k}). This is the multi-class SAMME
 * formulation; for k = 2 it reduces to classic AdaBoost.
 *
 * <p>Predictions are the arg-max of the tree-weighted class votes.
 */
public class AdaBoost implements SoftClassifier<double[]>, Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(AdaBoost.class);
    private static final String INVALID_NUMBER_OF_TREES = "Invalid number of trees: ";

    /** Number of classes; labels are required to be exactly 0..k-1. */
    private final int k;
    /** The accepted weak learners, in training order. */
    private DecisionTree[] trees;
    /** Voting weight of each tree (parallel to {@link #trees}). */
    private double[] alpha;
    /** Weighted training error of each tree (parallel to {@link #trees}). */
    private double[] error;
    /** Per-attribute sum of the importance reported by every tree in {@link #trees}. */
    private final double[] importance;

    /**
     * Trains AdaBoost on numeric attributes with decision stumps (2 leaf nodes).
     *
     * @param x training instances
     * @param y class labels, which must be exactly 0..k-1 with k >= 2
     * @param ntrees number of trees to train (at least 1)
     */
    public AdaBoost(double[][] x, int[] y, int ntrees) {
        this(null, x, y, ntrees);
    }

    /**
     * Trains AdaBoost on numeric attributes.
     *
     * @param x training instances
     * @param y class labels, which must be exactly 0..k-1 with k >= 2
     * @param ntrees number of trees to train (at least 1)
     * @param maxNodes maximum number of leaf nodes per tree (at least 2)
     */
    public AdaBoost(double[][] x, int[] y, int ntrees, int maxNodes) {
        this(null, x, y, ntrees, maxNodes);
    }

    /**
     * Trains AdaBoost with decision stumps (2 leaf nodes).
     *
     * @param attributes attribute descriptors, or {@code null} for all-numeric attributes
     * @param x training instances
     * @param y class labels, which must be exactly 0..k-1 with k >= 2
     * @param ntrees number of trees to train (at least 1)
     */
    public AdaBoost(Attribute[] attributes, double[][] x, int[] y, int ntrees) {
        this(attributes, x, y, ntrees, 2);
    }

    /**
     * Trains AdaBoost.
     *
     * <p>If a weak learner fails to beat random guessing, it is discarded and the
     * slot is retrained on a fresh sample. After more than 3 consecutive failures,
     * training stops early and the model keeps only the trees accepted so far.
     *
     * @param attributes attribute descriptors, or {@code null} for all-numeric attributes
     * @param x training instances
     * @param y class labels, which must be exactly 0..k-1 with k >= 2
     * @param ntrees number of trees to train (at least 1)
     * @param maxNodes maximum number of leaf nodes per tree (at least 2)
     * @throws IllegalArgumentException if sizes mismatch, a parameter is out of range,
     *     a label is negative, a class in 0..k-1 is missing, or there is only one class
     */
    public AdaBoost(Attribute[] attributes, double[][] x, int[] y, int ntrees, int maxNodes) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (ntrees < 1) {
            throw new IllegalArgumentException(INVALID_NUMBER_OF_TREES + ntrees);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxNodes);
        }

        this.k = countClasses(y);
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }

        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int i = 0; i < p; i++) {
                attributes[i] = new NumericAttribute("V" + (i + 1));
            }
        }

        // Pre-sorted attribute orders, shared by every tree.
        int[][] order = SmileUtils.sort(attributes, x);
        int n = x.length;
        int[] samples = new int[n];
        boolean[] err = new boolean[n];
        double[] w = new double[n];
        Arrays.fill(w, 1.0);

        // A weak learner must be more accurate than uniform random guessing.
        double guess = 1.0 / this.k;
        // Multi-class correction added to every tree weight; zero when k == 2.
        double b = Math.log(this.k - 1.0);

        int failures = 0;
        this.trees = new DecisionTree[ntrees];
        this.alpha = new double[ntrees];
        this.error = new double[ntrees];

        for (int t = 0; t < ntrees; t++) {
            // Normalize the instance weights so they sum to 1.
            double total = Math.sum(w);
            for (int i = 0; i < n; i++) {
                w[i] /= total;
            }

            // Weighted bootstrap: samples[i] = number of times instance i was drawn.
            Arrays.fill(samples, 0);
            for (int s : Math.random(w, n)) {
                samples[s]++;
            }

            this.trees[t] = new DecisionTree(attributes, x, y, maxNodes, 1, x[0].length, DecisionTree.SplitRule.GINI, samples, order);

            // Weighted training error of the new tree over all instances.
            double e = 0.0;
            for (int i = 0; i < n; i++) {
                err[i] = this.trees[t].predict(x[i]) != y[i];
                if (err[i]) {
                    e += w[i];
                }
            }

            if (1.0 - e <= guess) {
                logger.error(String.format("Skip the weak classifier %d makes %.2f%% weighted error", t, 100.0 * e));
                if (++failures > 3) {
                    // Give up: keep only the t trees that were accepted.
                    this.trees = Arrays.copyOf(this.trees, t);
                    this.alpha = Arrays.copyOf(this.alpha, t);
                    this.error = Arrays.copyOf(this.error, t);
                    break;
                }
                t--; // retrain the same slot; the loop increment restores t
                continue;
            }

            failures = 0;
            this.error[t] = e;
            // Floor e to avoid log(x / 0) for a tree with zero training error.
            this.alpha[t] = Math.log((1.0 - e) / Math.max(1.0E-10, e)) + b;

            // Up-weight misclassified instances for the next round.
            double a = Math.exp(this.alpha[t]);
            for (int i = 0; i < n; i++) {
                if (err[i]) {
                    w[i] *= a;
                }
            }
        }

        this.importance = new double[attributes.length];
        computeImportance();
    }

    /**
     * Validates the labels and returns the number of classes.
     *
     * <p>Labels must be non-negative and cover 0..k-1 without gaps. The vote arrays
     * in {@code predict} are indexed directly by the class label, so a label >= k
     * would overflow them.
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Sorted, distinct, non-negative labels satisfy labels[i] >= i. The first
            // index where they differ is a class value that never occurs in y.
            if (labels[i] != i) {
                throw new IllegalArgumentException("Missing class: " + i);
            }
        }
        return labels.length;
    }

    /** Recomputes {@link #importance} as the per-attribute sum over the current trees. */
    private void computeImportance() {
        Arrays.fill(this.importance, 0.0);
        for (DecisionTree tree : this.trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; i++) {
                this.importance[i] += imp[i];
            }
        }
    }

    /**
     * Returns the variable importance: the sum of every tree's importance per attribute.
     * The returned array is the model's internal array, not a copy.
     */
    public double[] importance() {
        return this.importance;
    }

    /** Returns the number of trees in the model. */
    public int size() {
        return this.trees.length;
    }

    /**
     * Keeps only the first {@code ntrees} trees and updates the variable importance.
     *
     * @param ntrees the new model size, in 1..size()
     * @throws IllegalArgumentException if {@code ntrees} is out of range
     */
    public void trim(int ntrees) {
        if (ntrees > this.trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        if (ntrees < this.trees.length) {
            this.trees = Arrays.copyOf(this.trees, ntrees);
            this.alpha = Arrays.copyOf(this.alpha, ntrees);
            this.error = Arrays.copyOf(this.error, ntrees);
            // Importance must not include the contributions of the removed trees.
            computeImportance();
        }
    }

    /** Predicts the class with the largest alpha-weighted vote (ties go to the lowest label). */
    @Override
    public int predict(double[] x) {
        double[] votes = new double[this.k];
        for (int i = 0; i < this.trees.length; i++) {
            votes[this.trees[i].predict(x)] += this.alpha[i];
        }
        return Math.whichMax(votes);
    }

    /**
     * Predicts the class and fills {@code posteriori} with the normalized weighted votes.
     *
     * @param x the instance
     * @param posteriori output array of length k; overwritten with votes that sum to 1
     * @return the predicted class
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        Arrays.fill(posteriori, 0.0);
        for (int i = 0; i < this.trees.length; i++) {
            posteriori[this.trees[i].predict(x)] += this.alpha[i];
        }
        double sum = Math.sum(posteriori);
        for (int i = 0; i < this.k; i++) {
            posteriori[i] /= sum;
        }
        return Math.whichMax(posteriori);
    }

    /**
     * Adds tree {@code t}'s weighted votes to {@code votes} and refreshes {@code label}.
     *
     * <p>After the call, {@code label} holds the arg-max of each instance's
     * accumulated votes. After calling this for trees 0..t, the labels equal what
     * {@link #predict(double[])} would return for a model of the first t + 1 trees.
     */
    private void addVotes(int t, double[][] x, double[][] votes, int[] label) {
        for (int j = 0; j < x.length; j++) {
            votes[j][this.trees[t].predict(x[j])] += this.alpha[t];
            label[j] = Math.whichMax(votes[j]);
        }
    }

    /**
     * Evaluates accuracy as the ensemble grows.
     *
     * @param x test instances
     * @param y test labels
     * @return accuracy[t] = accuracy of the model made of the first t + 1 trees
     */
    public double[] test(double[][] x, int[] y) {
        int T = this.trees.length;
        double[] accuracy = new double[T];
        int[] label = new int[x.length];
        // Every k, including the binary case, uses the same weighted vote as predict().
        double[][] votes = new double[x.length][this.k];
        Accuracy measure = new Accuracy();
        for (int t = 0; t < T; t++) {
            addVotes(t, x, votes, label);
            accuracy[t] = measure.measure(y, label);
        }
        return accuracy;
    }

    /**
     * Evaluates several measures as the ensemble grows.
     *
     * @param x test instances
     * @param y test labels
     * @param measures the performance measures to compute
     * @return results[t][j] = measures[j] on the model made of the first t + 1 trees
     */
    public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
        int T = this.trees.length;
        int m = measures.length;
        double[][] results = new double[T][m];
        int[] label = new int[x.length];
        // Every k, including the binary case, uses the same weighted vote as predict().
        double[][] votes = new double[x.length][this.k];
        for (int t = 0; t < T; t++) {
            addVotes(t, x, votes, label);
            for (int j = 0; j < m; j++) {
                results[t][j] = measures[j].measure(y, label);
            }
        }
        return results;
    }

    /** Returns the trees of the model (the internal array, not a copy). */
    public DecisionTree[] getTrees() {
        return this.trees;
    }

    /** Trainer for AdaBoost; defaults to 500 trees of at most 2 leaf nodes. */
    public static class Trainer extends ClassifierTrainer<double[]> {
        private int ntrees = 500;
        private int maxNodes = 2;

        /** Creates a trainer with the default settings. */
        public Trainer() {
        }

        /**
         * Creates a trainer with the given number of trees.
         *
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer(int ntrees) {
            if (ntrees < 1) {
                throw new IllegalArgumentException(AdaBoost.INVALID_NUMBER_OF_TREES + ntrees);
            }
            this.ntrees = ntrees;
        }

        /**
         * Creates a trainer with the given attributes and number of trees.
         *
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer(Attribute[] attributes, int ntrees) {
            super(attributes);
            if (ntrees < 1) {
                throw new IllegalArgumentException(AdaBoost.INVALID_NUMBER_OF_TREES + ntrees);
            }
            this.ntrees = ntrees;
        }

        /**
         * Sets the number of trees.
         *
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer setNumTrees(int ntrees) {
            if (ntrees < 1) {
                throw new IllegalArgumentException(AdaBoost.INVALID_NUMBER_OF_TREES + ntrees);
            }
            this.ntrees = ntrees;
            return this;
        }

        /**
         * Sets the maximum number of leaf nodes per tree.
         *
         * @throws IllegalArgumentException if {@code maxNodes < 2}
         */
        public Trainer setMaxNodes(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
            return this;
        }

        /** Trains an AdaBoost model with this trainer's settings. */
        public AdaBoost train(double[][] x, int[] y) {
            return new AdaBoost(this.attributes, x, y, this.ntrees, this.maxNodes);
        }
    }
}

