com.tamingtext.qa.PassageRankingComponent.java Source code

Java tutorial

Introduction

Here is the source code for com.tamingtext.qa.PassageRankingComponent.java

Source

/*
 * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 * -------------------
 * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
 * http://www.manning.com/ingersoll
 */

package com.tamingtext.qa;

import com.tamingtext.texttamer.solr.NameFilter;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.TermVectorMapper;
import org.apache.lucene.index.TermVectorOffsetInfo;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.PriorityQueue;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.PluginInfo;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.SearchComponent;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.DocList;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.plugin.PluginInfoInitialized;
import org.apache.solr.util.plugin.SolrCoreAware;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * Given a SpanQuery, get windows around the matches and rank those results
 */
public class PassageRankingComponent extends SearchComponent
        implements PluginInfoInitialized, SolrCoreAware, QAParams {
    /** Class-wide SLF4J logger. */
    private static final Logger log = LoggerFactory.getLogger(PassageRankingComponent.class);

    /**
     * Lower-cased form of {@link NameFilter#NE_PREFIX}; terms starting with either form are skipped
     * when the passage windows are built.  Lower-casing uses the root locale so the result does not
     * depend on the JVM's default locale.
     */
    static final String NE_PREFIX_LOWER = NameFilter.NE_PREFIX.toLowerCase(java.util.Locale.ROOT);

    /** Default number of positions added on each side of the span match to form the primary window. */
    public static final int DEFAULT_PRIMARY_WINDOW_SIZE = 25;
    /** Default width, in positions, of the windows directly before and after the primary window. */
    public static final int DEFAULT_ADJACENT_WINDOW_SIZE = 25;
    /** Default width, in positions, of the windows beyond the adjacent windows. */
    public static final int DEFAULT_SECONDARY_WINDOW_SIZE = 25;

    /** Default multiplier applied to the score of the adjacent windows. */
    public static final float DEFAULT_ADJACENT_WEIGHT = 0.5f;
    /** Default multiplier applied to the score of the secondary windows. */
    public static final float DEFAULT_SECOND_ADJACENT_WEIGHT = 0.25f;
    /** Default multiplier applied to the bigram score. */
    public static final float DEFAULT_BIGRAM_WEIGHT = 1.0f;

    /**
     * Plugin initialisation callback.  Currently does nothing; the settings this component uses are
     * read from the request parameters in {@code process}.
     *
     * @param pluginInfo the plugin configuration (unused)
     */
    @Override
    public void init(PluginInfo pluginInfo) {

    }

    /**
     * Core-awareness callback.  Currently does nothing.
     *
     * @param solrCore the core this component is registered with (unused)
     */
    @Override
    public void inform(SolrCore solrCore) {

    }

    /**
     * Prepare phase of the component.  The enabling flag is looked up, but there is no set-up work
     * to do either way; everything happens in {@code process}.
     *
     * @param rb the response builder for the current request
     * @throws IOException declared by the overridden method; not thrown here
     */
    @Override
    public void prepare(ResponseBuilder rb) throws IOException {
        final SolrParams requestParams = rb.req.getParams();
        final boolean enabled = requestParams.getBool(COMPONENT_NAME, false);
        if (enabled) {
            // nothing to prepare ahead of process()
        }
    }

    /**
     * Ranks the passages around each span match of the request's query and adds them to the response
     * under the {@code qaResponse} key.
     * <p/>
     * Nothing happens unless the {@code COMPONENT_NAME} request parameter is true and the request has a
     * query.  For every span located in a document of the current result list, the windows around the
     * span are built from the field's term vector, scored, and offered to a bounded priority queue.  The
     * best {@code QA_ROWS} passages (default 5) are then written out, highest score first, each as an
     * {@code answer} entry holding the document id, the field name and, when the field value can be
     * loaded, the text of the window.
     *
     * @param rb the response builder carrying the request, the parsed query and the result list
     * @throws IOException   if reading the index fails
     * @throws SolrException if the query is not a {@link SpanNearQuery} or a passage cannot be cloned
     */
    @Override
    public void process(ResponseBuilder rb) throws IOException {
        SolrParams params = rb.req.getParams();
        if (!params.getBool(COMPONENT_NAME, false)) {
            return;
        }
        Query origQuery = rb.getQuery();
        //TODO: longer term, we don't have to be a span query, we could re-analyze the document
        if (origQuery == null) {
            return;
        }
        if (!(origQuery instanceof SpanNearQuery)) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
                    "Illegal query type.  The incoming query must be a Lucene SpanNearQuery and it was a "
                            + origQuery.getClass().getName());
        }
        SpanNearQuery sQuery = (SpanNearQuery) origQuery;
        SolrIndexSearcher searcher = rb.req.getSearcher();
        IndexReader reader = searcher.getIndexReader();
        Spans spans = sQuery.getSpans(reader);
        //Build up the query term weight map and the bi-gram
        Map<String, Float> termWeights = new HashMap<String, Float>();
        Map<String, Float> bigramWeights = new HashMap<String, Float>();
        createWeights(params.get(CommonParams.Q), sQuery, termWeights, bigramWeights, reader);
        float adjWeight = params.getFloat(ADJACENT_WEIGHT, DEFAULT_ADJACENT_WEIGHT);
        float secondAdjWeight = params.getFloat(SECOND_ADJ_WEIGHT, DEFAULT_SECOND_ADJACENT_WEIGHT);
        float bigramWeight = params.getFloat(BIGRAM_WEIGHT, DEFAULT_BIGRAM_WEIGHT);
        //get the passages
        int primaryWindowSize = params.getInt(QAParams.PRIMARY_WINDOW_SIZE, DEFAULT_PRIMARY_WINDOW_SIZE);
        int adjacentWindowSize = params.getInt(QAParams.ADJACENT_WINDOW_SIZE, DEFAULT_ADJACENT_WINDOW_SIZE);
        int secondaryWindowSize = params.getInt(QAParams.SECONDARY_WINDOW_SIZE, DEFAULT_SECONDARY_WINDOW_SIZE);
        WindowBuildingTVM tvm = new WindowBuildingTVM(primaryWindowSize, adjacentWindowSize,
                secondaryWindowSize);
        PassagePriorityQueue rankedPassages = new PassagePriorityQueue();
        //intersect w/ doclist
        DocList docList = rb.getResults().docList;
        while (spans.next()) {
            if (!docList.exists(spans.doc())) {
                continue;
            }
            //build up the window
            tvm.spanStart = spans.start();
            tvm.spanEnd = spans.end();
            reader.getTermFreqVector(spans.doc(), sQuery.getField(), tvm);
            //The passage now contains the windows, do some ranking of it
            if (!tvm.passage.terms.isEmpty()) {
                log.debug("Candidate: Doc: {} Start: {} End: {} ",
                        new Object[] { spans.doc(), spans.start(), spans.end() });
            }
            tvm.passage.lDocId = spans.doc();
            tvm.passage.field = sQuery.getField();
            //score this window
            try {
                addPassage(tvm.passage, rankedPassages, termWeights, bigramWeights, adjWeight,
                        secondAdjWeight, bigramWeight);
            } catch (CloneNotSupportedException e) {
                throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
                        "Internal error cloning Passage", e);
            }
            //clear out the entries for the next round
            tvm.passage.clear();
        }
        NamedList<Object> qaResp = new NamedList<Object>();
        rb.rsp.add("qaResponse", qaResp);
        int rows = params.getInt(QA_ROWS, 5);

        //pop() hands back the lowest ranked passage first, so drain the queue back-to-front to get the
        //passages in descending score order before taking the first `rows` of them
        Passage[] best = new Passage[rankedPassages.size()];
        for (int i = best.length - 1; i >= 0; i--) {
            best[i] = rankedPassages.pop();
        }

        SchemaField uniqField = rb.req.getSchema().getUniqueKeyField();
        int size = Math.min(rows, best.length);
        Set<String> fields = new HashSet<String>();
        for (int i = 0; i < size; i++) {
            Passage passage = best[i];
            NamedList<Object> passNL = new NamedList<Object>();
            qaResp.add("answer", passNL);
            String idName;
            String idValue;
            if (uniqField != null) {
                idName = uniqField.getName();
                fields.add(idName);
                fields.add(passage.field);//prefetch this now, so that it is cached
                idValue = searcher.doc(passage.lDocId, fields).get(idName);
            } else {
                idName = "luceneDocId";
                idValue = String.valueOf(passage.lDocId);
            }
            passNL.add(idName, idValue);
            passNL.add("field", passage.field);
            //get the window
            String fldValue = searcher.doc(passage.lDocId, fields).get(passage.field);
            if (fldValue != null) {
                passNL.add("window", extractWindow(passage, fldValue));
            }
        }
    }

    /**
     * Cuts the text of a passage's primary window out of the field value using the character offsets of
     * the first and last window terms.  The whole field value is returned when the passage has no
     * primary-window terms or its offsets are missing or do not fit the field value.
     *
     * @param passage  the passage whose primary window is wanted
     * @param fldValue the non-null value of the passage's field
     * @return the window text, or {@code fldValue} when the window cannot be determined
     */
    private String extractWindow(Passage passage, String fldValue) {
        //we don't use the passage window itself, as that is based on the term vector
        if (!passage.terms.isEmpty()) {
            int start = passage.terms.first().start;//use the offsets
            WindowTerm last = passage.terms.last();
            int end = last.end;
            if (start >= 0 && start < fldValue.length() && end >= 0 && end < fldValue.length()) {
                //NOTE(review): `end` is already the end offset of the last term, so adding the term length
                //probably overshoots the window; kept as is, but clamped so substring() cannot fail
                int stop = Math.min(fldValue.length(), end + last.term.length());
                if (stop >= start) {
                    return fldValue.substring(start, stop);
                }
            }
        }
        log.debug("Passage does not have correct offset information");
        return fldValue;//we don't have offsets, or they are incorrect, return the whole field value
    }

    /**
     * Adds up the weights of the window terms that have a weight and have not been counted before.
     * Each term that contributes is recorded in {@code covered} so it is only counted once.
     *
     * @param terms       the terms of one window
     * @param termWeights weight per term text; terms without an entry are ignored
     * @param covered     term texts already counted; updated by this call
     * @return the sum of the weights of the newly counted terms
     */
    protected float scoreTerms(SortedSet<WindowTerm> terms, Map<String, Float> termWeights, Set<String> covered) {
        float total = 0f;
        for (WindowTerm candidate : terms) {
            final String text = candidate.term;
            final Float weight = termWeights.get(text);
            if (weight == null) {
                continue;//no weight for this term, it does not count
            }
            if (covered.add(text)) {//add() returns false when the term was counted earlier
                total += weight.floatValue();
            }
        }
        return total;
    }

    /**
     * Adds up the weights of the bigrams that have a weight and have not been counted before.
     * Each bigram that contributes is recorded in {@code covered}.
     *
     * @param bigrams       the bigrams of the window, each stored as a {@code WindowTerm}
     * @param bigramWeights weight per bigram text; bigrams without an entry are ignored
     * @param covered       texts already counted; updated by this call
     * @return the sum of the weights of the newly counted bigrams
     */
    protected float scoreBigrams(SortedSet<WindowTerm> bigrams, Map<String, Float> bigramWeights,
            Set<String> covered) {
        float total = 0;
        for (WindowTerm pair : bigrams) {
            final String key = pair.term;
            if (covered.contains(key)) {
                continue;//already counted
            }
            final Float weight = bigramWeights.get(key);
            if (weight != null) {
                total += weight.floatValue();
                covered.add(key);
            }
        }
        return total;
    }

    /**
     * A fairly straightforward and simple scoring approach based on http://trec.nist.gov/pubs/trec8/papers/att-trec8.pdf.
     * <br/>
     * Score the {@link com.tamingtext.qa.PassageRankingComponent.Passage} as the sum of:
     * <ul>
     * <li>The sum of the IDF values for the primary window terms ({@link com.tamingtext.qa.PassageRankingComponent.Passage#terms})</li>
     * <li>The sum of the weights of the terms of the adjacent window ({@link com.tamingtext.qa.PassageRankingComponent.Passage#prevTerms} and {@link com.tamingtext.qa.PassageRankingComponent.Passage#followTerms}) * adjWeight</li>
     * <li>The sum of the weights terms of the second adjacent window ({@link com.tamingtext.qa.PassageRankingComponent.Passage#secPrevTerms} and {@link com.tamingtext.qa.PassageRankingComponent.Passage#secFollowTerms}) * secondAdjWeight</li>
     * <li>The sum of the weights of any bigram matches for the primary window * biWeight</li>
     * </ul>
     * In layman's terms, this is a decay function that gives higher scores to matching terms that are closer to the anchor
     * term (where the query matched, in the middle of the window) than those that are further away.
     *
     * @param p               The {@link com.tamingtext.qa.PassageRankingComponent.Passage} to score
     * @param termWeights     The weights of the terms, key is the term, value is the inverse doc frequency (or other weight)
     * @param bigramWeights   The weights of the bigrams, key is the bigram, value is the weight
     * @param adjWeight       The weight to be applied to the adjacent window score
     * @param secondAdjWeight The weight to be applied to the secondary adjacent window score
     * @param biWeight        The weight to be applied to the bigram window score
     * @return The score of passage
     */
    //<start id="qa.scorePassage"/>
    protected float scorePassage(Passage p, Map<String, Float> termWeights, Map<String, Float> bigramWeights,
            float adjWeight, float secondAdjWeight, float biWeight) {
        //One set is shared by every window, visited from the innermost window outwards, so a term
        //only earns credit in the first window it is seen in
        final Set<String> seen = new HashSet<String>();
        final float primary = scoreTerms(p.terms, termWeights, seen);//<co id="prc.main"/>
        final float before = scoreTerms(p.prevTerms, termWeights, seen);
        final float after = scoreTerms(p.followTerms, termWeights, seen);
        final float adjacent = before + after;//<co id="prc.adj"/>
        final float farBefore = scoreTerms(p.secPrevTerms, termWeights, seen);
        final float farAfter = scoreTerms(p.secFollowTerms, termWeights, seen);
        final float secondary = farBefore + farAfter;//<co id="prc.sec"/>
        //Bonus for bigram matches in the main window
        final float bigramBonus = scoreBigrams(p.bigrams, bigramWeights, seen);//<co id="prc.bigrams"/>
        return primary + (adjWeight * adjacent) + (secondAdjWeight * secondary)
                + (biWeight * bigramBonus);//<co id="prc.score"/>
    }
    /*
    <calloutlist>
        <callout arearefs="prc.main"><para>Score the terms in the main window</para></callout>
        <callout arearefs="prc.adj"><para>Score the terms in the window immediately to the left and right of the main window</para></callout>
        <callout arearefs="prc.sec"><para>Score the terms in the windows adjacent to the previous and following windows</para></callout>
        <callout arearefs="prc.bigrams"><para>Score any bigrams in the passage</para></callout>
        <callout arearefs="prc.score"><para>The final score for the passage is a combination of all the scores, each weighted separately.  A bonus is given for any bigram matches.</para></callout>
            
    </calloutlist>
    */
    //<end id="qa.scorePassage"/>

    /**
     * Scores the passage and, if it may belong among the retained passages, offers a copy of it to
     * the queue.  A copy is inserted so the caller can keep re-using the passage object it passed in.
     *
     * @param p               The passage to score and possibly add; its {@code score} is set
     * @param pq              The queue holding the best passages seen so far
     * @param termWeights     The weights of the terms
     * @param bigramWeights   The weights of the bigrams
     * @param adjWeight       The weight to be applied to the score of the adjacent window
     * @param secondAdjWeight The weight to be applied to the score of the second adjacent window
     * @param biWeight        The weight to be applied to the score of the bigrams
     * @throws CloneNotSupportedException if the passage cannot be cloned
     */
    private void addPassage(Passage p, PassagePriorityQueue pq, Map<String, Float> termWeights,
            Map<String, Float> bigramWeights, float adjWeight, float secondAdjWeight, float biWeight)
            throws CloneNotSupportedException {
        p.score = scorePassage(p, termWeights, bigramWeights, adjWeight, secondAdjWeight, biWeight);
        final Passage weakest = pq.top();
        //the checks run in this order on purpose: lessThan() must not see a null passage
        boolean worthKeeping = weakest == null;
        if (!worthKeeping) {
            worthKeeping = !pq.lessThan(p, weakest);
        }
        if (!worthKeeping) {
            worthKeeping = pq.size() < pq.capacity();
        }
        if (!worthKeeping) {
            return;
        }
        //TODO: Do we care about the overflow?
        pq.insertWithOverflow((Passage) p.clone());
    }

    /**
     * Fills the term and bigram weight maps from the clauses of the span query.
     * <p/>
     * Every clause must be a {@link SpanTermQuery}.  Each term is stored under its text with the weight
     * returned by {@code calculateWeight}.  Each pair of consecutive clauses is stored as a bigram
     * under the key {@code "first,second"}, built from the plain term texts, with a quarter of the
     * smaller of the two term weights.
     *
     * @param origQuery     the original query string (currently unused)
     * @param parsedQuery   the span query whose clauses supply the terms
     * @param termWeights   receives the weight of every query term, keyed by term text
     * @param bigramWeights receives the weight of every pair of adjacent query terms
     * @param reader        the index reader used to look up the term weights
     * @throws IOException   if reading the index fails
     * @throws SolrException if a clause is not a {@link SpanTermQuery}
     */
    protected void createWeights(String origQuery, SpanNearQuery parsedQuery, Map<String, Float> termWeights,
            Map<String, Float> bigramWeights, IndexReader reader) throws IOException {
        //TODO: we need to recurse through the clauses until we get to SpanTermQuery
        String lastText = null;
        float lastWeight = 0f;
        for (SpanQuery clause : parsedQuery.getClauses()) {
            if (!(clause instanceof SpanTermQuery)) {
                //TODO: handle the other types
                throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
                        "Unhandled query type: " + clause.getClass().getName());
            }
            Term term = ((SpanTermQuery) clause).getTerm();
            String text = term.text();
            float weight = calculateWeight(term, reader);
            termWeights.put(text, Float.valueOf(weight));
            if (lastText != null) {//calculate the bi-grams
                //The key must use the plain term text (not Term.toString(), which prepends the field
                //name) so that it matches the "text,text" form of the window bigrams.
                //use the smaller of the two weights
                bigramWeights.put(lastText + "," + text, Float.valueOf(Math.min(lastWeight, weight) * 0.25f));
            }
            lastText = text;
            lastWeight = weight;
        }
    }

    /**
     * Computes the weight of a query term as the reciprocal of its document frequency.
     *
     * @param term   the term to weigh
     * @param reader the index reader used to look the term up
     * @return {@code 1 / docFreq} when the term is in the index, otherwise 0
     * @throws IOException if reading the index fails
     */
    protected float calculateWeight(Term term, IndexReader reader) throws IOException {
        //if a term is not in the index, then its weight is 0
        TermEnum termEnum = reader.terms(term);
        try {
            //terms(term) positions the enum at the first term at or after the requested one, so make
            //sure we really landed on the term we asked for
            if (termEnum != null && termEnum.term() != null && termEnum.term().equals(term)) {
                int docFreq = termEnum.docFreq();
                if (docFreq > 0) {
                    return 1.0f / docFreq;
                }
            }
        } finally {
            if (termEnum != null) {
                termEnum.close();//release the enum, it was never closed before
            }
        }
        log.warn("Couldn't find doc freq for term {}", term);
        return 0;
    }

    /**
     * The terms found around one span match, split into the primary window, the adjacent windows on
     * either side, the secondary windows beyond those, and the bigrams of the primary window, together
     * with the document, the field and the score of the passage.
     * <p/>
     * One instance is re-used for every span: it is filled, scored, cloned if it is kept, and then
     * cleared with {@link #clear()}.
     */
    class Passage implements Cloneable {
        /** Lucene document id of the document the passage comes from. */
        int lDocId;
        /** Name of the field the passage comes from. */
        String field;

        /** Score assigned to the passage when it is ranked. */
        float score;
        /** Terms of the primary window. */
        SortedSet<WindowTerm> terms = new TreeSet<WindowTerm>();
        /** Terms of the adjacent window before the primary window. */
        SortedSet<WindowTerm> prevTerms = new TreeSet<WindowTerm>();
        /** Terms of the adjacent window after the primary window. */
        SortedSet<WindowTerm> followTerms = new TreeSet<WindowTerm>();
        /** Terms of the secondary window before the previous adjacent window. */
        SortedSet<WindowTerm> secPrevTerms = new TreeSet<WindowTerm>();
        /** Terms of the secondary window after the following adjacent window. */
        SortedSet<WindowTerm> secFollowTerms = new TreeSet<WindowTerm>();
        /** Bigrams collected for the primary window. */
        SortedSet<WindowTerm> bigrams = new TreeSet<WindowTerm>();

        Passage() {
        }

        /**
         * Creates a deep copy: every window set is replaced by a new set holding clones of its terms,
         * so later changes to this passage do not affect the copy.
         *
         * @return the copied passage
         * @throws CloneNotSupportedException if a term cannot be cloned
         */
        @Override
        protected Object clone() throws CloneNotSupportedException {
            Passage result = (Passage) super.clone();
            result.terms = copyOf(terms);
            result.prevTerms = copyOf(prevTerms);
            result.followTerms = copyOf(followTerms);
            result.secPrevTerms = copyOf(secPrevTerms);
            result.secFollowTerms = copyOf(secFollowTerms);
            result.bigrams = copyOf(bigrams);
            return result;
        }

        /** Returns a new sorted set holding a clone of every term in {@code source}. */
        private SortedSet<WindowTerm> copyOf(SortedSet<WindowTerm> source) throws CloneNotSupportedException {
            SortedSet<WindowTerm> copy = new TreeSet<WindowTerm>();
            for (WindowTerm term : source) {
                copy.add((WindowTerm) term.clone());
            }
            return copy;
        }

        /**
         * Empties every window so the passage can be re-used for the next span.  The document id,
         * field and score are left as they are.
         */
        public void clear() {
            terms.clear();
            prevTerms.clear();
            followTerms.clear();
            secPrevTerms.clear();
            secFollowTerms.clear();//was never cleared before, so its terms leaked into the next passage
            bigrams.clear();
        }

    }

    /**
     * A bounded priority queue of passages whose least element is the lowest ranked passage.
     */
    class PassagePriorityQueue extends PriorityQueue<Passage> {

        /** The maximum number of passages the queue was initialized to hold. */
        private final int limit;

        /** Creates a queue that keeps at most 10 passages. */
        PassagePriorityQueue() {
            this(10);
        }

        /**
         * Creates a queue that keeps at most {@code maxSize} passages.
         *
         * @param maxSize the maximum number of passages to keep
         */
        PassagePriorityQueue(int maxSize) {
            this.limit = maxSize;
            initialize(maxSize);
        }

        /**
         * Returns the maximum number of passages the queue can hold.  This is the size the queue was
         * created with, not the length of the backing heap array, which has an extra unused slot.
         *
         * @return the maximum size of the queue
         */
        public int capacity() {
            return limit;
        }

        /**
         * Orders passages by score; on equal scores the passage with the higher document id is the
         * lesser one.
         *
         * @return true if {@code passageA} ranks below {@code passageB}
         */
        @Override
        public boolean lessThan(Passage passageA, Passage passageB) {
            if (passageA.score == passageB.score) {
                return passageA.lDocId > passageB.lDocId;
            }
            return passageA.score < passageB.score;
        }
    }

    //Not thread-safe, but should be lightweight to build

    /**
     * The PassageRankingTVM is a Lucene TermVectorMapper that builds a five different windows around a matching term.
     * This Window can then be used to rank the passages
     */
    class WindowBuildingTVM extends TermVectorMapper {
        //spanStart and spanEnd are the start and positions of where the match occurred in the document
        //from these values, we can calculate the windows
        int spanStart, spanEnd;
        Passage passage;
        private int primaryWS, adjWS, secWS;

        public WindowBuildingTVM(int primaryWindowSize, int adjacentWindowSize, int secondaryWindowSize) {
            this.primaryWS = primaryWindowSize;
            this.adjWS = adjacentWindowSize;
            this.secWS = secondaryWindowSize;
            passage = new Passage();//reuse the passage, since it will be cloned if it makes it onto the priority queue
        }

        public void map(String term, int frequency, TermVectorOffsetInfo[] offsets, int[] positions) {
            if (positions.length > 0 && term.startsWith(NameFilter.NE_PREFIX) == false
                    && term.startsWith(NE_PREFIX_LOWER) == false) {//filter out the types, as we don't need them here
                //construct the windows, which means we need a bunch of bracketing variables to know what window we are in

                //start and end of the primary window
                int primStart = spanStart - primaryWS;
                int primEnd = spanEnd + primaryWS;
                // stores the start and end of the adjacent previous and following
                int adjLBStart = primStart - adjWS;
                int adjLBEnd = primStart - 1;//don't overlap
                int adjUBStart = primEnd + 1;//don't o
                int adjUBEnd = primEnd + adjWS;
                //stores the start and end of the secondary previous and the secondary following
                int secLBStart = adjLBStart - secWS;
                int secLBEnd = adjLBStart - 1; //don't overlap the adjacent window
                int secUBStart = adjUBEnd + 1;
                int secUBEnd = adjUBEnd + secWS;
                WindowTerm lastWT = null;
                for (int i = 0; i < positions.length; i++) {//unfortunately, we still have to loop over the positions
                    //we'll make this inclusive of the boundaries, do an upfront check here so we can skip over anything that is outside of all windows
                    if (positions[i] >= secLBStart && positions[i] <= secUBEnd) {
                        //fill in the windows
                        WindowTerm wt;
                        //offsets aren't required, but they are nice to have
                        if (offsets != null) {
                            wt = new WindowTerm(term, positions[i], offsets[i].getStartOffset(),
                                    offsets[i].getEndOffset());
                        } else {
                            wt = new WindowTerm(term, positions[i]);
                        }
                        if (positions[i] >= primStart && positions[i] <= primEnd) {//are we in the primary window
                            passage.terms.add(wt);
                            //we are only going to keep bigrams for the primary window.  You could do it for the other windows, too
                            if (lastWT != null) {
                                WindowTerm bigramWT = new WindowTerm(lastWT.term + "," + term, lastWT.position);//we don't care about offsets for bigrams
                                passage.bigrams.add(bigramWT);
                            }
                            lastWT = wt;
                        } else if (positions[i] >= secLBStart && positions[i] <= secLBEnd) {//are we in the secondary previous window?
                            passage.secPrevTerms.add(wt);
                        } else if (positions[i] >= secUBStart && positions[i] <= secUBEnd) {//are we in the secondary following window?
                            passage.secFollowTerms.add(wt);
                        } else if (positions[i] >= adjLBStart && positions[i] <= adjLBEnd) {//are we in the adjacent previous window?
                            passage.prevTerms.add(wt);
                        } else if (positions[i] >= adjUBStart && positions[i] <= adjUBEnd) {//are we in the adjacent following window?
                            passage.followTerms.add(wt);
                        }
                    }
                }
            }
        }

        public void setExpectations(String field, int numTerms, boolean storeOffsets, boolean storePositions) {
            // do nothing for this example
            //See also the PositionBasedTermVectorMapper.
        }

    }

    /**
     * One occurrence of a term (or bigram) inside a window: its text, its position and, when known,
     * its character offsets.  Ordered by position first and by term text second.
     */
    class WindowTerm implements Cloneable, Comparable<WindowTerm> {
        /** The text of the term. */
        String term;
        /** The position of the occurrence. */
        int position;
        /** Start character offset, or -1 when no offsets were supplied. */
        int start = -1;
        /** End character offset, or -1 when no offsets were supplied. */
        int end = -1;

        /**
         * Creates an occurrence with character offsets.
         *
         * @param term        the term text
         * @param position    the position of the occurrence
         * @param startOffset the start character offset
         * @param endOffset   the end character offset
         */
        WindowTerm(String term, int position, int startOffset, int endOffset) {
            this.term = term;
            this.position = position;
            this.start = startOffset;
            this.end = endOffset;
        }

        /**
         * Creates an occurrence without offsets; both offsets stay at -1.
         *
         * @param s        the term text
         * @param position the position of the occurrence
         */
        public WindowTerm(String s, int position) {
            this.term = s;
            this.position = position;
        }

        @Override
        protected Object clone() throws CloneNotSupportedException {
            return super.clone();
        }

        /**
         * Compares by position, then by term text.  The positions are compared directly rather than
         * subtracted, so extreme values cannot overflow and flip the sign.
         */
        @Override
        public int compareTo(WindowTerm other) {
            if (position != other.position) {
                return position < other.position ? -1 : 1;
            }
            return term.compareTo(other.term);
        }

        @Override
        public String toString() {
            return "WindowEntry{" + "term='" + term + '\'' + ", position=" + position + '}';
        }
    }

    /**
     * @return a short, fixed, human-readable description of this component
     */
    @Override
    public String getDescription() {
        return "Question Answering PassageRanking";
    }

    /**
     * @return the fixed string below; it looks like an unexpanded version-control keyword
     */
    @Override
    public String getVersion() {
        return "$Revision:$";
    }

    /**
     * @return the fixed string below; it looks like an unexpanded version-control keyword
     */
    @Override
    public String getSourceId() {
        return "$Id:$";
    }

    /**
     * @return the fixed string below; it looks like an unexpanded version-control keyword
     */
    @Override
    public String getSource() {
        return "$URL:$";
    }
}