uk.co.flax.biosolr.ontology.search.elasticsearch.ElasticDocumentSearch.java Source code

Java tutorial

Introduction

Here is the source code for uk.co.flax.biosolr.ontology.search.elasticsearch.ElasticDocumentSearch.java

Source

/**
 * Copyright (c) 2016 Lemur Consulting Ltd.
 * <p/>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p/>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p/>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package uk.co.flax.biosolr.ontology.search.elasticsearch;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.cardinality.Cardinality;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.co.flax.biosolr.ontology.api.Document;
import uk.co.flax.biosolr.ontology.config.ElasticSearchConfiguration;
import uk.co.flax.biosolr.ontology.search.DocumentSearch;
import uk.co.flax.biosolr.ontology.search.ResultsList;
import uk.co.flax.biosolr.ontology.search.SearchEngineException;

import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Created by mlp on 17/02/16.
 *
 * @author mlp
 */
public class ElasticDocumentSearch extends ElasticSearchEngine implements DocumentSearch {

    private static final Logger LOGGER = LoggerFactory.getLogger(ElasticDocumentSearch.class);

    private static final String GROUP_FIELD = "study_id";
    private static final String COUNT_AGGREGATION = "numFound";
    private static final String HITS_AGGREGATION = "study";
    private static final String SCORE_AGGREGATION = "topScore";

    private static final String[] DEFAULT_FIELDS = new String[] { "title", "first_author", "publication",
            "efo_uri.label" };
    private static final List<String> ANNOTATED_FIELDS = new ArrayList<>();
    static {
        ANNOTATED_FIELDS.add("label");
        ANNOTATED_FIELDS.add("child_labels");
        ANNOTATED_FIELDS.add("parent_labels");
    }

    public ElasticDocumentSearch(Client client, ElasticSearchConfiguration config) {
        super(client, config);
    }

    @Override
    public ResultsList<Document> searchDocuments(String term, int start, int rows, List<String> additionalFields,
            List<String> filters) throws SearchEngineException {
        // Build the query
        MultiMatchQueryBuilder qb = QueryBuilders.multiMatchQuery(term, DEFAULT_FIELDS).minimumShouldMatch("2<25%");
        if (additionalFields != null && additionalFields.size() > 0) {
            List<String> parsedAdditional = parseAdditionalFields(additionalFields);
            parsedAdditional.forEach(qb::field);
        }

        TopHitsBuilder topHitsBuilder = AggregationBuilders.topHits(HITS_AGGREGATION).setFrom(0).setSize(1);

        /* Build the terms aggregation, since we need a result set grouped by study ID.
         * The "top_score" sub-agg allows us to sort by the top score of the results;
         * the topHits sub-agg actually pulls back the record data, returning just the first
         * hit in the aggregation.
         * Note that we have to get _all_ rows up to and including the last required, annoyingly. */
        AggregationBuilder termsAgg = AggregationBuilders.terms(HITS_AGGREGATION).field(GROUP_FIELD)
                .order(Terms.Order.aggregation(SCORE_AGGREGATION, false)).size(start + rows)
                .subAggregation(AggregationBuilders.max(SCORE_AGGREGATION)
                        .script(new Script("_score", ScriptService.ScriptType.INLINE, "expression", null)))
                .subAggregation(topHitsBuilder);

        // Build the actual search request, including another aggregation to get
        // the number of unique study IDs returned.
        SearchRequestBuilder srb = getClient().prepareSearch(getIndexName()).setTypes(getDocumentType())
                .setQuery(qb).setSize(0).addAggregation(termsAgg)
                .addAggregation(AggregationBuilders.cardinality(COUNT_AGGREGATION).field(GROUP_FIELD));
        LOGGER.debug("ES Query: {}", srb.toString());

        SearchResponse response = srb.execute().actionGet();

        // Handle the response
        long total = ((Cardinality) (response.getAggregations().get(COUNT_AGGREGATION))).getValue();
        List<Document> docs;
        if (total == 0) {
            docs = new ArrayList<>();
        } else {
            // Build a map - need to look up annotation data separately.
            // This is because it's not in _source, and the fields() method
            // is not visible for a TopHitsBuilder.
            Map<String, Document> documentMap = new LinkedHashMap<>(rows);
            ObjectMapper mapper = buildObjectMapper();

            int lastIdx = (int) (start + rows <= total ? start + rows : total);
            StringTerms terms = response.getAggregations().get(HITS_AGGREGATION);
            List<Terms.Bucket> termBuckets = terms.getBuckets().subList(start, lastIdx);
            for (Terms.Bucket bucket : termBuckets) {
                TopHits hits = bucket.getAggregations().get(HITS_AGGREGATION);
                SearchHit hit = hits.getHits().getAt(0);
                documentMap.put(hit.getId(), extractDocument(mapper, hit));
            }

            // Populate annotation data for the document
            lookupAnnotationFields(documentMap);

            docs = new ArrayList<>(documentMap.values());
        }

        return new ResultsList<>(docs, start, (start / rows), total);
    }

