org.apache.lucene.queries.payloads.SpanPayloadCheckQuery.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.lucene.queries.payloads.SpanPayloadCheckQuery.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.queries.payloads;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafSimScorer;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.spans.FilterSpans;
import org.apache.lucene.search.spans.FilterSpans.AcceptStatus;
import org.apache.lucene.search.spans.SpanCollector;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanScorer;
import org.apache.lucene.search.spans.SpanWeight;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.BytesRef;

/**
 * Only return those matches that have a specific payload at the given position.
 *
 * <p>A span produced by the wrapped query is accepted only when the payloads collected from its
 * leaves equal {@link #payloadToMatch} element by element, in collection order, and the number of
 * collected leaves equals the size of that list. A {@code null} list entry requires
 * {@link PostingsEnum#getPayload()} to return {@code null} for the corresponding leaf.
 */
public class SpanPayloadCheckQuery extends SpanQuery {

    /** Expected payloads, one per collected leaf; entries may be {@code null}. */
    protected final List<BytesRef> payloadToMatch;

    /** The wrapped query whose spans are filtered by payload. */
    protected final SpanQuery match;

    /**
     * Creates a payload-checking wrapper around {@code match}.
     *
     * <p>The list is stored by reference (it is not copied), so it should not be modified after
     * being handed to this query.
     *
     * @param match The underlying {@link org.apache.lucene.search.spans.SpanQuery} to check
     * @param payloadToMatch The {@link java.util.List} of payloads to match
     * @throws NullPointerException if {@code match} or {@code payloadToMatch} is {@code null}
     */
    public SpanPayloadCheckQuery(SpanQuery match, List<BytesRef> payloadToMatch) {
        // Fail fast: every other method of this class dereferences both fields.
        this.match = Objects.requireNonNull(match, "match must not be null");
        this.payloadToMatch = Objects.requireNonNull(payloadToMatch, "payloadToMatch must not be null");
    }

    /**
     * @return the field of the wrapped query
     */
    @Override
    public String getField() {
        return match.getField();
    }

