/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import java.util.ArrayList;
import java.util.List;
import smile.math.Math;
import smile.stat.distribution.MultivariateExponentialFamily;
import smile.stat.distribution.MultivariateMixture;

/**
 * A finite mixture of multivariate distributions in which every component's
 * distribution is a {@link MultivariateExponentialFamily}. The mixture can be
 * fitted to data with the expectation-maximization (EM) algorithm, where each
 * component performs its own maximization step via
 * {@link MultivariateExponentialFamily#M}.
 */
public class MultivariateExponentialFamilyMixture
extends MultivariateMixture {
    /**
     * Package-private no-argument constructor.
     */
    MultivariateExponentialFamilyMixture() {
    }

    /**
     * Creates a mixture from the given components.
     *
     * @param mixture the mixture components.
     * @throws IllegalArgumentException if any component's distribution is not a
     *         {@link MultivariateExponentialFamily}.
     */
    public MultivariateExponentialFamilyMixture(List<MultivariateMixture.Component> mixture) {
        super(mixture);
        for (MultivariateMixture.Component component : mixture) {
            if (!(component.distribution instanceof MultivariateExponentialFamily)) {
                throw new IllegalArgumentException("Component " + component + " is not of multivariate exponential family.");
            }
        }
    }

    /**
     * Creates a mixture from the given initial components and then refines
     * {@code this.components} on {@code data} with EM, using the default
     * regularization factor 0.2 and no iteration limit.
     *
     * @param mixture the initial mixture components.
     * @param data the training samples, one row per sample.
     * @throws IllegalArgumentException if a component is not of the multivariate
     *         exponential family, or if EM rejects its arguments.
     */
    public MultivariateExponentialFamilyMixture(List<MultivariateMixture.Component> mixture, double[][] data) {
        this(mixture);
        this.EM(this.components, data);
    }

    /**
     * Runs EM with regularization factor 0.2 and no iteration limit.
     *
     * @param mixture the components to fit; replaced in place on improvement.
     * @param x the training samples.
     * @return the log-likelihood of {@code x} under the final components.
     */
    double EM(List<MultivariateMixture.Component> mixture, double[][] x) {
        return this.EM(mixture, x, 0.2);
    }

    /**
     * Runs EM with the given regularization factor and no iteration limit.
     *
     * @param mixture the components to fit; replaced in place on improvement.
     * @param x the training samples.
     * @param gamma the regularization factor, in [0, 0.2].
     * @return the log-likelihood of {@code x} under the final components.
     */
    double EM(List<MultivariateMixture.Component> mixture, double[][] x, double gamma) {
        return this.EM(mixture, x, gamma, Integer.MAX_VALUE);
    }

    /**
     * Fits the mixture components to {@code x} with the EM algorithm.
     *
     * <p>Each iteration computes posterior responsibilities (E-step), re-estimates
     * every component via {@link MultivariateExponentialFamily#M}, renormalizes
     * the priors to sum to one, and evaluates the new log-likelihood. Iteration
     * stops when the log-likelihood no longer strictly increases or after
     * {@code maxIter} iterations. The contents of {@code components} are replaced
     * only by configurations that strictly improved the log-likelihood.
     *
     * <p>When {@code gamma > 0}, each normalized posterior {@code r} is scaled by
     * {@code 1 + gamma * log2(r)}; results that are negative or NaN are set to 0.
     *
     * @param components the initial components; modified in place.
     * @param x the training samples, one row per sample.
     * @param gamma the regularization factor, in [0, 0.2].
     * @param maxIter the maximum number of iterations.
     * @return the log-likelihood of {@code x} under the components left in
     *         {@code components}. Samples with non-positive mixture density are
     *         skipped.
     * @throws IllegalArgumentException if {@code x.length < components.size() / 2}
     *         or {@code gamma} is outside [0, 0.2].
     */
    double EM(List<MultivariateMixture.Component> components, double[][] x, double gamma, int maxIter) {
        // NOTE(review): integer division makes this a loose bound (it allows more
        // components than samples); kept as-is for compatibility.
        if (x.length < components.size() / 2) {
            throw new IllegalArgumentException("Too many components");
        }
        if (gamma < 0.0 || gamma > 0.2) {
            throw new IllegalArgumentException("Invalid regularization factor gamma.");
        }

        int n = x.length;
        int m = components.size();
        double[][] posteriori = new double[m][n];

        double loglik = mixtureLogLikelihood(components, x);

        for (int iter = 0; iter < maxIter; iter++) {
            // E-step: unnormalized responsibilities priori_i * p_i(x_j).
            for (int i = 0; i < m; i++) {
                MultivariateMixture.Component c = components.get(i);
                for (int j = 0; j < n; j++) {
                    posteriori[i][j] = c.priori * c.distribution.p(x[j]);
                }
            }

            for (int j = 0; j < n; j++) {
                double p = 0.0;
                for (int i = 0; i < m; i++) {
                    p += posteriori[i][j];
                }

                if (!(p > 0.0)) {
                    // No component gives x[j] positive density. Dividing by p would
                    // produce NaN responsibilities that corrupt the M-step, so this
                    // sample gets zero weight (as the gamma > 0 path already did).
                    for (int i = 0; i < m; i++) {
                        posteriori[i][j] = 0.0;
                    }
                    continue;
                }

                for (int i = 0; i < m; i++) {
                    posteriori[i][j] /= p;
                }

                if (gamma > 0.0) {
                    for (int i = 0; i < m; i++) {
                        posteriori[i][j] *= (1.0 + gamma * Math.log2(posteriori[i][j]));
                        // log2(0) = -inf gives NaN; small responsibilities can go negative.
                        if (Double.isNaN(posteriori[i][j]) || posteriori[i][j] < 0.0) {
                            posteriori[i][j] = 0.0;
                        }
                    }
                }
            }

            // M-step: each component re-estimates itself from its weighted samples.
            List<MultivariateMixture.Component> newConfig = new ArrayList<MultivariateMixture.Component>(m);
            for (int i = 0; i < m; i++) {
                MultivariateExponentialFamily family = (MultivariateExponentialFamily) components.get(i).distribution;
                newConfig.add(family.M(x, posteriori[i]));
            }

            // Renormalize the priors so they sum to one.
            double sumAlpha = 0.0;
            for (MultivariateMixture.Component c : newConfig) {
                sumAlpha += c.priori;
            }
            for (MultivariateMixture.Component c : newConfig) {
                c.priori /= sumAlpha;
            }

            double newLoglik = mixtureLogLikelihood(newConfig, x);
            // Stop unless strictly improved; the negated form also stops on NaN.
            if (!(newLoglik > loglik)) {
                break;
            }

            loglik = newLoglik;
            components.clear();
            components.addAll(newConfig);
        }

        return loglik;
    }

    /**
     * Returns the sum of log mixture densities of the samples in {@code x}.
     * Samples whose mixture density is not positive are skipped.
     */
    private static double mixtureLogLikelihood(List<MultivariateMixture.Component> components, double[][] x) {
        double loglik = 0.0;
        for (double[] xi : x) {
            double p = 0.0;
            for (MultivariateMixture.Component c : components) {
                p += c.priori * c.distribution.p(xi);
            }
            if (p > 0.0) {
                loglik += Math.log(p);
            }
        }
        return loglik;
    }
}

