edu.cmu.lti.oaqa.baseqa.passage.rerank.scorers.LuceneInMemoryPassageScorer.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.lti.oaqa.baseqa.passage.rerank.scorers.LuceneInMemoryPassageScorer.java

Source

/*
 * Open Advancement Question Answering (OAQA) Project Copyright 2016 Carnegie Mellon University
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations
 * under the License.
 */

package edu.cmu.lti.oaqa.baseqa.passage.rerank.scorers;

import com.google.common.collect.*;
import com.google.common.io.Resources;
import edu.cmu.lti.oaqa.baseqa.learning_base.AbstractScorer;
import edu.cmu.lti.oaqa.type.kb.Concept;
import edu.cmu.lti.oaqa.type.kb.ConceptMention;
import edu.cmu.lti.oaqa.type.kb.ConceptType;
import edu.cmu.lti.oaqa.type.nlp.Token;
import edu.cmu.lti.oaqa.type.retrieval.Passage;
import edu.cmu.lti.oaqa.util.TypeUtil;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.store.RAMDirectory;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.resource.ResourceSpecifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.*;

/**
 * An instance of an {@link AbstractScorer} for {@link Passage}s that scores candidate
 * {@link Passage}s, by using various different ways of constructing the queries, from only the
 * tokens to a full combination of tokens and concepts with all the synonyms, to query an
 * in-memory Lucene index that is created on the fly.
 * Each query construction method contributes to a dimension of the feature, and the retrieval
 * scores are used as the values.
 * <p>
 * {@link #prepare(JCas)} runs the different query families concurrently on a thread pool. The
 * per-call score/rank tables are therefore only updated while holding an internal lock, and a
 * fresh {@link QueryParser} is created for every query because the classic parser is not
 * thread-safe.
 *
 * @see edu.cmu.lti.oaqa.baseqa.passage.retrieval.LuceneInMemorySentenceRetrievalExecutor
 *
 * @author <a href="mailto:ziy@cs.cmu.edu">Zi Yang</a> created on 4/5/16
 */
public class LuceneInMemoryPassageScorer extends AbstractScorer<Passage> {

    // TODO: Move to a separate file
    private final static Set<String> FORBIDDEN_CTYPES = ImmutableSet.of("lingpipe:other_name", "umls:qlco",
            "umls:qnco", "umls:ftcn", "umls:geoa");

    private static final Logger LOG = LoggerFactory.getLogger(LuceneInMemoryPassageScorer.class);

    /** Name of the indexed field that holds the passage text; also the default query field. */
    private static final String TEXT_FIELD = "text";

    /** Name of the stored field that holds the passage identifier (URI plus offsets). */
    private static final String URI_FIELD = "uri";

    /** Guards every compare-and-update on {@link #uri2conf2score} and {@link #uri2conf2rank}. */
    private final Object tableLock = new Object();

    /** Maximum number of hits requested from the searcher for each query. */
    private int hits;

    /** Tokens (compared in lower case) that are excluded from the token query. */
    private Set<String> stoplist;

    /** Passage identifier x query configuration -> best retrieval score seen. */
    private Table<String, String, Double> uri2conf2score;

    /** Passage identifier x query configuration -> best (smallest, 0-based) rank seen. */
    private Table<String, String, Integer> uri2conf2rank;

    /** Query configurations that produced at least one hit in the last {@link #prepare(JCas)}. */
    private Set<String> confs;

    private StandardAnalyzer analyzer;

    private IndexReader reader;

    private IndexSearcher searcher;