    /**
     * Builds the weight of the wrapped query and wraps it in a {@link SpanPayloadCheckWeight}.
     * Term states are only gathered when {@code scoreMode} needs scores; otherwise {@code null}
     * is passed to the weight.
     */
    @Override
    public SpanWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        final SpanWeight matchWeight = match.createWeight(searcher, scoreMode, boost);
        final Map<Term, TermStates> termStates = scoreMode.needsScores() ? getTermStates(matchWeight) : null;
        return new SpanPayloadCheckWeight(searcher, termStates, matchWeight, boost);
    }

    /**
     * Rewrites the wrapped query. If the rewrite yields a different {@link SpanQuery}, a new
     * {@code SpanPayloadCheckQuery} around it (sharing the same payload list) is returned;
     * otherwise the decision is delegated to the superclass.
     */
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        final Query matchRewritten = match.rewrite(reader);
        if (match != matchRewritten && matchRewritten instanceof SpanQuery) {
            return new SpanPayloadCheckQuery((SpanQuery) matchRewritten, payloadToMatch);
        }
        return super.rewrite(reader);
    }

    /**
     * Visits the wrapped query as a {@code MUST} sub-clause of this query, provided the visitor
     * accepts the wrapped query's field.
     */
    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(match.getField())) {
            match.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this));
        }
    }

    /**
     * Weight that filters the Spans of the wrapped weight through a payload-checking
     * {@link SpanCollector}.
     */
    public class SpanPayloadCheckWeight extends SpanWeight {

        /** Weight of the wrapped query; supplies the candidate spans. */
        final SpanWeight matchWeight;

        /**
         * @param searcher the searcher this weight is created for
         * @param termStates term states passed through to {@link SpanWeight}; may be {@code null}
         *        (the enclosing query passes {@code null} when scores are not needed)
         * @param matchWeight weight of the wrapped query
         * @param boost boost passed through to {@link SpanWeight}
         */
        public SpanPayloadCheckWeight(IndexSearcher searcher, Map<Term, TermStates> termStates,
                SpanWeight matchWeight, float boost) throws IOException {
            super(SpanPayloadCheckQuery.this, searcher, termStates, boost);
            this.matchWeight = matchWeight;
        }

        /** Delegates to the wrapped weight. */
        @Override
        public void extractTerms(Set<Term> terms) {
            matchWeight.extractTerms(terms);
        }

        /** Delegates to the wrapped weight. */
        @Override
        public void extractTermStates(Map<Term, TermStates> contexts) {
            matchWeight.extractTermStates(contexts);
        }

        /**
         * Returns the wrapped weight's spans, filtered so that only candidates whose collected
         * payloads match are accepted.
         *
         * @return the filtered spans, or {@code null} if the wrapped weight returns {@code null}
         */
        @Override
        public Spans getSpans(final LeafReaderContext context, Postings requiredPostings) throws IOException {
            // One checker per Spans instance; it is reset before every candidate.
            final PayloadChecker collector = new PayloadChecker();
            // Payloads must be available regardless of what the caller asked for.
            final Spans matchSpans = matchWeight.getSpans(context, requiredPostings.atLeast(Postings.PAYLOADS));
            if (matchSpans == null) {
                return null;
            }
            return new FilterSpans(matchSpans) {
                @Override
                protected AcceptStatus accept(Spans candidate) throws IOException {
                    collector.reset();
                    candidate.collect(collector);
                    return collector.match();
                }
            };
        }

        /**
         * Creates a scorer over the payload-filtered spans.
         *
         * @return the scorer, or {@code null} if this weight has no field or there are no spans
         * @throws IllegalStateException if the field has terms in this segment but was indexed
         *         without positions
         */
        @Override
        public SpanScorer scorer(LeafReaderContext context) throws IOException {
            if (field == null) {
                return null;
            }

            final Terms terms = context.reader().terms(field);
            if (terms != null && terms.hasPositions() == false) {
                throw new IllegalStateException("field \"" + field
                        + "\" was indexed without position data; cannot run SpanQuery (query=" + parentQuery + ")");
            }

            final Spans spans = getSpans(context, Postings.PAYLOADS);
            if (spans == null) {
                return null;
            }
            final LeafSimScorer docScorer = getSimScorer(context);
            return new SpanScorer(this, spans, docScorer);
        }

        /** Cacheable exactly when the wrapped weight is. */
        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
            return matchWeight.isCacheable(ctx);
        }

    }

    /**
     * Collector that compares the payload of each collected leaf against the entry of
     * {@link #payloadToMatch} at the same index.
     */
    private class PayloadChecker implements SpanCollector {

        /** Number of leaves compared so far for the current candidate. */
        int upto = 0;

        /** Result of the most recent comparison; once {@code false}, later leaves are ignored. */
        boolean matches = true;

        @Override
        public void collectLeaf(PostingsEnum postings, int position, Term term) throws IOException {
            if (!matches) {
                // A previous leaf already failed; nothing can turn this candidate into a match.
                return;
            }
            if (upto >= payloadToMatch.size()) {
                // More leaves than expected payloads.
                matches = false;
                return;
            }
            final BytesRef actual = postings.getPayload();
            final BytesRef expected = payloadToMatch.get(upto);
            upto++;
            if (expected == null || actual == null) {
                // A null expectation only matches a missing payload, and vice versa.
                matches = expected == null && actual == null;
            } else {
                matches = expected.bytesEquals(actual);
            }
        }

        /**
         * @return {@code YES} when the last comparison matched and exactly
         *         {@code payloadToMatch.size()} leaves were compared, otherwise {@code NO}
         */
        AcceptStatus match() {
            return matches && upto == payloadToMatch.size() ? AcceptStatus.YES : AcceptStatus.NO;
        }

        /** Prepares the checker for the next candidate span. */
        @Override
        public void reset() {
            this.upto = 0;
            this.matches = true;
        }
    }

    /**
     * Renders the wrapped query followed by the expected payloads, each terminated by
     * {@code ';'}. A {@code null} payload entry is printed as {@code null}.
     */
    @Override
    public String toString(String field) {
        StringBuilder buffer = new StringBuilder();
        buffer.append("SpanPayloadCheckQuery(");
        buffer.append(match.toString(field));
        buffer.append(", payloadRef: ");
        for (BytesRef bytes : payloadToMatch) {
            // null entries are legal (they mean "no payload expected"); render them explicitly
            // instead of handing null to Term.toString, which is not documented to accept it.
            buffer.append(bytes == null ? "null" : Term.toString(bytes));
            buffer.append(';');
        }
        buffer.append(")");
        return buffer.toString();
    }

    /**
     * Two instances are equal when they are of the same class and have equal payload lists and
     * equal wrapped queries.
     */
    @Override
    public boolean equals(Object other) {
        if (!sameClassAs(other)) {
            return false;
        }
        final SpanPayloadCheckQuery that = (SpanPayloadCheckQuery) other;
        return payloadToMatch.equals(that.payloadToMatch) && match.equals(that.match);
    }

    @Override
    public int hashCode() {
        int result = classHash();
        result = 31 * result + Objects.hashCode(match);
        result = 31 * result + Objects.hashCode(payloadToMatch);
        return result;
    }
}