Java tutorial: ImprovedLuceneInMemorySentenceRetrievalExecutor (OAQA, BioASQ 3B)
/*
 * Open Advancement Question Answering (OAQA) Project Copyright 2016 Carnegie Mellon University
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations
 * under the License.
 */

package edu.cmu.lti.oaqa.baseqa.passage.retrieval;

import com.aliasi.sentences.IndoEuropeanSentenceModel;
import com.aliasi.sentences.SentenceChunker;
import com.aliasi.sentences.SentenceModel;
import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
import com.aliasi.tokenizer.TokenizerFactory;
import edu.cmu.lti.oaqa.baseqa.providers.parser.ParserProvider;
import edu.cmu.lti.oaqa.baseqa.providers.query.BagOfPhraseQueryStringConstructor;
import edu.cmu.lti.oaqa.baseqa.providers.query.QueryStringConstructor;
import edu.cmu.lti.oaqa.baseqa.passage.RetrievalUtil;
import edu.cmu.lti.oaqa.baseqa.util.ProviderCache;
import edu.cmu.lti.oaqa.baseqa.util.UimaContextHelper;
import edu.cmu.lti.oaqa.type.nlp.Token;
import edu.cmu.lti.oaqa.type.retrieval.AbstractQuery;
import edu.cmu.lti.oaqa.type.retrieval.Passage;
import edu.cmu.lti.oaqa.util.TypeUtil;
import edu.stanford.nlp.process.Morphology;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.store.RAMDirectory;
import org.apache.uima.UIMAException;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_component.JCasAnnotator_ImplBase;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.factory.JCasFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.*;
import java.util.regex.Pattern;

import static java.util.stream.Collectors.toList;

/**
 * An improved version of {@link LuceneInMemorySentenceRetrievalExecutor} that is used in BioASQ 3B.
 *
 * @see LuceneInMemorySentenceRetrievalExecutor
 *
 * @author <a href="mailto:xiangyus@andrew.cmu.edu">Xiangyu Sun</a> created on 10/23/14
 */
public class ImprovedLuceneInMemorySentenceRetrievalExecutor extends JCasAnnotator_ImplBase {

  private Analyzer analyzer;

  private int hits;

  private QueryParser parser;

  private SentenceChunker chunker;

  private QueryStringConstructor queryStringConstructor;

  private ParserProvider parserProvider;

  private StanfordLemmatizer lemma;

  //private static GoldQuestions questions;
  //private static HashMap<String, HashSet<Snippet>> gold;

  private static final Logger LOG = LoggerFactory
          .getLogger(ImprovedLuceneInMemorySentenceRetrievalExecutor.class);

  @Override
  public void initialize(UimaContext context) throws ResourceInitializationException {
    super.initialize(context);
    TokenizerFactory tokenizerFactory = UimaContextHelper.createObjectFromConfigParameter(context,
            "tokenizer-factory", "tokenizer-factory-params", IndoEuropeanTokenizerFactory.class,
            TokenizerFactory.class);
    SentenceModel sentenceModel = UimaContextHelper.createObjectFromConfigParameter(context,
            "sentence-model", "sentence-model-params", IndoEuropeanSentenceModel.class,
            SentenceModel.class);
    chunker = new SentenceChunker(tokenizerFactory, sentenceModel);
    // initialize hits
    hits = UimaContextHelper.getConfigParameterIntValue(context, "hits", 200);
    // initialize query analyzer, index writer config, and query parser
    analyzer = UimaContextHelper.createObjectFromConfigParameter(context, "query-analyzer",
            "query-analyzer-params", StandardAnalyzer.class, Analyzer.class);
    parser = new QueryParser("text", analyzer);
    // initialize query string constructor
    queryStringConstructor = UimaContextHelper.createObjectFromConfigParameter(context,
            "query-string-constructor", "query-string-constructor-params",
            BagOfPhraseQueryStringConstructor.class, QueryStringConstructor.class);
    String parserProviderName = UimaContextHelper.getConfigParameterStringValue(context,
            "parser-provider");
    parserProvider = ProviderCache.getProvider(parserProviderName, ParserProvider.class);
    lemma = new StanfordLemmatizer();
  }

  @Override
  public void process(JCas jcas) throws AnalysisEngineProcessException {
    // create lucene documents for all sentences in all sections and delete the duplicate ones
    Map<Integer, Passage> hash2passage = new HashMap<Integer, Passage>();
    for (Passage d : TypeUtil.getRankedPassages(jcas)) {
      for (Passage s : RetrievalUtil.extractSentences(jcas, d, chunker)) {
        if (!hash2passage.containsKey(TypeUtil.hash(s))) {
          hash2passage.put(TypeUtil.hash(s), s);
        }
      }
    }
    // remove the documents from the pipeline
    TypeUtil.getRankedPassages(jcas).forEach(Passage::removeFromIndexes);
    List<Document> luceneDocs = hash2passage.values().stream()
            .map(RetrievalUtil::createLuceneDocument).collect(toList());
    // create lucene index
    RAMDirectory index = new RAMDirectory();
    try (IndexWriter writer = new IndexWriter(index, new IndexWriterConfig(analyzer))) {
      writer.addDocuments(luceneDocs);
    } catch (IOException e) {
      throw new AnalysisEngineProcessException(e);
    }
    // search in the index
    AbstractQuery aquery = TypeUtil.getAbstractQueries(jcas).stream().findFirst().get();
    Map<Integer, Float> hash2score = new HashMap<>();
    try (IndexReader reader = DirectoryReader.open(index)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      String queryString = queryStringConstructor.construct(aquery).replace("\"", " ")
              .replace("/", " ").replace("[", " ").replace("]", " ");
      LOG.info("Search for query: {}", queryString);
      // construct the query
      Query query = parser.parse(queryString);
      LOG.trace(query.toString());
      searcher.setSimilarity(new BM25Similarity());
      ScoreDoc[] scoreDocs = searcher.search(query, hits).scoreDocs;
      for (ScoreDoc scoreDoc : scoreDocs) {
        float score = scoreDoc.score;
        int hash = Integer.parseInt(searcher.doc(scoreDoc.doc).get("hash"));
        hash2score.put(hash, score);
      }
    } catch (IOException | ParseException e) {
      throw new AnalysisEngineProcessException(e);
    }
    LOG.info("The size of Returned Sentences: {}", hash2score.size());
    // add to CAS
    hash2score.entrySet().stream().map(entry -> {
      Passage passage = hash2passage.get(entry.getKey());
      passage.setScore(entry.getValue());
      return passage;
    }).sorted(Comparator.comparing(Passage::getScore).reversed()).forEach(Passage::addToIndexes);
    Collection<Passage> snippets = TypeUtil.getRankedPassages(jcas);
    // rank the snippets and add them to the pipeline
    rankSnippets(jcas, calSkip(jcas, hash2passage), calBM25(jcas, hash2passage),
            calAlignment(jcas, hash2passage), calSentenceLength(hash2passage), hash2passage);
  }

  /*
   * Combine all the evidence for each snippet and rank them.
   */
  private void rankSnippets(JCas jcas, Map<Integer, Float> skip_bigram, Map<Integer, Float> bm25,
          Map<Integer, Float> alignment, Map<Integer, Float> length,
          Map<Integer, Passage> hash2passage) throws AnalysisEngineProcessException {
    HashMap<Integer, Float> hash2score = new HashMap<Integer, Float>();
    double[] params = { -3, -3436.8, -0.2, 0, 0.3 };
    for (Integer it : hash2passage.keySet()) {
      double wT = skip_bigram.get(it) * params[0] + alignment.get(it) * params[1]
              + length.get(it) * params[2] + (bm25.get(it) == null ? 0 : bm25.get(it)) * params[3]
              + params[4];
      hash2score.put(it, (float) Math.exp(wT) / (float) (1 + Math.exp(wT)));
    }
    hash2score.entrySet().stream().map(entry -> {
      Passage passage = hash2passage.get(entry.getKey());
      passage.setScore(entry.getValue());
      return passage;
    }).sorted(Comparator.comparing(Passage::getScore).reversed()).forEach(Passage::addToIndexes);
  }

  /*
   * Use dependency relations to calculate the skip-bigram score.
   */
  private Map<Integer, Float> calSkip(JCas jcas, Map<Integer, Passage> hash2passage)
          throws AnalysisEngineProcessException {
    HashMap<Integer, Float> skip_bigram = new HashMap<Integer, Float>();
    String question = TypeUtil.getQuestion(jcas).getText();
    // question sentence analysis
    HashMap<String, String> questionTokens = sentenceAnalysis(question);
    for (Map.Entry<Integer, Passage> iter : hash2passage.entrySet()) {
      String text = iter.getValue().getText();
      HashMap<String, String> snippetTokens = sentenceAnalysis(text);
      int count = 0;
      for (String child : snippetTokens.keySet()) {
        // compare head lemmas with equals() rather than ==, which only tests reference identity
        if (questionTokens.containsKey(child)
                && questionTokens.get(child).equals(snippetTokens.get(child)))
          count++;
      }
      float scoreP = (float) count / (float) snippetTokens.size();
      float scoreQ = (float) count / (float) questionTokens.size();
      float score = scoreP * scoreQ / (scoreP + scoreQ);
      if (count == 0)
        score = 0;
      skip_bigram.put(iter.getKey(), score);
    }
    return skip_bigram;
  }

  /*
   * Dynamic programming to calculate the alignment score.
   */
  private Map<Integer, Float> calAlignment(JCas jcas, Map<Integer, Passage> hash2passage)
          throws AnalysisEngineProcessException {
    HashMap<Integer, Float> alignment = new HashMap<Integer, Float>();
    String question = TypeUtil.getQuestion(jcas).getText();
    String[] questionTokens = lemma.stemText(question).split(" ");
    for (Integer it : hash2passage.keySet()) {
      String[] text = hash2passage.get(it).getText().split(" ");
      int[][] score = new int[text.length][questionTokens.length];
      // initialize the score matrix
      for (int i = 0; i < text.length; i++) {
        if (text[i].equals(questionTokens[0]))
          score[i][0] = 1;
      }
      for (int i = 0; i < questionTokens.length; i++) {
        if (text[0].equals(questionTokens[i]))
          score[0][i] = 1;
      }
      // start calculating
      for (int i = 1; i < text.length; i++) {
        for (int j = 1; j < questionTokens.length; j++) {
          if (text[i].equals(questionTokens[j]))
            score[i][j] = Integer.max(score[i][j], score[i - 1][j - 1] + 1);
          else
            score[i][j] = Integer.max(score[i - 1][j], score[i][j - 1]);
        }
      }
      alignment.put(it, (float) score[text.length - 1][questionTokens.length - 1]);
    }
    return alignment;
  }

  private Map<Integer, Float> calBM25(JCas jcas, Map<Integer, Passage> hash2passage)
          throws AnalysisEngineProcessException {
    // index the documents using lucene
    List<Document> luceneDocs = hash2passage.values().stream()
            .map(RetrievalUtil::createLuceneDocument).collect(toList());
    // create lucene index
    RAMDirectory index = new RAMDirectory();
    try (IndexWriter writer = new IndexWriter(index, new IndexWriterConfig(analyzer))) {
      writer.addDocuments(luceneDocs);
    } catch (IOException e) {
      throw new AnalysisEngineProcessException(e);
    }
    // search in the index
    AbstractQuery aquery = TypeUtil.getAbstractQueries(jcas).stream().findFirst().get();
    Map<Integer, Float> hash2score = new HashMap<>();
    try (IndexReader reader = DirectoryReader.open(index)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      String queryString = queryStringConstructor.construct(aquery).replace("\"", " ")
              .replace("/", " ").replace("[", " ").replace("]", " ");
      LOG.info("Search for query: {}", queryString);
      // construct the query
      Query query = parser.parse(queryString);
      searcher.setSimilarity(new BM25Similarity());
      ScoreDoc[] scoreDocs = searcher.search(query, hits).scoreDocs;
      for (ScoreDoc scoreDoc : scoreDocs) {
        float score = scoreDoc.score;
        int hash = Integer.parseInt(searcher.doc(scoreDoc.doc).get("hash"));
        hash2score.put(hash, score);
      }
    } catch (IOException | ParseException e) {
      throw new AnalysisEngineProcessException(e);
    }
    return hash2score;
  }

  /*
   * Dependency analysis for all the snippets and questions.
   */
  private HashMap<String, String> sentenceAnalysis(String sentence) {
    HashMap<String, String> dependency = new HashMap<String, String>();
    try {
      JCas snippetJcas = JCasFactory.createJCas();
      snippetJcas.setDocumentText(sentence);
      List<Token> tokens = parserProvider.parseDependency(snippetJcas);
      for (Token tok : tokens) {
        if (tok.getHead() == null)
          continue;
        dependency.put(tok.getLemmaForm(), tok.getHead().getLemmaForm());
      }
      snippetJcas.release();
    } catch (UIMAException err) {
      err.printStackTrace();
    }
    return dependency;
  }

  /*
   * Calculate the length of all the snippets.
   */
  private HashMap<Integer, Float> calSentenceLength(Map<Integer, Passage> hash2passage) {
    HashMap<Integer, Float> ret = new HashMap<Integer, Float>();
    for (Integer it : hash2passage.keySet()) {
      ret.put(it, (float) hash2passage.get(it).getText().length());
    }
    return ret;
  }
}

class StanfordLemmatizer {

  private static Morphology morph = new Morphology();

  private static final Pattern p = Pattern.compile("[^a-z0-9 ]", Pattern.CASE_INSENSITIVE);

  public static int MAX_WORD_LEN = 128;

  public static String stemWord(String w) {
    String t = null;
    try {
      if (w.length() <= MAX_WORD_LEN)
        t = morph.stem(w);
    } catch (StackOverflowError e) {
      /*
       * TODO should we ignore stack overflow here? So far it happens only for very long tokens,
       * but who knows, there might be some other reasons as well. In that case, if stemming
       * failed, we can simply return the original, unmodified, string.
       */
      e.printStackTrace();
      System.err.println("Stack overflow for string: '" + w + "'");
      System.exit(1);
    }
    return t != null ? t : "";
  }

  public static String lemma(String w, String tag) {
    return morph.lemma(w, tag);
  }

  /**
   * Split the text into tokens (assuming tokens are separated by whitespace), then stem each
   * token separately.
   */
  public static String stemText(String text) {
    if (text == null || "".equals(text))
      return text;
    text = text.replaceAll("[-+.^:,?]", "");
    StringBuilder sb = new StringBuilder();
    for (String s : text.split("\\s+")) {
      if ((p.matcher(s).find()))
        continue;
      sb.append(stemWord(s));
      sb.append(' ');
    }
    return sb.toString();
  }
}
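The scoring logic above is interleaved with UIMA and Lucene plumbing, which makes it hard to experiment with on its own. The following standalone sketch is an illustration only: the class name SnippetScoringDemo, the toy question and sentence, and the stubbed skip-bigram, length, and BM25 values are all made up for the example (in the annotator they come from dependency parses, the snippet's character count, and a Lucene search). It mirrors the alignment dynamic program from calAlignment and the sigmoid combination from rankSnippets so those two pieces can be run in isolation.

// Hypothetical standalone sketch; not part of the OAQA code base.
public class SnippetScoringDemo {

  // LCS-style dynamic program over token arrays, mirroring calAlignment.
  static float alignmentScore(String[] text, String[] questionTokens) {
    int[][] score = new int[text.length][questionTokens.length];
    for (int i = 0; i < text.length; i++)
      if (text[i].equals(questionTokens[0])) score[i][0] = 1;
    for (int j = 0; j < questionTokens.length; j++)
      if (text[0].equals(questionTokens[j])) score[0][j] = 1;
    for (int i = 1; i < text.length; i++) {
      for (int j = 1; j < questionTokens.length; j++) {
        if (text[i].equals(questionTokens[j]))
          score[i][j] = Integer.max(score[i][j], score[i - 1][j - 1] + 1);
        else
          score[i][j] = Integer.max(score[i - 1][j], score[i][j - 1]);
      }
    }
    return score[text.length - 1][questionTokens.length - 1];
  }

  // Fixed linear combination of the four evidence scores squashed through a sigmoid,
  // using the same weights as rankSnippets.
  static float combine(float skipBigram, float alignment, float length, float bm25) {
    double[] params = { -3, -3436.8, -0.2, 0, 0.3 };
    double wT = skipBigram * params[0] + alignment * params[1] + length * params[2]
            + bm25 * params[3] + params[4];
    return (float) (Math.exp(wT) / (1 + Math.exp(wT)));
  }

  public static void main(String[] args) {
    String[] question = "what enzyme is inhibited by aspirin".split(" ");
    String[] sentence =
            "aspirin is a drug that inhibited the enzyme cyclooxygenase".split(" ");
    float alignment = alignmentScore(sentence, question);
    // skip-bigram, length, and BM25 are stubbed with toy values here
    float score = combine(0.1f, alignment, 60f, 4.2f);
    System.out.println("alignment = " + alignment + ", combined score = " + score);
  }
}

The sigmoid maps the weighted sum into (0, 1), so the combined value can be stored directly as a Passage score and sorted in descending order, exactly as rankSnippets does; the weights themselves appear to be constants tuned for BioASQ 3B rather than values learned at run time.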