searching.QueryExpansion.java Source code

Introduction

Here is the source code for searching.QueryExpansion.java.
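
The QueryExpansion class gathers term statistics from the top-ranked (pseudo-relevant) documents of an initial search and builds an expanded query model. Several expansion methods are implemented and selected through the static method field: RM3, the SPUD and Dirichlet query topic models, the positional relevance models PRM1 and PRM2, and a simple mixture model (SMM) estimated with EM.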

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */

package searching;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;
import java.util.TreeMap;
import java.util.logging.Logger;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.similarities.SPUDLMSimilarity;

/**
 *
 * @author ronanc
 */
public class QueryExpansion {

    private final static Logger logger = Logger.getLogger(QueryExpansion.class.getName());

    public static boolean EXPAND = false;
    public static int method = 1;

    // expansion methods implemented; selected via the static method field above
    public final static int RM3 = 1;
    public final static int SPUDQTM = 2;
    public final static int SPUDQTM2 = 9;
    public final static int DIRQTM = 3;
    public final static int PDCM = 4;
    public final static int SMM = 5;
    public final static int PRM1 = 6;
    public final static int PRM2 = 7;

    //number of pseudo-relevant docs to consider
    public static int pseudo_rel_docs = 10;

    //number of expansion term types (dimensions) to add
    public static int num_expansion_terms = 25;

    //sigma of the Gaussian proximity kernel in the PRM models
    public static double prm_sigma = 200.0;

    //lambda (background mixture weight) in SMM; the default should be low
    public static double ssm_lambda = 0.2;

    //SPUD omega parameter; the default of 0.8 works well
    public static double spud_omega = 0.8;

    //fraction of the total query mass assigned to the expansion terms
    public static double interpolation = 0.5;

    public TreeMap<String, Double>[] pdocs;
    public Integer[] pdoc_lengths;
    public Double[] pdoc_positional_lengths; // can be used if one needs PRM2
    public Double[] pdoc_scores;

    //number of pseudo-relevant docs actually stored
    private int actual_pdocs;

    public QueryExpansion() {
        pdocs = new TreeMap[pseudo_rel_docs];
        pdoc_lengths = new Integer[pseudo_rel_docs];
        pdoc_positional_lengths = new Double[pseudo_rel_docs];
        pdoc_scores = new Double[pseudo_rel_docs];

        actual_pdocs = 0;

    }

    /**
     * 
     * store frequencies of top docs in maps
     * 
     * @param text
     * @param doc_score
     * @param analyzer
     * @param reader
     * @throws IOException 
     */
    public void addExpansionDoc(String text, double doc_score, Analyzer analyzer, IndexReader reader)
            throws IOException {

        if (actual_pdocs < QueryExpansion.pseudo_rel_docs) {

            TreeMap<String, Double> map = new TreeMap<>();

            Integer length = 0;
            Double f;

            String term;
            TokenStream ts = analyzer.tokenStream("myfield", new StringReader(text));
            //OffsetAttribute offsetAtt = ts.addAttribute(OffsetAttribute.class);
            try {
                ts.reset(); // Resets this stream to the beginning. (Required)
                while (ts.incrementToken()) {
                    term = ts.getAttribute(CharTermAttribute.class).toString();
                    if (term.length() > 1) {

                        f = map.get(term);

                        if (f == null) {
                            map.put(term, 1.0);
                        } else {
                            map.put(term, f + 1.0);
                        }
                        length++;
                    }
                }
                ts.end();
            } finally {
                ts.close();
            }

            pdocs[actual_pdocs] = map;
            pdoc_lengths[actual_pdocs] = length;
            pdoc_scores[actual_pdocs] = doc_score;

            //logger.info(observed_bg_mass[iter] + "\t" + (1-observed_bg_prob));
            actual_pdocs++;
        }
    }

    /**
     * calculate positional relevance weights
     * 
     * @param query
     * @param text
     * @param doc_score
     * @param analyzer
     * @param reader
     * @throws IOException 
     */

    public void addPositionalExpansionDoc(CustomQuery query, String text, double doc_score, Analyzer analyzer,
            IndexReader reader) throws IOException {

        //System.out.println(query);
        //System.out.println(text);

        if (actual_pdocs < QueryExpansion.pseudo_rel_docs) {

            TreeMap<String, ArrayList<Long>> query_term_pos = new TreeMap<>();
            Integer length = 0;

            Long pos = 1L;
            String term;
            TokenStream ts = analyzer.tokenStream("myfield", new StringReader(text));

            ArrayList<Long> qpos;
            //OffsetAttribute offsetAtt = ts.addAttribute(OffsetAttribute.class);
            try {
                ts.reset(); // Resets this stream to the beginning. (Required)
                while (ts.incrementToken()) {
                    term = ts.getAttribute(CharTermAttribute.class).toString();
                    if (term.length() > 1) {

                        //System.out.print(pos + ":" + term + " ");
                        if (query.contains(term)) {
                            qpos = query_term_pos.get(term);
                            if (qpos == null) {
                                qpos = new ArrayList<>();
                            }
                            qpos.add(pos);
                            query_term_pos.put(term, qpos);
                        }

                        length++;
                        pos++;
                    }
                }
                ts.end();
            } finally {
                ts.close();
            }

            //
            // All positions collected
            // now iterate over the document again to get weights
            //
            //System.out.println("Doc length" + text.length());
            //System.out.println("Positions... ");
            //System.out.println(query_term_pos.toString());
            //System.out.println("END...");
            TreeMap<String, Double> map = new TreeMap<>();
            Double f;
            pos = 1L;
            double w, w_norm, prob, f0;
            Double pos_length = 0.0;
            Double sum_df = (double) reader.getSumDocFreq("text");
            double spud_pi = SPUDLMSimilarity.b0 * SPUDLMSimilarity.omega
                    / (query_term_pos.size() * (1 - SPUDLMSimilarity.omega)
                            + SPUDLMSimilarity.b0 * SPUDLMSimilarity.omega);
            Double df;
            double dist;
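
            //
            // In this second pass each position receives a proximity weight that is
            // accumulated for its term. For a token at position pos,
            //
            //   exp(prob) = prod over query terms qt of
            //               ( 1 + sum over positions p of qt of exp(-(pos - p)^2 / (2 * prm_sigma^2)) )
            //               / sqrt(2 * PI * prm_sigma^2)
            //
            // i.e. a Gaussian kernel centred on each query-term occurrence, combined in
            // log space across the query terms. pos_length accumulates the total positional
            // mass and acts as the document length for PRM2. Note that spud_pi, df and
            // sum_df are only needed by the commented-out variant below.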

            ts = analyzer.tokenStream("myfield", new StringReader(text));
            try {
                ts.reset(); // Resets this stream to the beginning. (Required)
                while (ts.incrementToken()) {
                    term = ts.getAttribute(CharTermAttribute.class).toString();
                    if (term.length() > 1) {

                        prob = 0.0;
                        //f is occurrence
                        w_norm = Math.sqrt(2 * Math.PI * prm_sigma * prm_sigma);
                        for (String qt : query_term_pos.keySet()) {
                            ArrayList<Long> pos_list = query_term_pos.get(qt);
                            w = 1.0;
                            df = (double) reader.docFreq(new Term("text", qt));
                            for (Long p : pos_list) {
                                dist = ((pos - p) * (pos - p)) / (2 * prm_sigma * prm_sigma);
                                f0 = Math.exp(-dist);

                                //if (QueryExpansion.method == QueryExpansion.PRM2QTM){
                                //w += (((double) ((1 - spud_pi) * f0) / (((1 - spud_pi) * f0 ) + spud_pi * (df / sum_df))));
                                //    w += f0;
                                //}else{
                                w += f0;
                                //}

                            }
                            //System.out.println("weight " + w );
                            prob += Math.log(w / w_norm);
                        }

                        //System.out.print(pos + "\t" + term + "\t" +  Math.exp(prob) + "\n");

                        // sum of the probabilities over the positional terms in the document
                        f = map.get(term);
                        if (f == null) {
                            map.put(term, Math.exp(prob));
                        } else {
                            map.put(term, f + Math.exp(prob));
                        }
                        pos_length += Math.exp(prob);
                        pos++;
                    }
                }
                ts.end();
            } finally {
                ts.close();
            }

            double sum = 0.0;
            for (String word : map.keySet()) {
                //logger.info(word + "\t" + map.get(word)/pos_length);
                sum += map.get(word) / pos_length;
            }
            //logger.info("sum is " + sum);

            pdocs[actual_pdocs] = map;
            pdoc_lengths[actual_pdocs] = length;
            pdoc_positional_lengths[actual_pdocs] = pos_length;
            pdoc_scores[actual_pdocs] = doc_score;

            actual_pdocs++;
        }
    }

    /**
     * select and weight expansion terms from the stored pseudo-relevant docs
     *
     * @param reader
     * @param query
     * @param query_length
     * @return map of expansion terms to normalised weights
     */
    public Map<String, Double> expansionTerms(IndexReader reader, CustomQuery query, int query_length)
            throws IOException {

        Map<String, Double> expansion_terms = new TreeMap<>();
        Map<String, Double> map;
        Double f;
        Double e, prob;
        Double df;
        Double sum_df = (double) reader.getSumDocFreq("text");
        Double cf;
        Double sum_cf = (double) reader.getSumTotalTermFreq("text");

        Double score_norm = 0.0;
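
        //
        // In the default branch below (RM3, DIRQTM, SPUDQTM, PRM1, PRM2), each candidate
        // term in each pseudo-relevant document gets an estimate of p(term | document model)
        // under the selected method (maximum likelihood for RM3 and PRM1, a discounted
        // topical probability for the Dirichlet and SPUD query topic models, positional
        // mass for PRM2), weighted by exp(doc score). The estimates are summed over the
        // documents, the top num_expansion_terms are kept, and their weights are
        // normalised to sum to 1.0.
        //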
        if (QueryExpansion.method == QueryExpansion.PDCM) {

            //logger.info(actual_pdocs + " docs" + this.pdocs.length);
            //expansion_terms = this.DCM().estimateDCM();
            //not implemented here

        } else if (QueryExpansion.method == QueryExpansion.SMM) {

            //get SMM estimates

            expansion_terms = this.SMM(reader, 20);

        } else {

            for (int i = 0; i < pseudo_rel_docs; i++) {

                map = this.pdocs[i];
                if (map != null) {

                    double spud_pi = SPUDLMSimilarity.b0 * QueryExpansion.spud_omega
                            / (map.size() * (1 - QueryExpansion.spud_omega)
                                    + SPUDLMSimilarity.b0 * QueryExpansion.spud_omega);
                    double dir_pi = SPUDLMSimilarity.dir_mu / (this.pdoc_lengths[i] + SPUDLMSimilarity.dir_mu);
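
                    // dir_pi is the Dirichlet smoothing weight mu / (|d| + mu);
                    // spud_pi is the corresponding SPUD background weight, based on the
                    // number of distinct terms in the document rather than its length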

                    for (String term : map.keySet()) {

                        double tf = (double) map.get(term);

                        if (!term.contains(":")) {
                            df = (double) reader.docFreq(new Term("text", term));
                            cf = (double) reader.totalTermFreq(new Term("text", term));
                            //logger.info(new Term(term) + "\t" + df + "\t" + sum_df);
                            //RM3
                            if (QueryExpansion.method == QueryExpansion.RM3) {
                                //RM3 with u=0
                                e = ((double) tf / this.pdoc_lengths[i]) * Math.exp(this.pdoc_scores[i]);

                                //e = ((1-spud_pi)*((double) tf / this.pdoc_lengths[i]) +  spud_pi*(df / sum_df)) * Math.exp(this.pdoc_scores[i]);

                            } else if (QueryExpansion.method == QueryExpansion.DIRQTM) {
                                //Dir Topic Model
                                e = (((double) ((1 - dir_pi) * tf / this.pdoc_lengths[i])
                                        / (((1 - dir_pi) * tf / this.pdoc_lengths[i]) + dir_pi * (cf / sum_cf))))
                                        * Math.exp(this.pdoc_scores[i]);
                            } else if ((QueryExpansion.method == QueryExpansion.SPUDQTM)
                                    || (QueryExpansion.method == QueryExpansion.SPUDQTM2)) {
                                //SPUD Topic Model
                                prob = (((double) ((1 - spud_pi) * tf / this.pdoc_lengths[i])
                                        / (((1 - spud_pi) * tf / this.pdoc_lengths[i]) + spud_pi * (df / sum_df))));
                                e = prob * Math.exp(this.pdoc_scores[i]);
                            } else if (QueryExpansion.method == QueryExpansion.PRM1) {
                                //Positional Relevance Model 1
                                e = ((double) tf / this.pdoc_lengths[i]) * Math.exp(this.pdoc_scores[i]);
                            } else if (QueryExpansion.method == QueryExpansion.PRM2) {
                                //Positional Relevance Model 2
                                e = ((double) tf / this.pdoc_positional_lengths[i]) * Math.exp(this.pdoc_scores[i]);
                            } else {
                                //default RM3
                                e = ((double) tf / this.pdoc_lengths[i]) * Math.exp(this.pdoc_scores[i]);
                            }

                            f = expansion_terms.get(term);
                            if (f == null) {
                                expansion_terms.put(term, e);
                            } else {
                                expansion_terms.put(term, e + f);
                            }
                        }

                    }

                    score_norm += Math.exp(this.pdoc_scores[i]);
                    //logger.info(i + "\t" + Math.exp(this.pdoc_scores[i]));

                }

            }

        }

        Double norm = 0.0, topic_prob;
        Double topical_mass = 0.0;
        int t = 0;
        //sort
        ArrayList list = sortValue(expansion_terms);

        //create the query topic model for the QTM probability
        TreeMap<String, Double> query_topic_model = new TreeMap<>();
        for (int i = 0; (i < num_expansion_terms) && (i < list.size()); i++) {

            Double tsv = (double) ((Map.Entry) list.get(i)).getValue();
            String term = ((Map.Entry) list.get(i)).getKey().toString();
            topic_prob = tsv / score_norm;
            topical_mass += topic_prob;

            norm += tsv;
            t++;

            query_topic_model.put(term, topic_prob);
            //System.out.println(term + "\t" + topic_prob + "\t" +  (double)((Map.Entry)list.get(i)).getValue());
        }

        /*
        if (QueryExpansion.method == QueryExpansion.SPUDQTM2){
        Double gen = this.QueryModelLikelihood(reader, query, query_topic_model);
        logger.info("Topic score " + gen + "\t" + query.mass());
        QueryExpansion.interpolation =  gen;
        }
        */

        //now grab the selected terms and normalise them to sum to 1.0
        TreeMap<String, Double> selected_terms = new TreeMap<>();
        double sum = 0;
        for (int i = 0; (i < t) && (i < list.size()); i++) {

            f = (double) ((Map.Entry) list.get(i)).getValue();
            ((Map.Entry) list.get(i)).setValue(f / norm);
            selected_terms.put(((Map.Entry) list.get(i)).getKey().toString(), f / norm);
            sum += f / norm;
        }

        return selected_terms;
    }

    /**
     * smooth the expansion distribution with the original query
     * @param query
     * @param e_terms 
     */
    public void expandCustomQuery(IndexReader reader, CustomQuery query, Map<String, Double> e_terms)
            throws IOException {

        double expansion_mass;
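
        // with interpolation = a, the expansion terms receive query.mass() * a / (1 - a),
        // so after expansion they carry a fraction a of the total query mass and the
        // original query terms carry the remaining (1 - a)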

        expansion_mass = query.mass() * QueryExpansion.interpolation / (1.0 - QueryExpansion.interpolation);

        for (String term : e_terms.keySet()) {

            //add the expansion term to the query object
            query.add(term, (expansion_mass * e_terms.get(term)));
        }

        //logger.info(query.examine());

    }

    /**
     * return the map entries as a list sorted by value in descending order
     * @param t
     * @return list of entries sorted by descending value
     */

    public static ArrayList sortValue(Map<String, Double> t) {

        //Transfer as List and sort it
        ArrayList<Map.Entry<String, Double>> l = new ArrayList<>(t.entrySet());
        Collections.sort(l, new Comparator<Map.Entry<String, Double>>() {

            public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                return o2.getValue().compareTo(o1.getValue());
            }
        });

        return l;
    }

    /**
     * 
     * estimate the SMM (simple mixture model) using Expectation-Maximization
     * for a multinomial feedback model mixed with the collection model
     * 
     * @return normalised SMM term distribution
     */
    private Map<String, Double> SMM(IndexReader reader, int iterations) throws IOException {

        double mass = 0.0;
        for (int i = 0; i < this.actual_pdocs; i++) {
            mass += (double) pdoc_lengths[i];
        }

        //double lambda = 0.0 ;
        //get initial estimate counts

        Map<String, Double> counts = new TreeMap<>();
        Double f, est;
        for (int i = 0; i < this.actual_pdocs; i++) {

            if (pdocs[i] != null) {
                for (String term : this.pdocs[i].keySet()) {
                    f = this.pdocs[i].get(term);

                    est = counts.get(term);

                    if (est == null) {
                        counts.put(term, f);
                    } else {
                        counts.put(term, f + est);
                    }
                }
            }
        }

        //now we have initial estimates of the maximum likelihood multinomial
        //use EM to find likelihood given the background model and fixed mixture parameter

        TreeMap<String, Double> rel_likelihoods = new TreeMap<>();
        Double cf, ptF, ptC, rl, co;
        Double sum_cf = (double) reader.getSumTotalTermFreq("text");
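
        // each EM iteration computes, for every term w, the posterior probability that
        // an occurrence of w came from the feedback model rather than the collection:
        //
        //   rl(w) = (1 - lambda) * p(w|F) / ( (1 - lambda) * p(w|F) + lambda * cf(w) / sum_cf )
        //
        // with p(w|F) = counts(w) / mass; the M-step then re-weights counts(w) by rl(w)
        // and recomputes mass, and the final counts are normalised into a distribution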

        for (int i = 0; i < iterations; i++) {

            //E-step (update relative likelihoods)
            for (String w : counts.keySet()) {
                cf = (double) reader.totalTermFreq(new Term("text", w));
                ptF = (1 - ssm_lambda) * counts.get(w) / mass;
                ptC = (ssm_lambda) * cf / sum_cf;
                rl = ptF / (ptF + ptC);
                rel_likelihoods.put(w, rl);
            }

            //M-step (recalculate max-likelihood of estimates given relative likelihoods)
            mass = 0.0;
            for (String w : counts.keySet()) {
                co = counts.get(w);
                rl = rel_likelihoods.get(w);
                mass += co * rl;
                counts.put(w, co * rl);
            }

            //logger.info("iter " + i + "\t"  + mass + " ");

        }

        //normalise partial count vector by updated mass and return

        for (String w : counts.keySet()) {
            counts.put(w, counts.get(w) / mass);
        }

        return counts;

    }

    /**
     * What is the likelihood of the Query Model given the original query
     * 
     * @param reader
     * @param query
     * @param e_terms
     * @return
     * @throws IOException 
     */
    /*
    private double QueryModelLikelihood(IndexReader reader,CustomQuery query, Map<String,Double> e_terms) throws IOException{
        
    double prob=0.0;
        
    for (String term:query.keySet()){
            
        Double p = e_terms.get(term);
            
            
        if (p == null){
            p = 0.5;
        }
            
        prob = prob * p;
            
    }
    return prob;
        
    }
    */

}
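
Example usage

The following sketch is not part of the original file; it shows one way the class might be driven: run a first-pass Lucene search, feed the top-ranked documents to addExpansionDoc (or addPositionalExpansionDoc for the positional models), select expansion terms with expansionTerms, and fold them back into the query model with expandCustomQuery. CustomQuery is the project's own query class referenced above; its construction is not shown, and the "text" field name simply matches the field used inside QueryExpansion.

package searching;

import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;

public class ExpansionExample {

    /**
     * Run a first-pass search, feed the top documents to QueryExpansion,
     * and smooth the selected expansion terms back into the query model.
     */
    public static void expand(IndexSearcher searcher, IndexReader reader, Analyzer analyzer,
            Query firstPassQuery, CustomQuery query, int queryLength) throws Exception {

        QueryExpansion.method = QueryExpansion.RM3; // or SPUDQTM, DIRQTM, PRM1, PRM2, SMM
        QueryExpansion qe = new QueryExpansion();

        ScoreDoc[] hits = searcher.search(firstPassQuery, QueryExpansion.pseudo_rel_docs).scoreDocs;
        for (ScoreDoc hit : hits) {
            String text = searcher.doc(hit.doc).get("text");
            // doc scores are exponentiated inside QueryExpansion, i.e. treated as log scores
            qe.addExpansionDoc(text, hit.score, analyzer, reader);
        }

        Map<String, Double> terms = qe.expansionTerms(reader, query, queryLength);
        qe.expandCustomQuery(reader, query, terms);
        // re-run retrieval with the now-expanded CustomQuery
    }
}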