Example usage of org.apache.lucene.classification.ClassificationResult.getAssignedClass()

List of usage examples for org.apache.lucene.classification ClassificationResult getAssignedClass

Introduction

On this page you can find example usages of org.apache.lucene.classification.ClassificationResult.getAssignedClass().

Prototype

public T getAssignedClass() 

Source Link

Document

Retrieves the class assigned by this classification result.

Usage

From source file:SimpleNaiveBayesClassifier.java

License:Apache License

/**
 * Normalize the classification results based on the max score available
 * @param assignedClasses the list of assigned classes
 * @return the normalized results/*www  . java  2s .co  m*/
 */
protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults(
        List<ClassificationResult<BytesRef>> assignedClasses) {
    // normalization; the values transforms to a 0-1 range
    ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
    if (!assignedClasses.isEmpty()) {
        Collections.sort(assignedClasses);
        // this is a negative number closest to 0 = a
        double smax = assignedClasses.get(0).getScore();

        double sumLog = 0;
        // log(sum(exp(x_n-a)))
        for (ClassificationResult<BytesRef> cr : assignedClasses) {
            // getScore-smax <=0 (both negative, smax is the smallest abs()
            sumLog += Math.exp(cr.getScore() - smax);
        }
        // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
        double loga = smax;
        loga += Math.log(sumLog);

        // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
        for (ClassificationResult<BytesRef> cr : assignedClasses) {
            double scoreDiff = cr.getScore() - loga;
            returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff)));
        }
    }
    return returnList;
}

From source file:KNearestNeighborClassifier.java

License:Apache License

/**
 * Builds one {@link ClassificationResult} per class found among the search hits.
 * <p>
 * Every hit that has a stored {@code classFieldName} value casts one vote for its class and
 * contributes a boost of {@code hitScore / maxScore}. A class scores
 * {@code (votes * averageBoost) / k}. When fewer than {@code k} hits carried a class, every
 * score is rescaled by {@code k / totalVotes}.
 *
 * @param topDocs the search results as a {@link TopDocs} object
 * @return a {@link List} of {@link ClassificationResult}, one for each existing class
 * @throws IOException if it's not possible to get the stored value of class field
 */
protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
    final float maxScore = topDocs.getMaxScore();
    final Map<BytesRef, Integer> votes = new HashMap<>();
    // boost accumulated per class, based on each hit's score relative to the best hit
    final Map<BytesRef, Double> boosts = new HashMap<>();

    for (ScoreDoc hit : topDocs.scoreDocs) {
        IndexableField classField = indexSearcher.doc(hit.doc).getField(classFieldName);
        if (classField == null) {
            continue;
        }
        BytesRef label = new BytesRef(classField.stringValue());

        Integer previousVotes = votes.get(label);
        votes.put(label, previousVotes == null ? 1 : previousVotes + 1);

        // float / float division, then widened to double
        double hitBoost = hit.score / maxScore;
        Double previousBoost = boosts.get(label);
        boosts.put(label, previousBoost == null ? hitBoost : previousBoost + hitBoost);
    }

    final List<ClassificationResult<BytesRef>> results = new ArrayList<>();
    int totalVotes = 0;
    for (Map.Entry<BytesRef, Integer> vote : votes.entrySet()) {
        BytesRef label = vote.getKey();
        int count = vote.getValue();
        // average boost per vote, 0 < b <= 1
        double averageBoost = boosts.get(label) / count;
        results.add(new ClassificationResult<>(label.clone(), (count * averageBoost) / (double) k));
        totalVotes += count;
    }

    if (totalVotes >= k) {
        return results;
    }
    // fewer than k hits carried a class: rescale relative to the hits actually seen
    final List<ClassificationResult<BytesRef>> rescaled = new ArrayList<>();
    for (ClassificationResult<BytesRef> result : results) {
        rescaled.add(new ClassificationResult<>(result.getAssignedClass(),
                result.getScore() * k / (double) totalVotes));
    }
    return rescaled;
}

From source file:com.github.tteofili.apacheconeu14.oak.search.nls.NLSQueryIndex.java

License:Apache License

/**
 * Runs the natural-language query carried by the {@code NATIVE_NLS_QUERY} property restriction
 * against the Lucene index and exposes the hits as a {@link Cursor}.
 * <p>
 * The raw query string is reduced by {@code pcfg.filterQuestion}, for example
 * "who is the admin" becomes "admin". The reduced string is matched as a term against
 * {@code jcr:title}, {@code jcr:description} and {@code text}. If the classifier assigns a
 * class to the raw query string, a {@code jcr:primaryType} clause for that class is added,
 * wrapped in a {@link BoostedQuery} with a constant 2.0 value source. At most 100 hits are
 * returned.
 *
 * @param filter the filter holding the native NLS query restriction
 * @param nodeState not used by this implementation
 * @return a cursor over the hits, or {@code null} if no searcher is available, the filter
 *         carries no NLS query, or the search failed with an {@link IOException}
 */
