Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.meizu.nlp.classification; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.index.*; import org.apache.lucene.search.*; import org.apache.lucene.util.BytesRef; import java.io.IOException; import java.util.*; /** * A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code> * * @lucene.experimental */ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> { /** * {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene}'s * index */ protected LeafReader leafReader; /** * names of the fields to be used as input text */ protected String[] textFieldNames; /** * name of the field to be used as a class / category output */ protected String classFieldName; /** * name of the field to be used as a rank class / category output */ protected String[] rankClassFieldNames; /** * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text */ protected Analyzer analyzer; /** * {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies */ protected IndexSearcher indexSearcher; /** * {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify */ protected Query query; /** * Creates a new NaiveBayes classifier. * Note that you must call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before you can * classify any documents. */ public SimpleNaiveBayesClassifier() { } /** * {@inheritDoc} */ public void train(LeafReader leafReader, String textFieldName, String classFieldName, String rankClassFieldNames, Analyzer analyzer) throws IOException { train(leafReader, textFieldName, classFieldName, analyzer, null); this.rankClassFieldNames = new String[] { rankClassFieldNames }; } /** * {@inheritDoc} */ @Override public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { train(leafReader, textFieldName, classFieldName, analyzer, null); } /** * {@inheritDoc} */ @Override public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException { train(leafReader, new String[] { textFieldName }, classFieldName, analyzer, query); } /** * {@inheritDoc} */ @Override public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException { this.leafReader = leafReader; this.indexSearcher = new IndexSearcher(this.leafReader); this.textFieldNames = textFieldNames; this.classFieldName = classFieldName; this.analyzer = analyzer; this.query = query; } /** * {@inheritDoc} */ @Override public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException { List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(inputDocument); ClassificationResult<BytesRef> retval = null; double maxscore = -Double.MAX_VALUE; for (ClassificationResult<BytesRef> element : doclist) { if (element.getScore() > maxscore) { retval = element; maxscore = element.getScore(); } } return retval; } public ClassificationResult<BytesRef> assignRankClass(String inputDocument) throws IOException { List<ClassificationResult<BytesRef>> doclist = assignRankClassNormalizedList(inputDocument); ClassificationResult<BytesRef> retval = null; double maxscore = -Double.MAX_VALUE; for (ClassificationResult<BytesRef> element : doclist) { if (element.getScore() > maxscore) { retval = element; maxscore = element.getScore(); } } return retval; } /** * {@inheritDoc} */ @Override public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException { List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text); Collections.sort(doclist); return doclist; } /** * {@inheritDoc} */ @Override public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException { List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text); Collections.sort(doclist); return doclist.subList(0, max); } private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException { if (leafReader == null) { throw new IOException("You must first call Classifier#train"); } List<ClassificationResult<BytesRef>> dataList = new ArrayList<>(); Terms terms = MultiFields.getTerms(leafReader, classFieldName); TermsEnum termsEnum = terms.iterator(); BytesRef next; String[] tokenizedDoc = tokenizeDoc(inputDocument); int docsWithClassSize = countDocsWithClass(); int count = 0; while ((next = termsEnum.next()) != null) { double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize); dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal)); count++; } // normalization; the values transforms to a 0-1 range ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); if (!dataList.isEmpty()) { Collections.sort(dataList); // this is a negative number closest to 0 = a double smax = dataList.get(0).getScore(); double sumLog = 0; // log(sum(exp(x_n-a))) for (ClassificationResult<BytesRef> cr : dataList) { // getScore-smax <=0 (both negative, smax is the smallest abs() sumLog += Math.exp(cr.getScore() - smax); } // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n))) double loga = smax; loga += Math.log(sumLog); // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum)) for (ClassificationResult<BytesRef> cr : dataList) { returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga))); } } return returnList; } private List<ClassificationResult<BytesRef>> assignRankClassNormalizedList(String inputDocument) throws IOException { if (leafReader == null) { throw new IOException("You must first call Classifier#train"); } List<ClassificationResult<BytesRef>> dataList = new ArrayList<>(); if (this.rankClassFieldNames == null || this.rankClassFieldNames.length == 0) { throw new IOException("rankClassField must defind"); } for (String rankClassName : rankClassFieldNames) { } Terms terms = MultiFields.getTerms(leafReader, classFieldName); TermsEnum termsEnum = terms.iterator(); BytesRef next; String[] tokenizedDoc = tokenizeDoc(inputDocument); int docsWithClassSize = countDocsWithClass(); while ((next = termsEnum.next()) != null) { double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize); dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal)); } // normalization; the values transforms to a 0-1 range ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); if (!dataList.isEmpty()) { Collections.sort(dataList); // this is a negative number closest to 0 = a double smax = dataList.get(0).getScore(); double sumLog = 0; // log(sum(exp(x_n-a))) for (ClassificationResult<BytesRef> cr : dataList) { // getScore-smax <=0 (both negative, smax is the smallest abs() sumLog += Math.exp(cr.getScore() - smax); } // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n))) double loga = smax; loga += Math.log(sumLog); // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum)) for (ClassificationResult<BytesRef> cr : dataList) { returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga))); } } return returnList; } /** * count the number of documents in the index having at least a value for the 'class' field * * @return the no. of documents having a value for the 'class' field * @throws IOException if accessing to term vectors or search fails */ protected int countDocsWithClass() throws IOException { int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount(); if (docCount == -1) { // in case codec doesn't support getDocCount TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); BooleanQuery q = new BooleanQuery(); q.add(new BooleanClause( new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST)); if (query != null) { q.add(query, BooleanClause.Occur.MUST); } indexSearcher.search(q, totalHitCountCollector); docCount = totalHitCountCollector.getTotalHits(); } return docCount; } /** * tokenize a <code>String</code> on this classifier's text fields and analyzer * * @param doc the <code>String</code> representing an input text (to be classified) * @return a <code>String</code> array of the resulting tokens * @throws IOException if tokenization fails */ protected String[] tokenizeDoc(String doc) throws IOException { Collection<String> result = new LinkedList<>(); for (String textFieldName : textFieldNames) { try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) { CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class); tokenStream.reset(); while (tokenStream.incrementToken()) { result.add(charTermAttribute.toString()); } tokenStream.end(); } } return result.toArray(new String[result.size()]); } private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException { // for each word double result = 0d; for (String word : tokenizedDoc) { // search with text:word AND class:c int hits = getWordFreqForClass(word, c); // num : count the no of times the word appears in documents of class c (+1) double num = hits + 1; // +1 is added because of add 1 smoothing // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|) double den = getTextTermFreqForClass(c) + docsWithClassSize; // P(w|c) = num/den double wordProbability = num / den; result += Math.log(wordProbability); } // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c)) return result; } private double getTextTermFreqForClass(BytesRef c) throws IOException { double avgNumberOfUniqueTerms = 0; for (String textFieldName : textFieldNames) { Terms terms = MultiFields.getTerms(leafReader, textFieldName); long numPostings = terms.getSumDocFreq(); // number of term/doc pairs avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc } int docsWithC = leafReader.docFreq(new Term(classFieldName, c)); return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c } private int getWordFreqForClass(String word, BytesRef c) throws IOException { BooleanQuery booleanQuery = new BooleanQuery(); BooleanQuery subQuery = new BooleanQuery(); for (String textFieldName : textFieldNames) { subQuery.add( new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD)); } booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST)); booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST)); if (query != null) { booleanQuery.add(query, BooleanClause.Occur.MUST); } TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); indexSearcher.search(booleanQuery, totalHitCountCollector); return totalHitCountCollector.getTotalHits(); } private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException { return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize); } private int docCount(BytesRef countedClass) throws IOException { return leafReader.docFreq(new Term(classFieldName, countedClass)); } }