org.elasticsearch.action.termwalker.TransportTermwalkerAction.java Source code

Introduction

Here is the source code for org.elasticsearch.action.termwalker.TransportTermwalkerAction.java.
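
In brief, this class is the server-side half of a term-statistics plugin: it broadcasts a request to every active primary shard, walks all terms of the _all field on each shard (collecting document frequency and total term frequency), and reduces the per-shard results into per-index totals, keeping only the top termLimit terms whose document frequency falls inside the requested min/max percentage window.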

Source

/*
 * Licensed to ElasticSearch and Shay Banon under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. ElasticSearch licenses this
 * file to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.action.termwalker;

import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.MultiFields;
import org.elasticsearch.ElasticSearchException;
import org.elasticsearch.action.ShardOperationFailedException;
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
import org.elasticsearch.action.support.broadcast.BroadcastShardOperationFailedException;
import org.elasticsearch.action.support.broadcast.TransportBroadcastOperationAction;
import org.elasticsearch.cluster.ClusterService;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.service.IndexService;
import org.elasticsearch.index.shard.service.InternalIndexShard;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReferenceArray;

import static org.elasticsearch.common.collect.Lists.newArrayList;

/**
 * Transport action for the Termwalker plugin: broadcasts a request to every active
 * primary shard, walks all terms of the {@code _all} field on each shard, and
 * reduces the per-shard document and total term frequencies into per-index totals,
 * keeping the top {@code termLimit} terms inside the requested df window.
 */
public class TransportTermwalkerAction extends
        TransportBroadcastOperationAction<TermwalkerRequest, TermwalkerResponse, ShardTermwalkerRequest, ShardTermwalkerResponse> {

    private final IndicesService indicesService;
    private final Object mutex = new Object();

    @Inject
    public TransportTermwalkerAction(Settings settings, ThreadPool threadPool, ClusterService clusterService,
            TransportService transportService, IndicesService indicesService) {
        super(settings, threadPool, clusterService, transportService);
        this.indicesService = indicesService;
    }

    @Override
    protected String executor() {
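        // shard walks run on the merge thread pool, keeping the heavy term
        // enumeration off the search and index pools (apparent intent; not documented)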
        return ThreadPool.Names.MERGE;
    }

    @Override
    protected String transportAction() {
        return TermwalkerAction.NAME;
    }

    @Override
    protected TermwalkerRequest newRequest() {
        return new TermwalkerRequest();
    }

    @Override
    protected boolean ignoreNonActiveExceptions() {
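        // with this enabled, non-active shards contribute null slots to the
        // shards-response array, which newResponse() skips over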
        return true;
    }

    @Override
    protected TermwalkerResponse newResponse(TermwalkerRequest request, AtomicReferenceArray shardsResponses,
            ClusterState clusterState) {
        int successfulShards = 0;
        int failedShards = 0;
        List<ShardOperationFailedException> shardFailures = null;
        Map response = new HashMap();
        Map<String, Map<String, Integer>> dfAccumulator = new HashMap<String, Map<String, Integer>>();
        Map<String, Map<String, Long>> ttfAccumulator = new HashMap<String, Map<String, Long>>();

        Integer termLimit = request.termLimit();
        Boolean includeDF = request.includeDF();
        Boolean includeTTF = request.includeTTF();

        logger.info("termwalker:" + " df: " + includeDF + " ttf: " + includeTTF + " max_doc_percent: "
                + request.maxDocFrequency() + " min_doc_percent: " + request.minDocFrequency() + " limit: "
                + termLimit);

        // Reduce phase: fold each shard's partial statistics into per-index totals.
        for (int i = 0; i < shardsResponses.length(); i++) {
            Object shardResponse = shardsResponses.get(i);
            if (shardResponse == null) {
                // a non active shard, ignore...
            } else if (shardResponse instanceof BroadcastShardOperationFailedException) {
                failedShards++;
                if (shardFailures == null) {
                    shardFailures = newArrayList();
                }
                shardFailures.add(new DefaultShardOperationFailedException(
                        (BroadcastShardOperationFailedException) shardResponse));
            } else {
                successfulShards++;
                if (shardResponse instanceof ShardTermwalkerResponse) {
                    ShardTermwalkerResponse shardResp = (ShardTermwalkerResponse) shardResponse;
                    String index = shardResp.getIndex();
                    int shardId = shardResp.getShardId();
                    // one map per index
                    Map indexresponse = (Map) response.get(index);
                    if (indexresponse == null) {
                        indexresponse = new HashMap();
                        response.put(index, indexresponse);
                    }
                    // shard-wise data
                    Map shardmap = shardResp.getResponse();

                    // Roll up the shard responses into per-index totals.

                    // first, sum all the numDocs
                    Integer numDocs = (Integer) indexresponse.get("num_docs");
                    if (numDocs == null) {
                        numDocs = 0;
                    }
                    indexresponse.put("num_docs", (Integer) shardmap.get("num_docs") + numDocs);

                    // estimate numTerms by taking the max
                    Integer numTerms = (Integer) indexresponse.get("num_terms");
                    if (numTerms == null) {
                        numTerms = 0;
                    }
                    indexresponse.put("num_terms", Math.max((Integer) shardmap.get("num_terms"), numTerms));

                    // add up the total term counts 
                    Long totalTerms = (Long) indexresponse.get("total_terms");
                    if (totalTerms == null) {
                        totalTerms = 0L;
                    }
                    indexresponse.put("total_terms", (Long) shardmap.get("total_terms") + totalTerms);

                    // shared df map
                    Map<String, Integer> dfMap = dfAccumulator.get(index);
                    if (dfMap == null) {
                        dfMap = new HashMap<String, Integer>();
                        dfAccumulator.put(index, dfMap);
                    }

                    // shared ttf map
                    Map<String, Long> ttfMap = ttfAccumulator.get(index);
                    if (ttfMap == null) {
                        ttfMap = new HashMap<String, Long>();
                        ttfAccumulator.put(index, ttfMap);
                    }

                    // sum each shard's df and ttf contributions per term
                    ArrayList<Map> topTerms = (ArrayList<Map>) shardmap.get("terms");
                    for (Map term : topTerms) {
                        String key = (String) term.get("text");

                        if (includeDF) {
                            Integer docFreq = dfMap.get(key);
                            if (docFreq == null) {
                                docFreq = 0;
                            }
                            docFreq = docFreq + (Integer) term.get("df");
                            dfMap.put(key, docFreq);
                        }

                        if (includeTTF) {
                            Long totalTermFreq = ttfMap.get(key);
                            if (totalTermFreq == null) {
                                totalTermFreq = 0L;
                            }

                            Long count = (Long) term.get("ttf");

                            totalTermFreq = totalTermFreq + count;
                            ttfMap.put(key, totalTermFreq);
                        }
                    }
                }
            }
        }

        // Finally, trim the accumulated totals to fit within the limits specified
        for (String index : dfAccumulator.keySet()) {
            // one map per index
            Map indexresponse = (Map) response.get(index);
            Integer numDocs = (Integer) indexresponse.get("num_docs");

            // clamp both df thresholds to the range [2, numDocs]; for example, with
            // numDocs = 1000 and maxDocFrequency() = 0.5, maxDocFreq becomes 500
            Long maxDocFreq = Math.max(Math.min(Math.round(numDocs * request.maxDocFrequency()), numDocs), 2);
            Long minDocFreq = Math.max(Math.min(Math.round(numDocs * request.minDocFrequency()), numDocs), 2);

            logger.info("termwalker: max_doc_freq: {} min_doc_freq: {}", maxDocFreq, minDocFreq);

            indexresponse.put("max_doc_freq", maxDocFreq);
            indexresponse.put("min_doc_freq", minDocFreq);

            // Lucene's PriorityQueue caps the result at termLimit entries:
            // insertWithOverflow evicts the smallest element, so the lowest-frequency
            // terms fall out first once the queue is full.
            PriorityQueue<TermwalkerTerm> terms = new PriorityQueue<TermwalkerTerm>(termLimit) {
                @Override
                protected boolean lessThan(TermwalkerTerm a, TermwalkerTerm b) {
                    // rank by document frequency when it was collected,
                    // otherwise fall back to total term frequency
                    if (a.df > 0) {
                        return a.df < b.df;
                    } else {
                        return a.ttf < b.ttf;
                    }
                }
            };

            Map<String, Integer> dfMap = dfAccumulator.get(index);
            Map<String, Long> ttfMap = ttfAccumulator.get(index);

            indexresponse.put("num_terms_reduced", dfMap.size());

            Integer numTermsOver = 0;
            Integer numTermsUnder = 0;

            for (String text : dfMap.keySet()) {
                Integer docFreq = dfMap.get(text);

                if (docFreq > maxDocFreq) {
                    numTermsOver += 1;
                } else if (docFreq < minDocFreq) {
                    numTermsUnder += 1;
                } else {
                    // Go ahead and offer the result         
                    TermwalkerTerm term = new TermwalkerTerm(text, docFreq, ttfMap.get(text));

                    terms.insertWithOverflow(term);
                }
            }

            indexresponse.put("num_terms_over", numTermsOver);
            indexresponse.put("num_terms_under", numTermsUnder);
            indexresponse.put("df_num_terms", terms.size());
            indexresponse.put("ttf_num_terms", terms.size());

            // emit the trimmed term list whenever either statistic was requested
            if (includeDF || includeTTF) {
                List termList = new ArrayList();

                while (terms.size() > 0) {
                    TermwalkerTerm term = terms.pop();
                    Map tiMap = new HashMap();

                    tiMap.put("text", term.text);
                    if (includeDF) {
                        tiMap.put("df", term.df);
                    }
                    if (includeTTF) {
                        tiMap.put("ttf", term.ttf);
                    }
                    termList.add(tiMap);
                }

                indexresponse.put("terms", termList);
            }
        }
        return new TermwalkerResponse(shardsResponses.length(), successfulShards, failedShards, shardFailures)
                .setResponse(response);
    }

    @Override
    protected ShardTermwalkerRequest newShardRequest() {
        return new ShardTermwalkerRequest();
    }

    @Override
    protected ShardTermwalkerRequest newShardRequest(ShardRouting shard, TermwalkerRequest request) {
        return new ShardTermwalkerRequest(shard.index(), shard.id(), request);
    }

    @Override
    protected ShardTermwalkerResponse newShardResponse() {
        return new ShardTermwalkerResponse();
    }

    @Override
    protected ShardTermwalkerResponse shardOperation(ShardTermwalkerRequest request) throws ElasticSearchException {
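        // the node-wide mutex serializes shard walks: enumerating the entire _all
        // terms dictionary is expensive, so at most one walk runs per node at a time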
        synchronized (mutex) {
            IndexService indexService = indicesService.indexServiceSafe(request.index());
            InternalIndexShard indexShard = (InternalIndexShard) indexService.shardSafe(request.shardId());
            // acquire the shard searcher once; it is released in the finally block at
            // the end of this method so the underlying reader is never leaked
            Engine.Searcher searcher = indexShard.searcher();
            try {
                Map<String, Object> response = new HashMap<String, Object>();
                IndexReader reader = searcher.reader();

                int termCount = 0;
                long totalCount = 0L;
                List termList = new ArrayList();
                Fields fields = MultiFields.getFields(reader);
                Terms terms = fields.terms("_all");

                Boolean includeDF = request.includeDF();
                Boolean includeTTF = request.includeTTF();

                logger.info("termwalker:" + " shard: " + request.shardId() + " df: " + includeDF + " ttf: "
                        + includeTTF);

                if (terms != null) {
                    TermsEnum iterator = terms.iterator(null);

                    for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
                        Integer df = iterator.docFreq();
                        Long ttf = iterator.totalTermFreq();

                        termCount += 1;
                        totalCount += ttf;

                        // skip terms that appear in only one document; the reduce
                        // side also floors min_doc_freq at 2
                        if ((includeDF || includeTTF) && df > 1) {
                            Map tiMap = new HashMap();
                            tiMap.put("text", term.utf8ToString());
                            if (includeDF) {
                                tiMap.put("df", df);
                            }
                            if (includeTTF) {
                                tiMap.put("ttf", ttf);
                            }
                            termList.add(tiMap);
                        }
                    }
                } else {
                    logger.error("Terms for _all is null.");
                }
                response.put("terms", termList);
                response.put("num_docs", reader.numDocs());
                response.put("num_terms", termCount);
                response.put("total_terms", totalCount);

                return new ShardTermwalkerResponse(request.index(), request.shardId()).setResponse(response);
            } catch (IOException ex) {
                throw new ElasticSearchException(ex.getMessage(), ex);
            } finally {
                searcher.release();
            }
        }
    }

    @Override
    protected GroupShardsIterator shards(ClusterState clusterState, TermwalkerRequest request,
            String[] concreteIndices) {
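        // walk only active primary shards, so each document is counted exactly once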
        return clusterState.routingTable().activePrimaryShardsGrouped(concreteIndices, true);
    }

    @Override
    protected ClusterBlockException checkGlobalBlock(ClusterState state, TermwalkerRequest request) {
        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA);
    }

    @Override
    protected ClusterBlockException checkRequestBlock(ClusterState state, TermwalkerRequest request,
            String[] concreteIndices) {
        return state.blocks().indicesBlockedException(ClusterBlockLevel.METADATA, concreteIndices);
    }

}
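
Usage

A minimal sketch of how the action might be invoked from a client. Only TermwalkerAction.NAME appears in the file above; the INSTANCE singleton, the request setters (mirroring the getters termLimit(), includeDF(), includeTTF(), maxDocFrequency(), minDocFrequency()), the inherited indices(...) setter, and the TermwalkerResponse.getResponse() accessor are assumptions about the plugin's companion classes, which are not shown here.

// hypothetical invocation; assumes an org.elasticsearch.client.Client named "client"
TermwalkerRequest request = new TermwalkerRequest();
request.indices("myindex");     // assumed inherited broadcast-request setter
request.termLimit(100);         // keep the top 100 terms
request.includeDF(true);        // collect document frequencies
request.includeTTF(true);       // collect total term frequencies
request.maxDocFrequency(0.5);   // drop terms present in more than 50% of docs
request.minDocFrequency(0.001); // drop terms present in fewer than 0.1% of docs

TermwalkerResponse response = client.execute(TermwalkerAction.INSTANCE, request).actionGet();

// the reduced map is keyed by index name; see newResponse() above for its fields
Map indexStats = (Map) response.getResponse().get("myindex");
System.out.println("num_docs = " + indexStats.get("num_docs"));
System.out.println("top terms = " + indexStats.get("terms"));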