/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.BBDTree;
import smile.clustering.KMeans;
import smile.math.Math;
import smile.sort.QuickSort;

/**
 * X-Means clustering: k-means with the number of clusters chosen automatically.
 *
 * <p>The constructor starts from a single cluster and repeatedly tries to
 * bisect every cluster with a 2-means run. A split is scored by the change in
 * the Bayesian information criterion (BIC) computed under a spherical Gaussian
 * model; clusters whose split improves the BIC are replaced by their two
 * children, subject to the {@code kmax} limit. After each round all centroids
 * are refined globally with Lloyd-style iterations driven by a {@link BBDTree}.
 * The loop stops when no cluster is split or {@code kmax} is reached.
 */
public class XMeans
extends KMeans
implements Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(XMeans.class);
    /** log(2 * pi); note that the bare name {@code Math} in this file is smile.math.Math. */
    private static final double LOG2PI = Math.log(2.0 * java.lang.Math.PI);
    /** Clusters with fewer samples than this are never considered for splitting. */
    private static final int MIN_SPLIT_SIZE = 25;
    /** Upper bound on the global refinement iterations after each round of splits. */
    private static final int MAX_REFINE_ITERATIONS = 100;

    /**
     * Clusters {@code data} into at most {@code kmax} groups.
     *
     * @param data the samples, one row per observation; all rows are expected
     *             to have the same length as the first row.
     * @param kmax the maximum number of clusters; must be at least 2.
     * @throws IllegalArgumentException if {@code kmax < 2} or {@code data} is empty.
     */
    public XMeans(double[][] data, int kmax) {
        if (kmax < 2) {
            throw new IllegalArgumentException("Invalid parameter kmax = " + kmax);
        }
        if (data.length == 0) {
            // Without this check data[0] below fails with an unhelpful
            // ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("Empty data set");
        }

        int n = data.length;
        int d = data[0].length;

        // Start with one cluster whose centroid is the mean of all samples.
        this.k = 1;
        this.size = new int[] {n};
        this.y = new int[n];
        this.centroids = new double[1][d];
        double[] mean = this.centroids[0];
        for (double[] x : data) {
            for (int j = 0; j < d; ++j) {
                mean[j] += x[j];
            }
        }
        for (int j = 0; j < d; ++j) {
            mean[j] /= n;
        }

        // Within-cluster sum of squares, per cluster; needed for the "before split" BIC.
        double[] wcss = sumOfSquaresByCluster(data, this.centroids, this.y);
        this.distortion = wcss[0];
        logger.info(String.format("X-Means distortion with %d clusters: %.5f", this.k, this.distortion));

        BBDTree bbd = new BBDTree(data);
        while (this.k < kmax) {
            // score[i] = BIC(cluster i split in two) - BIC(cluster i as is).
            // Clusters that are skipped keep the default score 0 (= do not split).
            double[] score = new double[this.k];
            KMeans[] kmeans = new KMeans[this.k];

            for (int i = 0; i < this.k; ++i) {
                if (this.size[i] < MIN_SPLIT_SIZE) {
                    logger.info("Cluster {} too small to split: {} samples", i, this.size[i]);
                    continue;
                }

                // Gather the samples currently assigned to cluster i.
                double[][] subset = new double[this.size[i]][];
                int l = 0;
                for (int j = 0; j < n; ++j) {
                    if (this.y[j] == i) {
                        subset[l++] = data[j];
                    }
                }

                // Bisect the cluster. NOTE(review): the trailing arguments are
                // presumably (k, maxIter, runs) -- confirm against KMeans.
                kmeans[i] = new KMeans(subset, 2, 100, 4);
                double newBIC = this.bic(2, this.size[i], d, kmeans[i].distortion, kmeans[i].size);
                double oldBIC = this.bic(this.size[i], d, wcss[i]);
                double improvement = newBIC - oldBIC;
                // A degenerate cluster (zero variance) makes both BICs infinite
                // and their difference NaN. NaN is neither <= 0 nor > 0, so the
                // cluster would silently vanish below; treat it as "no improvement".
                score[i] = Double.isNaN(improvement) ? 0.0 : improvement;
                logger.info(String.format("Cluster %3d\tBIC: %.5f\tBIC after split: %.5f\timprovement: %.5f", i, oldBIC, newBIC, score[i]));
            }

            // QuickSort.sort reorders score in place (ascending) and returns,
            // for each sorted position, the original cluster index. From here
            // on score[i] belongs to cluster index[i].
            int[] index = QuickSort.sort(score);

            // Clusters without BIC improvement are kept unchanged.
            ArrayList<double[]> centers = new ArrayList<double[]>();
            for (int i = 0; i < this.k; ++i) {
                if (score[i] <= 0.0) {
                    centers.add(this.centroids[index[i]]);
                }
            }

            // Visit the improving clusters from best to worst. A cluster is
            // split only if, after keeping every remaining candidate as a
            // single cluster (i - m + 1 of them, this one included), there is
            // still room for one extra centroid below kmax.
            int m = centers.size();
            for (int i = this.k - 1; i >= 0; --i) {
                if (score[i] > 0.0) {
                    if (centers.size() + i - m + 1 < kmax) {
                        logger.info("Split cluster {}", index[i]);
                        centers.add(kmeans[index[i]].centroids[0]);
                        centers.add(kmeans[index[i]].centroids[1]);
                    } else {
                        centers.add(this.centroids[index[i]]);
                    }
                }
            }

            // Nothing was split in this round: converged.
            if (centers.size() == this.k) {
                break;
            }

            this.k = centers.size();
            double[][] sums = new double[this.k][d];
            this.size = new int[this.k];
            this.centroids = centers.toArray(new double[this.k][]);

            // Global refinement: reassign samples and recompute centroids until
            // the distortion stops decreasing. bbd.clustering is used here as
            // filling sums/size/y for the given centroids and returning the
            // resulting distortion.
            this.distortion = Double.MAX_VALUE;
            for (int iter = 0; iter < MAX_REFINE_ITERATIONS; ++iter) {
                double newDistortion = bbd.clustering(this.centroids, sums, this.size, this.y);
                for (int i = 0; i < this.k; ++i) {
                    if (this.size[i] > 0) {
                        for (int j = 0; j < d; ++j) {
                            this.centroids[i][j] = sums[i][j] / this.size[i];
                        }
                    }
                }
                if (this.distortion <= newDistortion) {
                    break;
                }
                this.distortion = newDistortion;
            }

            wcss = sumOfSquaresByCluster(data, this.centroids, this.y);
            logger.info(String.format("X-Means distortion with %d clusters: %.5f", this.k, this.distortion));
        }
    }

    /**
     * Sums, for every cluster, the squared distances between its samples and
     * its centroid.
     *
     * @param data      the samples.
     * @param centroids the cluster centroids.
     * @param labels    the cluster index of each sample.
     * @return an array of length {@code centroids.length} with the per-cluster sums.
     */
    private static double[] sumOfSquaresByCluster(double[][] data, double[][] centroids, int[] labels) {
        double[] wcss = new double[centroids.length];
        for (int i = 0; i < data.length; ++i) {
            int c = labels[i];
            wcss[c] += Math.squaredDistance(data[i], centroids[c]);
        }
        return wcss;
    }

    /**
     * BIC of a single cluster modelled as one spherical Gaussian.
     *
     * @param n          the number of samples in the cluster.
     * @param d          the dimension of the samples.
     * @param distortion the sum of squared distances to the centroid.
     * @return the BIC score (larger is better).
     */
    private double bic(int n, int d, double distortion) {
        double variance = distortion / (n - 1);

        double p1 = -n * LOG2PI;
        // Multiply in double: n * d can overflow int for large data sets.
        double p2 = -((double) n * d) * Math.log(variance);
        double p3 = -(n - 1);
        double L = (p1 + p2 + p3) / 2.0;

        int numParameters = d + 1;
        return L - 0.5 * numParameters * Math.log((double) n);
    }

    /**
     * BIC of a cluster after it has been partitioned into {@code k} sub-clusters.
     *
     * @param k           the number of sub-clusters.
     * @param n           the total number of samples.
     * @param d           the dimension of the samples.
     * @param distortion  the total distortion of the partition.
     * @param clusterSize the number of samples in each sub-cluster.
     * @return the BIC score (larger is better).
     */
    private double bic(int k, int n, int d, double distortion, int[] clusterSize) {
        double variance = distortion / (n - k);

        double L = 0.0;
        for (int i = 0; i < k; ++i) {
            L += XMeans.logLikelihood(k, n, clusterSize[i], d, variance);
        }

        int numParameters = k + k * d;
        return L - 0.5 * numParameters * Math.log((double) n);
    }

    /**
     * Log-likelihood contribution of one sub-cluster.
     *
     * @param k        the number of sub-clusters.
     * @param n        the total number of samples.
     * @param ni       the number of samples in this sub-cluster.
     * @param d        the dimension of the samples.
     * @param variance the pooled variance estimate.
     * @return the log-likelihood term for this sub-cluster.
     */
    private static double logLikelihood(int k, int n, int ni, int d, double variance) {
        double p1 = -ni * LOG2PI;
        // Multiply in double: ni * d can overflow int for large data sets.
        double p2 = -((double) ni * d) * Math.log(variance);
        double p3 = -(ni - k);
        double p4 = ni * Math.log((double) ni);
        double p5 = -ni * Math.log((double) n);
        return (p1 + p2 + p3) / 2.0 + p4 + p5;
    }

    /**
     * Returns a summary with the distortion and, for each cluster, its index,
     * size and share of the data set as a percentage with one decimal.
     */
    @Override
    public String toString() {
        int total = this.y.length;
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("X-Means distortion: %.5f%n", this.distortion));
        sb.append(String.format("Clusters of %d data points of dimension %d:%n", total, this.centroids[0].length));
        for (int i = 0; i < this.k; ++i) {
            // Share in tenths of a percent, so it can be printed as "xx.x%".
            int r = (int) Math.round(1000.0 * this.size[i] / total);
            sb.append(String.format("%3d\t%5d (%2d.%1d%%)%n", i, this.size[i], r / 10, r % 10));
        }
        return sb.toString();
    }
}

