/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.classification;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.classification.Classifier;
import de.lmu.ifi.dbs.elki.data.ClassLabel;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DBIDRef;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListIter;
import de.lmu.ifi.dbs.elki.database.ids.KNNList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.utilities.documentation.Description;
import de.lmu.ifi.dbs.elki.utilities.documentation.Title;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.util.ArrayList;
import java.util.Collections;

/**
 * Lazy k-nearest-neighbor classifier.
 *
 * <p>"Training" ({@link #buildClassifier}) only prepares a kNN query over the
 * training relation and remembers the label relation. Classification then
 * looks up the {@code k} nearest training objects of the query instance and
 * counts their class labels.
 *
 * @param <O> object type handled by the distance function
 */
@Title(value="kNN-classifier")
@Description(value="Lazy classifier classifies a given instance to the majority class of the k-nearest neighbors.")
public class KNNClassifier<O>
extends AbstractDistanceBasedAlgorithm<O, Result>
implements Classifier<O> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(KNNClassifier.class);

    /** Number of neighbors to take into account for classification. */
    protected int k;

    /** kNN query over the training data; set by {@link #buildClassifier}. */
    protected KNNQuery<O> knnq;

    /** Class labels of the training objects; set by {@link #buildClassifier}. */
    protected Relation<? extends ClassLabel> labelrep;

    /**
     * Constructor.
     *
     * @param distanceFunction distance function used for the neighbor search
     * @param k number of neighbors to take into account
     */
    public KNNClassifier(DistanceFunction<? super O> distanceFunction, int k) {
        super(distanceFunction);
        this.k = k;
    }

    /**
     * Prepares the classifier: obtains the relation matching the distance
     * function's input type, builds a kNN query on it and stores the label
     * relation for later lookups.
     *
     * @param database database providing the training data
     * @param labels class labels of the training objects
     */
    @Override
    public void buildClassifier(Database database, Relation<? extends ClassLabel> labels) {
        Relation<O> relation = database.getRelation(this.getDistanceFunction().getInputTypeRestriction());
        DistanceQuery<O> distanceQuery = database.getDistanceQuery(relation, this.getDistanceFunction());
        // k is passed as a hint for the query.
        this.knnq = database.getKNNQuery(distanceQuery, this.k);
        this.labelrep = labels;
    }

    /**
     * Classifies an instance to the label occurring most often among its
     * {@code k} nearest neighbors.
     *
     * <p>On a tie, the label encountered first while iterating the internal
     * hash map wins (strict {@code >} comparison), so the tie-break order is
     * not specified.
     *
     * @param instance object to classify
     * @return the majority label, or {@code null} if the neighbor list is empty
     */
    @Override
    public ClassLabel classify(O instance) {
        TObjectIntHashMap<ClassLabel> counts = new TObjectIntHashMap<ClassLabel>();
        KNNList neighbors = this.knnq.getKNNForObject(instance, this.k);
        for (DoubleDBIDListIter neighbor = neighbors.iter(); neighbor.valid(); neighbor.advance()) {
            // Increment the label's count, or insert it with count 1.
            counts.adjustOrPutValue(this.labelrep.get(neighbor), 1, 1);
        }

        int bestCount = Integer.MIN_VALUE;
        ClassLabel bestLabel = null;
        TObjectIntIterator<ClassLabel> it = counts.iterator();
        while (it.hasNext()) {
            it.advance();
            if (it.value() > bestCount) {
                bestCount = it.value();
                bestLabel = it.key();
            }
        }
        return bestLabel;
    }

    /**
     * Computes, for each given label, the fraction of the instance's nearest
     * neighbors carrying that label.
     *
     * <p>Labels are located with {@link Collections#binarySearch}, so
     * {@code labels} must be sorted in ascending natural order. Neighbors
     * whose label is not found in {@code labels} are ignored, hence the
     * returned values may sum to less than 1.
     *
     * @param instance object to compute the class distribution for
     * @param labels sorted list of candidate labels
     * @return array parallel to {@code labels} holding the relative
     *         frequencies; all zeros if the neighbor list is empty
     */
    public double[] classProbabilities(O instance, ArrayList<ClassLabel> labels) {
        int[] occurrences = new int[labels.size()];
        KNNList neighbors = this.knnq.getKNNForObject(instance, this.k);
        for (DoubleDBIDListIter neighbor = neighbors.iter(); neighbor.valid(); neighbor.advance()) {
            int index = Collections.binarySearch(labels, this.labelrep.get(neighbor));
            if (index >= 0) {
                occurrences[index]++;
            }
        }

        double[] distribution = new double[labels.size()];
        int total = neighbors.size();
        if (total == 0) {
            // No neighbors: avoid 0/0 (NaN) and report an all-zero distribution.
            return distribution;
        }
        for (int i = 0; i < distribution.length; ++i) {
            distribution[i] = (double) occurrences[i] / (double) total;
        }
        return distribution;
    }

    /**
     * Describes the model of this classifier.
     *
     * @return a fixed message stating that no model is provided
     */
    @Override
    public String model() {
        return "lazy learner - provides no model";
    }

    /**
     * Not supported; always throws.
     *
     * @param database ignored
     * @return never returns normally
     * @throws AbortException always
     * @deprecated classifiers must be trained via {@link #buildClassifier}
     *             and queried via {@link #classify}
     */
    @Override
    @Deprecated
    public Result run(Database database) throws IllegalStateException {
        throw new AbortException("Classifiers cannot auto-run on a database, but need to be trained and can then predict.");
    }

    /**
     * Declares the input type of this algorithm.
     *
     * @return a single-element array containing the number vector field type
     */
    @Override
    public TypeInformation[] getInputTypeRestriction() {
        // NOTE(review): buildClassifier selects its relation via the distance
        // function's input type restriction, while this advertises number
        // vector fields; confirm the difference is intended.
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }

    /**
     * @return the class logger
     */
    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class; reads {@code k} in addition to the options of
     * the parent parameterizer.
     *
     * @param <O> object type handled by the distance function
     */
    public static class Parameterizer<O>
    extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
        /** Option for the number of neighbors. */
        public static final OptionID K_ID = new OptionID("knnclassifier.k", "The number of neighbors to take into account for classification.");

        /** Number of neighbors; only assigned when the option is grabbed successfully. */
        protected int k;

        /**
         * Registers the parent's options plus the {@code k} option
         * (default 1, constrained to be at least one).
         *
         * @param parameterization parameterization to read the options from
         */
        @Override
        protected void makeOptions(Parameterization parameterization) {
            super.makeOptions(parameterization);
            IntParameter kParam = (IntParameter)new IntParameter(K_ID, 1).addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (parameterization.grab(kParam)) {
                this.k = kParam.intValue();
            }
        }

        /**
         * Creates the classifier from the configured distance function and k.
         *
         * @return new classifier instance
         */
        @Override
        protected KNNClassifier<O> makeInstance() {
            return new KNNClassifier<O>(this.distanceFunction, this.k);
        }
    }
}

