lucene.security.search.DocumentVisibilityFilter.java Source code

Java tutorial

Introduction

Here is the source code for lucene.security.search.DocumentVisibilityFilter.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package lucene.security.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import lucene.security.DocumentAuthorizations;
import lucene.security.DocumentVisibility;
import lucene.security.DocumentVisibilityEvaluator;
import lucene.security.search.DocumentVisibilityFilterCacheStrategy.Builder;

import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.DocsEnum;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Filter;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;

/**
 * A {@link Filter} that restricts a search to the documents whose visibility
 * label, stored as a term in a configurable field, evaluates as visible for
 * a given set of {@link DocumentAuthorizations}.
 *
 * <p>
 * For every term of the visibility field that evaluates as visible, the
 * matching documents are obtained from (or built through) the
 * {@link DocumentVisibilityFilterCacheStrategy}; the per-term sets are then
 * combined with a logical OR. Documents that carry no visible term are never
 * part of the result.
 */
public class DocumentVisibilityFilter extends Filter {

    private final String _fieldName;
    private final DocumentAuthorizations _authorizations;
    private final DocumentVisibilityFilterCacheStrategy _filterCacheStrategy;

    /**
     * Creates a new filter.
     *
     * @param fieldName
     *            name of the field whose terms hold the visibility labels.
     * @param authorizations
     *            the authorizations every visibility label is evaluated
     *            against.
     * @param filterCacheStrategy
     *            strategy used to look up and to build the per-term document
     *            sets.
     */
    public DocumentVisibilityFilter(String fieldName, DocumentAuthorizations authorizations,
            DocumentVisibilityFilterCacheStrategy filterCacheStrategy) {
        _fieldName = fieldName;
        _authorizations = authorizations;
        _filterCacheStrategy = filterCacheStrategy;
    }

    @Override
    public String toString() {
        return "DocumentVisibilityFilter [_fieldName=" + _fieldName + ", _authorizations=" + _authorizations
                + ", _filterCacheStrategy=" + _filterCacheStrategy + "]";
    }

    /**
     * Computes the set of visible documents for one segment.
     *
     * @param context
     *            the segment to filter.
     * @param acceptDocs
     *            the documents that may be returned; handed to the postings
     *            enumeration when a per-term set has to be built.
     * @return the union of the document sets of all visible terms, or
     *         {@link DocIdSet#EMPTY_DOCIDSET} when the segment has no terms
     *         for the visibility field or none of them is visible.
     * @throws IOException
     *             if reading the index fails or a per-term set does not
     *             support random access (see {@link #getLogicalOr(List)}).
     */
    @Override
    public DocIdSet getDocIdSet(AtomicReaderContext context, Bits acceptDocs) throws IOException {
        AtomicReader reader = context.reader();

        // Both the Fields instance and the Terms of a single field may be
        // null when the segment has no postings (for that field).
        Fields fields = reader.fields();
        Terms terms = fields == null ? null : fields.terms(_fieldName);
        if (terms == null) {
            // No document in this segment carries a visibility label. The loop
            // below only ever exposes documents that have a visible label, so
            // the consistent (and fail-closed) answer is "nothing is visible".
            return DocIdSet.EMPTY_DOCIDSET;
        }

        List<DocIdSet> list = new ArrayList<DocIdSet>();
        TermsEnum iterator = terms.iterator(null);
        BytesRef bytesRef;
        DocumentVisibilityEvaluator visibilityEvaluator = new DocumentVisibilityEvaluator(_authorizations);
        while ((bytesRef = iterator.next()) != null) {
            if (isVisible(visibilityEvaluator, bytesRef)) {
                DocIdSet docIdSet = _filterCacheStrategy.getDocIdSet(_fieldName, bytesRef, reader);
                if (docIdSet != null) {
                    list.add(docIdSet);
                } else {
                    // NOTE(review): the set handed to the cache strategy's builder
                    // is already restricted by acceptDocs, while a set returned by
                    // the strategy above is not re-filtered -- confirm the strategy
                    // accounts for that.
                    DocsEnum docsEnum = iterator.docs(acceptDocs, null);
                    list.add(buildCache(reader, docsEnum, bytesRef));
                }
            }
        }
        return getLogicalOr(list);
    }

    /**
     * Builds the document set of one term through the builder supplied by the
     * cache strategy.
     *
     * @param reader
     *            the segment reader the set belongs to.
     * @param it
     *            iterator over the documents containing the term.
     * @param bytesRef
     *            the term (visibility label).
     * @return the document set produced by the builder.
     * @throws IOException
     *             if consuming the iterator fails.
     */
    private DocIdSet buildCache(AtomicReader reader, DocIdSetIterator it, BytesRef bytesRef) throws IOException {
        Builder builder = _filterCacheStrategy.createBuilder(_fieldName, bytesRef, reader);
        builder.or(it);
        return builder.getDocIdSet();
    }

    /**
     * Evaluates the visibility label held by the given term.
     *
     * @param visibilityEvaluator
     *            evaluator bound to this filter's authorizations.
     * @param bytesRef
     *            the term holding the visibility label.
     * @return the result of the evaluation.
     * @throws IOException
     *             declared for callers; propagated from the evaluation if it
     *             throws.
     */
    private boolean isVisible(DocumentVisibilityEvaluator visibilityEvaluator, BytesRef bytesRef)
            throws IOException {
        DocumentVisibility visibility = new DocumentVisibility(trim(bytesRef));
        return visibilityEvaluator.evaluate(visibility);
    }

