org.tallison.solr.search.concordance.KeywordCooccurRankHandler.java Source code

Java tutorial

Introduction

Here is the source code for org.tallison.solr.search.concordance.KeywordCooccurRankHandler.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.tallison.solr.search.concordance;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.corpus.stats.IDFCalc;
import org.apache.lucene.corpus.stats.TermIDF;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.concordance.charoffsets.TargetTokenNotFoundException;
import org.apache.lucene.search.concordance.classic.ConcordanceSortOrder;
import org.apache.lucene.search.concordance.classic.DocIdBuilder;
import org.apache.lucene.search.concordance.classic.DocMetadataExtractor;
import org.apache.lucene.search.concordance.classic.impl.FieldBasedDocIdBuilder;
import org.apache.lucene.search.concordance.classic.impl.SimpleDocMetadataExtractor;
import org.apache.solr.cloud.RequestThreads;
import org.apache.solr.cloud.RequestWorker;
import org.apache.solr.cloud.ZkController;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SimpleOrderedMap;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.QParser;
import org.apache.solr.search.SolrIndexSearcher;
import org.tallison.lucene.search.concordance.windowvisitor.ConcordanceArrayWindowSearcher;
import org.tallison.lucene.search.concordance.windowvisitor.CooccurVisitor;
import org.tallison.lucene.search.concordance.windowvisitor.Grammer;
import org.tallison.lucene.search.concordance.windowvisitor.WGrammer;

/**
 * <requestHandler name="/kwCooccur" class="org.tallison.solr.search.concordance.KeywordCooccurRankHandler">
 * <lst name="defaults">
 * <str name="echoParams">explicit</str>
 * <str name="defType">spanquery</str>
 * <str name="f">content_txt</str>
 * <str name="df">content_txt</str>
 * <str name="wt">xml</str>
 * <str name="minNGram">1</str>
 * <str name="maxNGram">2</str>
 * <str name="minTF">3</str>
 * <str name="numResults">50</str>
 * <p>
 * <!--  More fields
 * <str name="maxWindows">500</str>
 * <str name="debug">false</str>
 * <str name="fl">metadata field1,metadata field2,metadata field3</str>
 * <str name="targetOverlaps">true</str>
 * <str name="contentDisplaySize">42</str>
 * <str name="targetDisplaySize">42</str>
 * <str name="tokensAfter">42</str>
 * <str name="tokensBefore">42</str>
 * <str name="sortOrder">TARGET_PRE</str>
 * -->
 * </lst>
 * <lst name="invariants">
 * <str name="prop1">value1</str>
 * <int name="prop2">2</int>
 * <!-- ... more config items here ... -->
 * </lst>
 * </requestHandler>
 *
 * @author JRROBINSON
 */

public class KeywordCooccurRankHandler extends SolrConcordanceBase {

    /** Default handler path passed to {@code getHandlerName} when resolving this handler's name. */
    public static final String DefaultName = "/kwCo";

    /**
     * Key under which this handler adds its payload to the response; shard responses are
     * unwrapped with the same key when they are merged.
     */
    public static final String NODE = "contextKeywords";
    /**
     * Max number of request threads to spawn.  Since this service wasn't intended to return
     * ALL possible results, it seems reasonable to cap this at something
     */
    public final static int MAX_THREADS = 25;

    // stray empty declaration; has no effect
    ;

    /**
     * Convenience overload that caps the number of shard workers at {@link #MAX_THREADS}.
     *
     * @param shards shard addresses (host/path, without scheme) to query
     * @param req    the originating request
     * @return the request pool produced by the three-argument overload
     */
    static public RequestThreads<CooccurConfig> initRequestPump(List<String> shards, SolrQueryRequest req) {
        final int threadCap = MAX_THREADS;
        return initRequestPump(shards, req, threadCap);
    }

