com.meizu.nlp.classification.SimpleNaiveBayesClassifier.java Source code

Java tutorial

Introduction

Here is the source code for com.meizu.nlp.classification.SimpleNaiveBayesClassifier.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.meizu.nlp.classification;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.*;

/**
 * A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
 *
 * @lucene.experimental
 */
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {

    /**
     * {@link org.apache.lucene.index.LeafReader} used to access the Lucene index holding the training documents.
     */
    protected LeafReader leafReader;

    /**
     * names of the fields to be used as input text
     */
    protected String[] textFieldNames;

    /**
     * name of the field to be used as a class / category output
     */
    protected String classFieldName;

    /**
     * names of the fields to be used as rank class / category output (set only by the rank-aware
     * {@code train} overload)
     */
    protected String[] rankClassFieldNames;

    /**
     * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text
     */
    protected Analyzer analyzer;

    /**
     * {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies
     */
    protected IndexSearcher indexSearcher;

    /**
     * {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify
     */
    protected Query query;

    /**
     * Creates a new NaiveBayes classifier.
     * Note that you must call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before you can
     * classify any documents.
     */
    public SimpleNaiveBayesClassifier() {
    }

    /**
     * Trains the classifier like {@link #train(LeafReader, String, String, Analyzer)} and additionally records a
     * single rank class field name, which {@link #assignRankClass(String)} requires to be present.
     *
     * @param leafReader          the reader over the training index
     * @param textFieldName       the field holding the input text
     * @param classFieldName      the field holding the class / category
     * @param rankClassFieldNames the name of the rank class field (a single name despite the plural)
     * @param analyzer            the analyzer used to tokenize unseen input text
     * @throws IOException propagated from the delegated {@code train} call
     */
    public void train(LeafReader leafReader, String textFieldName, String classFieldName,
            String rankClassFieldNames, Analyzer analyzer) throws IOException {
        train(leafReader, textFieldName, classFieldName, analyzer, null);
        this.rankClassFieldNames = new String[] { rankClassFieldNames };
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer)
            throws IOException {
        train(leafReader, textFieldName, classFieldName, analyzer, null);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer,
            Query query) throws IOException {
        train(leafReader, new String[] { textFieldName }, classFieldName, analyzer, query);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer,
            Query query) throws IOException {
        this.leafReader = leafReader;
        this.indexSearcher = new IndexSearcher(this.leafReader);
        this.textFieldNames = textFieldNames;
        this.classFieldName = classFieldName;
        this.analyzer = analyzer;
        this.query = query;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
        return highestScoring(assignClassNormalizedList(inputDocument));
    }

    /**
     * Assigns the most probable class to the input text, after verifying that rank class fields were configured.
     *
     * @param inputDocument the text to classify
     * @return the highest-scoring class, or {@code null} if no class could be scored
     * @throws IOException if the classifier is untrained, no rank class field was configured, or index access fails
     */
    public ClassificationResult<BytesRef> assignRankClass(String inputDocument) throws IOException {
        return highestScoring(assignRankClassNormalizedList(inputDocument));
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
        Collections.sort(doclist);
        return doclist;
    }

    /**
     * {@inheritDoc}
     * <p>
     * If {@code max} exceeds the number of known classes, all classes are returned. A negative {@code max}
     * is rejected by {@link List#subList(int, int)}.
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
        Collections.sort(doclist);
        // subList would throw if max were larger than the number of classes; cap it instead
        return doclist.subList(0, Math.min(max, doclist.size()));
    }

    /**
     * Returns the element with the strictly greatest score, preferring the earliest one on ties.
     *
     * @param candidates scored classes (may be empty)
     * @return the best-scoring element, or {@code null} when {@code candidates} is empty
     */
    private static ClassificationResult<BytesRef> highestScoring(List<ClassificationResult<BytesRef>> candidates) {
        ClassificationResult<BytesRef> best = null;
        double bestScore = -Double.MAX_VALUE;
        for (ClassificationResult<BytesRef> candidate : candidates) {
            if (candidate.getScore() > bestScore) {
                best = candidate;
                bestScore = candidate.getScore();
            }
        }
        return best;
    }

    /**
     * Scores every class of {@link #classFieldName} against the input and normalizes the scores into 0-1.
     */
    private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument)
            throws IOException {
        checkTrained();
        return normalize(scoreClasses(inputDocument));
    }

    /**
     * Same as {@link #assignClassNormalizedList(String)} but first requires rank class fields to be configured.
     */
    private List<ClassificationResult<BytesRef>> assignRankClassNormalizedList(String inputDocument)
            throws IOException {
        checkTrained();
        if (this.rankClassFieldNames == null || this.rankClassFieldNames.length == 0) {
            throw new IOException("rankClassField must be defined");
        }
        // TODO(review): rankClassFieldNames is only validated here; scoring still runs over classFieldName,
        // so the result equals assignClassNormalizedList. Confirm the intended rank semantics.
        return normalize(scoreClasses(inputDocument));
    }

    /** Fails fast when no {@code train} method has been called yet. */
    private void checkTrained() throws IOException {
        if (leafReader == null) {
            throw new IOException("You must first call Classifier#train");
        }
    }

    /**
     * Computes the unnormalized log score (log prior + log likelihood) of every term in the class field.
     *
     * @param inputDocument the text to classify
     * @return one result per class term; empty when the class field has no indexed terms
     */
    private List<ClassificationResult<BytesRef>> scoreClasses(String inputDocument) throws IOException {
        List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
        Terms terms = MultiFields.getTerms(leafReader, classFieldName);
        if (terms == null) {
            // no indexed terms for the class field: nothing to score
            return dataList;
        }
        TermsEnum termsEnum = terms.iterator();
        String[] tokenizedDoc = tokenizeDoc(inputDocument);
        int docsWithClassSize = countDocsWithClass();
        BytesRef next;
        while ((next = termsEnum.next()) != null) {
            double clVal = calculateLogPrior(next, docsWithClassSize)
                    + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
            // the enum may reuse the returned BytesRef, so keep a private copy
            dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
        }
        return dataList;
    }

    /**
     * Turns log scores into probabilities summing to 1 using the log-sum-exp trick:
     * p_i = exp(x_i - log(sum_n exp(x_n))), with log(sum exp(x_n)) = a + log(sum exp(x_n - a)).
     *
     * @param dataList log-scored classes; sorted in place
     * @return a new list of normalized results in the sorted order (empty if {@code dataList} is empty)
     */
    private static List<ClassificationResult<BytesRef>> normalize(List<ClassificationResult<BytesRef>> dataList) {
        List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
        if (dataList.isEmpty()) {
            return returnList;
        }
        Collections.sort(dataList);
        // assumes ClassificationResult sorts by descending score, so element 0 holds the largest log score
        // and exp(x - smax) cannot overflow — TODO confirm against ClassificationResult#compareTo
        double smax = dataList.get(0).getScore();

        double sumExp = 0;
        for (ClassificationResult<BytesRef> cr : dataList) {
            sumExp += Math.exp(cr.getScore() - smax);
        }
        // logSum = log(sum(exp(x_n)))
        double logSum = smax + Math.log(sumExp);

        for (ClassificationResult<BytesRef> cr : dataList) {
            returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - logSum)));
        }
        return returnList;
    }

    /**
     * count the number of documents in the index having at least a value for the 'class' field
     *
     * @return the no. of documents having a value for the 'class' field (0 if the field has no indexed terms)
     * @throws IOException if accessing to term vectors or search fails
     */
    protected int countDocsWithClass() throws IOException {
        Terms classTerms = MultiFields.getTerms(this.leafReader, this.classFieldName);
        if (classTerms == null) {
            return 0;
        }
        int docCount = classTerms.getDocCount();
        if (docCount == -1) { // in case codec doesn't support getDocCount
            TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
            BooleanQuery q = new BooleanQuery();
            q.add(new BooleanClause(
                    new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))),
                    BooleanClause.Occur.MUST));
            if (query != null) {
                q.add(query, BooleanClause.Occur.MUST);
            }
            indexSearcher.search(q, totalHitCountCollector);
            docCount = totalHitCountCollector.getTotalHits();
        }
        return docCount;
    }

    /**
     * tokenize a <code>String</code> on this classifier's text fields and analyzer
     *
     * @param doc the <code>String</code> representing an input text (to be classified)
     * @return a <code>String</code> array of the resulting tokens, concatenated across all text fields
     * @throws IOException if tokenization fails
     */
    protected String[] tokenizeDoc(String doc) throws IOException {
        List<String> result = new ArrayList<>();
        for (String textFieldName : textFieldNames) {
            try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
                CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                while (tokenStream.incrementToken()) {
                    result.add(charTermAttribute.toString());
                }
                tokenStream.end();
            }
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Computes log(P(d|c)) = log(P(w1|c)) + ... + log(P(wn|c)) with add-one smoothing in the numerator.
     *
     * @param tokenizedDoc      the input tokens
     * @param c                 the class being scored
     * @param docsWithClassSize number of documents having any class value, added to the denominator
     * @return the summed log word probabilities (0 for an empty token array)
     */
    private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize)
            throws IOException {
        if (tokenizedDoc.length == 0) {
            return 0d;
        }
        // den does not depend on the word, so compute it once per class instead of once per token
        double den = getTextTermFreqForClass(c) + docsWithClassSize;
        double result = 0d;
        for (String word : tokenizedDoc) {
            // num: no. of documents of class c containing the word (+1 for add-one smoothing)
            double num = getWordFreqForClass(word, c) + 1;
            // P(w|c) = num/den
            result += Math.log(num / den);
        }
        return result;
    }

    /**
     * Estimates the number of term occurrences in documents of class {@code c}: the sum over text fields of
     * the average number of unique terms per document, multiplied by the number of documents with class c.
     * Fields with no indexed terms, or whose statistics are unavailable (-1) or based on zero documents, are
     * skipped rather than contributing NaN.
     */
    private double getTextTermFreqForClass(BytesRef c) throws IOException {
        double avgNumberOfUniqueTerms = 0;
        for (String textFieldName : textFieldNames) {
            Terms terms = MultiFields.getTerms(leafReader, textFieldName);
            if (terms == null) {
                continue;
            }
            long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
            int fieldDocCount = terms.getDocCount();
            if (numPostings < 0 || fieldDocCount <= 0) {
                continue;
            }
            avgNumberOfUniqueTerms += numPostings / (double) fieldDocCount; // avg # of unique terms per doc
        }
        // avg # of unique terms in text fields per doc * # docs with c
        return avgNumberOfUniqueTerms * docCount(c);
    }

    /**
     * Counts documents that contain {@code word} in any text field AND have class {@code c}, restricted by
     * {@link #query} when one was supplied at training time.
     */
    private int getWordFreqForClass(String word, BytesRef c) throws IOException {
        BooleanQuery booleanQuery = new BooleanQuery();
        BooleanQuery subQuery = new BooleanQuery();
        for (String textFieldName : textFieldNames) {
            subQuery.add(
                    new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
        }
        booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
        booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
        if (query != null) {
            booleanQuery.add(query, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
        indexSearcher.search(booleanQuery, totalHitCountCollector);
        return totalHitCountCollector.getTotalHits();
    }

    /**
     * log(P(c)) = log(#docs with class c) - log(#docs with any class).
     * NOTE(review): unlike getWordFreqForClass, this uses raw docFreq and ignores {@link #query}.
     */
    private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
        return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
    }

    /** Number of documents whose class field contains the term {@code countedClass}. */
    private int docCount(BytesRef countedClass) throws IOException {
        return leafReader.docFreq(new Term(classFieldName, countedClass));
    }
}