Example usage of the org.apache.lucene.classification.SimpleNaiveBayesClassifier constructor

List of usage examples for the org.apache.lucene.classification.SimpleNaiveBayesClassifier constructor

Introduction

On this page you can find example usages of the org.apache.lucene.classification.SimpleNaiveBayesClassifier constructor.

Prototype

public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query,
        String classFieldName, String... textFieldNames) 

Source Link

Document

Creates a new NaiveBayes classifier.

Usage

From source file:com.github.tteofili.looseen.Test20NewsgroupsClassification.java

License:Apache License

/**
 * Evaluates a collection of Lucene classifiers on the 20 Newsgroups corpus and prints one
 * report per classifier.
 *
 * <p>Two system properties tune the run. Each is parsed as a boolean; when a property is
 * absent the current value of the corresponding field is kept:
 * <ul>
 *   <li>{@code index} - call {@code delete} on the main index path
 *       ({@code INDEX + "/original"}) and rebuild the index from
 *       {@code PREFIX + "/20n/20_newsgroups"};</li>
 *   <li>{@code split} - train on the index at {@code INDEX + "/train"} and evaluate on the one
 *       at {@code INDEX + "/test"}; combined with {@code index}, the freshly built index is
 *       split into train/test/cv first.</li>
 * </ul>
 *
 * <p>All readers, directories, the index writer and the executor are released even when the
 * run fails; a failure while closing never hides the failure that aborted the run.
 *
 * @throws Exception if indexing, splitting, evaluation or resource cleanup fails
 */
