/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.BBDTree;
import smile.clustering.KMeans;
import smile.math.Math;
import smile.sort.QuickSort;
import smile.stat.distribution.GaussianDistribution;

/**
 * G-Means clustering.
 *
 * <p>Starts from a single cluster holding every sample. Each round runs 2-means on every
 * sufficiently large cluster, projects that cluster's members onto the line joining the two
 * resulting sub-centroids, and computes an adjusted Anderson-Darling statistic on the
 * standardized projections. Clusters whose statistic exceeds {@code CRITICAL_VALUE} are
 * replaced by their two sub-centroids (highest statistics first, while the total stays within
 * {@code kmax}). The new centroid set is then refined with BBD-tree based k-means iterations.
 * Rounds stop when no cluster is split or {@code kmax} clusters are reached.
 * Presumably follows Hamerly &amp; Elkan, "Learning the k in k-means" (NIPS 2003).
 */
public class GMeans extends KMeans implements Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(GMeans.class);

    /** Clusters with fewer members than this are never tested or split. */
    private static final int MIN_SPLIT_SIZE = 25;

    /**
     * Threshold on the adjusted Anderson-Darling statistic. A cluster scoring above it is
     * treated as non-Gaussian and becomes a split candidate. Presumably this is the critical
     * value for significance level 1e-4 used by Hamerly &amp; Elkan; confirm.
     */
    private static final double CRITICAL_VALUE = 1.8692;

    /** Upper bound on k-means refinement iterations after each round of splits. */
    private static final int MAX_REFINE_ITER = 100;

    /**
     * Clusters {@code data} with G-Means, growing the number of clusters up to {@code kmax}.
     *
     * @param data the samples, one per row; every row is assumed to have the same length as
     *     {@code data[0]}
     * @param kmax the maximum number of clusters
     * @throws IllegalArgumentException if {@code kmax < 2} or {@code data} is empty
     */
    public GMeans(double[][] data, int kmax) {
        if (kmax < 2) {
            throw new IllegalArgumentException("Invalid parameter kmax = " + kmax);
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }

        int n = data.length;
        int d = data[0].length;

        // Round 0: a single cluster containing every sample, centered at the sample mean.
        this.k = 1;
        this.size = new int[] {n};
        this.y = new int[n];
        this.centroids = new double[1][d];
        double[] mean = this.centroids[0];
        for (double[] x : data) {
            for (int j = 0; j < d; j++) {
                mean[j] += x[j];
            }
        }
        for (int j = 0; j < d; j++) {
            mean[j] /= n;
        }

        this.distortion = 0.0;
        for (double[] x : data) {
            this.distortion += Math.squaredDistance(x, mean);
        }
        logger.info(String.format("G-Means distortion with %d clusters: %.5f", this.k, this.distortion));

        BBDTree bbd = new BBDTree(data);
        while (this.k < kmax) {
            // For clusters too small to test, score[i] stays 0.0 (never split) and
            // kmeans[i] stays null.
            double[] score = new double[this.k];
            KMeans[] kmeans = new KMeans[this.k];
            for (int i = 0; i < this.k; i++) {
                if (this.size[i] < MIN_SPLIT_SIZE) {
                    logger.info("Cluster {} too small to split: {} samples", i, this.size[i]);
                    continue;
                }
                double[][] subset = members(data, i);
                // Arguments as in the original: (data, k = 2, 100, 4); presumably max
                // iterations and number of runs — confirm against KMeans.
                kmeans[i] = new KMeans(subset, 2, 100, 4);
                score[i] = splitScore(subset, kmeans[i]);
                logger.info(String.format("Cluster %3d\tAnderson-Darling adjusted test statistic: %3.4f", i, score[i]));
            }

            List<double[]> centers = nextCenters(score, kmeans, kmax);
            if (centers.size() == this.k) {
                // No cluster was split in this round: done.
                break;
            }

            this.k = centers.size();
            this.size = new int[this.k];
            this.centroids = centers.toArray(new double[this.k][]);
            refine(bbd, d);
            logger.info(String.format("G-Means distortion with %d clusters: %.5f", this.k, this.distortion));
        }
    }

    /**
     * Returns the rows of {@code data} currently labeled {@code cluster} in {@code y}.
     * The rows are shared with {@code data}, not copied.
     */
    private double[][] members(double[][] data, int cluster) {
        // Assumes size[cluster] equals the number of samples labeled cluster in y.
        double[][] subset = new double[this.size[cluster]][];
        int l = 0;
        for (int i = 0; i < data.length; i++) {
            if (this.y[i] == cluster) {
                subset[l++] = data[i];
            }
        }
        return subset;
    }

    /**
     * Scores how non-Gaussian a cluster looks along the axis of its 2-means split.
     *
     * @param subset the cluster's members
     * @param split the 2-means result computed on {@code subset}
     * @return the adjusted Anderson-Darling statistic of the standardized projections of
     *     {@code subset} onto the line through the two sub-centroids, or 0.0 when the two
     *     sub-centroids coincide (the cluster cannot be split)
     */
    private static double splitScore(double[][] subset, KMeans split) {
        double[] c0 = split.centroids[0];
        double[] c1 = split.centroids[1];
        double[] v = new double[c0.length];
        for (int j = 0; j < v.length; j++) {
            v[j] = c0[j] - c1[j];
        }

        double vp = Math.dot(v, v);
        if (!(vp > 0.0)) {
            // Projecting onto a zero vector gives 0/0 = NaN. A NaN score fails both
            // threshold checks in nextCenters, which would silently drop the cluster.
            // Treat it as unsplittable instead.
            return 0.0;
        }

        double[] x = new double[subset.length];
        for (int i = 0; i < x.length; i++) {
            x[i] = Math.dot(subset[i], v) / vp;
        }
        Math.standardize(x);
        return AndersonDarling(x);
    }

    /**
     * Chooses the centroids for the next round from this round's test statistics.
     *
     * <p>Clusters scoring at most {@code CRITICAL_VALUE} keep their centroid. The remaining
     * clusters are visited from the highest score down. Each one is replaced by its two
     * 2-means sub-centroids as long as doing so still leaves room, within {@code kmax}, for
     * every cluster not yet visited; otherwise it keeps its centroid. Every current cluster
     * therefore contributes to the result exactly once.
     *
     * @param score per-cluster statistic, indexed by cluster id; sorted in place by this method
     * @param kmeans per-cluster 2-means result; non-null wherever the score exceeds the threshold
     * @param kmax the maximum number of clusters
     * @return the centroid arrays for the next round (shared, not copied)
     */
    private List<double[]> nextCenters(double[] score, KMeans[] kmeans, int kmax) {
        // QuickSort.sort sorts score ascending in place, and index[p] is the cluster whose
        // statistic now sits at sorted position p. From here on, score must be indexed by
        // sorted position, never by cluster id.
        int[] index = QuickSort.sort(score);
        List<double[]> centers = new ArrayList<>();

        for (int p = 0; p < this.k; p++) {
            if (score[p] <= CRITICAL_VALUE) {
                centers.add(this.centroids[index[p]]);
            }
        }

        // Sorted positions 0..m-1 are the clusters kept above; m..k-1 are split candidates.
        int m = centers.size();
        for (int p = this.k - 1; p >= 0; p--) {
            if (!(score[p] > CRITICAL_VALUE)) {
                continue;
            }
            int cluster = index[p];
            // Positions m..p (p - m + 1 clusters, including this one) still need at least one
            // slot each; splitting this cluster costs one extra slot.
            if (centers.size() + p - m + 1 < kmax) {
                logger.info("Split cluster {}", cluster);
                centers.add(kmeans[cluster].centroids[0]);
                centers.add(kmeans[cluster].centroids[1]);
            } else {
                centers.add(this.centroids[cluster]);
            }
        }
        return centers;
    }

    /**
     * Refines {@code centroids} with up to {@code MAX_REFINE_ITER} BBD-tree k-means
     * iterations. {@code size} and {@code y} are passed to {@code bbd.clustering} for updating,
     * and {@code distortion} is set from its results.
     *
     * <p>Stops as soon as an iteration fails to lower the distortion. Clusters with no
     * members keep their previous centroid.
     */
    private void refine(BBDTree bbd, int d) {
        double[][] sums = new double[this.k][d];
        this.distortion = Double.MAX_VALUE;
        for (int iter = 0; iter < MAX_REFINE_ITER; iter++) {
            // Presumably assigns each sample to its nearest centroid, fills per-cluster
            // coordinate sums, sizes and labels, and returns the distortion — confirm in BBDTree.
            double newDistortion = bbd.clustering(this.centroids, sums, this.size, this.y);
            for (int i = 0; i < this.k; i++) {
                if (this.size[i] > 0) {
                    for (int j = 0; j < d; j++) {
                        this.centroids[i][j] = sums[i][j] / this.size[i];
                    }
                }
            }
            if (this.distortion <= newDistortion) {
                break;
            }
            this.distortion = newDistortion;
        }
    }

    /**
     * Computes the adjusted Anderson-Darling statistic A*^2 = A^2 (1 + 4/n - 25/n^2) of
     * {@code x} against {@code GaussianDistribution.getInstance()} (presumably the standard
     * normal; callers standardize {@code x} first).
     *
     * <p>Destroys {@code x}: it is sorted and then overwritten with CDF values.
     */
    private static double AndersonDarling(double[] x) {
        int n = x.length;
        Arrays.sort(x);
        GaussianDistribution gaussian = GaussianDistribution.getInstance();
        for (int i = 0; i < n; i++) {
            double p = gaussian.cdf(x[i]);
            // Clamp away from exactly 0 and 1 so the logarithms below stay finite.
            if (p == 0.0) {
                p = 1.0E-7;
            } else if (p == 1.0) {
                p = 0.9999999;
            }
            x[i] = p;
        }

        // A^2 = -n - (1/n) * sum_{i=1..n} (2i-1) [ln F(x_i) + ln(1 - F(x_{n+1-i}))], 0-based here.
        double a = 0.0;
        for (int i = 0; i < n; i++) {
            a -= (2.0 * i + 1.0) * (Math.log(x[i]) + Math.log(1.0 - x[n - i - 1]));
        }
        a = a / n - n;

        // Square in double arithmetic: int n * n overflows for n > 46340 (and is 0 for n = 65536).
        double nSquared = (double) n * n;
        return a * (1.0 + 4.0 / n - 25.0 / nSquared);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("G-Means distortion: %.5f%n", this.distortion));
        sb.append(String.format("Clusters of %d data points of dimension %d:%n", this.y.length, this.centroids[0].length));
        for (int i = 0; i < this.k; i++) {
            // Cluster share in tenths of a percent, printed as "xx.y%".
            int r = (int) Math.round(1000.0 * this.size[i] / this.y.length);
            sb.append(String.format("%3d\t%5d (%2d.%1d%%)%n", i, this.size[i], r / 10, r % 10));
        }
        return sb.toString();
    }
}

