com.ibm.watson.developer_cloud.retrieve_and_rank.v1.utils.SolrUtils.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.watson.developer_cloud.retrieve_and_rank.v1.utils.SolrUtils.java

Source

/**
 * Copyright 2015 IBM Corp. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package com.ibm.watson.developer_cloud.retrieve_and_rank.v1.utils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.SolrRequest.METHOD;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.impl.HttpSolrClient;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrDocumentList;
import org.apache.solr.common.params.ModifiableSolrParams;

import com.google.gson.JsonObject;
import com.ibm.watson.developer_cloud.retrieve_and_rank.v1.model.RankResult;
import com.ibm.watson.developer_cloud.retrieve_and_rank.v1.model.Ranker;
import com.ibm.watson.developer_cloud.retrieve_and_rank.v1.model.SolrResult;
import com.ibm.watson.developer_cloud.retrieve_and_rank.v1.model.SolrResults;
import com.ibm.watson.developer_cloud.retrieve_and_rank.v1.payload.QueryRequestPayload;

/**
 * Utility class to search for documents in Solr, with or without a {@link Ranker}, and to fetch
 * the stored documents by id.
 */
public class SolrUtils {

    private static final Logger logger = Logger.getLogger(SolrUtils.class.getName());

    private static final String TITLE = "title";
    private static final String ID = "id";
    private static final String BODY = "body";
    private static final String FCSELECT_REQUEST_HANDLER = "/fcselect";
    private static final String FEATURE_VECTOR_FIELD = "featureVector";
    private static final String FIELD_LIST_PARAM = "fl";
    private static final String ID_FIELD = ID;
    private static final String RANKER_ID = "ranker_id";
    private static final String SCORE_FIELD = "score";

    /** Total number of times a Solr request is attempted before the failure is propagated. */
    private static final int MAX_ATTEMPTS = 3;

    /** Pause between two consecutive attempts, in milliseconds. */
    private static final long RETRY_DELAY_MILLIS = 1000L;

    /** Number of leading search hits that are converted into a {@link RankResult}. */
    private static final int MAX_RANKED_ANSWERS = 3;

    /** Number of rows requested from Solr for every query. */
    private static final int ROWS_TO_FETCH = 1000;

    /** Query id of a payload that is not a canned query, i.e. has no ground truth entry. */
    private static final int NO_QUERY_ID = -1;

    private final JsonObject groundTruth;
    private final String collectionName;
    private final String rankerId;
    private final HttpSolrClient solrClient;

    /**
     * Creates the utility bound to one Solr collection and one ranker.
     *
     * @param solrClient the Solr client
     * @param groundTruth the ground truth, keyed by query id and then by answer id; may be
     *        <code>null</code>, in which case no relevance is set on the results
     * @param collectionName the collection name
     * @param rankerID the ranker id
     */
    public SolrUtils(HttpSolrClient solrClient, JsonObject groundTruth, String collectionName, String rankerID) {
        this.rankerId = rankerID;
        this.collectionName = collectionName;
        this.solrClient = solrClient;
        this.groundTruth = groundTruth;
    }

    /**
     * Process a Solr request, attempting it up to {@value #MAX_ATTEMPTS} times and sleeping
     * {@value #RETRY_DELAY_MILLIS} ms after each failed attempt except the last one.
     *
     * @param request the request
     * @return the query response of the first successful attempt
     * @throws IOException Signals that an I/O exception has occurred on the last attempt.
     * @throws SolrServerException the Solr server exception raised by the last attempt
     * @throws InterruptedException if the thread is interrupted while waiting between attempts
     */
    private QueryResponse processSolrRequest(QueryRequest request)
            throws IOException, SolrServerException, InterruptedException {
        for (int attempt = 1;; attempt++) {
            try {
                return request.process(solrClient, collectionName);
            } catch (IOException | SolrServerException | RuntimeException e) {
                // Unchecked exceptions are retried as well, exactly like the checked ones.
                if (attempt >= MAX_ATTEMPTS) {
                    throw e;
                }
                Thread.sleep(RETRY_DELAY_MILLIS);
            }
        }
    }

    /**
     * Search Solr, using a {@link Ranker} if <code>useRanker</code> is specified.
     *
     * <p>
     * The returned object lists the ids of all the documents found, while only the first
     * {@value #MAX_RANKED_ANSWERS} documents are turned into {@link RankResult}s.
     * </p>
     *
     * @param query the query
     * @param useRanker <code>true</code> to have the results re-ranked by the ranker
     * @return the search results, or <code>null</code> if the request failed or the thread was
     *         interrupted (the failure is logged)
     */
    public SolrResults search(QueryRequestPayload query, boolean useRanker) {
        try {
            logger.info("Searching for document...");
            SolrDocumentList documents = solrRuntimeQuery(query.getQuery(), useRanker).getResults();
            logger.info("Found " + documents.size() + " documents!");

            List<RankResult> answerList = new ArrayList<>();
            List<String> ids = new ArrayList<>();

            for (SolrDocument doc : documents) {
                String answerId = (String) doc.getFieldValue(ID_FIELD);
                ids.add(answerId);
                if (answerList.size() < MAX_RANKED_ANSWERS) {
                    // The final rank is 1-based and follows the order returned by Solr.
                    answerList.add(toRankResult(doc, answerId, answerList.size() + 1, query));
                }
            }

            SolrResults result = new SolrResults();
            result.setNumberOfResults(documents.size());
            result.setIds(ids);
            result.setResult(answerList);
            return result;
        } catch (InterruptedException e) {
            // Restore the interrupt status so that the caller can still observe the interruption.
            Thread.currentThread().interrupt();
            logger.log(Level.SEVERE, "Error searching Solr", e);
        } catch (IOException | SolrServerException e) {
            logger.log(Level.SEVERE, "Error searching Solr", e);
        }
        return null;
    }

