/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.benchmark;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDRange;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListIter;
import de.lmu.ifi.dbs.elki.database.ids.KNNList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.math.MeanVariance;
import de.lmu.ifi.dbs.elki.math.random.RandomFactory;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.utilities.Util;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;

/**
 * Benchmarking algorithm that issues kNN queries and reports aggregate
 * statistics via the statistics logger: a hash code mixed from the result
 * DBIDs, the mean (and naive standard deviation) of the result list sizes, and
 * the mean (and naive standard deviation) of the k-distances.
 * <p>
 * When no separate query source is configured, the queries are a random sample
 * of the database relation itself. Otherwise the queries are loaded from the
 * configured {@link DatabaseConnection}, using the first column whose type is
 * compatible with the distance function's input type.
 * <p>
 * This algorithm does not produce a result object; {@link #run} returns
 * {@code null}.
 *
 * @param <O> object type
 */
public class KNNBenchmarkAlgorithm<O>
extends AbstractDistanceBasedAlgorithm<O, Result> {
    private static final Logging LOG = Logging.getLogger(KNNBenchmarkAlgorithm.class);

    /** Number of neighbors to request per query. */
    protected int k;

    /** External query source, or {@code null} to sample queries from the database. */
    protected DatabaseConnection queries;

    /** Sampling parameter; see {@link Parameterizer#SAMPLING_ID} for its interpretation. */
    protected double sampling;

    /** Random generator factory used to draw the query sample. */
    protected RandomFactory random;

    /**
     * Constructor.
     *
     * @param distanceFunction distance function used for the kNN queries
     * @param k number of neighbors to request per query
     * @param queries external query source, may be {@code null}
     * @param sampling sampling parameter (relative share or absolute size)
     * @param random random generator factory for sampling
     */
    public KNNBenchmarkAlgorithm(DistanceFunction<? super O> distanceFunction, int k, DatabaseConnection queries, double sampling, RandomFactory random) {
        super(distanceFunction);
        this.k = k;
        this.queries = queries;
        this.sampling = sampling;
        this.random = random;
    }

    /**
     * Run the kNN benchmark.
     *
     * @param database database to query
     * @param relation relation providing the indexed objects
     * @return always {@code null}; results are only reported via logging
     */
    public Result run(Database database, Relation<O> relation) {
        DistanceQuery<O> distanceQuery = database.getDistanceQuery(relation, this.getDistanceFunction());
        // k is passed as an optimizer hint for the kNN query construction.
        KNNQuery<O> knnQuery = database.getKNNQuery(distanceQuery, this.k);
        if (this.queries == null) {
            this.runDatabaseQueries(knnQuery, relation);
        } else {
            this.runExternalQueries(knnQuery);
        }
        return null;
    }

    /** Benchmark using a random sample of the relation's own objects as queries. */
    private void runDatabaseQueries(KNNQuery<O> knnQuery, Relation<O> relation) {
        DBIDs sample = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        FiniteProgress progress = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null;
        QueryStatistics stats = new QueryStatistics();
        for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
            stats.add(knnQuery.getKNNForDBID(iditer, this.k));
            LOG.incrementProcessed(progress);
        }
        LOG.ensureCompleted(progress);
        stats.log();
    }

    /** Benchmark using a random sample of objects loaded from the external query source. */
    private void runExternalQueries(KNNQuery<O> knnQuery) {
        TypeInformation restriction = this.getDistanceFunction().getInputTypeRestriction();
        MultipleObjectsBundle bundle = this.queries.loadData();
        int column = findCompatibleColumn(bundle, restriction);
        // Static DBIDs are generated only to reuse the DBID-based sampling;
        // their offsets map back to the rows of the bundle.
        DBIDRange sids = DBIDUtil.generateStaticDBIDRange(bundle.dataLength());
        DBIDs sample = DBIDUtil.randomSample((DBIDs) sids, this.sampling, this.random);
        FiniteProgress progress = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null;
        QueryStatistics stats = new QueryStatistics();
        for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
            int offset = sids.binarySearch(iditer);
            assert (offset >= 0);
            // Safe: the column was checked against the distance function's input type above.
            @SuppressWarnings("unchecked")
            O query = (O) bundle.data(offset, column);
            stats.add(knnQuery.getKNNForObject(query, this.k));
            LOG.incrementProcessed(progress);
        }
        LOG.ensureCompleted(progress);
        stats.log();
    }

    /**
     * Find the first bundle column whose type is compatible with the given restriction.
     *
     * @param bundle loaded query data
     * @param restriction required input type
     * @return column index
     * @throws AbortException if no column is compatible
     */
    private static int findCompatibleColumn(MultipleObjectsBundle bundle, TypeInformation restriction) {
        for (int i = 0; i < bundle.metaLength(); ++i) {
            if (restriction.isAssignableFromType(bundle.meta(i))) {
                return i;
            }
        }
        throw new AbortException("No compatible data type in query input was found. Expected: " + restriction.toString());
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(this.getDistanceFunction().getInputTypeRestriction());
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Accumulates the per-query statistics shared by both query modes.
     */
    private static final class QueryStatistics {
        /** Hash code mixed from the summed result DBIDs of every query. */
        private int hash = 0;

        /** Distribution of result list sizes. */
        private final MeanVariance resultSizes = new MeanVariance();

        /** Distribution of k-distances. */
        private final MeanVariance kDistances = new MeanVariance();

        /** Record the result of a single kNN query. */
        void add(KNNList knns) {
            int idSum = 0;
            for (DoubleDBIDListIter it = knns.iter(); it.valid(); it.advance()) {
                idSum += DBIDUtil.asInteger(it);
            }
            this.hash = Util.mixHashCodes(this.hash, idSum);
            this.resultSizes.put(knns.size());
            this.kDistances.put(knns.getKNNDistance());
        }

        /** Emit the accumulated statistics, if statistics logging is enabled. */
        void log() {
            if (!LOG.isStatistics()) {
                return;
            }
            LOG.statistics("Result hashcode: " + this.hash);
            LOG.statistics("Mean number of results: " + this.resultSizes.getMean() + " +- " + this.resultSizes.getNaiveStddev());
            if (this.kDistances.getCount() > 0.0) {
                LOG.statistics("Mean k-distance: " + this.kDistances.getMean() + " +- " + this.kDistances.getNaiveStddev());
            }
        }
    }

    /**
     * Parameterization class.
     *
     * @param <O> object type
     */
    public static class Parameterizer<O>
    extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
        public static final OptionID K_ID = new OptionID("knnbench.k", "Number of neighbors to retrieve for kNN benchmarking.");
        public static final OptionID QUERY_ID = new OptionID("knnbench.query", "Data source for the queries. If not set, the queries are taken from the database.");
        public static final OptionID SAMPLING_ID = new OptionID("knnbench.sampling", "Sampling size parameter. If the value is less or equal 1, it is assumed to be the relative share. Larger values will be interpreted as integer sizes. By default, all data will be used.");
        public static final OptionID RANDOM_ID = new OptionID("knnbench.random", "Random generator for sampling.");

        /** Number of neighbors to request. */
        protected int k = 10;

        /** External query source, or {@code null}. */
        protected DatabaseConnection queries = null;

        /** Sampling parameter; negative means no sampling was configured. */
        protected double sampling = -1.0;

        /** Random generator factory. */
        protected RandomFactory random;

        @Override
        protected void makeOptions(Parameterization parameterization) {
            super.makeOptions(parameterization);
            IntParameter kParam = new IntParameter(K_ID);
            if (parameterization.grab(kParam)) {
                this.k = kParam.intValue();
            }
            ObjectParameter<DatabaseConnection> queryParam = new ObjectParameter<>(QUERY_ID, DatabaseConnection.class);
            queryParam.setOptional(true);
            if (parameterization.grab(queryParam)) {
                this.queries = queryParam.instantiateClass(parameterization);
            }
            DoubleParameter samplingParam = new DoubleParameter(SAMPLING_ID);
            samplingParam.setOptional(true);
            if (parameterization.grab(samplingParam)) {
                this.sampling = samplingParam.doubleValue();
            }
            RandomParameter randomParam = new RandomParameter(RANDOM_ID, RandomFactory.DEFAULT);
            if (parameterization.grab(randomParam)) {
                this.random = randomParam.getValue();
            }
        }

        @Override
        protected KNNBenchmarkAlgorithm<O> makeInstance() {
            return new KNNBenchmarkAlgorithm<>(this.distanceFunction, this.k, this.queries, this.sampling, this.random);
        }
    }
}