@Override
public Cursor query(Filter filter, NodeState nodeState) {

    Thread thread = Thread.currentThread();
    ClassLoader loader = thread.getContextClassLoader();
    // swap the context class loader for the duration of the query; restored in finally
    thread.setContextClassLoader(Client.class.getClassLoader());
    try {
        final IndexSearcher searcher = IndexUtils.getSearcher();
        if (searcher == null) {
            return null;
        }

        Filter.PropertyRestriction nativeQueryRestriction = filter.getPropertyRestriction(NATIVE_NLS_QUERY);
        if (nativeQueryRestriction == null || nativeQueryRestriction.first == null) {
            // no NLS query to run; avoid dereferencing a missing restriction
            return null;
        }
        String nativeQueryString = String
                .valueOf(nativeQueryRestriction.first.getValue(nativeQueryRestriction.first.getType()));

        // build the parse tree of the query and filter the uninteresting part (e.g. "who is the admin" -> "admin")
        String purgedQuery = pcfg.filterQuestion(nativeQueryString);

        BooleanQuery booleanClauses = new BooleanQuery();

        // add clauses for the purged natural language query (if existing)
        if (purgedQuery != null) {
            booleanClauses.add(new BooleanClause(new TermQuery(new Term("jcr:title", purgedQuery)),
                    BooleanClause.Occur.SHOULD));
            booleanClauses.add(new BooleanClause(new TermQuery(new Term("jcr:description", purgedQuery)),
                    BooleanClause.Occur.SHOULD));
            booleanClauses.add(new BooleanClause(new TermQuery(new Term("text", purgedQuery)),
                    BooleanClause.Occur.SHOULD));
        }

        try {
            // infer "class" of the query and boost based on that
            initializeClassifier(searcher);
            ClassificationResult<BytesRef> result = null;
            try {
                result = classifier.assignClass(nativeQueryString);
            } catch (Exception ignored) {
                // best effort: without a class the query simply gets no type boost
            }
            if (result != null) {
                booleanClauses.add(new BooleanClause(new BoostedQuery(
                        new TermQuery(new Term("jcr:primaryType", result.getAssignedClass())),
                        new ConstValueSource(2.0f)), BooleanClause.Occur.SHOULD));
            }

            final TopDocs topDocs = searcher.search(booleanClauses, 100);
            final ScoreDoc[] scoreDocs = topDocs.scoreDocs;

            return new Cursor() {
                private int index = 0;

                @Override
                public IndexRow next() {
                    if (!hasNext()) {
                        throw new java.util.NoSuchElementException();
                    }
                    final ScoreDoc scoreDoc = scoreDocs[index];
                    index++;
                    return new IndexRow() {
                        @Override
                        public String getPath() {
                            try {
                                return searcher.doc(scoreDoc.doc).get("path");
                            } catch (IOException e) {
                                return null;
                            }
                        }

                        @Override
                        public PropertyValue getValue(String s) {
                            if ("jcr:score".equals(s)) {
                                // report the hit's own score rather than a stored field
                                return PropertyValues.newString(String.valueOf(scoreDoc.score));
                            }
                            try {
                                return PropertyValues.newString(searcher.doc(scoreDoc.doc).get(s));
                            } catch (IOException e) {
                                return null;
                            }
                        }
                    };
                }

                @Override
                public boolean hasNext() {
                    return index < scoreDocs.length;
                }

                @Override
                public void remove() {
                    // removal is not supported; intentionally a no-op
                }
            };
        } catch (IOException e) {
            // search failed: fall through and return null
        }
    } finally {
        thread.setContextClassLoader(loader);
    }
    return null;
}

From source file:com.github.tteofili.looseen.MinHashClassifier.java

License:Apache License

/**
 * Builds one {@link ClassificationResult} per class found among the search hits.
 * <p>
 * Every hit that has a stored {@code categoryFieldName} value casts one vote for its class and
 * contributes a boost of {@code hitScore / maxScore}. A class scores
 * {@code (votes * averageBoost) / k}. When fewer than {@code k} hits carried a class, every
 * score is rescaled by {@code k / totalVotes}.
 *
 * @param searcher the searcher used to load the stored class field of each hit
 * @param categoryFieldName the stored field holding the class
 * @param topDocs the search results
 * @param k the number of neighbours the scores are normalized against
 * @return one result per class seen among the hits
 * @throws IOException if a hit's stored fields cannot be read
 */
