/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import smile.math.IntArrayList;
import smile.math.Math;
import smile.neighbor.KNNSearch;
import smile.neighbor.NearestNeighborSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.RNNSearch;
import smile.sort.HeapSelect;
import smile.stat.distribution.GaussianDistribution;

public class MPLSH<E>
implements NearestNeighborSearch<double[], E>,
KNNSearch<double[], E>,
RNNSearch<double[], E> {
    /** The keys stored in the index; a key's position in this list is its index. */
    ArrayList<double[]> keys;
    /** The values stored in the index, in the same order as {@code keys}. */
    ArrayList<E> data;
    /** The hash tables; the constructor creates {@code L} of them. */
    List<Hash> hash;
    /** The size of each hash table's slot array. */
    int H;
    /** The dimension of the input space. */
    int d;
    /** The number of hash tables requested at construction. */
    int L;
    /** The number of random projections per hash value. */
    int k;
    /** The width of the random projections; projected values are divided by it. */
    double r;
    /** Random integer coefficients, one per projection, used to fold the k projections into one bucket id. */
    int[] c;
    /** The modulus applied to the folded hash value; also the argument passed to Math.randomInt when drawing {@code c}. */
    int P = Integer.MAX_VALUE;
    /** Whether a stored key that is the very same object as the query is left out of search results. */
    boolean identicalExcluded = true;
    /** The per-table posteriori models; only assigned by {@code learn}, so it is null until then. */
    private List<PosterioriModel> model;

    /**
     * Creates an index using a hash table size of 1017881.
     *
     * @param d the dimension of the input space.
     * @param L the number of hash tables.
     * @param k the number of random projections per hash value.
     * @param r the width of random projections.
     */
    public MPLSH(int d, int L, int k, double r) {
        this(d, L, k, r, 1017881);
    }

    /**
     * Creates an empty index.
     *
     * @param d the dimension of the input space; must be at least 2.
     * @param L the number of hash tables; must be at least 1.
     * @param k the number of random projections per hash value; must be at least 1.
     * @param r the width of random projections; must be positive.
     * @param H the size of each hash table; must be at least 1.
     * @throws IllegalArgumentException if any argument is out of range.
     */
    public MPLSH(int d, int L, int k, double r, int H) {
        if (d < 2) {
            throw new IllegalArgumentException("Invalid input space dimension: " + d);
        }
        if (L < 1) {
            throw new IllegalArgumentException("Invalid number of hash tables: " + L);
        }
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of random projections per hash value: " + k);
        }
        if (r <= 0.0) {
            throw new IllegalArgumentException("Invalid width of random projections: " + r);
        }
        if (H < 1) {
            throw new IllegalArgumentException("Invalid size of hash tables: " + H);
        }

        // The Hash constructor reads d, k, r and H from the enclosing
        // instance, so all parameters are stored before any table is built.
        this.d = d;
        this.L = L;
        this.k = k;
        this.r = r;
        this.H = H;

        keys = new ArrayList<double[]>();
        data = new ArrayList<E>();

        // Coefficients that fold the k integer projections into one bucket id.
        c = new int[k];
        for (int m = 0; m < k; m++) {
            c[m] = Math.randomInt(P);
        }

        hash = new ArrayList<Hash>(L);
        for (int t = 0; t < L; t++) {
            hash.add(new Hash());
        }
    }

    /**
     * Returns a one-line description listing the number of hash tables, k,
     * the table size H and the projection width (printed as {@code w}).
     */
    public String toString() {
        return String.format("Multi-Probe LSH (L=%d, k=%d, H=%d, w=%.4f)", this.hash.size(), this.k, this.H, this.r);
    }

    /**
     * Returns whether a stored key that is the same object as the query
     * (reference equality) is left out of search results.
     */
    public boolean isIdenticalExcluded() {
        return this.identicalExcluded;
    }

    /**
     * Sets whether a stored key that is the same object as the query
     * (reference equality) is left out of search results.
     *
     * @param excluded true to exclude the query object itself.
     * @return this index, to allow call chaining.
     */
    public MPLSH<E> setIdenticalExcluded(boolean excluded) {
        this.identicalExcluded = excluded;
        return this;
    }

    /**
     * Adds a key/value pair to the index and registers it in every hash table.
     *
     * @param key the point to index.
     * @param value the value associated with the key.
     */
    public void put(double[] key, E value) {
        // The new entry's index is its position in the key/value lists.
        final int index = keys.size();
        keys.add(key);
        data.add(value);

        for (int t = 0; t < hash.size(); t++) {
            hash.get(t).add(index, key, value);
        }
    }

    /**
     * Trains the posteriori models with a quantization size of 2500
     * (and the default sigma of the overload it delegates to).
     *
     * @param range the range search used to find the neighbors of each sample.
     * @param samples the training queries.
     * @param radius the search radius passed to {@code range}.
     */
    public void learn(RNNSearch<double[], double[]> range, double[][] samples, double radius) {
        this.learn(range, samples, radius, 2500);
    }

    /**
     * Trains the posteriori models with a Parzen kernel sigma of 0.2.
     *
     * @param range the range search used to find the neighbors of each sample.
     * @param samples the training queries.
     * @param radius the search radius passed to {@code range}.
     * @param Nz the maximum number of lookup-table rows per projection.
     */
    public void learn(RNNSearch<double[], double[]> range, double[][] samples, double radius, int Nz) {
        this.learn(range, samples, radius, Nz, 0.2);
    }

    /**
     * Trains one posteriori model per hash table from a set of sample queries.
     * For every sample, {@code range} is asked for the neighbors within
     * {@code radius}; the keys of those neighbors are looked up in this index
     * by the neighbor's {@code index}.
     *
     * NOTE(review): this assumes {@code range} was built over the same keys,
     * in the same order, as this index - confirm against callers.
     *
     * @param range the range search used to find the neighbors of each sample.
     * @param samples the training queries.
     * @param radius the search radius passed to {@code range}.
     * @param Nz the maximum number of lookup-table rows per projection.
     * @param sigma the standard deviation of the Parzen kernel.
     */
    public void learn(RNNSearch<double[], double[]> range, double[][] samples, double radius, int Nz, double sigma) {
        TrainSample[] training = new TrainSample[samples.length];
        for (int i = 0; i < samples.length; ++i) {
            TrainSample sample = new TrainSample();
            sample.query = samples[i];
            sample.neighbors = new ArrayList<double[]>();

            // Typed list: iterating a raw list yields Object, which cannot be
            // used as a Neighbor without a cast.
            List<Neighbor<double[], double[]>> found = new ArrayList<Neighbor<double[], double[]>>();
            range.range(sample.query, radius, found);
            for (Neighbor<double[], double[]> n : found) {
                sample.neighbors.add(this.keys.get(n.index));
            }
            training[i] = sample;
        }

        this.model = new ArrayList<PosterioriModel>(this.hash.size());
        for (int i = 0; i < this.hash.size(); ++i) {
            this.model.add(new PosterioriModel(this.hash.get(i), training, Nz, sigma));
        }
    }

    /**
     * Finds the approximate nearest neighbor of {@code q} using a target
     * recall of 0.95 and at most 100 probes per hash table.
     */
    @Override
    public Neighbor<double[], E> nearest(double[] q) {
        return this.nearest(q, 0.95, 100);
    }

    /**
     * Finds the approximate nearest neighbor of {@code q}.
     * Requires {@code learn} to have been called, since the probe sequences
     * come from the models it builds.
     *
     * @param q the query point.
     * @param recall the target recall, in [0, 1].
     * @param T the maximum number of probes per hash table.
     * @return the closest candidate found; if no candidate was found, a
     *         neighbor with null key and value, index -1 and distance
     *         {@code Double.MAX_VALUE}.
     * @throws IllegalArgumentException if {@code recall} is outside [0, 1].
     */
    public Neighbor<double[], E> nearest(double[] q, double recall, int T) {
        if (recall > 1.0 || recall < 0.0) {
            throw new IllegalArgumentException("Invalid recall: " + recall);
        }
        // Per-table recall such that the L tables together reach the target.
        double alpha = 1.0 - Math.pow(1.0 - recall, 1.0 / (double) this.hash.size());

        IntArrayList candidates = new IntArrayList();
        for (int i = 0; i < this.hash.size(); ++i) {
            IntArrayList buckets = this.model.get(i).getProbeSequence(q, alpha, T);
            for (int j = 0; j < buckets.size(); ++j) {
                int bucket = buckets.get(j);
                ArrayList<HashEntry> bin = this.hash.get(i).table[bucket % this.H];
                if (bin == null) {
                    continue;
                }
                for (HashEntry e : bin) {
                    // A slot may hold several bucket ids; keep exact matches only.
                    if (e.bucket != bucket) {
                        continue;
                    }
                    if (q == e.key && this.identicalExcluded) {
                        continue;
                    }
                    candidates.add(e.index);
                }
            }
        }

        // Typed as the declared return type (was Neighbor<Object, Object>).
        Neighbor<double[], E> neighbor = new Neighbor<double[], E>(null, null, -1, Double.MAX_VALUE);

        // Sorting puts duplicates (same entry found in several tables) next
        // to each other so each distance is computed once.
        int[] cand = candidates.toArray();
        Arrays.sort(cand);
        int prev = -1;
        for (int index : cand) {
            if (index == prev) {
                continue;
            }
            prev = index;
            double[] key = this.keys.get(index);
            double distance = Math.distance(q, key);
            if (distance < neighbor.distance) {
                neighbor.index = index;
                neighbor.distance = distance;
                neighbor.key = key;
                neighbor.value = this.data.get(index);
            }
        }
        return neighbor;
    }

    /**
     * Finds the approximate k nearest neighbors of {@code q} using a target
     * recall of 0.95 and at most 100 probes per hash table.
     */
    @Override
    public Neighbor<double[], E>[] knn(double[] q, int k) {
        return this.knn(q, k, 0.95, 100);
    }

    /**
     * Finds the approximate k nearest neighbors of {@code q}.
     * Requires {@code learn} to have been called, since the probe sequences
     * come from the models it builds.
     *
     * @param q the query point.
     * @param k the number of neighbors wanted; must be at least 1.
     * @param recall the target recall, in [0, 1].
     * @param T the maximum number of probes per hash table.
     * @return the neighbors found; the array is shorter than {@code k} when
     *         fewer than {@code k} candidates entered the selection heap.
     * @throws IllegalArgumentException if {@code recall} is outside [0, 1]
     *         or {@code k} is less than 1.
     */
    public Neighbor<double[], E>[] knn(double[] q, int k, double recall, int T) {
        if (recall > 1.0 || recall < 0.0) {
            throw new IllegalArgumentException("Invalid recall: " + recall);
        }
        if (k < 1) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        // Per-table recall such that the L tables together reach the target.
        double alpha = 1.0 - Math.pow(1.0 - recall, 1.0 / (double) this.hash.size());

        IntArrayList candidates = new IntArrayList();
        for (int i = 0; i < this.hash.size(); ++i) {
            IntArrayList buckets = this.model.get(i).getProbeSequence(q, alpha, T);
            for (int j = 0; j < buckets.size(); ++j) {
                int bucket = buckets.get(j);
                ArrayList<HashEntry> bin = this.hash.get(i).table[bucket % this.H];
                if (bin == null) {
                    continue;
                }
                for (HashEntry e : bin) {
                    // A slot may hold several bucket ids; keep exact matches only.
                    if (e.bucket != bucket) {
                        continue;
                    }
                    // Skip the query object itself only when exclusion is ON,
                    // the same rule nearest() and range() apply.
                    if (q == e.key && this.identicalExcluded) {
                        continue;
                    }
                    candidates.add(e.index);
                }
            }
        }

        // Sorting puts duplicates (same entry found in several tables) next
        // to each other so each distance is computed once.
        int[] cand = candidates.toArray();
        Arrays.sort(cand);

        // The heap starts full of placeholders at maximum distance; real
        // candidates replace them as they are found.
        Neighbor<double[], E> neighbor = new Neighbor<double[], E>(null, null, 0, Double.MAX_VALUE);
        @SuppressWarnings("unchecked")
        Neighbor<double[], E>[] neighbors = (Neighbor<double[], E>[]) Array.newInstance(neighbor.getClass(), k);
        HeapSelect<Neighbor<double[], E>> heap = new HeapSelect<Neighbor<double[], E>>(neighbors);
        for (int i = 0; i < k; ++i) {
            heap.add(neighbor);
        }

        int hit = 0;
        int prev = -1;
        for (int index : cand) {
            if (index == prev) {
                continue;
            }
            prev = index;
            double[] key = this.keys.get(index);
            double dist = Math.distance(q, key);
            if (dist < heap.peek().distance) {
                heap.add(new Neighbor<double[], E>(key, this.data.get(index), index, dist));
                ++hit;
            }
        }
        heap.sort();

        if (hit < k) {
            // Fewer than k real neighbors: drop the remaining placeholders,
            // which are expected in the first k - hit slots after sorting.
            @SuppressWarnings("unchecked")
            Neighbor<double[], E>[] n2 = (Neighbor<double[], E>[]) Array.newInstance(neighbor.getClass(), hit);
            int start = k - hit;
            for (int i = 0; i < hit; ++i) {
                n2[i] = neighbors[i + start];
            }
            neighbors = n2;
        }
        return neighbors;
    }

    /**
     * Collects the approximate neighbors of {@code q} within {@code radius}
     * into {@code neighbors}, using a target recall of 0.95 and at most 100
     * probes per hash table.
     */
    @Override
    public void range(double[] q, double radius, List<Neighbor<double[], E>> neighbors) {
        this.range(q, radius, neighbors, 0.95, 100);
    }

    /**
     * Collects the approximate neighbors of {@code q} within {@code radius}.
     * Entries whose index is already present in {@code neighbors} are not
     * added again. Requires {@code learn} to have been called.
     *
     * @param q the query point.
     * @param radius the search radius; must be positive.
     * @param neighbors the output list that matching entries are appended to.
     * @param recall the target recall, in [0, 1].
     * @param T the maximum number of probes per hash table.
     * @throws IllegalArgumentException if {@code radius} is not positive or
     *         {@code recall} is outside [0, 1].
     */
    public void range(double[] q, double radius, List<Neighbor<double[], E>> neighbors, double recall, int T) {
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }
        if (recall > 1.0 || recall < 0.0) {
            throw new IllegalArgumentException("Invalid recall: " + recall);
        }

        // Per-table recall such that the L tables together reach the target.
        final double alpha = 1.0 - Math.pow(1.0 - recall, 1.0 / (double) this.hash.size());

        for (int t = 0; t < this.hash.size(); t++) {
            IntArrayList probes = this.model.get(t).getProbeSequence(q, alpha, T);
            for (int p = 0; p < probes.size(); p++) {
                final int bucket = probes.get(p);
                ArrayList<HashEntry> slot = this.hash.get(t).table[bucket % this.H];
                if (slot != null) {
                    for (HashEntry entry : slot) {
                        // A slot may hold several bucket ids; keep exact matches only.
                        if (entry.bucket != bucket) {
                            continue;
                        }
                        if (this.identicalExcluded && q == entry.key) {
                            continue;
                        }
                        double distance = Math.distance(q, entry.key);
                        if (distance <= radius) {
                            // Linear scan of the output for an entry with the same index.
                            boolean seen = false;
                            for (Neighbor<double[], E> found : neighbors) {
                                if (found.index == entry.index) {
                                    seen = true;
                                    break;
                                }
                            }
                            if (!seen) {
                                neighbors.add(new Neighbor<double[], E>(entry.key, entry.data, entry.index, distance));
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * One candidate bucket in a probe sequence. {@code bucket[i]} selects an
     * entry of the i-th component's candidate list ({@code pz[i].prh}), and
     * {@code last} is the position of the last component the generator
     * operations act on. New probes are derived from an existing one with
     * {@link #shift()}, {@link #expand()} and {@link #extend()}.
     */
    class Probe
    implements Comparable<Probe> {
        /** The number of candidate values available for each component. */
        int[] range;
        /** For each component, the index of the chosen candidate value. */
        int[] bucket;
        /** The position of the last perturbed component. */
        int last;
        /** The probability of this probe, as computed by {@link #setProb}. */
        double prob;

        /** Creates a probe with every component at candidate 0 and {@code last} at 0. */
        Probe(int[] range) {
            this.range = range;
            this.bucket = new int[range.length];
            this.last = 0;
        }

        /**
         * Returns true if the last component is at candidate 1, is not the
         * final component, and the next component has more than one candidate.
         */
        boolean isShiftable() {
            return this.bucket[this.last] == 1
                    && this.last + 1 < this.bucket.length
                    && this.range[this.last + 1] > 1;
        }

        /**
         * Returns a copy in which the last component is reset to candidate 0
         * and the following component is set to candidate 1.
         */
        Probe shift() {
            Probe p = this.copy();
            p.bucket[p.last] = 0;
            ++p.last;
            // Index with the copy's advanced position; indexing with the old
            // position would undo the reset above and never move the 1.
            p.bucket[p.last] = 1;
            return p;
        }

        /**
         * Returns true if the last component is not the final component and
         * the next component has more than one candidate.
         */
        boolean isExpandable() {
            return this.last + 1 < this.bucket.length
                    && this.range[this.last + 1] > 1;
        }

        /**
         * Returns a copy in which the component after the last one is set to
         * candidate 1; all earlier components are kept as they are.
         */
        Probe expand() {
            Probe p = this.copy();
            ++p.last;
            // Index with the copy's advanced position so that the current
            // last component keeps its value.
            p.bucket[p.last] = 1;
            return p;
        }

        /** Returns true if the last component can move to its next candidate. */
        boolean isExtendable() {
            return this.bucket[this.last] + 1 < this.range[this.last];
        }

        /** Returns a copy in which the last component moves to its next candidate. */
        Probe extend() {
            Probe p = this.copy();
            ++p.bucket[p.last];
            return p;
        }

        /**
         * Orders probes by decreasing probability, so that a
         * {@code PriorityQueue} (which polls its least element) hands out the
         * most probable probe first.
         */
        @Override
        public int compareTo(Probe o) {
            return Double.compare(o.prob, this.prob);
        }

        /** Sets {@code prob} to the product of the chosen candidates' probabilities. */
        void setProb(PrZ[] pz) {
            this.prob = 1.0;
            for (int i = 0; i < this.bucket.length; ++i) {
                this.prob *= pz[i].prh[this.bucket[i]].pr;
            }
        }

        /**
         * Computes the bucket id of this probe.
         *
         * @param hash unused by the computation.
         * @param pz the per-component candidate lists; {@code pz[i].m} gives
         *        the projection whose coefficient in {@code c} is applied.
         * @return the bucket id, in [0, P).
         */
        int hash(Hash hash, PrZ[] pz) {
            long r = 0L;
            for (int i = 0; i < MPLSH.this.k; ++i) {
                // The product is deliberately formed in int arithmetic, as in
                // Hash.hash, so both produce the same id for the same values.
                int term = MPLSH.this.c[pz[i].m] * pz[i].prh[this.bucket[i]].u;
                r += (long) term;
            }
            int h = (int) (r % (long) MPLSH.this.P);
            if (h < 0) {
                h += MPLSH.this.P;
            }
            return h;
        }

        /** Returns a new probe with the same range, bucket choices and last position. */
        private Probe copy() {
            Probe p = new Probe(this.range);
            p.last = this.last;
            System.arraycopy(this.bucket, 0, p.bucket, 0, this.bucket.length);
            return p;
        }
    }

    /**
     * The posteriori model of one hash table. For every projection it holds a
     * lookup table that maps a quantized query hash value to the list of
     * integer hash values its neighbors are likely to have, each with an
     * estimated probability, sorted by decreasing probability.
     */
    class PosterioriModel {
        /** The hash table this model belongs to. */
        Hash hash;
        /** lookup[m][n] is the candidate list of projection m for quantization row n. */
        PrH[][][] lookup;

        /**
         * Builds the lookup tables.
         *
         * @param hash the hash table to model.
         * @param samples the training queries with their neighbors.
         * @param Nz the maximum number of lookup-table rows per projection.
         * @param sigma the standard deviation of the Parzen kernel.
         */
        PosterioriModel(Hash hash, TrainSample[] samples, int Nz, double sigma) {
            this.hash = hash;
            HashValueParzenModel parzen = new HashValueParzenModel(hash, samples, sigma);
            this.lookup = new PrH[MPLSH.this.k][][];
            for (int m = 0; m < MPLSH.this.k; ++m) {
                int minh = (int) Math.floor(hash.umin[m]);
                int maxh = (int) Math.floor(hash.umax[m]);
                int size = Math.min(maxh - minh + 1, Nz);
                double delta = (double) (maxh - minh) / (double) size;
                this.lookup[m] = new PrH[size][];
                for (int n = 0; n < size; ++n) {
                    // Estimate the neighbors' hash distribution at the row's midpoint.
                    parzen.estimate(m, (double) minh + ((double) n + 0.5) * delta);
                    GaussianDistribution gaussian = new GaussianDistribution(parzen.mean, parzen.std);
                    List<PrH> probs = new ArrayList<PrH>();
                    int h0 = (int) Math.floor(parzen.mean);

                    // Walk upwards starting AT h0, the interval containing the
                    // mean, until the interval's probability becomes negligible.
                    for (int h = h0; ; ++h) {
                        PrH prh = new PrH();
                        prh.u = h;
                        prh.pr = gaussian.cdf((double) (h + 1)) - gaussian.cdf((double) h);
                        if (prh.pr < 1.0E-7) {
                            break;
                        }
                        probs.add(prh);
                    }
                    // Walk downwards starting at h0 - 1, the interval just below.
                    for (int h = h0 - 1; ; --h) {
                        PrH prh = new PrH();
                        prh.u = h;
                        prh.pr = gaussian.cdf((double) (h + 1)) - gaussian.cdf((double) h);
                        if (prh.pr < 1.0E-7) {
                            break;
                        }
                        probs.add(prh);
                    }

                    this.lookup[m][n] = probs.toArray(new PrH[probs.size()]);
                    // PrH sorts by decreasing probability: entry 0 is the most likely value.
                    Arrays.sort(this.lookup[m][n]);
                }
            }
        }

        /**
         * Generates the bucket ids to probe for a query. The sequence starts
         * with the query's own bucket, then the bucket made of the most
         * probable value of every component, then perturbations of it taken
         * from a priority queue until the accumulated probability reaches
         * {@code recall} or {@code T} ids have been produced.
         * Note that computing the query's own bucket goes through
         * {@code Hash.hash(double[])}, which also updates the table's
         * umin/umax bounds.
         *
         * @param x the query point.
         * @param recall the probability mass to accumulate before stopping.
         * @param T the maximum length of the sequence.
         * @return the bucket ids to probe, in order.
         */
        IntArrayList getProbeSequence(double[] x, double recall, int T) {
            PrZ[] pz = new PrZ[MPLSH.this.k];
            for (int i = 0; i < MPLSH.this.k; ++i) {
                double h = this.hash.hash(x, i);
                // Offset of the hash value within [umin, umax], clamped at both ends.
                double hmin = h - this.hash.umin[i];
                if (hmin < 0.0) {
                    hmin = 0.0;
                }
                if (h > this.hash.umax[i]) {
                    hmin = this.hash.umax[i] - this.hash.umin[i];
                }
                int qh = (int) (hmin * (double) this.lookup[i].length / (this.hash.umax[i] - this.hash.umin[i] + 1.0));
                pz[i] = new PrZ();
                pz[i].m = i;
                pz[i].prh = this.lookup[i][qh];
            }
            Arrays.sort(pz);

            IntArrayList seq = new IntArrayList();
            seq.add(this.hash.hash(x));

            int[] range = new int[MPLSH.this.k];
            for (int i = 0; i < MPLSH.this.k; ++i) {
                range[i] = pz[i].prh.length;
            }

            // Root probe: the most probable value of every component. It is
            // emitted once here and its probability counted once.
            Probe root = new Probe(range);
            root.setProb(pz);
            double pr = root.prob;
            seq.add(root.hash(this.hash, pz));

            // Seed the queue with the first real perturbation instead of the
            // root itself, which would be emitted and counted a second time.
            // Normally this is component 0 at candidate 1; components that
            // offer a single candidate cannot be perturbed and are skipped.
            PriorityQueue<Probe> heap = new PriorityQueue<Probe>();
            int first = 0;
            while (first < range.length && range[first] < 2) {
                ++first;
            }
            if (first < range.length) {
                Probe seed = new Probe(range);
                seed.bucket[first] = 1;
                seed.last = first;
                seed.setProb(pz);
                heap.add(seed);
            }

            while (!heap.isEmpty() && pr < recall && seq.size() < T) {
                Probe p = heap.poll();
                seq.add(p.hash(this.hash, pz));
                pr += p.prob;
                if (p.isShiftable()) {
                    Probe p2 = p.shift();
                    p2.setProb(pz);
                    heap.offer(p2);
                }
                if (p.isExpandable()) {
                    Probe p2 = p.expand();
                    p2.setProb(pz);
                    heap.offer(p2);
                }
                if (p.isExtendable()) {
                    Probe p2 = p.extend();
                    p2.setProb(pz);
                    heap.offer(p2);
                }
            }
            return seq;
        }
    }

    /**
     * Kernel-weighted estimate of the mean and standard deviation of the
     * neighbors' hash values, given a query's hash value. Each training
     * sample with more than one neighbor contributes one per-sample model.
     */
    class HashValueParzenModel {
        /** The kernel: a zero-mean Gaussian with the given sigma. */
        GaussianDistribution gaussian;
        /** One model per training sample that has more than one neighbor. */
        NeighborHashValueModel[] neighborHashValueModels;
        /** The mean produced by the latest call to {@link #estimate}. */
        double mean;
        /** The standard deviation produced by the latest call to {@link #estimate}. */
        double std;

        /**
         * Builds the per-sample models.
         *
         * @param hash the hash table whose projections are modelled.
         * @param samples the training queries with their neighbors.
         * @param sigma the standard deviation of the kernel.
         */
        HashValueParzenModel(Hash hash, TrainSample[] samples, double sigma) {
            this.gaussian = new GaussianDistribution(0.0, sigma);

            // Only samples with at least two neighbors are usable.
            int usable = 0;
            for (TrainSample sample : samples) {
                if (sample.neighbors.size() > 1) {
                    usable++;
                }
            }
            this.neighborHashValueModels = new NeighborHashValueModel[usable];

            int next = 0;
            for (TrainSample sample : samples) {
                final int count = sample.neighbors.size();
                if (count > 1) {
                    double[] queryHash = new double[MPLSH.this.k];
                    double[] mu = new double[MPLSH.this.k];
                    double[] var = new double[MPLSH.this.k];
                    for (int m = 0; m < MPLSH.this.k; m++) {
                        queryHash[m] = hash.hash(sample.query, m);

                        // Mean and variance of the neighbors' hash values on projection m.
                        double sum = 0.0;
                        double sumsq = 0.0;
                        for (double[] neighbor : sample.neighbors) {
                            double value = hash.hash(neighbor, m);
                            sum += value;
                            sumsq += value * value;
                        }
                        mu[m] = sum / (double) count;
                        var[m] = sumsq / (double) count - mu[m] * mu[m];
                    }
                    this.neighborHashValueModels[next] = new NeighborHashValueModel(queryHash, mu, var);
                    next++;
                }
            }
        }

        /**
         * Estimates the neighbors' hash distribution on projection {@code m}
         * for a query whose hash value is {@code h}; the result is stored in
         * {@code mean} and {@code std}.
         *
         * @param m the projection index.
         * @param h the query's hash value on that projection.
         */
        void estimate(int m, double h) {
            double weightedMean = 0.0;
            double weightedVar = 0.0;
            double totalWeight = 0.0;
            for (NeighborHashValueModel sample : this.neighborHashValueModels) {
                double weight = this.gaussian.p(sample.H[m] - h);
                weightedMean += weight * sample.mean[m];
                weightedVar += weight * sample.var[m];
                totalWeight += weight;
            }

            if (totalWeight > 1.0E-7) {
                this.mean = weightedMean / totalWeight;
                this.std = Math.sqrt(weightedVar / totalWeight);
            } else {
                // No training sample is close enough to h to carry weight.
                this.mean = h;
                this.std = 0.0;
            }

            if (this.std < 1.0E-5) {
                // Fall back to the unweighted average variance over all samples.
                double total = 0.0;
                for (NeighborHashValueModel sample : this.neighborHashValueModels) {
                    total += sample.var[m];
                }
                this.std = Math.sqrt(total / (double) this.neighborHashValueModels.length);
            }
        }
    }

    /**
     * Per-training-sample statistics: the query's hash value on each
     * projection and the mean and variance of its neighbors' hash values.
     */
    static class NeighborHashValueModel {
        /** The hash value of the query on each projection. */
        double[] H;
        /** The mean of the neighbors' hash values on each projection. */
        double[] mean;
        /** The variance of the neighbors' hash values on each projection. */
        double[] var;

        /** Stores the given arrays as they are (no copy is made). */
        NeighborHashValueModel(double[] H, double[] mean, double[] var) {
            this.H = H;
            this.mean = mean;
            this.var = var;
        }
    }

    /** A training query together with the keys of its neighbors. */
    static class TrainSample {
        /** The query point. */
        double[] query;
        /** The keys of the query's neighbors. */
        ArrayList<double[]> neighbors;

        TrainSample() {
        }
    }

    /**
     * The candidate hash values of one projection for a given query:
     * a lookup-table row tagged with the projection it came from.
     */
    static class PrZ
    implements Comparable<PrZ> {
        /** The index of the projection. */
        int m;
        /** The candidate hash values of this projection with their probabilities. */
        PrH[] prh;

        PrZ() {
        }

        /** Compares by the first candidate of each list, using the ordering of PrH. */
        @Override
        public int compareTo(PrZ o) {
            return this.prh[0].compareTo(o.prh[0]);
        }
    }

    /** A candidate integer hash value together with its probability. */
    static class PrH
    implements Comparable<PrH> {
        /** The integer hash value. */
        int u;
        /** The probability assigned to this hash value. */
        double pr;

        PrH() {
        }

        /** Orders by decreasing probability, so sorting puts the most probable value first. */
        @Override
        public int compareTo(PrH o) {
            return (int)Math.signum((double)(o.pr - this.pr));
        }
    }

    /**
     * One hash table: k random projections of the input, folded into a
     * single bucket id, plus the slot array that stores the entries.
     */
    class Hash {
        /** The projection directions, one row of length d per projection. */
        double[][] a;
        /** The offset of each projection. */
        double[] b;
        /** The smallest projected value seen so far on each projection. */
        double[] umin;
        /** The largest projected value seen so far on each projection. */
        double[] umax;
        /** The slot array; a slot is null until its first entry is added. */
        ArrayList<HashEntry>[] table;

        /** Draws the random projections and allocates an empty slot array of size H. */
        @SuppressWarnings("unchecked")
        Hash() {
            final int projections = MPLSH.this.k;
            final int dimension = MPLSH.this.d;

            a = new double[projections][dimension];
            b = new double[projections];
            umin = new double[projections];
            umax = new double[projections];
            Arrays.fill(umin, Double.POSITIVE_INFINITY);
            Arrays.fill(umax, Double.NEGATIVE_INFINITY);

            GaussianDistribution gaussian = GaussianDistribution.getInstance();
            for (int m = 0; m < projections; m++) {
                double[] direction = a[m];
                for (int j = 0; j < dimension; j++) {
                    direction[j] = gaussian.rand();
                }
                b[m] = Math.random(0.0, MPLSH.this.r);
            }

            table = (ArrayList<HashEntry>[]) Array.newInstance(ArrayList.class, MPLSH.this.H);
        }

        /**
         * Returns the real-valued hash of {@code x} on projection {@code m}:
         * the offset plus the dot product with the direction, divided by r.
         */
        double hash(double[] x, int m) {
            double[] direction = a[m];
            double g = b[m];
            for (int j = 0; j < MPLSH.this.d; j++) {
                g += direction[j] * x[j];
            }
            return g / MPLSH.this.r;
        }

        /**
         * Returns the bucket id of {@code x}, in [0, P). As a side effect the
         * umin/umax bounds are widened to include the projected values of
         * {@code x}.
         */
        int hash(double[] x) {
            long g = 0L;
            for (int m = 0; m < MPLSH.this.k; m++) {
                double gi = hash(x, m);
                if (gi < umin[m]) {
                    umin[m] = gi;
                }
                if (gi > umax[m]) {
                    umax[m] = gi;
                }
                // The product is formed in int arithmetic before widening.
                int term = MPLSH.this.c[m] * (int) Math.floor(gi);
                g += (long) term;
            }
            int h = (int) (g % (long) MPLSH.this.P);
            if (h < 0) {
                h += MPLSH.this.P;
            }
            return h;
        }

        /**
         * Adds an entry to the table.
         *
         * @param index the position of the entry in the index's key list.
         * @param x the key.
         * @param data the value.
         */
        void add(int index, double[] x, E data) {
            final int bucket = hash(x);
            final int slot = bucket % MPLSH.this.H;
            ArrayList<HashEntry> entries = table[slot];
            if (entries == null) {
                entries = new ArrayList<HashEntry>();
                table[slot] = entries;
            }
            entries.add(new HashEntry(bucket, index, x, data));
        }
    }

    /** An entry of a hash table slot. */
    class HashEntry {
        /** The full bucket id of the key (before reduction modulo the table size). */
        int bucket;
        /** The position of the key in the index's key list. */
        int index;
        /** The key. */
        double[] key;
        /** The value associated with the key. */
        E data;

        /** Stores the given fields as they are. */
        HashEntry(int bucket, int index, double[] x, E data) {
            this.bucket = bucket;
            this.index = index;
            this.key = x;
            this.data = data;
        }
    }
}

