org.apache.lucene.classification.KNearestNeighborClassifier.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.lucene.classification.KNearestNeighborClassifier.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;

/**
 * A k-Nearest Neighbor classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
 * on {@link MoreLikeThis}.
 * <p>
 * The input text is turned into one MLT query per text field; the top {@code k} matching documents
 * (restricted to documents that have a value in the class field, and optionally to an extra filter
 * {@link Query}) are the "neighbors", and each class is scored by the rank-weighted votes of the
 * neighbors carrying it.
 * <p>
 * NOTE(review): every search mutates the shared {@link #mlt} instance (boost flags), so instances
 * should not be assumed safe for concurrent use.
 *
 * @lucene.experimental
 */
public class KNearestNeighborClassifier implements Classifier<BytesRef> {

    /**
     * a {@link MoreLikeThis} instance used to perform MLT queries
     */
    protected final MoreLikeThis mlt;

    /**
     * the name of the fields used as the input text (entries may carry a boost suffix, e.g. {@code title^10})
     */
    protected final String[] textFieldNames;

    /**
     * the name of the field used as the output text
     */
    protected final String classFieldName;

    /**
     * an {@link IndexSearcher} used to perform queries
     */
    protected final IndexSearcher indexSearcher;

    /**
     * the no. of docs to compare in order to find the nearest neighbor to the input text
     */
    protected final int k;

    /**
     * a {@link Query} used to filter the documents that should be used from this classifier's underlying {@link LeafReader}
     */
    protected final Query query;

