Java tutorial
/* * (c) 2005 David B. Bracewell * * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package com.davidbracewell.ml.classification.lazy; import com.davidbracewell.collection.InvertedIndex; import com.davidbracewell.collection.Sorting; import com.davidbracewell.math.distance.DistanceMeasure; import com.davidbracewell.math.linear.VectorMap; import com.davidbracewell.ml.Feature; import com.davidbracewell.ml.Instance; import com.davidbracewell.ml.classification.ClassificationModel; import com.davidbracewell.ml.classification.ClassificationResult; import com.davidbracewell.tuple.Pair; import com.google.common.collect.MinMaxPriorityQueue; import com.google.common.collect.Ordering; /** * @author David B. Bracewell */ public class SparseKNN extends ClassificationModel { private static final long serialVersionUID = 1670966911919829066L; int K; InvertedIndex<Instance, Feature> index; DistanceMeasure distanceMeasure; @Override protected ClassificationResult classifyImpl(Instance instance) { final MinMaxPriorityQueue<Pair<Double, Double>> neighbors = MinMaxPriorityQueue .orderedBy(Ordering.from(Sorting.<Double, Double>mapEntryComparator(false, true))).maximumSize(K) .create(); for (Instance inst : index.query(instance)) { double distance = distanceMeasure.calculate(inst, instance, VectorMap.VALID_VALUES.FINITE); neighbors.add(Pair.of(inst.getTargetValue(), distance)); } double[] p = new double[getTargetFeature().alphabetSize()]; for (Pair<Double, Double> pair : neighbors) { p[pair.getFirst().intValue()] += 1d / (pair.getSecond() + 0.00000001); } return new ClassificationResult(getTargetFeature(), p); } @Override public boolean isTrained() { return index != null; } }//END OF KNN