    @Override
    public ResultsList<Document> searchByEfoUri(int start, int rows, String term, String... uris)
            throws SearchEngineException {
        return null;
    }

    private List<String> parseAdditionalFields(List<String> additional) {
        List<String> parsed;

        if (additional == null || additional.size() == 0) {
            parsed = null;
        } else {
            // Need to add annotation field name to all additional fields
            // Also need to handle hard-coded Solr field names
            parsed = additional.stream().map(add -> add.replaceAll("^efo_uri_(.*)_t$", "$1"))
                    .map(add -> getAnnotationField() + "." + add).collect(Collectors.toList());
        }

        return parsed;
    }

    private ObjectMapper buildObjectMapper() {
        ObjectMapper mapper = new ObjectMapper();
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        return mapper;
    }

    private Document extractDocument(ObjectMapper mapper, SearchHit hit) throws SearchEngineException {
        Document doc;

        try {
            doc = mapper.readValue(hit.getSourceAsString(), Document.class);
        } catch (IOException e) {
            LOGGER.error("Error reading document from source: {}", e.getMessage());
            throw new SearchEngineException(e);
        }

        return doc;
    }

    private void lookupAnnotationFields(Map<String, Document> idMap) {
        QueryBuilder qb = QueryBuilders.idsQuery(getDocumentType()).addIds(idMap.keySet());
        SearchRequestBuilder srb = getClient().prepareSearch(getIndexName()).addFields("*").setQuery(qb)
                .setSize(idMap.size());
        LOGGER.debug("Annotation field lookup query: {}", srb.toString());

        SearchResponse response = srb.execute().actionGet();
        for (SearchHit hit : response.getHits().getHits()) {
            populateAnnotationFields(hit, idMap.get(hit.getId()));
        }
    }

    private void populateAnnotationFields(SearchHit hit, Document doc) {
        if (doc != null && hit.fields().size() > 0) {
            for (Map.Entry<String, SearchHitField> fieldEntry : hit.fields().entrySet()) {
                if (fieldEntry.getKey().startsWith(getAnnotationField())) {
                    String fieldName = fieldEntry.getKey();

                    switch (fieldName) {
                    case "efo_uri.label":
                        doc.setEfoLabels(getStringValues(fieldEntry.getValue().getValues()));
                        break;
                    case "efo_uri.child_labels":
                        doc.setChildLabels(getStringValues(fieldEntry.getValue().getValues()));
                        break;
                    case "efo_uri.parent_labels":
                        doc.setParentLabels(getStringValues(fieldEntry.getValue().getValues()));
                        break;
                    default:
                        String shortName = fieldName.substring("efo_uri.".length());
                        if (fieldName.endsWith("_rel_uris")) {
                            doc.getRelatedIris().put(shortName, getStringValues(fieldEntry.getValue().getValues()));
                        } else if (fieldName.endsWith("_rel_labels")) {
                            List<String> labels = getStringValues(fieldEntry.getValue().getValues());
                            if (labels != null) {
                                doc.getRelatedLabels().put(shortName, labels);
                            }
                        }
                    }
                }
            }
        }
    }

    private List<String> getStringValues(List<Object> fieldValues) {
        List<String> retList;
        if (fieldValues == null || fieldValues.size() == 0) {
            retList = null;
        } else {
            retList = fieldValues.stream().map(Object::toString).collect(Collectors.toList());
        }
        return retList;
    }

}