    /**
     * Builds the {@link RankResult} of a single search hit.
     *
     * @param doc the Solr document of the hit
     * @param answerId the id of the document
     * @param rank the 1-based position of the document in the result list
     * @param query the query that produced the hit
     * @return the rank result
     */
    private RankResult toRankResult(SolrDocument doc, String answerId, int rank, QueryRequestPayload query) {
        RankResult answer = new RankResult();
        answer.setAnswerId(answerId);
        answer.setScore(Float.parseFloat(String.valueOf(doc.getFieldValue(SCORE_FIELD))));
        answer.setFinalRank(rank);

        if (query.getId() != NO_QUERY_ID && groundTruth != null) {
            // If it is a canned query, get ground truth info
            String queryId = String.valueOf(query.getId());
            if (groundTruth.has(queryId)) {
                JsonObject gtForQuery = groundTruth.get(queryId).getAsJsonObject();
                if (gtForQuery.has(answerId)) {
                    answer.setRelevance(gtForQuery.get(answerId).getAsInt());
                } else {
                    // Answers that the ground truth does not mention get relevance 0.
                    answer.setRelevance(0);
                }
            }
        }
        return answer;
    }

    /**
     * Create and process a Solr query
     *
     * @param query the query
     * @param featureVector <code>true</code> to send the query to the
     *        <code>/fcselect</code> request handler together with the ranker id
     * @return the query response
     * @throws IOException Signals that an I/O exception has occurred.
     * @throws SolrServerException the Solr server exception
     * @throws InterruptedException the interrupted exception
     */
    private QueryResponse solrRuntimeQuery(String query, boolean featureVector)
            throws IOException, SolrServerException, InterruptedException {
        SolrQuery featureSolrQuery = new SolrQuery(query);
        if (featureVector) {
            featureSolrQuery.setRequestHandler(FCSELECT_REQUEST_HANDLER);
            // add the ranker id - this tells the plugin to re-rank the results in a single pass
            featureSolrQuery.setParam(RANKER_ID, rankerId);
        }

        // bring back the id, score, and featureVector for the feature query
        featureSolrQuery.setParam(FIELD_LIST_PARAM, ID_FIELD, SCORE_FIELD, FEATURE_VECTOR_FIELD);

        // need to ask for enough rows to ensure the correct answer is included in the resultset
        featureSolrQuery.setRows(ROWS_TO_FETCH);

        return processSolrRequest(new QueryRequest(featureSolrQuery, METHOD.POST));
    }

    /**
     * Gets the documents by ids.
     *
     * <p>
     * Runs of whitespace in the body and in the title are collapsed to a single space. A document
     * without a body or a title gets an empty string for that field; a document without an id is
     * skipped because it cannot be used as a key.
     * </p>
     *
     * @param idsToRetrieve the ids of documents to retrieve; <code>null</code> or empty yields an
     *        empty map without contacting Solr
     * @return the documents keyed by id; empty if nothing was requested or the request failed
     *         (the failure is logged)
     */
    public Map<String, SolrResult> getDocumentsByIds(ArrayList<String> idsToRetrieve) {
        Map<String, SolrResult> idsToDocs = new HashMap<>();
        if (idsToRetrieve == null || idsToRetrieve.isEmpty()) {
            // Nothing to look up; also avoids handing an empty id list to the Solr client.
            return idsToDocs;
        }

        try {
            SolrDocumentList docs = solrClient.getById(collectionName, idsToRetrieve, new ModifiableSolrParams());

            for (SolrDocument doc : docs) {
                Object id = doc.getFirstValue(ID);
                if (id == null) {
                    logger.warning("Skipping a Solr document without an id");
                    continue;
                }
                SolrResult result = new SolrResult();
                result.setBody(collapseWhitespace(doc.getFirstValue(BODY)));
                result.setId(id.toString());
                result.setTitle(collapseWhitespace(doc.getFirstValue(TITLE)));
                idsToDocs.put(result.getId(), result);
            }
        } catch (IOException | SolrServerException e) {
            logger.log(Level.SEVERE, "Error retrieven the Solr documents", e);
        }
        return idsToDocs;
    }

    /**
     * Converts a field value to text, replacing every run of whitespace with a single space and
     * trimming the result.
     *
     * @param value the field value, possibly <code>null</code>
     * @return the normalized text, or an empty string if the value is <code>null</code>
     */
    private static String collapseWhitespace(Object value) {
        if (value == null) {
            return "";
        }
        return value.toString().replaceAll("\\s+", " ").trim();
    }

}