@Test
public void test20Newsgroups() throws Exception {

    // Boolean.parseBoolean never throws, so no defensive try/catch is required.
    String indexProperty = System.getProperty("index");
    if (indexProperty != null) {
        index = Boolean.parseBoolean(indexProperty);
    }

    String splitProperty = System.getProperty("split");
    if (splitProperty != null) {
        split = Boolean.parseBoolean(splitProperty);
    }

    Path mainIndexPath = Paths.get(INDEX + "/original");
    Path trainPath = Paths.get(INDEX + "/train");
    Path testPath = Paths.get(INDEX + "/test");
    Path cvPath = Paths.get(INDEX + "/cv");

    Directory directory = null;
    FSDirectory cv = null;
    FSDirectory test = null;
    FSDirectory train = null;
    IndexReader reader = null;
    IndexReader testReader = null;
    List<Classifier<BytesRef>> classifiers = new LinkedList<>();
    Throwable primaryFailure = null;
    try {
        // Opened inside the try so that a failure half-way through still closes what was opened.
        directory = FSDirectory.open(mainIndexPath);
        if (split) {
            cv = FSDirectory.open(cvPath);
            test = FSDirectory.open(testPath);
            train = FSDirectory.open(trainPath);
        }

        if (index) {
            delete(mainIndexPath);
            if (split) {
                delete(trainPath, testPath, cvPath);
            }
        }

        Analyzer analyzer = new StandardAnalyzer();
        if (index) {

            System.out.format("Indexing 20 Newsgroups...%n");

            long startIndex = System.currentTimeMillis();
            // try-with-resources: the writer is closed even if buildIndex throws.
            try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig(analyzer))) {

                buildIndex(new File(PREFIX + "/20n/20_newsgroups"), indexWriter);

                long endIndex = System.currentTimeMillis();
                System.out.format("Indexed %d pages in %ds %n", indexWriter.maxDoc(),
                        (endIndex - startIndex) / 1000);
            }

        }

        if (split && !index) {
            reader = DirectoryReader.open(train);
        } else {
            reader = DirectoryReader.open(directory);
        }

        if (index && split) {
            // split the index
            System.out.format("Splitting the index...%n");

            long startSplit = System.currentTimeMillis();
            DatasetSplitter datasetSplitter = new DatasetSplitter(0.1, 0);
            datasetSplitter.split(reader, train, test, cv, analyzer, false, CATEGORY_FIELD, BODY_FIELD,
                    SUBJECT_FIELD, CATEGORY_FIELD);
            reader.close();
            reader = DirectoryReader.open(train); // using the train index from now on
            long endSplit = System.currentTimeMillis();
            System.out.format("Splitting done in %ds %n", (endSplit - startSplit) / 1000);
        }

        final long startTime = System.currentTimeMillis();

        classifiers.add(new KNearestNeighborClassifier(reader, new ClassicSimilarity(), analyzer, null, 1, 0, 0,
                CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, CATEGORY_FIELD,
                BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new ClassicSimilarity(), analyzer, null, 3, 0, 0,
                CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new AxiomaticF1EXP(), analyzer, null, 3, 0, 0,
                CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new AxiomaticF1LOG(), analyzer, null, 3, 0, 0,
                CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new LMDirichletSimilarity(), analyzer, null, 3,
                1, 1, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new LMJelinekMercerSimilarity(0.3f), analyzer,
                null, 3, 1, 1, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, null, analyzer, null, 3, 1, 1, CATEGORY_FIELD,
                BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader,
                new DFRSimilarity(new BasicModelG(), new AfterEffectB(), new NormalizationH1()), analyzer, null,
                3, 1, 1, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader,
                new DFRSimilarity(new BasicModelP(), new AfterEffectL(), new NormalizationH3()), analyzer, null,
                3, 1, 1, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader,
                new IBSimilarity(new DistributionSPL(), new LambdaDF(), new Normalization.NoNormalization()),
                analyzer, null, 3, 1, 1, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader,
                new IBSimilarity(new DistributionLL(), new LambdaTTF(), new NormalizationH1()), analyzer, null,
                3, 1, 1, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new MinHashClassifier(reader, BODY_FIELD, CATEGORY_FIELD, 15, 1, 100));
        classifiers.add(new MinHashClassifier(reader, BODY_FIELD, CATEGORY_FIELD, 30, 3, 300));
        classifiers.add(new MinHashClassifier(reader, BODY_FIELD, CATEGORY_FIELD, 10, 1, 100));
        classifiers.add(new KNearestFuzzyClassifier(reader, new LMJelinekMercerSimilarity(0.3f), analyzer, null,
                1, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestFuzzyClassifier(reader,
                new IBSimilarity(new DistributionLL(), new LambdaTTF(), new NormalizationH1()), analyzer, null,
                1, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestFuzzyClassifier(reader, new ClassicSimilarity(), analyzer, null, 1,
                CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestFuzzyClassifier(reader, new ClassicSimilarity(), analyzer, null, 3,
                CATEGORY_FIELD, BODY_FIELD));
        classifiers
                .add(new KNearestFuzzyClassifier(reader, null, analyzer, null, 1, CATEGORY_FIELD, BODY_FIELD));
        classifiers
                .add(new KNearestFuzzyClassifier(reader, null, analyzer, null, 3, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestFuzzyClassifier(reader, new AxiomaticF1EXP(), analyzer, null, 3,
                CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new KNearestFuzzyClassifier(reader, new AxiomaticF1LOG(), analyzer, null, 3,
                CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new BM25NBClassifier(reader, analyzer, null, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new CachingNaiveBayesClassifier(reader, analyzer, null, CATEGORY_FIELD, BODY_FIELD));
        classifiers.add(new SimpleNaiveBayesClassifier(reader, analyzer, null, CATEGORY_FIELD, BODY_FIELD));

        int maxdoc;

        if (split) {
            testReader = DirectoryReader.open(test);
            maxdoc = testReader.maxDoc();
        } else {
            maxdoc = reader.maxDoc();
        }

        System.out.format("Starting evaluation on %d docs...%n", maxdoc);

        ExecutorService service = Executors.newCachedThreadPool();
        try {
            List<Future<String>> futures = new LinkedList<>();
            for (Classifier<BytesRef> classifier : classifiers) {
                testClassifier(reader, startTime, testReader, service, futures, classifier);
            }
            for (Future<String> f : futures) {
                System.out.println(f.get());
            }

            // Every tracked future has completed at this point; instead of always sleeping 10s,
            // wait at most that long for anything still running in the pool.
            service.shutdown();
            service.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS);
        } finally {
            // Also reached when a future failed: stop outstanding work before the readers go away.
            service.shutdownNow();
        }

    } catch (Throwable t) {
        // Remember the real failure so that cleanup problems are attached to it, not thrown over it.
        primaryFailure = t;
        throw t;
    } finally {
        // Close order: classifiers, then readers, then the directories the readers were opened on.
        List<AutoCloseable> resources = new LinkedList<>();
        for (Classifier<BytesRef> c : classifiers) {
            if (c instanceof Closeable) {
                resources.add((Closeable) c);
            }
        }
        resources.add(testReader);
        resources.add(reader);
        resources.add(test);
        resources.add(train);
        resources.add(cv);
        resources.add(directory);

        // Attempt every close even if an earlier one fails.
        Exception closeFailure = null;
        for (AutoCloseable resource : resources) {
            if (resource == null) {
                continue;
            }
            try {
                resource.close();
            } catch (Exception e) {
                if (closeFailure == null) {
                    closeFailure = e;
                } else {
                    closeFailure.addSuppressed(e);
                }
            }
        }

        if (closeFailure != null) {
            if (primaryFailure != null) {
                primaryFailure.addSuppressed(closeFailure);
            } else {
                throw closeFailure;
            }
        }
    }
}

From source file:com.github.tteofili.looseen.TestWikipediaClassification.java

License:Apache License

/**
 * Evaluates a collection of Lucene classifiers on an Italian Wikipedia dump and prints one
 * report (accuracy, precision, recall, F1, average classification time, elapsed time) per
 * classifier.
 *
 * <p>Two system properties tune the run. Each is parsed as a boolean; when a property is
 * absent the current value of the corresponding field is kept:
 * <ul>
 *   <li>{@code index} - call {@code delete} on the main index path
 *       ({@code INDEX + "/original"}) and rebuild the index from the four
 *       {@code itwiki-20150405-pages-meta-current*.xml} files under {@code PREFIX + "/itwiki"};</li>
 *   <li>{@code split} - train on the index at {@code INDEX + "/train"} and evaluate on the one
 *       at {@code INDEX + "/test"}; combined with {@code index}, the freshly built index is
 *       split into train/test/cv first.</li>
 * </ul>
 *
 * <p>Cleanup is best-effort: each resource is closed independently and close failures are
 * only printed, so they never replace the outcome of the run itself.
 *
 * @throws Exception if indexing, splitting or evaluation fails
 */
@Test
public void testItalianWikipedia() throws Exception {

    // Boolean.parseBoolean never throws, so no defensive try/catch is required.
    String indexProperty = System.getProperty("index");
    if (indexProperty != null) {
        index = Boolean.parseBoolean(indexProperty);
    }

    String splitProperty = System.getProperty("split");
    if (splitProperty != null) {
        split = Boolean.parseBoolean(splitProperty);
    }

    Path mainIndexPath = Paths.get(INDEX + "/original");
    Path trainPath = Paths.get(INDEX + "/train");
    Path testPath = Paths.get(INDEX + "/test");
    Path cvPath = Paths.get(INDEX + "/cv");

    Directory directory = null;
    FSDirectory cv = null;
    FSDirectory test = null;
    FSDirectory train = null;
    IndexReader reader = null;
    DirectoryReader testReader = null;
    try {
        // Opened inside the try so that a failure half-way through still closes what was opened.
        directory = FSDirectory.open(mainIndexPath);
        if (split) {
            cv = FSDirectory.open(cvPath);
            test = FSDirectory.open(testPath);
            train = FSDirectory.open(trainPath);
        }

        if (index) {
            delete(mainIndexPath);
            if (split) {
                delete(trainPath, testPath, cvPath);
            }
        }

        Collection<String> stopWordsList = Arrays.asList("di", "a", "da", "in", "per", "tra", "fra", "il", "lo",
                "la", "i", "gli", "le");
        CharArraySet stopWords = new CharArraySet(stopWordsList, true);
        Analyzer analyzer = new ItalianAnalyzer(stopWords);
        if (index) {

            System.out.format("Indexing Italian Wikipedia...%n");

            long startIndex = System.currentTimeMillis();
            // try-with-resources: the writer is closed even if an import throws.
            try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig(analyzer))) {

                importWikipedia(new File(PREFIX + "/itwiki/itwiki-20150405-pages-meta-current1.xml"), indexWriter);
                importWikipedia(new File(PREFIX + "/itwiki/itwiki-20150405-pages-meta-current2.xml"), indexWriter);
                importWikipedia(new File(PREFIX + "/itwiki/itwiki-20150405-pages-meta-current3.xml"), indexWriter);
                importWikipedia(new File(PREFIX + "/itwiki/itwiki-20150405-pages-meta-current4.xml"), indexWriter);

                long endIndex = System.currentTimeMillis();
                System.out.format("Indexed %d pages in %ds %n", indexWriter.maxDoc(),
                        (endIndex - startIndex) / 1000);
            }

        }

        if (split && !index) {
            reader = DirectoryReader.open(train);
        } else {
            reader = DirectoryReader.open(directory);
        }

        if (index && split) {
            // split the index
            System.out.format("Splitting the index...%n");

            long startSplit = System.currentTimeMillis();
            DatasetSplitter datasetSplitter = new DatasetSplitter(0.1, 0);
            for (LeafReaderContext context : reader.leaves()) {
                datasetSplitter.split(context.reader(), train, test, cv, analyzer, false, CATEGORY_FIELD,
                        TEXT_FIELD, CATEGORY_FIELD);
            }
            reader.close();
            reader = DirectoryReader.open(train); // using the train index from now on
            long endSplit = System.currentTimeMillis();
            System.out.format("Splitting done in %ds %n", (endSplit - startSplit) / 1000);
        }

        final long startTime = System.currentTimeMillis();

        List<Classifier<BytesRef>> classifiers = new LinkedList<>();
        classifiers.add(new KNearestNeighborClassifier(reader, new ClassicSimilarity(), analyzer, null, 1, 0, 0,
                CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new BM25Similarity(), analyzer, null, 1, 0, 0,
                CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, CATEGORY_FIELD,
                TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new LMDirichletSimilarity(), analyzer, null, 3,
                1, 1, CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new LMJelinekMercerSimilarity(0.3f), analyzer,
                null, 3, 1, 1, CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new ClassicSimilarity(), analyzer, null, 3, 0, 0,
                CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader, new ClassicSimilarity(), analyzer, null, 3, 1, 1,
                CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader,
                new DFRSimilarity(new BasicModelG(), new AfterEffectB(), new NormalizationH1()), analyzer, null,
                3, 1, 1, CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader,
                new DFRSimilarity(new BasicModelP(), new AfterEffectL(), new NormalizationH3()), analyzer, null,
                3, 1, 1, CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader,
                new IBSimilarity(new DistributionSPL(), new LambdaDF(), new Normalization.NoNormalization()),
                analyzer, null, 3, 1, 1, CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestNeighborClassifier(reader,
                new IBSimilarity(new DistributionLL(), new LambdaTTF(), new NormalizationH1()), analyzer, null,
                3, 1, 1, CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new MinHashClassifier(reader, TEXT_FIELD, CATEGORY_FIELD, 5, 1, 100));
        classifiers.add(new MinHashClassifier(reader, TEXT_FIELD, CATEGORY_FIELD, 10, 1, 100));
        classifiers.add(new MinHashClassifier(reader, TEXT_FIELD, CATEGORY_FIELD, 15, 1, 100));
        classifiers.add(new MinHashClassifier(reader, TEXT_FIELD, CATEGORY_FIELD, 15, 3, 100));
        classifiers.add(new MinHashClassifier(reader, TEXT_FIELD, CATEGORY_FIELD, 15, 3, 300));
        classifiers.add(new MinHashClassifier(reader, TEXT_FIELD, CATEGORY_FIELD, 5, 3, 100));
        classifiers.add(new KNearestFuzzyClassifier(reader, new ClassicSimilarity(), analyzer, null, 3,
                CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestFuzzyClassifier(reader, new ClassicSimilarity(), analyzer, null, 1,
                CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestFuzzyClassifier(reader, new BM25Similarity(), analyzer, null, 3,
                CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new KNearestFuzzyClassifier(reader, new BM25Similarity(), analyzer, null, 1,
                CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new BM25NBClassifier(reader, analyzer, null, CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new CachingNaiveBayesClassifier(reader, analyzer, null, CATEGORY_FIELD, TEXT_FIELD));
        classifiers.add(new SimpleNaiveBayesClassifier(reader, analyzer, null, CATEGORY_FIELD, TEXT_FIELD));

        int maxdoc;

        if (split) {
            testReader = DirectoryReader.open(test);
            maxdoc = testReader.maxDoc();
        } else {
            maxdoc = reader.maxDoc();
        }

        System.out.format("Starting evaluation on %d docs...%n", maxdoc);

        // Final copies so the lambdas below can capture the (reassigned) readers.
        final IndexReader finalReader = reader;
        final DirectoryReader finalTestReader = testReader;

        ExecutorService service = Executors.newCachedThreadPool();
        try {
            List<Future<String>> futures = new LinkedList<>();
            for (Classifier<BytesRef> classifier : classifiers) {

                futures.add(service.submit(() -> {
                    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix;
                    if (split) {
                        confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(finalTestReader, classifier,
                                CATEGORY_FIELD, TEXT_FIELD, 60000 * 30);
                    } else {
                        confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(finalReader, classifier,
                                CATEGORY_FIELD, TEXT_FIELD, 60000 * 30);
                    }

                    final long endTime = System.currentTimeMillis();
                    // Divide before narrowing: the original cast the millisecond delta to int first.
                    final long elapse = (endTime - startTime) / 1000;

                    return " * " + classifier + " \n    * accuracy = " + confusionMatrix.getAccuracy()
                            + "\n    * precision = " + confusionMatrix.getPrecision() + "\n    * recall = "
                            + confusionMatrix.getRecall() + "\n    * f1-measure = " + confusionMatrix.getF1Measure()
                            + "\n    * avgClassificationTime = " + confusionMatrix.getAvgClassificationTime()
                            + "\n    * time = " + elapse + " (sec)\n ";
                }));

            }
            for (Future<String> f : futures) {
                System.out.println(f.get());
            }

            // Every submitted task has completed at this point; instead of always sleeping 10s,
            // wait at most that long for the pool to wind down.
            service.shutdown();
            service.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS);
        } finally {
            // Also reached when a future failed: stop outstanding work before the readers go away.
            service.shutdownNow();
        }

    } finally {
        // Readers first, then the directories they were opened on. Each close is attempted
        // independently so one failure cannot leave the remaining resources open.
        AutoCloseable[] resources = { testReader, reader, test, train, cv, directory };
        for (AutoCloseable resource : resources) {
            if (resource == null) {
                continue;
            }
            try {
                resource.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
}