    /**
     * Reads the <tt>hits</tt> and <tt>stoplist-path</tt> parameters, loads the stoplist from the
     * classpath (one trimmed entry per line, UTF-8) and creates the analyzer used for both
     * indexing and query parsing.
     *
     * @param aSpecifier        passed through to the superclass
     * @param aAdditionalParams passed through to the superclass
     * @return always <tt>true</tt>
     * @throws ResourceInitializationException if the stoplist cannot be read
     */
    @Override
    public boolean initialize(ResourceSpecifier aSpecifier, Map<String, Object> aAdditionalParams)
            throws ResourceInitializationException {
        super.initialize(aSpecifier, aAdditionalParams);
        hits = Integer.class.cast(getParameterValue("hits"));
        // query constructor
        String stoplistPath = String.class.cast(getParameterValue("stoplist-path"));
        try {
            stoplist = Resources.readLines(getClass().getResource(stoplistPath), UTF_8).stream().map(String::trim)
                    .collect(toSet());
        } catch (IOException e) {
            throw new ResourceInitializationException(e);
        }
        analyzer = new StandardAnalyzer();
        return true;
    }

    /**
     * Escapes the query-syntax characters in a name and wraps it in double quotes, so that the
     * query parser treats it as a single phrase.
     */
    private static String normalizeQuoteName(String name) {
        return "\"" + QueryParser.escape(name) + "\"";
    }

