Usage examples for org.apache.lucene.classification.ClassificationResult#getAssignedClass
public T getAssignedClass()
From source file:SimpleNaiveBayesClassifier.java
License:Apache License
/** * Normalize the classification results based on the max score available * @param assignedClasses the list of assigned classes * @return the normalized results/*www . java 2s .co m*/ */ protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults( List<ClassificationResult<BytesRef>> assignedClasses) { // normalization; the values transforms to a 0-1 range ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); if (!assignedClasses.isEmpty()) { Collections.sort(assignedClasses); // this is a negative number closest to 0 = a double smax = assignedClasses.get(0).getScore(); double sumLog = 0; // log(sum(exp(x_n-a))) for (ClassificationResult<BytesRef> cr : assignedClasses) { // getScore-smax <=0 (both negative, smax is the smallest abs() sumLog += Math.exp(cr.getScore() - smax); } // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n))) double loga = smax; loga += Math.log(sumLog); // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum)) for (ClassificationResult<BytesRef> cr : assignedClasses) { double scoreDiff = cr.getScore() - loga; returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff))); } } return returnList; }
From source file:KNearestNeighborClassifier.java
License:Apache License
/** * build a list of classification results from search results * @param topDocs the search results as a {@link TopDocs} object * @return a {@link List} of {@link ClassificationResult}, one for each existing class * @throws IOException if it's not possible to get the stored value of class field *//*from w ww .j ava 2s. co m*/ protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException { Map<BytesRef, Integer> classCounts = new HashMap<>(); Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs float maxScore = topDocs.getMaxScore(); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { IndexableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName); if (storableField != null) { BytesRef cl = new BytesRef(storableField.stringValue()); //update count Integer count = classCounts.get(cl); if (count != null) { classCounts.put(cl, count + 1); } else { classCounts.put(cl, 1); } //update boost, the boost is based on the best score Double totalBoost = classBoosts.get(cl); double singleBoost = scoreDoc.score / maxScore; if (totalBoost != null) { classBoosts.put(cl, totalBoost + singleBoost); } else { classBoosts.put(cl, singleBoost); } } } List<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>(); int sumdoc = 0; for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) { Integer count = entry.getValue(); Double normBoost = classBoosts.get(entry.getKey()) / count; //the boost is normalized to be 0<b<1 temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), (count * normBoost) / (double) k)); sumdoc += count; } //correction if (sumdoc < k) { for (ClassificationResult<BytesRef> cr : temporaryList) { returnList.add( new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc)); } } else { returnList = temporaryList; } return 
returnList; }
From source file:com.github.tteofili.apacheconeu14.oak.search.nls.NLSQueryIndex.java
License:Apache License
@Override public Cursor query(Filter filter, NodeState nodeState) { Thread thread = Thread.currentThread(); ClassLoader loader = thread.getContextClassLoader(); thread.setContextClassLoader(Client.class.getClassLoader()); try {/*www . j a v a 2 s . c o m*/ final IndexSearcher searcher = IndexUtils.getSearcher(); if (searcher != null) { Filter.PropertyRestriction nativeQueryRestriction = filter.getPropertyRestriction(NATIVE_NLS_QUERY); String nativeQueryString = String .valueOf(nativeQueryRestriction.first.getValue(nativeQueryRestriction.first.getType())); // build the parse tree of the query and filter the uninteresting part (e.g. "who is the admin" -> "admin") String purgedQuery = pcfg.filterQuestion(nativeQueryString); BooleanQuery booleanClauses = new BooleanQuery(); // add clauses for the purged natural language query (if existing) if (purgedQuery != null) { booleanClauses.add(new BooleanClause(new TermQuery(new Term("jcr:title", purgedQuery)), BooleanClause.Occur.SHOULD)); booleanClauses.add(new BooleanClause(new TermQuery(new Term("jcr:description", purgedQuery)), BooleanClause.Occur.SHOULD)); booleanClauses.add(new BooleanClause(new TermQuery(new Term("text", purgedQuery)), BooleanClause.Occur.SHOULD)); } // infer "class" of the query and boost based on that try { initializeClassifier(searcher); ClassificationResult<BytesRef> result = null; try { result = classifier.assignClass(nativeQueryString); } catch (Exception e) { // do nothing } if (result != null) { booleanClauses.add(new BooleanClause(new BoostedQuery( new TermQuery(new Term("jcr:primaryType", result.getAssignedClass())), new ConstValueSource(2.0f)), BooleanClause.Occur.SHOULD)); } final TopDocs topDocs = searcher.search(booleanClauses, 100); final ScoreDoc[] scoreDocs = topDocs.scoreDocs; return new Cursor() { private int index = 0; @Override public IndexRow next() { final ScoreDoc scoreDoc = scoreDocs[index]; index++; return new IndexRow() { @Override public String getPath() { try { return 
searcher.doc(scoreDoc.doc).get("path"); } catch (IOException e) { return null; } } @Override public PropertyValue getValue(String s) { try { if ("jcr:score".equals(s)) { PropertyValues.newString(String.valueOf(scoreDoc.score)); } return PropertyValues.newString(searcher.doc(scoreDoc.doc).get(s)); } catch (IOException e) { return null; } } }; } @Override public boolean hasNext() { return index < scoreDocs.length; } @Override public void remove() { } }; } catch (IOException e) { // do nothing } } } finally { thread.setContextClassLoader(loader); } return null; }
From source file:com.github.tteofili.looseen.MinHashClassifier.java
License:Apache License
/**
 * Builds one {@link ClassificationResult} per distinct category found in the
 * given search results, scoring each category by vote count weighted with its
 * average rank-based boost.
 *
 * @param searcher          the searcher used to load stored fields of the hits
 * @param categoryFieldName the stored field holding the category value
 * @param topDocs           the search results to aggregate
 * @param k                 the number of neighbors considered
 * @return one result per category seen in {@code topDocs}
 * @throws IOException if a hit's stored fields cannot be read
 */
List<ClassificationResult<BytesRef>> buildListFromTopDocs(IndexSearcher searcher, String categoryFieldName,
        TopDocs topDocs, int k) throws IOException {
    // votes per category
    Map<BytesRef, Integer> categoryCounts = new HashMap<>();
    // accumulated boost per category; each hit contributes score/maxScore
    Map<BytesRef, Double> categoryBoosts = new HashMap<>();
    float maxScore = topDocs.getMaxScore();
    for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
        IndexableField categoryField = searcher.doc(scoreDoc.doc).getField(categoryFieldName);
        if (categoryField == null) {
            continue; // hit has no category stored; skip it
        }
        BytesRef category = new BytesRef(categoryField.stringValue());
        categoryCounts.merge(category, 1, Integer::sum);
        categoryBoosts.merge(category, (double) (scoreDoc.score / maxScore), Double::sum);
    }
    List<ClassificationResult<BytesRef>> aggregated = new ArrayList<>();
    int votedDocs = 0;
    for (Map.Entry<BytesRef, Integer> entry : categoryCounts.entrySet()) {
        int votes = entry.getValue();
        // mean boost per vote, normalized to 0<b<1
        double meanBoost = categoryBoosts.get(entry.getKey()) / votes;
        aggregated.add(new ClassificationResult<>(entry.getKey().clone(), (votes * meanBoost) / (double) k));
        votedDocs += votes;
    }
    if (votedDocs >= k) {
        return aggregated;
    }
    // correction: fewer than k hits carried a category, so rescale the scores
    List<ClassificationResult<BytesRef>> corrected = new ArrayList<>();
    for (ClassificationResult<BytesRef> entry : aggregated) {
        corrected.add(new ClassificationResult<>(entry.getAssignedClass(),
                entry.getScore() * k / (double) votedDocs));
    }
    return corrected;
}
From source file:de.uni_koeln.spinfo.textengineering.tm.classification.lucene.LuceneAdapter.java
License:Open Source License
@Override public String classify(Document document) { try {// www .j av a 2 s. com ClassificationResult<BytesRef> result = classifier.assignClass(document.getText()); BytesRef assignedClass = result.getAssignedClass(); // printAssignments(document, result);//optional return assignedClass.utf8ToString(); } catch (IOException e) { e.printStackTrace(); } return null; }
From source file:de.uni_koeln.spinfo.textengineering.tm.classification.lucene.LuceneAdapter.java
License:Open Source License
/**
 * Debug helper: dumps a single classification assignment (source document,
 * assigned class, score) to stdout.
 *
 * @param document the classified document
 * @param result   the classification result to print
 */
@SuppressWarnings("unused")
private void printAssignments(Document document, ClassificationResult<BytesRef> result) {
    String assignedLabel = result.getAssignedClass().utf8ToString();
    System.out.println("doc: " + document.getSource());
    System.out.println("class: " + assignedLabel);
    System.out.println("score: " + result.getScore());
    System.out.println("---------");
}
From source file:org.apache.solr.update.processor.ClassificationUpdateProcessor.java
License:Apache License
/** * @param cmd the update command in input containing the Document to classify * @throws IOException If there is a low-level I/O error */// ww w . j a v a 2 s .co m @Override public void processAdd(AddUpdateCommand cmd) throws IOException { SolrInputDocument doc = cmd.getSolrInputDocument(); Document luceneDocument = cmd.getLuceneDocument(); String assignedClass; Object documentClass = doc.getFieldValue(trainingClassField); if (documentClass == null) { List<ClassificationResult<BytesRef>> assignedClassifications = classifier.getClasses(luceneDocument, maxOutputClasses); if (assignedClassifications != null) { for (ClassificationResult<BytesRef> singleClassification : assignedClassifications) { assignedClass = singleClassification.getAssignedClass().utf8ToString(); doc.addField(predictedClassField, assignedClass); } } } super.processAdd(cmd); }
From source file:org.solr.classtify.SimpleNaiveBayesClassifierTest.java
License:Apache License
/**
 * Trains a naive Bayes classifier over the index at {@code dir} and prints the
 * class assigned to {@code newText}.
 *
 * @throws IOException if the index cannot be opened or read
 */
@Test
public void classtify() throws IOException {
    SimpleNaiveBayesClassifier classifier = new SimpleNaiveBayesClassifier();
    // RESOURCE FIX: the IndexReader was never closed (leaked index files);
    // try-with-resources guarantees it is released even on failure.
    try (IndexReader reader = DirectoryReader.open(FSDirectory.open(new File(dir)))) {
        AtomicReader wrap = SlowCompositeReaderWrapper.wrap(reader);
        classifier.train(wrap, textFieldName, categoryFieldName, analyzer);
        ClassificationResult<BytesRef> assignClass = classifier.assignClass(newText);
        BytesRef assignedClass = assignClass.getAssignedClass();
        double score = assignClass.getScore();
        System.out.println(assignedClass.utf8ToString() + "," + score);
    }
}