    /**
     * Creates a {@link KNearestNeighborClassifier}.
     *
     * @param indexReader    the reader on the index to be used for classification
     * @param similarity     the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
     *                       (defaults to {@link org.apache.lucene.search.similarities.BM25Similarity})
     * @param analyzer       an {@link Analyzer} used to analyze unseen text
     * @param query          a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
     *                       if all the indexed docs should be used
     * @param k              the no. of docs to select in the MLT results to find the nearest neighbor
     * @param minDocsFreq    value passed to {@link MoreLikeThis#setMinDocFreq(int)}; ignored unless greater than 0
     * @param minTermFreq    value passed to {@link MoreLikeThis#setMinTermFreq(int)}; ignored unless greater than 0
     * @param classFieldName the name of the field used as the output for the classifier
     * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
     */
    public KNearestNeighborClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer,
            Query query, int k, int minDocsFreq, int minTermFreq, String classFieldName, String... textFieldNames) {
        this.textFieldNames = textFieldNames;
        this.classFieldName = classFieldName;
        this.mlt = new MoreLikeThis(indexReader);
        this.mlt.setAnalyzer(analyzer);
        this.mlt.setFieldNames(textFieldNames);
        this.indexSearcher = new IndexSearcher(indexReader);
        if (similarity != null) {
            this.indexSearcher.setSimilarity(similarity);
        } else {
            this.indexSearcher.setSimilarity(new BM25Similarity());
        }
        if (minDocsFreq > 0) {
            mlt.setMinDocFreq(minDocsFreq);
        }
        if (minTermFreq > 0) {
            mlt.setMinTermFreq(minTermFreq);
        }
        this.query = query;
        this.k = k;
    }

    @Override
    public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
        return classifyFromTopDocs(knnSearch(text));
    }

    /**
     * Picks the single best class out of k-NN search results.
     *
     * @param knnResults the results of a k-NN search
     * @return the {@link ClassificationResult} with the highest score (on ties, the earliest one in the list
     *         built by {@link #buildListFromTopDocs(TopDocs)}), or {@code null} if no class could be assigned
     * @throws IOException if the stored values of the class field cannot be read
     */
    protected ClassificationResult<BytesRef> classifyFromTopDocs(TopDocs knnResults) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
        ClassificationResult<BytesRef> assignedClass = null;
        double maxscore = -Double.MAX_VALUE;
        for (ClassificationResult<BytesRef> cl : assignedClasses) {
            // strict '>' keeps the first of several equally scored classes
            if (cl.getScore() > maxscore) {
                assignedClass = cl;
                maxscore = cl.getScore();
            }
        }
        return assignedClass;
    }

    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        TopDocs knnResults = knnSearch(text);
        List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
        Collections.sort(assignedClasses);
        return assignedClasses;
    }

    /**
     * Returns at most {@code max} classes for the given text, in the order defined by
     * {@link ClassificationResult}'s natural ordering.
     * <p>
     * If fewer than {@code max} distinct classes were found, all of them are returned instead of failing.
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        TopDocs knnResults = knnSearch(text);
        List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
        Collections.sort(assignedClasses);
        // the neighbors may carry fewer distinct classes than requested
        return assignedClasses.subList(0, Math.min(max, assignedClasses.size()));
    }

    /**
     * Runs the k-NN search: one SHOULD MLT clause per text field, plus MUST clauses requiring a value in the
     * class field and (if configured) matching the filter {@link #query}.
     *
     * @param text the unseen text to classify
     * @return the top {@link #k} documents
     * @throws IOException if analyzing the text or searching the index fails
     */
    private TopDocs knnSearch(String text) throws IOException {
        BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();
        mlt.setBoost(true); // terms boost actually helps in MLT queries
        for (String fieldSpec : textFieldNames) {
            String fieldName = fieldSpec;
            String boost = null;
            if (fieldSpec.contains("^")) {
                String[] field2boost = fieldSpec.split("\\^");
                fieldName = field2boost.length > 0 ? field2boost[0] : "";
                // a spec with nothing after the caret (e.g. "title^") carries no boost
                boost = field2boost.length > 1 ? field2boost[1] : null;
            }
            if (boost != null) {
                mlt.setBoostFactor(Float.parseFloat(boost)); // if we have a field boost, we add it
            }
            try {
                mltQuery.add(
                        new BooleanClause(mlt.like(fieldName, new StringReader(text)), BooleanClause.Occur.SHOULD));
            } finally {
                // restore neutral boost for the next field (and for later calls, even if like() failed)
                mlt.setBoostFactor(1);
            }
        }
        Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
        mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
        if (query != null) {
            mltQuery.add(query, BooleanClause.Occur.MUST);
        }
        return indexSearcher.search(mltQuery.build(), k);
    }

    /**
     * Builds a list of classification results from search results, taking the ranking of the hits into
     * consideration.
     * <p>
     * Every stored value of the class field in every hit is a vote, weighted by {@code hitScore / firstHitScore}.
     * A class's score is the sum of its weights divided by {@link #k}. When the hits carry fewer than {@code k}
     * class values in total, scores are rescaled by {@code k / (number of class values found)}.
     *
     * @param topDocs the search results as a {@link TopDocs} object
     * @return a {@link List} of {@link ClassificationResult}, one for each existing class
     * @throws IOException if it's not possible to get the stored value of class field
     */
    protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
        Map<BytesRef, Integer> classCounts = new HashMap<>();
        Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
        // guard on the actual array, not on totalHits: the hit count may be non-zero while scoreDocs is empty
        float maxScore = scoreDocs.length == 0 ? Float.NaN : scoreDocs[0].score;
        for (ScoreDoc scoreDoc : scoreDocs) {
            // the boost is based on the best (first) score; constant for every class value of this hit
            double singleBoost = scoreDoc.score / maxScore;
            IndexableField[] storableFields = indexSearcher.doc(scoreDoc.doc).getFields(classFieldName);
            for (IndexableField singleStorableField : storableFields) {
                if (singleStorableField != null) {
                    BytesRef cl = new BytesRef(singleStorableField.stringValue());
                    classCounts.merge(cl, 1, Integer::sum);
                    classBoosts.merge(cl, singleBoost, Double::sum);
                }
            }
        }

        List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
        int sumdoc = 0;
        for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
            int count = entry.getValue();
            // average boost per vote; within (0, 1] assuming positive, score-sorted hits
            double normBoost = classBoosts.get(entry.getKey()) / count;
            temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), (count * normBoost) / (double) k));
            sumdoc += count;
        }

        if (sumdoc >= k) {
            return temporaryList;
        }
        // correction: fewer class values than k were retrieved, so rescale relative to what was actually found
        List<ClassificationResult<BytesRef>> returnList = new ArrayList<>(temporaryList.size());
        for (ClassificationResult<BytesRef> cr : temporaryList) {
            returnList.add(
                    new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
        }
        return returnList;
    }

    @Override
    public String toString() {
        return "KNearestNeighborClassifier{" + "textFieldNames=" + Arrays.toString(textFieldNames)
                + ", classFieldName='" + classFieldName + '\'' + ", k=" + k + ", query=" + query + ", similarity="
                + indexSearcher.getSimilarity() + '}';
    }
}