List<ClassificationResult<BytesRef>> buildListFromTopDocs(IndexSearcher searcher, String categoryFieldName,
        TopDocs topDocs, int k) throws IOException {
    final float maxScore = topDocs.getMaxScore();
    final Map<BytesRef, Integer> votes = new HashMap<>();
    // boost accumulated per class, based on each hit's score relative to the best hit
    final Map<BytesRef, Double> boosts = new HashMap<>();

    for (ScoreDoc hit : topDocs.scoreDocs) {
        IndexableField classField = searcher.doc(hit.doc).getField(categoryFieldName);
        if (classField == null) {
            continue;
        }
        BytesRef label = new BytesRef(classField.stringValue());

        Integer previousVotes = votes.get(label);
        votes.put(label, previousVotes == null ? 1 : previousVotes + 1);

        // float / float division, then widened to double
        double hitBoost = hit.score / maxScore;
        Double previousBoost = boosts.get(label);
        boosts.put(label, previousBoost == null ? hitBoost : previousBoost + hitBoost);
    }

    final List<ClassificationResult<BytesRef>> results = new ArrayList<>();
    int totalVotes = 0;
    for (Map.Entry<BytesRef, Integer> vote : votes.entrySet()) {
        BytesRef label = vote.getKey();
        int count = vote.getValue();
        // average boost per vote, 0 < b <= 1
        double averageBoost = boosts.get(label) / count;
        results.add(new ClassificationResult<>(label.clone(), (count * averageBoost) / (double) k));
        totalVotes += count;
    }

    if (totalVotes >= k) {
        return results;
    }
    // fewer than k hits carried a class: rescale relative to the hits actually seen
    final List<ClassificationResult<BytesRef>> rescaled = new ArrayList<>();
    for (ClassificationResult<BytesRef> result : results) {
        rescaled.add(new ClassificationResult<>(result.getAssignedClass(),
                result.getScore() * k / (double) totalVotes));
    }
    return rescaled;
}

From source file:de.uni_koeln.spinfo.textengineering.tm.classification.lucene.LuceneAdapter.java

License:Open Source License

/**
 * Classifies the text of the given document with the wrapped Lucene classifier.
 *
 * @param document the document whose {@code getText()} is classified
 * @return the assigned class decoded from UTF-8, or {@code null} if the classifier returned
 *         no result/class or an {@link IOException} occurred
 */
@Override
public String classify(Document document) {
    try {
        ClassificationResult<BytesRef> result = classifier.assignClass(document.getText());
        // NOTE(review): assignClass may return null (presumably when no class can be scored --
        // confirm against the classifier); guard instead of throwing an NPE.
        if (result == null || result.getAssignedClass() == null) {
            return null;
        }
        // printAssignments(document, result); // optional debug output
        return result.getAssignedClass().utf8ToString();
    } catch (IOException e) {
        e.printStackTrace();
    }
    return null;
}

From source file:de.uni_koeln.spinfo.textengineering.tm.classification.lucene.LuceneAdapter.java

License:Open Source License

/**
 * Debug helper: writes the document source, its assigned class and the class score to
 * standard output, followed by a separator line.
 */
@SuppressWarnings("unused")
private void printAssignments(Document document, ClassificationResult<BytesRef> c) {
    System.out.printf("doc: %s%n", document.getSource());
    System.out.printf("class: %s%n", c.getAssignedClass().utf8ToString());
    System.out.printf("score: %s%n", c.getScore());
    System.out.printf("---------%n");
}

From source file:org.apache.solr.update.processor.ClassificationUpdateProcessor.java

License:Apache License

/**
 * Classifies documents that arrive without a value in {@code trainingClassField}. Each class
 * returned by the classifier (which is passed {@code maxOutputClasses}) is added to
 * {@code predictedClassField}. The command is then forwarded down the processor chain.
 *
 * @param cmd the update command in input containing the Document to classify
 * @throws IOException If there is a low-level I/O error
 */
@Override
public void processAdd(AddUpdateCommand cmd) throws IOException {
    SolrInputDocument doc = cmd.getSolrInputDocument();
    Object documentClass = doc.getFieldValue(trainingClassField);
    if (documentClass == null) {
        // build the Lucene document only when it is actually needed for classification
        Document luceneDocument = cmd.getLuceneDocument();
        List<ClassificationResult<BytesRef>> assignedClassifications = classifier.getClasses(luceneDocument,
                maxOutputClasses);
        if (assignedClassifications != null) {
            for (ClassificationResult<BytesRef> singleClassification : assignedClassifications) {
                String assignedClass = singleClassification.getAssignedClass().utf8ToString();
                doc.addField(predictedClassField, assignedClass);
            }
        }
    }
    super.processAdd(cmd);
}

From source file:org.solr.classtify.SimpleNaiveBayesClassifierTest.java

License:Apache License

/**
 * Trains a {@link SimpleNaiveBayesClassifier} on the index stored in {@code dir}. It then prints
 * the class assigned to {@code newText} and that class's score.
 */
@Test
public void classtify() throws IOException {
    SimpleNaiveBayesClassifier classifier = new SimpleNaiveBayesClassifier();
    // try-with-resources closes the reader and directory even if training or classification fails
    try (FSDirectory directory = FSDirectory.open(new File(dir));
         IndexReader reader = DirectoryReader.open(directory)) {
        AtomicReader wrap = SlowCompositeReaderWrapper.wrap(reader);
        classifier.train(wrap, textFieldName, categoryFieldName, analyzer);
        ClassificationResult<BytesRef> assignClass = classifier.assignClass(newText);
        BytesRef assignedClass = assignClass.getAssignedClass();

        double score = assignClass.getScore();
        System.out.println(assignedClass.utf8ToString() + "," + score);
    }
}