    /**
     * Creates a request pool, submits one {@link RequestWorker} per shard (at most
     * {@code maxThreads} of them) and seals the pool.
     * <p>
     * The configured {@code maxWindows} budget is split evenly across the workers that are
     * actually submitted, and the parsed {@link CooccurConfig} is attached to the pool as metadata.
     *
     * @param shards     shard addresses (host/path, without scheme) to query
     * @param req        the originating request; supplies the parameters and schema
     * @param maxThreads upper bound on the number of workers submitted
     * @return the sealed pool holding the submitted workers
     */
    static public RequestThreads<CooccurConfig> initRequestPump(List<String> shards, SolrQueryRequest req,
            int maxThreads) {

        SolrParams params = req.getParams();
        String field = SolrConcordanceBase.getField(params, req.getSchema().getDefaultSearchFieldName());
        String q = params.get(CommonParams.Q);
        CooccurConfig config = configureParams(field, params);

        // Number of shards that will really be queried (never negative).
        int workerCount = Math.max(0, Math.min(shards.size(), maxThreads));

        // java.util.concurrent fixed pools reject a size of 0; presumably RequestThreads does
        // too, so always ask for at least one thread even when there is nothing to submit.
        RequestThreads<CooccurConfig> threads = RequestThreads
                .<CooccurConfig>newFixedThreadPool(Math.max(1, workerCount)).setMetadata(config);

        String handler = getHandlerName(req, DefaultName, KeywordCooccurRankHandler.class);

        // Split the window budget over the workers that are submitted, not over every known
        // shard, and avoid dividing by zero when the shard list is empty.
        int partial = Math.round(config.getMaxWindows() / (float) Math.max(1, workerCount));

        ModifiableSolrParams p = getWorkerParams(field, q, params, partial);

        int submitted = 0;
        for (String node : shards) {
            // was "i++ > maxThreads", which let maxThreads + 1 workers through
            if (submitted >= workerCount) {
                break;
            }

            //could be https, no?
            String url = "http://" + node;

            RequestWorker worker = new RequestWorker(url, handler, p).setName(node);
            threads.addExecute(worker);
            submitted++;
        }
        threads.seal(); //disallow future requests (& execute

        return threads;
    }

    /**
     * Runs the local co-occurrence search without an explicit filter query.
     *
     * @param req the request to execute
     * @return the result list produced by {@link #doLocalSearch(Query, SolrQueryRequest)}
     * @throws Exception propagated from the two-argument overload
     */
    public static NamedList doLocalSearch(SolrQueryRequest req) throws Exception {
        return doLocalSearch((Query) null, req);
    }