    /**
     * Copies the valid slice ({@code offset}, {@code length}) of the given
     * {@link BytesRef} into a new, exactly sized array.
     */
    private byte[] trim(BytesRef bytesRef) {
        byte[] buf = new byte[bytesRef.length];
        System.arraycopy(bytesRef.bytes, bytesRef.offset, buf, 0, bytesRef.length);
        return buf;
    }

    /**
     * Varargs convenience for {@link #getLogicalOr(List)}.
     *
     * @param list
     *            the sets to combine.
     * @return the union of the given sets.
     * @throws IOException
     *             see {@link #getLogicalOr(List)}.
     */
    public static DocIdSet getLogicalOr(DocIdSet... list) throws IOException {
        return getLogicalOr(Arrays.asList(list));
    }

    /**
     * Returns a {@link DocIdSet} that is the logical OR (union) of the given
     * sets.
     *
     * <p>
     * An empty list yields {@link DocIdSet#EMPTY_DOCIDSET}; a single set is
     * returned as is. Otherwise a view over the given sets is returned whose
     * {@link DocIdSet#bits()} and {@link DocIdSet#iterator()} combine the
     * underlying sets lazily.
     *
     * @param list
     *            the sets to combine; every set must provide non-null
     *            {@link DocIdSet#bits()}, and all bits must have the same
     *            length.
     * @return the union of the given sets.
     * @throws IOException
     *             if a set has no bits, if the bits lengths differ, or if an
     *             underlying set fails.
     */
    public static DocIdSet getLogicalOr(final List<DocIdSet> list) throws IOException {
        // Snapshot the list so later changes made by the caller cannot affect
        // the returned view.
        final DocIdSet[] sets = list.toArray(new DocIdSet[list.size()]);
        if (sets.length == 0) {
            return DocIdSet.EMPTY_DOCIDSET;
        }
        if (sets.length == 1) {
            DocIdSet docIdSet = sets[0];
            Bits bits = docIdSet.bits();
            if (bits == null) {
                throw new IOException("Bits are not allowed to be null for DocIdSet [" + docIdSet + "].");
            }
            return docIdSet;
        }
        final Bits[] bitsArray = new Bits[sets.length];
        int length = -1;
        for (int i = 0; i < sets.length; i++) {
            DocIdSet docIdSet = sets[i];
            Bits bits = docIdSet.bits();
            if (bits == null) {
                throw new IOException("Bits are not allowed to be null for DocIdSet [" + docIdSet + "].");
            }
            bitsArray[i] = bits;
            if (length < 0) {
                length = bits.length();
            } else if (length != bits.length()) {
                throw new IOException(
                        "Bits length need to be the same [" + length + "] and [" + bits.length() + "]");
            }
        }
        final int len = length;
        return new DocIdSet() {

            @Override
            public Bits bits() throws IOException {
                return new Bits() {

                    // A document is set as soon as any underlying set has it.
                    @Override
                    public boolean get(int index) {
                        for (int i = 0; i < bitsArray.length; i++) {
                            if (bitsArray[i].get(index)) {
                                return true;
                            }
                        }
                        return false;
                    }

                    @Override
                    public int length() {
                        return len;
                    }

                };
            }

            @Override
            public boolean isCacheable() {
                return true;
            }

            @Override
            public DocIdSetIterator iterator() throws IOException {
                List<DocIdSetIterator> positioned = new ArrayList<DocIdSetIterator>(sets.length);
                long c = 0;
                for (int i = 0; i < sets.length; i++) {
                    DocIdSetIterator iterator = sets[i].iterator();
                    if (iterator == null) {
                        // A null iterator stands for an empty set; it contributes
                        // nothing to the union.
                        continue;
                    }
                    // Position every sub-iterator on its first document so that
                    // docID() can be compared against the target in advance().
                    iterator.nextDoc();
                    positioned.add(iterator);
                    c += iterator.cost();
                }
                final DocIdSetIterator[] subs = positioned.toArray(new DocIdSetIterator[positioned.size()]);
                final long cost = c;
                return new DocIdSetIterator() {

                    private int _docId = -1;

                    @Override
                    public int advance(int target) throws IOException {
                        // Move every sub-iterator that is behind the target and
                        // keep the smallest resulting doc id; NO_MORE_DOCS is the
                        // largest possible value, so it only wins when every
                        // sub-iterator is exhausted.
                        int min = NO_MORE_DOCS;
                        for (int i = 0; i < subs.length; i++) {
                            DocIdSetIterator sub = subs[i];
                            int doc = sub.docID();
                            if (doc < target) {
                                doc = sub.advance(target);
                            }
                            if (doc < min) {
                                min = doc;
                            }
                        }
                        return _docId = min;
                    }

                    @Override
                    public int nextDoc() throws IOException {
                        if (_docId == NO_MORE_DOCS) {
                            // Stay exhausted; avoids overflowing _docId + 1.
                            return NO_MORE_DOCS;
                        }
                        return advance(_docId + 1);
                    }

                    @Override
                    public int docID() {
                        return _docId;
                    }

                    @Override
                    public long cost() {
                        return cost;
                    }

                };
            }
        };
    }

}