    /**
     * Indexes the ranked passages of the given CAS into a fresh in-memory index, builds the
     * different query strings (tokens, concept names, concept mentions; overall, per concept, per
     * concept type and per concept type prefix), runs them concurrently, and records for every
     * passage and query configuration the best score and rank.
     *
     * @param jcas the CAS holding the candidate passages, tokens and concepts
     * @throws AnalysisEngineProcessException if the index cannot be built or opened, or if the
     *                                        calling thread is interrupted while waiting for the
     *                                        queries to finish
     */
    @Override
    public void prepare(JCas jcas) throws AnalysisEngineProcessException {
        uri2conf2score = HashBasedTable.create();
        uri2conf2rank = HashBasedTable.create();
        // index
        List<Passage> passages = TypeUtil.getRankedPassages(jcas);
        RAMDirectory index = new RAMDirectory();
        try {
            // release the reader left over from the previous call before replacing it
            if (reader != null) {
                reader.close();
                reader = null;
            }
            try (IndexWriter writer = new IndexWriter(index, new IndexWriterConfig(analyzer))) {
                for (Passage passage : passages) {
                    Document doc = new Document();
                    doc.add(new StringField(URI_FIELD, TypeUtil.getUriOffsets(passage, ":"), Field.Store.YES));
                    doc.add(new TextField(TEXT_FIELD, passage.getText(), Field.Store.NO));
                    writer.addDocument(doc);
                }
            }
            // the writer is closed at this point, so the reader opens on the written index
            reader = DirectoryReader.open(index);
            searcher = new IndexSearcher(reader);
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
        // queries
        List<String> tokens = TypeUtil.getOrderedTokens(jcas).stream().map(Token::getCoveredText)
                .map(QueryParser::escape).filter(name -> !name.isEmpty() && !stoplist.contains(name.toLowerCase()))
                .collect(toList());
        Multimap<String, String> ctype2names = HashMultimap.create();
        Multimap<String, String> ctype2mentions = HashMultimap.create();
        for (Concept concept : TypeUtil.getConcepts(jcas)) {
            Set<String> ctypes = TypeUtil.getConceptTypes(concept).stream().map(ConceptType::getAbbreviation)
                    .filter(ctype -> !FORBIDDEN_CTYPES.contains(ctype)).collect(toSet());
            String cnames = TypeUtil.getConceptNames(concept).stream()
                    .map(LuceneInMemoryPassageScorer::normalizeQuoteName).distinct().collect(joining(" "));
            String cmentions = TypeUtil.getConceptMentions(concept).stream().map(ConceptMention::getMatchedName)
                    .map(LuceneInMemoryPassageScorer::normalizeQuoteName).distinct().collect(joining(" "));
            for (String ctype : ctypes) {
                ctype2names.put(ctype, cnames);
                ctype2mentions.put(ctype, cmentions);
            }
        }
        Multimap<String, String> ctypepre2names = groupByTypePrefix(ctype2names);
        // built from ctype2mentions; the map must not be derived from itself, or it stays empty
        Multimap<String, String> ctypepre2mentions = groupByTypePrefix(ctype2mentions);
        LOG.debug("Query strings");
        ExecutorService service = Executors.newCachedThreadPool();
        // execute against all tokens
        submit(service, "tokens_concatenated@all", () -> {
            String concatTokens = String.join(" ", tokens);
            LOG.debug(" - Concatenated tokens: {}", concatTokens);
            search(concatTokens, "tokens_concatenated@all");
        });
        // execute against concatenated concept names
        submit(service, "cnames_concatenated@all", () -> {
            String concatCnames = String.join(" ", ctype2names.values());
            LOG.debug(" - Concatenated concept names: {}", concatCnames);
            search(concatCnames, "cnames_concatenated@all");
        });
        // execute against concatenated concept mentions
        submit(service, "cmentions_concatenated@all", () -> {
            String concatCmentions = String.join(" ", ctype2mentions.values());
            LOG.debug(" - Concatenated concept mentions: {}", concatCmentions);
            search(concatCmentions, "cmentions_concatenated@all");
        });
        // execute against concept names for each concept
        submit(service, "cnames_individual@all", () -> {
            for (String cnames : ImmutableSet.copyOf(ctype2names.values())) {
                LOG.debug(" - Concatenated concept names: {}", cnames);
                search(cnames, "cnames_individual@all");
            }
        });
        // execute against concept names for each concept type
        submit(service, "cnames per concept type", () -> {
            for (String ctype : ctype2names.keySet()) {
                String concatCnames = String.join(" ", ctype2names.get(ctype));
                LOG.debug(" - Concatenated concept names for {}: {}", ctype, concatCnames);
                search(concatCnames, "cnames@" + ctype + "@all");
            }
        });
        // execute against concept names for each concept type prefix
        submit(service, "cnames per concept type prefix", () -> {
            for (String ctypepre : ctypepre2names.keySet()) {
                String concatCnames = String.join(" ", ctypepre2names.get(ctypepre));
                LOG.debug(" - Concatenated concept names for {}: {}", ctypepre, concatCnames);
                search(concatCnames, "cnames@" + ctypepre + "@all");
            }
        });
        // execute against concept mentions for each concept
        submit(service, "cmentions_individual@all", () -> {
            for (String cmentions : ImmutableSet.copyOf(ctype2mentions.values())) {
                LOG.debug(" - Concatenated concept mentions: {}", cmentions);
                search(cmentions, "cmentions_individual@all");
            }
        });
        // execute against concept mentions for each concept type
        submit(service, "cmentions per concept type", () -> {
            for (String ctype : ctype2mentions.keySet()) {
                String concatCmentions = String.join(" ", ctype2mentions.get(ctype));
                LOG.debug(" - Concatenated concept mentions for {}: {}", ctype, concatCmentions);
                search(concatCmentions, "cmentions@" + ctype + "@all");
            }
        });
        // execute against concept mentions for each concept type prefix
        submit(service, "cmentions per concept type prefix", () -> {
            for (String ctypepre : ctypepre2mentions.keySet()) {
                String concatCmentions = String.join(" ", ctypepre2mentions.get(ctypepre));
                LOG.debug(" - Concatenated concept mentions for {}: {}", ctypepre, concatCmentions);
                search(concatCmentions, "cmentions@" + ctypepre + "@all");
            }
        });
        service.shutdown();
        try {
            if (!service.awaitTermination(1, TimeUnit.MINUTES)) {
                LOG.warn("Queries did not finish within 1 minute; passage features may be incomplete.");
            }
        } catch (InterruptedException e) {
            // keep the interrupt visible to the caller's thread before reporting the failure
            Thread.currentThread().interrupt();
            throw new AnalysisEngineProcessException(e);
        }
        // snapshot instead of a live view, so a straggling task cannot change it under score()
        synchronized (tableLock) {
            confs = ImmutableSet.copyOf(uri2conf2score.columnKeySet());
        }
    }

    /**
     * Regroups the values of a concept-type keyed multimap under the part of each key that
     * precedes the first <tt>:</tt>.
     *
     * @param ctype2values values keyed by full concept type abbreviation
     * @return a new multimap with the same values keyed by concept type prefix
     */
    private static Multimap<String, String> groupByTypePrefix(Multimap<String, String> ctype2values) {
        Multimap<String, String> ctypepre2values = HashMultimap.create();
        ctype2values.asMap().forEach((ctype, values) -> ctypepre2values.putAll(ctype.split(":")[0], values));
        return ctypepre2values;
    }

    /**
     * Submits a query task to the pool. A runtime failure inside the task is logged rather than
     * being lost in the discarded {@link java.util.concurrent.Future}; the remaining tasks are
     * unaffected.
     *
     * @param service     the pool to run the task on
     * @param description short label of the query family, used in the log message only
     * @param task        the queries to execute
     */
    private static void submit(ExecutorService service, String description, Runnable task) {
        service.submit(() -> {
            try {
                task.run();
            } catch (RuntimeException e) {
                LOG.warn("Failed to execute queries for {}", description, e);
            }
        });
    }

    /**
     * Parses and executes one query and merges its hits into the score and rank tables under the
     * given configuration name. For each passage the highest score and the smallest rank across
     * all queries of the same configuration are kept. Blank query strings are ignored.
     *
     * @param queryString the query in Lucene classic query syntax
     * @param conf        name of the query configuration (the feature name prefix)
     * @throws RuntimeException wrapping a {@link ParseException} or {@link IOException}
     */
    private void search(String queryString, String conf) throws RuntimeException {
        if (queryString.trim().isEmpty()) {
            return;
        }
        ScoreDoc[] results;
        String[] uris;
        try {
            // QueryParser is not thread-safe and search() runs on several threads: one per call
            Query query = new QueryParser(TEXT_FIELD, analyzer).parse(queryString);
            results = searcher.search(query, hits).scoreDocs;
            // load the stored identifiers before taking the lock, to keep the critical section short
            uris = new String[results.length];
            for (int i = 0; i < results.length; i++) {
                uris[i] = reader.document(results[i].doc).get(URI_FIELD);
            }
        } catch (ParseException | IOException e) {
            throw new RuntimeException(e);
        }
        // the read and the conditional write must be atomic with respect to the other tasks
        synchronized (tableLock) {
            for (int i = 0; i < results.length; i++) {
                String uri = uris[i];
                Integer bestRank = uri2conf2rank.get(uri, conf);
                if (bestRank == null || bestRank > i) {
                    uri2conf2rank.put(uri, conf, i);
                }
                double score = results[i].score;
                Double bestScore = uri2conf2score.get(uri, conf);
                if (bestScore == null || bestScore < score) {
                    uri2conf2score.put(uri, conf, score);
                }
            }
        }
    }

    /**
     * Returns the features of one passage: for every query configuration seen during
     * {@link #prepare(JCas)}, a <tt>conf/rank</tt> entry holding the reciprocal of the 1-based
     * rank and a <tt>conf/score</tt> entry holding the retrieval score. Both are <tt>0.0</tt>
     * when the passage was not among the hits of that configuration.
     *
     * @param jcas   unused by this scorer
     * @param result the passage to score
     * @return feature name to feature value
     */
    @Override
    public Map<String, Double> score(JCas jcas, Passage result) {
        ImmutableMap.Builder<String, Double> ret = new ImmutableMap.Builder<>();
        String uri = TypeUtil.getUriOffsets(result, ":");
        for (String conf : confs) {
            Integer bestRank = uri2conf2rank.get(uri, conf);
            double rank = bestRank == null ? 0.0 : 1.0 / (bestRank + 1.0);
            ret.put(conf + "/rank", rank);
            Double bestScore = uri2conf2score.get(uri, conf);
            double score = bestScore == null ? 0.0 : bestScore;
            ret.put(conf + "/score", score);
        }
        return ret.build();
    }

}