org.apache.lucene.classification.utils.DatasetSplitter.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.lucene.classification.utils.DatasetSplitter.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification.utils;

import java.io.IOException;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;

/**
 * Utility class for creating training / test / cross validation indexes from the original index.
 */
public class DatasetSplitter {

    private final double crossValidationRatio;
    private final double testRatio;

    /**
     * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes
     *
     * @param testRatio            the ratio of the original index to be used for the test IDX as a <code>double</code> between 0.0 and 1.0
     * @param crossValidationRatio the ratio of the original index to be used for the c.v. IDX as a <code>double</code> between 0.0 and 1.0
     */
    public DatasetSplitter(double testRatio, double crossValidationRatio) {
        this.crossValidationRatio = crossValidationRatio;
        this.testRatio = testRatio;
    }

    /**
     * Split a given index into 3 indexes for training, test and cross validation tasks respectively
     *
     * @param originalIndex        an {@link org.apache.lucene.index.LeafReader} on the source index
     * @param trainingIndex        a {@link Directory} used to write the training index
     * @param testIndex            a {@link Directory} used to write the test index
     * @param crossValidationIndex a {@link Directory} used to write the cross validation index
     * @param analyzer             {@link Analyzer} used to create the new docs
     * @param termVectors          {@code true} if term vectors should be kept
     * @param classFieldName       name of the field used as the label for classification; this must be indexed with sorted doc values
     * @param fieldNames           names of fields that need to be put in the new indexes or <code>null</code> if all should be used
     * @throws IOException if any writing operation fails on any of the indexes
     */
    public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex,
            Directory crossValidationIndex, Analyzer analyzer, boolean termVectors, String classFieldName,
            String... fieldNames) throws IOException {

        // create IWs for train / test / cv IDXs
        IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
        IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
        IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));

        // get the exact no. of existing classes
        int noOfClasses = 0;
        for (LeafReaderContext leave : originalIndex.leaves()) {
            long valueCount = 0;
            SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName);
            if (classValues != null) {
                valueCount = classValues.getValueCount();
            } else {
                SortedSetDocValues sortedSetDocValues = leave.reader().getSortedSetDocValues(classFieldName);
                if (sortedSetDocValues != null) {
                    valueCount = sortedSetDocValues.getValueCount();
                }
            }
            if (classValues == null) {
                // approximate with no. of terms
                noOfClasses += leave.reader().terms(classFieldName).size();
            }
            noOfClasses += valueCount;
        }

        try {

            IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
            GroupingSearch gs = new GroupingSearch(classFieldName);
            gs.setGroupSort(Sort.INDEXORDER);
            gs.setSortWithinGroup(Sort.INDEXORDER);
            gs.setAllGroups(true);
            gs.setGroupDocsLimit(originalIndex.maxDoc());
            TopGroups<Object> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, noOfClasses);

            // set the type to be indexed, stored, with term vectors
            FieldType ft = new FieldType(TextField.TYPE_STORED);
            if (termVectors) {
                ft.setStoreTermVectors(true);
                ft.setStoreTermVectorOffsets(true);
                ft.setStoreTermVectorPositions(true);
            }

            int b = 0;

            // iterate over existing documents
            for (GroupDocs<Object> group : topGroups.groups) {
                assert group.totalHits.relation == TotalHits.Relation.EQUAL_TO;
                long totalHits = group.totalHits.value;
                double testSize = totalHits * testRatio;
                int tc = 0;
                double cvSize = totalHits * crossValidationRatio;
                int cvc = 0;
                for (ScoreDoc scoreDoc : group.scoreDocs) {

                    // create a new document for indexing
                    Document doc = createNewDoc(originalIndex, ft, scoreDoc, fieldNames);

                    // add it to one of the IDXs
                    if (b % 2 == 0 && tc < testSize) {
                        testWriter.addDocument(doc);
                        tc++;
                    } else if (cvc < cvSize) {
                        cvWriter.addDocument(doc);
                        cvc++;
                    } else {
                        trainingWriter.addDocument(doc);
                    }
                    b++;
                }
            }
            // commit
            testWriter.commit();
            cvWriter.commit();
            trainingWriter.commit();

            // merge
            testWriter.forceMerge(3);
            cvWriter.forceMerge(3);
            trainingWriter.forceMerge(3);
        } catch (Exception e) {
            throw new IOException(e);
        } finally {
            // close IWs
            testWriter.close();
            cvWriter.close();
            trainingWriter.close();
            originalIndex.close();
        }
    }

    private Document createNewDoc(IndexReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames)
            throws IOException {
        Document doc = new Document();
        Document document = originalIndex.document(scoreDoc.doc);
        if (fieldNames != null && fieldNames.length > 0) {
            for (String fieldName : fieldNames) {
                IndexableField field = document.getField(fieldName);
                if (field != null) {
                    doc.add(new Field(fieldName, field.stringValue(), ft));
                }
            }
        } else {
            for (IndexableField field : document.getFields()) {
                if (field.readerValue() != null) {
                    doc.add(new Field(field.name(), field.readerValue(), ft));
                } else if (field.binaryValue() != null) {
                    doc.add(new Field(field.name(), field.binaryValue(), ft));
                } else if (field.stringValue() != null) {
                    doc.add(new Field(field.name(), field.stringValue(), ft));
                } else if (field.numericValue() != null) {
                    doc.add(new Field(field.name(), field.numericValue().toString(), ft));
                }
            }
        }
        return doc;
    }

}