    /**
     * Executes the co-occurrence search against the local index reader and packages the
     * ranked terms together with a few counters.
     * <p>
     * The query is parsed from the {@code q} parameter, the analyzer is the index analyzer of
     * the target field, and the filter applied to the search is the one derived from the
     * request via {@code getFilterQuery(req)}.
     * {@link IllegalArgumentException} and {@link TargetTokenNotFoundException} raised by the
     * search are printed and otherwise ignored, so whatever the visitor collected up to that
     * point is still returned.
     *
     * @param filter not used by this method
     * @param req    the request to execute
     * @return a list holding {@code results} (when any), {@code collectionSize},
     *         {@code numDocsVisited}, {@code numWindowsVisited}, {@code numResults} and {@code minTF}
     * @throws Exception if query parsing or the search fails in another way
     */
    public static NamedList doLocalSearch(Query filter, SolrQueryRequest req) throws Exception {
        final SolrParams requestParams = req.getParams();
        final String targetField = getField(requestParams);

        // Constructed as before, although nothing below consumes it.
        final String fieldList = requestParams.get(CommonParams.FL);
        DocMetadataExtractor metadataExtractor;
        if (fieldList == null || fieldList.length() == 0) {
            metadataExtractor = new SimpleDocMetadataExtractor();
        } else {
            metadataExtractor = new SimpleDocMetadataExtractor(fieldList.split(","));
        }

        final CooccurConfig cfg = configureParams(targetField, requestParams);

        final IndexSchema indexSchema = req.getSchema();
        final SchemaField schemaField = indexSchema.getField(targetField);
        final Analyzer indexAnalyzer = schemaField.getType().getIndexAnalyzer();
        final Filter requestFilter = getFilterQuery(req);
        final String queryString = requestParams.get(CommonParams.Q);
        final Query parsedQuery = QParser.getParser(queryString, null, req).parse();
        final String uniqueKeyField = req.getSchema().getUniqueKeyField().getName();

        final SolrIndexSearcher indexSearcher = req.getSearcher();
        final IndexReader indexReader = indexSearcher.getIndexReader();
        final boolean allowDuplicates = false;
        final boolean allowFieldSeparators = false;

        final Grammer ngrammer = new WGrammer(cfg.getMinNGram(), cfg.getMaxNGram(), allowFieldSeparators);
        final IDFCalc idf = new IDFCalc(indexReader);

        final CooccurVisitor cooccurVisitor = new CooccurVisitor(targetField, cfg.getTokensBefore(),
                cfg.getTokensAfter(), ngrammer, idf, cfg.getMaxWindows(), allowDuplicates);
        cooccurVisitor.setMinTermFreq(cfg.getMinTermFreq());

        try {
            final ConcordanceArrayWindowSearcher windowSearcher = new ConcordanceArrayWindowSearcher();
            System.out.println("UNIQUE KEY FIELD: " + uniqueKeyField);
            final DocIdBuilder idBuilder = new FieldBasedDocIdBuilder(uniqueKeyField);
            System.out.println("QUERY: " + parsedQuery.toString());
            windowSearcher.search(indexReader, targetField, parsedQuery, requestFilter, indexAnalyzer,
                    cooccurVisitor, idBuilder);
        } catch (IllegalArgumentException | TargetTokenNotFoundException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

        final List<TermIDF> rankedTerms = cooccurVisitor.getResults();
        final NamedList response = toNamedList(rankedTerms);

        // the counters below are needed for cloud computations, merging cores
        response.add("collectionSize", indexReader.numDocs());
        response.add("numDocsVisited", cooccurVisitor.getNumDocsVisited());
        response.add("numWindowsVisited", cooccurVisitor.getNumWindowsVisited());
        response.add("numResults", rankedTerms.size());
        response.add("minTF", cooccurVisitor.getMinTermFreq());

        return response;
    }

    /**
     * Builds the parameter set sent to each shard worker.
     * <p>
     * {@code f}, {@code q} and {@code maxWindows} come from the arguments, {@code lq} is set to
     * {@code true} and {@code rows} to {@code 0}; the remaining parameters are handed to
     * {@code setParam} one by one together with the parent parameters.
     *
     * @param field      field to search
     * @param q          query string
     * @param parent     parameters of the originating request
     * @param maxWindows window budget for a single worker
     * @return the freshly built parameter set
     */
    public static ModifiableSolrParams getWorkerParams(String field, String q, SolrParams parent,
            Integer maxWindows) {
        final ModifiableSolrParams workerParams = new ModifiableSolrParams();

        workerParams.set("f", field);
        workerParams.set("q", q);
        workerParams.set("maxWindows", maxWindows);
        workerParams.set("lq", true); //flag to disallow recursive zoo queries
        workerParams.set("rows", 0);

        // Handed to setParam in exactly this order.
        final String[] inheritedNames = {
            "fq", "anType", "numResults", "minNGram", "maxNGram", "minTF", "minDF",
            "echoParams", "defType", "wt", "debug", "fl", "targetOverlaps",
            "contentDisplaySize", "targetDisplaySize", "tokensAfter", "tokensBefore", "sortOrder"
        };
        for (String name : inheritedNames) {
            setParam(name, workerParams, parent);
        }

        return workerParams;
    }

    /**
     * Collects worker results into a new {@link Results} created from the pool's metadata.
     *
     * @param threads the request pool; its metadata is read before any null check happens
     * @return the merged results
     */
    public static Results spinWait(RequestThreads<CooccurConfig> threads) {
        return spinWait(threads, new Results(threads.getMetadata()));
    }

    /**
     * Polls the pool's workers and merges the output of every worker that has stopped running
     * into {@code results}.
     * <p>
     * Polling happens in two phases: first while the pool is neither terminated nor empty,
     * then, after {@code shutdownNow()}, over whatever workers are still held by the pool.
     * Both phases stop as soon as {@code results.hitMax} is set. The pool is cleared before
     * returning.
     *
     * @param threads the request pool; {@code null} or empty returns {@code results} untouched
     * @param results accumulator to merge into
     * @return {@code results}
     */
    public static Results spinWait(RequestThreads<CooccurConfig> threads, Results results) {
        if (threads == null || threads.empty()) {
            return results;
        }

        // Phase 1: busy-poll while the pool is alive.
        while (!(threads.isTerminated() || threads.empty() || results.hitMax)) {
            final RequestWorker worker = threads.next();
            if (worker.isRunning()) {
                continue;
            }
            final NamedList payload = worker.getResults();
            if (payload != null) {
                results.add(payload, worker.getName());
            }
            threads.removeLast();
        }

        //force complete shutdown
        threads.shutdownNow();

        // Phase 2: if not enough hits, check any remaining workers that haven't been collected.
        while (!(threads.empty() || results.hitMax)) {
            final RequestWorker worker = threads.next();
            if (worker == null || worker.isRunning()) {
                continue;
            }
            final NamedList payload = worker.getResults();
            if (payload != null) {
                results.add(payload, worker.getName());
            }
            threads.removeLast();
        }

        threads.clear();

        return results;
    }

    /**
     * Builds a {@link CooccurConfig} for {@code field} and overrides its settings with any
     * of the following request parameters that are present and non-empty:
     * {@code targetOverlaps}, {@code contentDisplaySize}, {@code targetDisplaySize},
     * {@code maxWindows}, {@code tokensAfter}, {@code tokensBefore}, {@code sortOrder},
     * {@code minNGram}, {@code maxNGram}, {@code minTF} and {@code numResults}.
     * <p>
     * Parsing is best effort: a value that cannot be parsed or applied is ignored and the
     * setting keeps whatever value the new config started with.
     *
     * @param field  field the config is created for
     * @param params request parameters to read overrides from
     * @return the populated config
     */
    public static CooccurConfig configureParams(String field, SolrParams params) {
        CooccurConfig config = new CooccurConfig(field);

        String value = nonEmptyParam(params, "targetOverlaps");
        if (value != null) {
            try {
                config.setAllowTargetOverlaps(Boolean.parseBoolean(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "contentDisplaySize");
        if (value != null) {
            try {
                config.setMaxContextDisplaySizeChars(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "targetDisplaySize");
        if (value != null) {
            try {
                config.setMaxTargetDisplaySizeChars(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "maxWindows");
        if (value != null) {
            try {
                config.setMaxWindows(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        // tokensAfter / tokensBefore used to be parsed and applied a second time further
        // down; the duplicate blocks set the same values again and have been removed.
        value = nonEmptyParam(params, "tokensAfter");
        if (value != null) {
            try {
                config.setTokensAfter(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "tokensBefore");
        if (value != null) {
            try {
                config.setTokensBefore(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "sortOrder");
        if (value != null) {
            try {
                config.setSortOrder(ConcordanceSortOrder.valueOf(value));
            } catch (Exception ignored) {
                // unknown sort order name: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "minNGram");
        if (value != null) {
            try {
                config.setMinNGram(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "maxNGram");
        if (value != null) {
            try {
                config.setMaxNGram(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "minTF");
        if (value != null) {
            try {
                config.setMinTermFreq(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        value = nonEmptyParam(params, "numResults");
        if (value != null) {
            try {
                config.setNumResults(Integer.parseInt(value));
            } catch (Exception ignored) {
                // best effort: keep the existing setting
            }
        }

        return config;
    }

    /** Returns the named parameter, or {@code null} when it is absent or the empty string. */
    private static String nonEmptyParam(SolrParams params, String name) {
        String value = params.get(name);
        return (value == null || value.length() == 0) ? null : value;
    }

    /**
     * Converts ranked terms into the response structure.
     * <p>
     * For a non-empty input the returned map holds a single {@code results} entry containing
     * one {@code result} per term with the keys {@code term}, {@code tfidf}, {@code tf},
     * {@code idf} and {@code df}. For an empty input the returned map is empty.
     *
     * @param results ranked terms, in the order they should appear
     * @return the response structure
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    public static NamedList toNamedList(List<TermIDF> results) {
        final SimpleOrderedMap response = new SimpleOrderedMap();
        if (results.size() <= 0) {
            return response;
        }

        final NamedList entries = new SimpleOrderedMap();
        response.add("results", entries);

        for (TermIDF termStats : results) {
            final NamedList entry = new SimpleOrderedMap();
            entry.add("term", termStats.getTerm());
            entry.add("tfidf", termStats.getTFIDF());
            entry.add("tf", termStats.getTermFreq());
            entry.add("idf", termStats.getIDF());
            entry.add("df", termStats.getDocFreq());
            entries.add("result", entry);
        }

        return response;
    }

    /*
        
        public void search(IndexReader reader, String fieldName,
          Query query, Filter filter, Analyzer analyzer,
          ArrayWindowVisitor visitor, DocIdBuilder docIdBuilder ) throws IllegalArgumentException
        {
        
    try {
        ConcordanceArrayWindowSearcher searcher = new ConcordanceArrayWindowSearcher();
     searcher.search(reader, fieldName, query, filter, analyzer, visitor, docIdBuilder );
          } catch (IllegalArgumentException e) {
     e.printStackTrace();
          } catch (TargetTokenNotFoundException e) {
        e.printStackTrace();
    } catch (IOException e) {
        e.printStackTrace();
    }
        
    //xx copy constructor instead?
    CooccurVisitor covisitor = (CooccurVisitor)visitor;
    List<TermIDF>  overallResults = covisitor.getResults();
          NamedList results = toNamedList(overallResults);
    results.add("collectionSize", reader.numDocs());
    results.add("numDocsVisited", covisitor.getNumDocsVisited());
    results.add("numWindowsVisited", covisitor.getNumWindowsVisited());
    results.add("numResults", overallResults.size());
    results.add("minTF", covisitor.getMinTermFreq());
        
    //TODO: convert results to docIdBuilder? xx
          //xx return results;
       }
        
    */

    /**
     * Resolves the field to search from the request parameters.
     * <p>
     * Lookup order: {@link CommonParams#FIELD}, then {@link CommonParams#DF}, then the first
     * comma-separated entry of {@link CommonParams#FL} (trimmed). A value that is absent or
     * equals {@code "null"} (ignoring case) counts as missing for the first two lookups.
     *
     * @param params request parameters
     * @return the resolved field name, or {@code null} when none of the parameters yields one
     */
    public static String getField(SolrParams params) {
        String fieldName = params.get(CommonParams.FIELD);

        if (isMissingField(fieldName)) {
            fieldName = params.get(CommonParams.DF);
        }

        if (isMissingField(fieldName)) {
            //check field list if not in field
            fieldName = params.get(CommonParams.FL);

            //TODO: change when/if request allows for multiple terms
            if (fieldName != null) {
                // split() drops trailing empty strings, so a value made only of commas
                // yields an empty array; treat that as "no field" instead of failing.
                String[] parts = fieldName.split(",");
                fieldName = parts.length > 0 ? parts[0].trim() : null;
            }
        }
        return fieldName;
    }

    /** Tells whether a looked-up field name is absent or the literal text "null". */
    private static boolean isMissingField(String fieldName) {
        return fieldName == null || fieldName.equalsIgnoreCase("null");
    }

    /**
     * Initializes the handler by delegating to the superclass; no handler-specific
     * settings are read here.
     *
     * @param args configuration arguments, passed through unchanged
     */
    @Override
    public void init(@SuppressWarnings("rawtypes") NamedList args) {
        // placeholder for handler-specific invariants, e.g.: this.prop1 = invariants.get("prop1");
        super.init(args);
    }

    /**
     * @return a one-line, human-readable description of this handler
     */
    @Override
    public String getDescription() {
        final String description = "Returns tokens that frequently co-occur within concordance windows";
        return description;
    }

    /**
     * @return pointers to the issue and repository this handler originates from
     */
    @Override
    public String getSource() {
        final String source = "https://issues.apache.org/jira/browse/SOLR-5411 - https://github.com/tballison/lucene-addons";
        return source;
    }

    /**
     * Entry point for a request: runs the query locally unless {@code isDistributed(req)}
     * reports a distributed request, in which case the query is fanned out to other nodes.
     *
     * @param req the incoming request
     * @param rsp the response to populate
     * @throws Exception propagated from the local or distributed query
     */
    @SuppressWarnings("unchecked")
    @Override
    public void handleRequestBody(SolrQueryRequest req, SolrQueryResponse rsp) throws Exception {
        if (!isDistributed(req)) {
            doQuery(req, rsp);
            return;
        }

        System.out.println("DOING ZOO QUERY");
        doZooQuery(req, rsp);
    }

    /**
     * Fans the request out to the other live nodes of the cluster and adds the merged
     * result to the response under {@link #NODE}.
     * <p>
     * Each live node name is turned into a shard address by replacing the first underscore
     * (the separator between {@code host:port} and the context path) with a slash and
     * appending this core's name. The node whose address is contained in this node's base
     * URL is skipped. No local search is performed here, so the merged result covers only
     * the other live nodes.
     * <p>
     * When the {@code debug} parameter is true an (empty) {@code DEBUG} entry is added to
     * the response.
     *
     * @param req the incoming request
     * @param rsp the response to populate
     * @throws Exception propagated from the ZooKeeper lookups or the shard requests
     */
    @SuppressWarnings("unchecked")
    private void doZooQuery(SolrQueryRequest req, SolrQueryResponse rsp) throws Exception {
        SolrParams params = req.getParams();

        boolean debug = params.getBool("debug", false);
        NamedList nlDebug = new SimpleOrderedMap();

        if (debug)
            rsp.add("DEBUG", nlDebug);

        // Look the controller up once and reuse it for both the live nodes and the base URL.
        ZkController zoo = req.getCore().getCoreDescriptor().getCoreContainer().getZkController();
        Set<String> nodes = zoo.getClusterState().getLiveNodes();
        String thisUrl = zoo.getBaseUrl();
        String coreName = req.getCore().getName();

        List<String> shards = new ArrayList<String>(nodes.size());

        for (String node : nodes) {
            // Only the first underscore separates host:port from the context path; replacing
            // every underscore would corrupt context paths that contain one.
            int separator = node.indexOf('_');
            String shard = separator < 0
                    ? node
                    : node.substring(0, separator) + "/" + node.substring(separator + 1);

            if (thisUrl.contains(shard))
                continue; // this is the local node

            shards.add(shard + "/" + coreName);
        }
        System.out.println("SHARDS SIZE: " + shards.size());
        RequestThreads<CooccurConfig> threads = initRequestPump(shards, req);

        Results results = new Results(threads.getMetadata());

        results = spinWait(threads, results);

        rsp.add(NODE, results.toNamedList());

    }

    ;

    /**
     * Runs the search against the local index and adds its result to the response
     * under {@link #NODE}.
     *
     * @param req the incoming request
     * @param rsp the response to populate
     */
    private void doQuery(SolrQueryRequest req, SolrQueryResponse rsp)
            throws Exception, IllegalArgumentException, ParseException, TargetTokenNotFoundException {
        rsp.add(NODE, doLocalSearch(req));
    }

    ;

    /**
     * Resolves this handler's name for the given request, using {@link #DefaultName}
     * as the default and this instance's runtime class.
     *
     * @param req the current request
     * @return the resolved handler name
     */
    @Override
    protected String getHandlerName(SolrQueryRequest req) {
        final Class<? extends KeywordCooccurRankHandler> handlerClass = this.getClass();
        return getHandlerName(req, DefaultName, handlerClass);
    }

    /**
     * Accumulator that merges per-node co-occurrence results and re-ranks the merged terms.
     * <p>
     * Counters are summed across nodes, term and document frequencies are summed per term,
     * and {@link #toNamedList()} recomputes a tf-idf score from the merged numbers.
     */
    public static class Results {
        /** Window budget; {@link #hitMax} is set once {@link #numWindows} reaches it. */
        long maxWindows = -1;
        /** Maximum number of terms reported by {@link #toNamedList()}. */
        int maxResults = -1;
        /** True once the accumulated window count has reached {@link #maxWindows}. */
        boolean hitMax = false;
        /** True once the accumulated result count has reached {@link #maxResults}. */
        boolean maxTerms = false;
        /** Sum of the {@code collectionSize} values reported by the merged nodes. */
        int size = 0;
        /** Sum of the visited-document counts reported by the merged nodes. */
        long numDocs = 0;
        /** Sum of the visited-window counts reported by the merged nodes. */
        long numWindows = 0;
        /** Sum of the {@code numResults} values reported by the merged nodes. */
        int numResults = 0;
        /** Merged statistics per term. */
        HashMap<String, Keyword> keywords = new HashMap<String, Keyword>();

        Results(CooccurConfig config) {
            this.maxWindows = config.getMaxWindows();
            this.maxResults = config.getNumResults();
        }

        Results(int maxWindows, int maxResults) {
            this.maxWindows = maxWindows;
            this.maxResults = maxResults;
        }

        /**
         * Merges one node's response into this accumulator.
         *
         * @param nl    the node's response; either the payload itself or a list holding the
         *              payload under {@link #NODE}
         * @param extra label of the node the response came from (not used)
         */
        void add(NamedList nl, String extra) {
            NamedList nlRS = (NamedList) nl.get(NODE);

            if (nlRS == null)
                nlRS = nl;

            // A local search reports "numDocsVisited"/"numWindowsVisited", whereas an already
            // merged result reports "numDocs"/"numWindows". Reading only the latter meant the
            // counters from local searches were never picked up.
            numDocs += getInt(presentKey(nlRS, "numDocsVisited", "numDocs"), nlRS);
            size += getInt("collectionSize", nlRS);
            numResults += getInt("numResults", nlRS);
            numWindows += getLong(presentKey(nlRS, "numWindowsVisited", "numWindows"), nlRS);

            hitMax = numWindows >= maxWindows;
            maxTerms = numResults >= maxResults;

            Object o = nlRS.get("results");
            if (o != null) {
                NamedList nlRes = (NamedList) o;

                List<NamedList> res = nlRes.getAll("result");

                for (NamedList nlTerm : res) {
                    Keyword tmp = new Keyword(nlTerm);

                    Keyword kw = keywords.get(tmp.term);

                    if (kw == null)
                        keywords.put(tmp.term, tmp);
                    else {
                        kw.tf += tmp.tf;
                        kw.df += tmp.df;
                        kw.minDF += tmp.minDF;
                    }
                }
            }
        }

        /** Returns {@code preferred} when the list has a value for it, otherwise {@code fallback}. */
        private static String presentKey(NamedList nl, String preferred, String fallback) {
            return nl.get(preferred) != null ? preferred : fallback;
        }

        /**
         * Scores a merged term as {@code tf * ln(collectionSize / df)} using floating-point
         * division. Returns 0 when either the collection size or the document frequency is
         * not positive, since no meaningful score exists then (the integer division used
         * previously threw on {@code df == 0} and truncated the ratio otherwise).
         */
        private static double mergedTfIdf(long tf, long df, long collectionSize) {
            if (df <= 0 || collectionSize <= 0) {
                return 0.0;
            }
            return tf * Math.log((double) collectionSize / df);
        }

        /**
         * Renders the merged counters and the top {@link #maxResults} terms, ordered by the
         * recomputed tf-idf (descending) and then by document frequency (descending).
         *
         * @return the merged response structure
         */
        NamedList toNamedList() {
            NamedList nl = new SimpleOrderedMap<>();
            nl.add("hitMax", hitMax);
            nl.add("maxTerms", maxTerms);
            nl.add("numDocs", numDocs);
            nl.add("collectionSize", size);
            nl.add("numWindows", numWindows);
            nl.add("numResults", numResults);

            if (keywords.size() > 0) {

                //sort by new tf-idf's
                Integer[] idxs = new Integer[keywords.size()];
                final double[] tfidfs = new double[keywords.size()];
                final Keyword[] terms = new Keyword[keywords.size()];

                int i = 0;
                for (Entry<String, Keyword> kv : keywords.entrySet()) {
                    idxs[i] = i;
                    Keyword kw = kv.getValue();
                    terms[i] = kw;

                    tfidfs[i] = mergedTfIdf(kw.tf, kw.df, size);
                    i++;
                }

                Arrays.sort(idxs, new Comparator<Integer>() {
                    public int compare(Integer a, Integer b) {
                        int ret = Double.compare(tfidfs[b], tfidfs[a]);
                        if (ret == 0)
                            ret = Long.compare(terms[b].df, terms[a].df);
                        return ret;

                    }
                });

                NamedList<NamedList> nlResults = new SimpleOrderedMap<NamedList>();
                for (i = 0; i < idxs.length && i < maxResults; i++) {
                    Keyword kw = terms[idxs[i]];
                    NamedList nlKw = new SimpleOrderedMap<Object>();

                    nlKw.add("term", kw.term);
                    nlKw.add("tfidf", tfidfs[idxs[i]]);
                    nlKw.add("orig_tfidf", kw.tfidf);
                    nlKw.add("tf", kw.tf);
                    nlKw.add("df", kw.df);
                    nlKw.add("minDF", kw.minDF);

                    nlResults.add("result", nlKw);
                }
                nl.add("results", nlResults);
            }
            return nl;
        }
    }

    /**
     * Statistics for a single term as read from one {@code result} entry of a node response.
     */
    static class Keyword {
        /** The term text. */
        String term;
        /** The tf-idf value read from the entry's {@code tfidf} key. */
        double tfidf = 0;
        /** Term frequency, read from the entry's {@code tf} key. */
        long tf = 0;
        /** Document frequency, read from the entry's {@code df} key. */
        long df = 0;
        /** Value read from the entry's {@code minDF} key. */
        int minDF = 0;

        /**
         * Reads the term and its statistics from a {@code result} entry.
         *
         * @param nl the entry; must contain a {@code term} value
         */
        Keyword(NamedList nl) {
            this.term = nl.get("term").toString();
            this.tf = getInt("tf", nl);
            this.df = getInt("df", nl);
            this.minDF = getInt("minDF", nl);
            this.tfidf = getDouble("tfidf", nl);
        }

        /** @return the term text */
        @Override
        public String toString() {
            return this.term;
        }

    }

}