// Java tutorial: Solr question-answering passage-ranking search component (from "Taming Text")
/*
 * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * -------------------
 * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
 * http://www.manning.com/ingersoll
 */
package com.tamingtext.qa;

import com.tamingtext.texttamer.solr.NameFilter;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.TermVectorMapper;
import org.apache.lucene.index.TermVectorOffsetInfo;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.PriorityQueue;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.PluginInfo;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.SearchComponent;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.DocList;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.plugin.PluginInfoInitialized;
import org.apache.solr.util.plugin.SolrCoreAware;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * Given a SpanQuery, get windows around the matches and rank those results.
 * <p>
 * For each span match, a term-vector mapper ({@link WindowBuildingTVM}) collects the terms in a
 * primary window around the match plus two concentric pairs of flanking windows.  Each passage is
 * scored (see {@link #scorePassage}) and the top {@code QA_ROWS} passages are returned under the
 * {@code qaResponse} section of the Solr response.
 */
public class PassageRankingComponent extends SearchComponent implements PluginInfoInitialized, SolrCoreAware, QAParams {
  private transient static Logger log = LoggerFactory.getLogger(PassageRankingComponent.class);

  static final String NE_PREFIX_LOWER = NameFilter.NE_PREFIX.toLowerCase();
  //default sizes (in positions) of the primary, adjacent and secondary windows
  public static final int DEFAULT_PRIMARY_WINDOW_SIZE = 25;
  public static final int DEFAULT_ADJACENT_WINDOW_SIZE = 25;
  public static final int DEFAULT_SECONDARY_WINDOW_SIZE = 25;
  //default score multipliers: terms further from the match contribute less
  public static final float DEFAULT_ADJACENT_WEIGHT = 0.5f;
  public static final float DEFAULT_SECOND_ADJACENT_WEIGHT = 0.25f;
  public static final float DEFAULT_BIGRAM_WEIGHT = 1.0f;

  @Override
  public void init(PluginInfo pluginInfo) {
  }

  @Override
  public void inform(SolrCore solrCore) {
  }

  /** No preparation work is needed; bail out early unless this component is enabled. */
  @Override
  public void prepare(ResponseBuilder rb) throws IOException {
    SolrParams params = rb.req.getParams();
    if (!params.getBool(COMPONENT_NAME, false)) {
      return;
    }
  }

  /**
   * Walk the spans of the (SpanNearQuery) main query, build a windowed {@link Passage} around each
   * match that falls in the result {@link DocList}, rank the passages and add the top ones to the
   * response as {@code qaResponse}.
   *
   * @throws SolrException if the incoming query is not a {@link SpanNearQuery}
   */
  @Override
  public void process(ResponseBuilder rb) throws IOException {
    SolrParams params = rb.req.getParams();
    if (!params.getBool(COMPONENT_NAME, false)) {
      return;
    }
    Query origQuery = rb.getQuery();
    //TODO: longer term, we don't have to be a span query, we could re-analyze the document
    if (origQuery != null) {
      if (!(origQuery instanceof SpanNearQuery)) {
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
            "Illegal query type. The incoming query must be a Lucene SpanNearQuery and it was a "
                + origQuery.getClass().getName());
      }
      SpanNearQuery sQuery = (SpanNearQuery) origQuery;
      SolrIndexSearcher searcher = rb.req.getSearcher();
      IndexReader reader = searcher.getIndexReader();
      Spans spans = sQuery.getSpans(reader);//Assumes the query is a SpanQuery
      //Build up the query term weight map and the bi-gram weight map
      Map<String, Float> termWeights = new HashMap<String, Float>();
      Map<String, Float> bigramWeights = new HashMap<String, Float>();
      createWeights(params.get(CommonParams.Q), sQuery, termWeights, bigramWeights, reader);
      float adjWeight = params.getFloat(ADJACENT_WEIGHT, DEFAULT_ADJACENT_WEIGHT);
      float secondAdjWeight = params.getFloat(SECOND_ADJ_WEIGHT, DEFAULT_SECOND_ADJACENT_WEIGHT);
      float bigramWeight = params.getFloat(BIGRAM_WEIGHT, DEFAULT_BIGRAM_WEIGHT);
      //get the passage window sizes
      int primaryWindowSize = params.getInt(QAParams.PRIMARY_WINDOW_SIZE, DEFAULT_PRIMARY_WINDOW_SIZE);
      int adjacentWindowSize = params.getInt(QAParams.ADJACENT_WINDOW_SIZE, DEFAULT_ADJACENT_WINDOW_SIZE);
      int secondaryWindowSize = params.getInt(QAParams.SECONDARY_WINDOW_SIZE, DEFAULT_SECONDARY_WINDOW_SIZE);
      WindowBuildingTVM tvm = new WindowBuildingTVM(primaryWindowSize, adjacentWindowSize, secondaryWindowSize);
      PassagePriorityQueue rankedPassages = new PassagePriorityQueue();
      //intersect w/ doclist
      DocList docList = rb.getResults().docList;
      while (spans.next()) {
        //build up the window around this match, but only for docs in the result list
        if (docList.exists(spans.doc())) {
          tvm.spanStart = spans.start();
          tvm.spanEnd = spans.end();
          reader.getTermFreqVector(spans.doc(), sQuery.getField(), tvm);
          //The entries map contains the window, do some ranking of it
          if (!tvm.passage.terms.isEmpty()) {
            log.debug("Candidate: Doc: {} Start: {} End: {} ",
                new Object[] {spans.doc(), spans.start(), spans.end()});
          }
          tvm.passage.lDocId = spans.doc();
          tvm.passage.field = sQuery.getField();
          //score this window
          try {
            addPassage(tvm.passage, rankedPassages, termWeights, bigramWeights,
                adjWeight, secondAdjWeight, bigramWeight);
          } catch (CloneNotSupportedException e) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
                "Internal error cloning Passage", e);
          }
          //clear out the entries for the next round
          tvm.passage.clear();
        }
      }
      NamedList qaResp = new NamedList();
      rb.rsp.add("qaResponse", qaResp);
      int rows = params.getInt(QA_ROWS, 5);

      SchemaField uniqField = rb.req.getSchema().getUniqueKeyField();
      if (rankedPassages.size() > 0) {
        int size = Math.min(rows, rankedPassages.size());
        Set<String> fields = new HashSet<String>();
        //pop() returns lowest-scoring first, so fill from the back
        for (int i = size - 1; i >= 0; i--) {
          Passage passage = rankedPassages.pop();
          if (passage != null) {
            NamedList passNL = new NamedList();
            qaResp.add(("answer"), passNL);
            String idName;
            String idValue;
            if (uniqField != null) {
              idName = uniqField.getName();
              fields.add(idName);
              fields.add(passage.field);//prefetch this now, so that it is cached
              idValue = searcher.doc(passage.lDocId, fields).get(idName);
            } else {
              idName = "luceneDocId";
              idValue = String.valueOf(passage.lDocId);
            }
            passNL.add(idName, idValue);
            passNL.add("field", passage.field);
            //get the window
            String fldValue = searcher.doc(passage.lDocId, fields).get(passage.field);
            if (fldValue != null) {
              //get the window of words to display, we don't use the passage window, as that is based on the term vector
              int start = passage.terms.first().start;//use the offsets
              int end = passage.terms.last().end;
              if (start >= 0 && start < fldValue.length() && end >= 0 && end < fldValue.length()) {
                //BUGFIX: clamp the upper bound — end + term.length() could run past the stored
                //field value and throw StringIndexOutOfBoundsException
                int stop = Math.min(end + passage.terms.last().term.length(), fldValue.length());
                passNL.add("window", fldValue.substring(start, stop));
              } else {
                log.debug("Passage does not have correct offset information");
                passNL.add("window", fldValue);//we don't have offsets, or they are incorrect, return the whole field value
              }
            }
          } else {
            break;
          }
        }
      }
    }
  }

  /**
   * Sum the weights of the given window terms, counting each distinct term only once.
   *
   * @param terms       the window terms to score
   * @param termWeights term text -> weight (inverse doc freq)
   * @param covered     terms already counted in an earlier window; updated in place
   * @return the sum of the weights of the not-yet-covered terms
   */
  protected float scoreTerms(SortedSet<WindowTerm> terms, Map<String, Float> termWeights,
                             Set<String> covered) {
    float score = 0;
    for (WindowTerm wTerm : terms) {
      Float tw = termWeights.get(wTerm.term);
      if (tw != null && !covered.contains(wTerm.term)) {
        score += tw.floatValue();
        covered.add(wTerm.term);
      }
    }
    return score;
  }

  /**
   * Sum the weights of the given bigrams, counting each distinct bigram only once.
   *
   * @param bigrams       the bigram entries ("term1,term2") to score
   * @param bigramWeights bigram key -> weight
   * @param covered       bigrams already counted; updated in place
   * @return the sum of the weights of the not-yet-covered bigrams
   */
  protected float scoreBigrams(SortedSet<WindowTerm> bigrams, Map<String, Float> bigramWeights,
                               Set<String> covered) {
    float result = 0;
    for (WindowTerm bigram : bigrams) {
      Float tw = bigramWeights.get(bigram.term);
      if (tw != null && !covered.contains(bigram.term)) {
        result += tw.floatValue();
        covered.add(bigram.term);
      }
    }
    return result;
  }

  /**
   * A fairly straightforward and simple scoring approach based on http://trec.nist.gov/pubs/trec8/papers/att-trec8.pdf.
   * <br/>
   * Score the {@link com.tamingtext.qa.PassageRankingComponent.Passage} as the sum of:
   * <ul>
   * <li>The sum of the IDF values for the primary window terms ({@link com.tamingtext.qa.PassageRankingComponent.Passage#terms}</li>
   * <li>The sum of the weights of the terms of the adjacent window ({@link com.tamingtext.qa.PassageRankingComponent.Passage#prevTerms} and {@link com.tamingtext.qa.PassageRankingComponent.Passage#followTerms}) * adjWeight</li>
   * <li>The sum of the weights terms of the second adjacent window ({@link com.tamingtext.qa.PassageRankingComponent.Passage#secPrevTerms} and {@link com.tamingtext.qa.PassageRankingComponent.Passage#secFollowTerms}) * secondAdjWeight</li>
   * <li>The sum of the weights of any bigram matches for the primary window * biWeight</li>
   * </ul>
   * In laymen's terms, this is a decay function that gives higher scores to matching terms that are closer to the anchor
   * term (where the query matched, in the middle of the window) than those that are further away.
   *
   * @param p               The {@link com.tamingtext.qa.PassageRankingComponent.Passage} to score
   * @param termWeights     The weights of the terms, key is the term, value is the inverse doc frequency (or other weight)
   * @param bigramWeights   The weights of the bigrams, key is the bigram, value is the weight
   * @param adjWeight       The weight to be applied to the adjacent window score
   * @param secondAdjWeight The weight to be applied to the secondary adjacent window score
   * @param biWeight        The weight to be applied to the bigram window score
   * @return The score of passage
   */
  //<start id="qa.scorePassage"/>
  protected float scorePassage(Passage p, Map<String, Float> termWeights,
                               Map<String, Float> bigramWeights,
                               float adjWeight, float secondAdjWeight,
                               float biWeight) {
    Set<String> covered = new HashSet<String>();
    float termScore = scoreTerms(p.terms, termWeights, covered);//<co id="prc.main"/>
    float adjScore = scoreTerms(p.prevTerms, termWeights, covered)
        + scoreTerms(p.followTerms, termWeights, covered);//<co id="prc.adj"/>
    float secondScore = scoreTerms(p.secPrevTerms, termWeights, covered)
        + scoreTerms(p.secFollowTerms, termWeights, covered);//<co id="prc.sec"/>
    //Give a bonus for bigram matches in the main window, could also
    float bigramScore = scoreBigrams(p.bigrams, bigramWeights, covered);//<co id="prc.bigrams"/>
    float score = termScore + (adjWeight * adjScore)
        + (secondAdjWeight * secondScore)
        + (biWeight * bigramScore);//<co id="prc.score"/>
    return (score);
  }
  /*
  <calloutlist>
    <callout arearefs="prc.main"><para>Score the terms in the main window</para></callout>
    <callout arearefs="prc.adj"><para>Score the terms in the window immediately to the left and right of the main window</para></callout>
    <callout arearefs="prc.sec"><para>Score the terms in the windows adjacent to the previous and following windows</para></callout>
    <callout arearefs="prc.bigrams"><para>Score any bigrams in the passage</para></callout>
    <callout arearefs="prc.score"><para>The final score for the passage is a combination of all the scores, each weighted separately.  A bonus is given for any bigram matches.</para></callout>
  </calloutlist>
  */
  //<end id="qa.scorePassage"/>

  /**
   * Potentially add the passage to the PriorityQueue.
   *
   * @param p               The passage to add
   * @param pq              The {@link org.apache.lucene.util.PriorityQueue} to add the passage to if it ranks high enough
   * @param termWeights     The weights of the terms
   * @param bigramWeights   The weights of the bigrams
   * @param adjWeight       The weight to be applied to the score of the adjacent window
   * @param secondAdjWeight The weight to be applied to the score of the second adjacent window
   * @param biWeight        The weight to be applied to the score of the bigrams
   * @throws CloneNotSupportedException if not cloneable
   */
  private void addPassage(Passage p, PassagePriorityQueue pq, Map<String, Float> termWeights,
                          Map<String, Float> bigramWeights,
                          float adjWeight, float secondAdjWeight, float biWeight)
      throws CloneNotSupportedException {
    p.score = scorePassage(p, termWeights, bigramWeights, adjWeight, secondAdjWeight, biWeight);
    Passage lowest = pq.top();
    //insert when the queue isn't full yet, or when p beats the current lowest entry
    if (lowest == null || !pq.lessThan(p, lowest) || pq.size() < pq.capacity()) {
      //by doing this, we can re-use the Passage object
      Passage cloned = (Passage) p.clone();
      //TODO: Do we care about the overflow?
      pq.insertWithOverflow(cloned);
    }
  }

  /**
   * Fill in the term and bigram weight maps from the clauses of the parsed query.
   * Each term is weighted by {@link #calculateWeight}; each adjacent pair of terms gets a
   * bigram weight of a quarter of the smaller of the two term weights.
   *
   * @param origQuery     the original query string (currently unused)
   * @param parsedQuery   the parsed {@link SpanNearQuery}; all clauses must be {@link SpanTermQuery}s
   * @param termWeights   out-param: term text -> weight
   * @param bigramWeights out-param: "term1,term2" -> weight
   * @param reader        used to look up document frequencies
   * @throws SolrException if a clause is not a {@link SpanTermQuery}
   */
  protected void createWeights(String origQuery, SpanNearQuery parsedQuery,
                               Map<String, Float> termWeights, Map<String, Float> bigramWeights,
                               IndexReader reader) throws IOException {
    SpanQuery[] clauses = parsedQuery.getClauses();
    //we need to recurse through the clauses until we get to SpanTermQuery
    Term lastTerm = null;
    Float lastWeight = null;
    for (int i = 0; i < clauses.length; i++) {
      SpanQuery clause = clauses[i];
      if (clause instanceof SpanTermQuery) {
        Term term = ((SpanTermQuery) clause).getTerm();
        Float weight = calculateWeight(term, reader);
        termWeights.put(term.text(), weight);
        if (lastTerm != null) {//calculate the bi-grams
          //use the smaller of the two weights
          //BUGFIX: key on lastTerm.text(), not the Term itself — Term.toString() renders as
          //"field:text", which never matches the "term1,term2" keys built in WindowBuildingTVM.map()
          if (lastWeight.floatValue() < weight.floatValue()) {
            bigramWeights.put(lastTerm.text() + "," + term.text(),
                Float.valueOf((float) (lastWeight.floatValue() * 0.25)));
          } else {
            bigramWeights.put(lastTerm.text() + "," + term.text(),
                Float.valueOf((float) (weight.floatValue() * 0.25)));
          }
        }
        //remember the last term so the next iteration can form a bigram with it
        lastTerm = term;
        lastWeight = weight;
      } else {
        //TODO: handle the other types
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
            "Unhandled query type: " + clause.getClass().getName());
      }
    }
  }

  /**
   * Weight a term as the inverse of its document frequency (1/docFreq).
   *
   * @return 1/docFreq, or 0 if the term is not in the index
   */
  protected float calculateWeight(Term term, IndexReader reader) throws IOException {
    //if a term is not in the index, then it's weight is 0
    TermEnum termEnum = reader.terms(term);
    if (termEnum != null && termEnum.term() != null && termEnum.term().equals(term)) {
      return 1.0f / termEnum.docFreq();
    } else {
      log.warn("Couldn't find doc freq for term {}", term);
      return 0;
    }
  }

  /**
   * The terms collected around one span match: the primary window plus two pairs of
   * flanking windows, and the bigrams seen in the primary window.
   */
  class Passage implements Cloneable {
    int lDocId;        //Lucene internal doc id
    String field;      //field the passage came from
    float score;
    SortedSet<WindowTerm> terms = new TreeSet<WindowTerm>();          //primary window
    SortedSet<WindowTerm> prevTerms = new TreeSet<WindowTerm>();      //adjacent window before
    SortedSet<WindowTerm> followTerms = new TreeSet<WindowTerm>();    //adjacent window after
    SortedSet<WindowTerm> secPrevTerms = new TreeSet<WindowTerm>();   //secondary window before
    SortedSet<WindowTerm> secFollowTerms = new TreeSet<WindowTerm>(); //secondary window after
    SortedSet<WindowTerm> bigrams = new TreeSet<WindowTerm>();        //bigrams in the primary window

    Passage() {
    }

    /** Deep-copies every window set so the caller can keep reusing the original Passage. */
    @Override
    protected Object clone() throws CloneNotSupportedException {
      Passage result = (Passage) super.clone();
      result.terms = new TreeSet<WindowTerm>();
      for (WindowTerm term : terms) {
        result.terms.add((WindowTerm) term.clone());
      }
      result.prevTerms = new TreeSet<WindowTerm>();
      for (WindowTerm term : prevTerms) {
        result.prevTerms.add((WindowTerm) term.clone());
      }
      result.followTerms = new TreeSet<WindowTerm>();
      for (WindowTerm term : followTerms) {
        result.followTerms.add((WindowTerm) term.clone());
      }
      result.secPrevTerms = new TreeSet<WindowTerm>();
      for (WindowTerm term : secPrevTerms) {
        result.secPrevTerms.add((WindowTerm) term.clone());
      }
      result.secFollowTerms = new TreeSet<WindowTerm>();
      for (WindowTerm term : secFollowTerms) {
        result.secFollowTerms.add((WindowTerm) term.clone());
      }
      result.bigrams = new TreeSet<WindowTerm>();
      for (WindowTerm term : bigrams) {
        result.bigrams.add((WindowTerm) term.clone());
      }
      return result;
    }

    /** Empties every window set so this instance can be reused for the next span match. */
    public void clear() {
      terms.clear();
      prevTerms.clear();
      followTerms.clear();
      secPrevTerms.clear();
      //BUGFIX: was secPrevTerms.clear() twice; secFollowTerms was never cleared, leaking
      //terms from one passage into the scoring of all subsequent passages
      secFollowTerms.clear();
      bigrams.clear();
    }
  }

  /** Min-heap of passages, lowest score on top; ties broken by preferring the lower doc id. */
  class PassagePriorityQueue extends PriorityQueue<Passage> {

    PassagePriorityQueue() {
      initialize(10);
    }

    PassagePriorityQueue(int maxSize) {
      initialize(maxSize);
    }

    public int capacity() {
      return getHeapArray().length;
    }

    @Override
    public boolean lessThan(Passage passageA, Passage passageB) {
      if (passageA.score == passageB.score)
        return passageA.lDocId > passageB.lDocId;
      else return passageA.score < passageB.score;
    }
  }

  //Not thread-safe, but should be lightweight to build
  /**
   * The PassageRankingTVM is a Lucene TermVectorMapper that builds five different windows around a matching term.
   * This Window can then be used to rank the passages
   */
  class WindowBuildingTVM extends TermVectorMapper {
    //spanStart and spanEnd are the start and end positions of where the match occurred in the document
    //from these values, we can calculate the windows
    int spanStart, spanEnd;
    Passage passage;
    private int primaryWS, adjWS, secWS;

    public WindowBuildingTVM(int primaryWindowSize, int adjacentWindowSize, int secondaryWindowSize) {
      this.primaryWS = primaryWindowSize;
      this.adjWS = adjacentWindowSize;
      this.secWS = secondaryWindowSize;
      passage = new Passage();//reuse the passage, since it will be cloned if it makes it onto the priority queue
    }

    /**
     * Route each term vector entry into the window set its position falls into, skipping
     * named-entity marker terms produced by {@link NameFilter}.
     */
    public void map(String term, int frequency, TermVectorOffsetInfo[] offsets, int[] positions) {
      if (positions.length > 0 && !term.startsWith(NameFilter.NE_PREFIX)
          && !term.startsWith(NE_PREFIX_LOWER)) {//filter out the types, as we don't need them here
        //construct the windows, which means we need a bunch of bracketing variables to know what window we are in
        //start and end of the primary window
        int primStart = spanStart - primaryWS;
        int primEnd = spanEnd + primaryWS;
        //stores the start and end of the adjacent previous and following windows
        int adjLBStart = primStart - adjWS;
        int adjLBEnd = primStart - 1;//don't overlap
        int adjUBStart = primEnd + 1;//don't overlap
        int adjUBEnd = primEnd + adjWS;
        //stores the start and end of the secondary previous and the secondary following windows
        int secLBStart = adjLBStart - secWS;
        int secLBEnd = adjLBStart - 1;//don't overlap the adjacent window
        int secUBStart = adjUBEnd + 1;
        int secUBEnd = adjUBEnd + secWS;
        WindowTerm lastWT = null;
        for (int i = 0; i < positions.length; i++) {//unfortunately, we still have to loop over the positions
          //we'll make this inclusive of the boundaries, do an upfront check here so we can skip over anything that is outside of all windows
          if (positions[i] >= secLBStart && positions[i] <= secUBEnd) {
            //fill in the windows; offsets aren't required, but they are nice to have
            WindowTerm wt;
            if (offsets != null) {
              wt = new WindowTerm(term, positions[i], offsets[i].getStartOffset(), offsets[i].getEndOffset());
            } else {
              wt = new WindowTerm(term, positions[i]);
            }
            if (positions[i] >= primStart && positions[i] <= primEnd) {//are we in the primary window
              passage.terms.add(wt);
              //we are only going to keep bigrams for the primary window.  You could do it for the other windows, too
              if (lastWT != null) {
                WindowTerm bigramWT = new WindowTerm(lastWT.term + "," + term, lastWT.position);//we don't care about offsets for bigrams
                passage.bigrams.add(bigramWT);
              }
              lastWT = wt;
            } else if (positions[i] >= secLBStart && positions[i] <= secLBEnd) {//are we in the secondary previous window?
              passage.secPrevTerms.add(wt);
            } else if (positions[i] >= secUBStart && positions[i] <= secUBEnd) {//are we in the secondary following window?
              passage.secFollowTerms.add(wt);
            } else if (positions[i] >= adjLBStart && positions[i] <= adjLBEnd) {//are we in the adjacent previous window?
              passage.prevTerms.add(wt);
            } else if (positions[i] >= adjUBStart && positions[i] <= adjUBEnd) {//are we in the adjacent following window?
              passage.followTerms.add(wt);
            }
          }
        }
      }
    }

    public void setExpectations(String field, int numTerms, boolean storeOffsets, boolean storePositions) {
      //do nothing for this example
      //See also the PositionBasedTermVectorMapper.
    }
  }

  /**
   * A term occurrence inside a window: its text, position, and (when available)
   * character offsets into the stored field.  Sorted by position, then term text.
   */
  class WindowTerm implements Cloneable, Comparable<WindowTerm> {
    String term;
    int position;
    int start, end = -1;//character offsets; -1 when offsets were not stored

    WindowTerm(String term, int position, int startOffset, int endOffset) {
      this.term = term;
      this.position = position;
      this.start = startOffset;
      this.end = endOffset;
    }

    public WindowTerm(String s, int position) {
      this.term = s;
      this.position = position;
    }

    @Override
    protected Object clone() throws CloneNotSupportedException {
      return super.clone();
    }

    @Override
    public int compareTo(WindowTerm other) {
      //positions are non-negative term positions, so this subtraction cannot overflow
      int result = position - other.position;
      if (result == 0) {
        result = term.compareTo(other.term);
      }
      return result;
    }

    @Override
    public String toString() {
      return "WindowEntry{" +
          "term='" + term + '\'' +
          ", position=" + position +
          '}';
    }
  }

  @Override
  public String getDescription() {
    return "Question Answering PassageRanking";
  }

  @Override
  public String getVersion() {
    return "$Revision:$";
  }

  @Override
  public String getSourceId() {
    return "$Id:$";
  }

  @Override
  public String getSource() {
    return "$URL:$";
  }
}