ivory.server.RunRetrievalBroker.java Source code

Java tutorial

Introduction

Here is the source code for ivory.server.RunRetrievalBroker.java

Source

/*
 * Ivory: A Hadoop toolkit for web-scale information retrieval
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You may
 * obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0 
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package ivory.server;

import ivory.smrf.retrieval.Accumulator;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.net.InetAddress;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLEncoder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Logger;
import org.mortbay.jetty.Server;
import org.mortbay.jetty.servlet.Context;
import org.mortbay.jetty.servlet.ServletHolder;

import edu.umd.cloud9.io.FSProperty;
import edu.umd.cloud9.mapred.NullInputFormat;
import edu.umd.cloud9.mapred.NullMapper;
import edu.umd.cloud9.mapred.NullOutputFormat;

/**
 * @author Tamer Elsayed
 * @author Jimmy Lin
 */
public class RunRetrievalBroker extends Configured implements Tool {

    private static final Logger sLogger = Logger.getLogger(RunRetrievalBroker.class);

    private static Map<Integer, Integer> docnoToServerMapping = new HashMap<Integer, Integer>();

    /**
     * Map task that hosts the retrieval broker inside a Hadoop job. It
     * publishes its own address under the configuration path, reads the
     * addresses of the retrieval servers from the same location, starts an
     * embedded Jetty server exposing the broker servlets, and then keeps the
     * task alive so the server continues to run.
     */
    private static class ServerMapper extends NullMapper {

        /** Port on which the embedded Jetty server listens. */
        private static final int PORT = 9999;

        /** Sleep interval used while keeping the task alive. */
        private static final long KEEP_ALIVE_SLEEP_MS = 60000L;

        private String[] serverAddresses = null;

        /**
         * Starts the broker and blocks until the task thread is interrupted.
         *
         * @param conf job configuration; reads the keys
         *        <code>ServerAddressPath</code>, <code>ScoreMergeModel</code>
         *        and <code>serverIDs</code>
         * @param reporter not used by this method
         * @throws IOException if reading or writing the address files fails
         * @throws RuntimeException if the score merging model is neither
         *         <code>sort</code> nor <code>normalize</code>
         * @throws IllegalStateException if the Jetty server cannot be started
         */
        public void run(JobConf conf, Reporter reporter) throws IOException {
            String configPath = conf.get("ServerAddressPath");
            String scoreMergeModel = conf.get("ScoreMergeModel");

            FileSystem fs = FileSystem.get(conf);

            // InetAddress.toString() is "hostname/literal-address"; ask for the
            // literal address directly instead of parsing that string.
            InetAddress localHost = InetAddress.getLocalHost();
            String hostIP = localHost.getHostAddress();

            String fname = appendPath(configPath, "broker.brokerhost");
            sLogger.info("Writing host address to " + fname);
            sLogger.info("  address: " + hostIP + ":" + PORT);

            FSProperty.writeString(fs, fname, hostIP + ":" + PORT);

            sLogger.info("writing done.");
            sLogger.info("Score merging model: " + scoreMergeModel);

            // Constant-first comparison so a missing setting is reported as an
            // unsupported model rather than a NullPointerException.
            if (!"sort".equals(scoreMergeModel) && !"normalize".equals(scoreMergeModel)) {
                throw new RuntimeException("Unsupported score mergeing model: " + scoreMergeModel);
            }

            String serverIDsStr = conf.get("serverIDs");

            sLogger.info("Host: " + localHost.toString());
            sLogger.info("Port: " + PORT);
            sLogger.info("ServerAddresses: " + serverIDsStr);

            // "".split(";") yields one empty ID, which would turn into a lookup
            // of "<configPath>/.host"; treat a missing/empty list as no servers.
            String[] serverIDs;
            if (serverIDsStr == null || serverIDsStr.length() == 0) {
                sLogger.warn("No retrieval servers configured.");
                serverIDs = new String[0];
            } else {
                serverIDs = serverIDsStr.split(";");
            }

            serverAddresses = new String[serverIDs.length];
            for (int i = 0; i < serverIDs.length; i++) {
                fname = appendPath(configPath, serverIDs[i] + ".host");
                serverAddresses[i] = FSProperty.readString(fs, fname);
            }

            Server server = new Server(PORT);
            Context root = new Context(server, "/", Context.SESSIONS);

            root.addServlet(
                    new ServletHolder(new QueryServlet(serverAddresses, docnoToServerMapping, scoreMergeModel)),
                    QueryServlet.ACTION);
            root.addServlet(
                    new ServletHolder(
                            new PlainTextQueryServlet(serverAddresses, docnoToServerMapping, scoreMergeModel)),
                    PlainTextQueryServlet.ACTION);
            root.addServlet(new ServletHolder(new BrokerFetchServlet(serverAddresses, docnoToServerMapping)),
                    BrokerFetchServlet.ACTION);
            root.addServlet(new ServletHolder(new HomeServlet()), "/");

            sLogger.info("Starting retrieval broker...");
            try {
                server.start();
            } catch (Exception e) {
                // Do not publish broker.ready for a server that is not running.
                throw new IllegalStateException("Failed to start retrieval broker on port " + PORT, e);
            }
            sLogger.info("Broker successfully started!");

            String s = localHost.toString() + ":" + PORT;
            FSProperty.writeString(fs, appendPath(configPath, "broker.ready"), s);

            // Keep the task (and with it the embedded server) alive without
            // spinning a CPU core.
            try {
                while (true) {
                    Thread.sleep(KEEP_ALIVE_SLEEP_MS);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                sLogger.info("Broker task interrupted; leaving keep-alive loop.");
            }
        }
    }

    /**
     * Default no-argument constructor; performs no initialization of its own.
     */
    public RunRetrievalBroker() {
        super();
    }

    /**
     * Writes the expected command-line arguments, followed by Hadoop's
     * generic command usage, to standard output.
     *
     * @return always -1
     */
    private static int printUsage() {
        final int usageExitCode = -1;
        System.out.println("usage: [config-path] [score-merge-model]");
        ToolRunner.printGenericCommandUsage(System.out);
        return usageExitCode;
    }

    /**
     * Runs this tool: collects the IDs of all <code>*.host</code> files under
     * the configuration path, submits a single-map job that hosts the broker,
     * and waits until the broker publishes its <code>broker.ready</code> file.
     *
     * @param args <code>[config-path] [score-merge-model]</code>; the model
     *        must be <code>sort</code> or <code>normalize</code>
     * @return 0 once the broker reports ready, -1 on bad usage or when the
     *         configuration path cannot be listed
     * @throws RuntimeException if the score merging model is not supported
     * @throws Exception on file system or job submission failures
     */
    public int run(String[] args) throws Exception {
        if (args.length != 2) {
            return printUsage();
        }

        String configPath = args[0];
        String scoreMergeModel = args[1];

        // Fail fast on a bad model, before touching the file system.
        if (!"sort".equals(scoreMergeModel) && !"normalize".equals(scoreMergeModel)) {
            throw new RuntimeException("Unsupported score merging model: " + scoreMergeModel);
        }

        FileSystem fs = FileSystem.get(getConf());

        sLogger.info("Starting retrieval broker...");
        sLogger.info("server config path: " + configPath);
        FileStatus[] stats = fs.listStatus(new Path(configPath));

        if (stats == null) {
            sLogger.error("Error: " + configPath + " not found!");
            return -1;
        }

        // Build the ';'-separated list of server IDs from the *.host files.
        StringBuilder ids = new StringBuilder();
        for (FileStatus stat : stats) {
            String s = stat.getPath().toString();
            if (!s.endsWith(".host")) {
                continue;
            }

            String sid = s.substring(s.lastIndexOf("/") + 1, s.lastIndexOf(".host"));
            sLogger.info("sid=" + sid + ", host=" + s);

            if (ids.length() != 0) {
                ids.append(';');
            }
            ids.append(sid);
        }

        // Seed the job from getConf() so that generic options handled by
        // ToolRunner (-D, -conf, -fs, ...) actually reach the submitted job.
        JobConf conf = new JobConf(getConf(), RunRetrievalBroker.class);
        conf.setJobName("RetrievalBroker");

        conf.setNumMapTasks(1);
        conf.setNumReduceTasks(0);

        conf.setInputFormat(NullInputFormat.class);
        conf.setOutputFormat(NullOutputFormat.class);
        conf.setMapperClass(ServerMapper.class);

        conf.set("serverIDs", ids.toString());
        conf.set("ServerAddressPath", configPath);
        conf.set("ScoreMergeModel", scoreMergeModel);
        conf.set("mapred.child.java.opts", "-Xmx2048m");

        // Remove any stale marker so the wait below only sees a fresh one.
        Path readyPath = new Path(appendPath(configPath, "broker.ready"));
        fs.delete(readyPath, true);

        JobClient client = new JobClient(conf);
        client.submitJob(conf);

        sLogger.info("broker started!");

        // NOTE(review): this polls indefinitely; if the submitted job fails
        // before writing broker.ready, this loop never terminates.
        while (!fs.exists(readyPath)) {
            Thread.sleep(5000);
        }

        String s = FSProperty.readString(fs, appendPath(configPath, "broker.ready"));
        sLogger.info("broker ready at " + s);

        return 0;
    }

    /**
     * Joins a base path and a file name with exactly one "/" between them
     * when the base does not already end with one.
     *
     * @param base directory part of the path
     * @param file name to append
     * @return the combined path
     */
    private static String appendPath(String base, String file) {
        if (base.endsWith("/")) {
            return base + file;
        }
        return base + "/" + file;
    }

    /**
     * Program entry point: hands the command-line arguments to a new
     * instance of this tool through <code>ToolRunner</code> and exits the
     * JVM with the code the tool returns.
     */
    public static void main(String[] args) throws Exception {
        Configuration baseConf = new Configuration();
        RunRetrievalBroker tool = new RunRetrievalBroker();
        System.exit(ToolRunner.run(baseConf, tool, args));
    }

    public static class QueryServlet extends HttpServlet {
        private static final long serialVersionUID = -5998786589277554550L;

        public static final String ACTION = "/search";
        public static final String QUERY_FIELD = "query";

        private String[] serverAddresses;
        private Map<Integer, Integer> docnoToServerMapping = null;
        private String scoreMergeModel = "";

        public QueryServlet(String[] addresses, Map<Integer, Integer> mapping, String model) {
            serverAddresses = addresses;
            docnoToServerMapping = mapping;
            scoreMergeModel = model;
        }

        public void doGet(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException {
            doPost(req, res);
        }

        public void doPost(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException {
            sLogger.info("Triggered servlet for running queries");
            res.setContentType("text/html");
            PrintWriter out = res.getWriter();

            String query = null;
            if (req.getParameterValues("query") != null)
                query = req.getParameterValues("query")[0];

            sLogger.info("Raw query: " + query);

            long startTime = System.currentTimeMillis();
            ServerThread[] servers = new ServerThread[serverAddresses.length];
            Thread[] threads = new Thread[serverAddresses.length];
            for (int i = 0; i < serverAddresses.length; i++) {
                servers[i] = new ServerThread(serverAddresses[i], query);
                threads[i] = new Thread(servers[i]);
                threads[i].start();
            }
            try {
                for (Thread thread : threads) {
                    thread.join();
                }
                sLogger.info("All servers: done.");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }

            sLogger.info("Score merging model: " + scoreMergeModel);

            Accumulator[] results = new Accumulator[0];
            for (int i = 0; i < servers.length; i++) {
                Accumulator[] serverResults = null;
                if (scoreMergeModel.equals("sort")) {
                    serverResults = servers[i].getResults();
                } else {
                    serverResults = servers[i].getZNormalizedResults();
                }

                if (docnoToServerMapping != null) {
                    for (Accumulator a : serverResults)
                        docnoToServerMapping.put(a.docno, i);
                }
                results = mergeScores(results, serverResults);
            }

            String formattedOutput = getFormattedResults(results, servers);
            long endTime = System.currentTimeMillis();
            sLogger.info("query execution time (ms): " + (endTime - startTime));

            out.println(formattedOutput);
            out.close();
        }

        protected String getFormattedResults(Accumulator[] results, ServerThread[] servers) {
            StringBuffer sb = new StringBuffer();
            sb.append("<html><head><title>Threaded Broker Results</title></head>\n<body>");

            sb.append("<ol>");
            for (Accumulator a : results) {
                sb.append("<li>docno <a href=" + BrokerFetchServlet.formatRequestURL(a.docno) + ">" + a.docno
                        + "</a> (" + a.score + ")</li>\n");
            }
            sb.append("</ol>");
            sb.append("</body></html>\n");

            return sb.toString();
        }

        private Accumulator[] mergeScores(Accumulator[] iScores, Accumulator[] jScores) {
            // assuming that scored documents are mutual exclusive
            Accumulator[] results = new Accumulator[iScores.length + jScores.length];
            int i = 0, j = 0, k = 0;
            while (i < iScores.length && j < jScores.length) {
                if (iScores[i].score > jScores[j].score) {
                    results[k] = iScores[i];
                    i++;
                } else {
                    results[k] = jScores[j];
                    j++;
                }
                k++;
            }
            while (i < iScores.length) {
                results[k] = iScores[i];
                i++;
                k++;
            }

            while (j < jScores.length) {
                results[k] = jScores[j];
                j++;
                k++;
            }

            return results;
        }

        protected static class ServerThread implements Runnable {

            String address;
            String query;
            String textResults = null;
            HashMap<Integer, String> docnoMapping = new HashMap<Integer, String>();

            public ServerThread() {
            }

            public ServerThread(String addr, String q) {
                address = addr;
                query = q;
            }

            public void set(String addr, String q) {
                address = addr;
                query = q;
            }

            public String getOriginalDocid(int docno) {
                return docnoMapping.get(docno);
            }

            public String getTextResults() {
                return textResults;
            }

            public Accumulator[] getZNormalizedResults() {
                float sum = 0, sumSq = 0;
                if (textResults == null)
                    return null;
                String[] lines = textResults.split("\t");
                Accumulator[] results = new Accumulator[lines.length / 3];
                int i = 0;
                int j = 0;
                while (i < lines.length) {
                    int docid = -1;
                    try {
                        docid = Integer.parseInt(lines[i]);
                    } catch (NumberFormatException e) {
                        i++;
                        continue;
                    }
                    i++;
                    float score = Float.parseFloat(lines[i]);
                    sum += score;
                    sumSq += score * score;
                    i++;
                    String originalDocID = lines[i];
                    docnoMapping.put(new Integer(docid), originalDocID);
                    i++;
                    results[j] = new Accumulator(docid, score);
                    j++;

                }
                int n = results.length;
                float muo = sum / n;
                float sigma = (float) Math.sqrt((sumSq - n * muo * muo) / (n - 1));

                for (Accumulator a : results) {
                    a.score = (a.score - muo) / sigma;
                }
                sLogger.info("returning z-normalized scores.");
                return results;
            }

            public Accumulator[] getMaxMinNormalizedResults() {
                float min = Float.MAX_VALUE, max = Float.MIN_VALUE;
                if (textResults == null)
                    return null;
                String[] lines = textResults.split("\t");
                Accumulator[] results = new Accumulator[lines.length / 3];
                int i = 0;
                int j = 0;
                while (i < lines.length) {
                    int docid = -1;
                    try {
                        docid = Integer.parseInt(lines[i]);
                    } catch (NumberFormatException e) {
                        i++;
                        continue;
                    }
                    i++;
                    float score = Float.parseFloat(lines[i]);
                    if (score > max)
                        max = score;
                    else if (score < min)
                        min = score;
                    i++;
                    String originalDocID = lines[i];
                    docnoMapping.put(new Integer(docid), originalDocID);
                    i++;
                    results[j] = new Accumulator(docid, score);
                    j++;

                }
                float d = max - min;
                for (Accumulator a : results) {
                    a.score = (a.score - min) / d;
                }
                sLogger.info("returning max/min normalized scores.");
                return results;
            }

            public Accumulator[] getResults() {
                if (textResults == null)
                    return null;
                String[] lines = textResults.split("\t");
                Accumulator[] results = new Accumulator[lines.length / 3];
                int i = 0;
                int j = 0;
                while (i < lines.length) {
                    int docid = -1;
                    try {
                        docid = Integer.parseInt(lines[i]);
                    } catch (NumberFormatException e) {
                        i++;
                        continue;
                    }
                    i++;
                    float score = Float.parseFloat(lines[i]);
                    i++;
                    String originalDocID = lines[i];
                    docnoMapping.put(new Integer(docid), originalDocID);
                    i++;
                    results[j] = new Accumulator(docid, score);
                    j++;
                }
                sLogger.info("returning original scores.");
                return results;
            }

            public void run() {
                try {
                    String url = "http://" + address + RetrievalServer.QueryBrokerServlet.ACTION + "?"
                            + RetrievalServer.QueryBrokerServlet.QUERY_FIELD + "=" + query.replaceAll(" ", "+");

                    sLogger.info("fetching " + url);

                    textResults = HttpUtils.fetchURL(new URL(url));
                    sLogger.info(Thread.currentThread().getName() + "-" + address + ": done.");
                    docnoMapping.clear();
                } catch (MalformedURLException e) {
                    e.printStackTrace();
                }
            }

        }

    }

    public static class PlainTextQueryServlet extends QueryServlet {
        private static final long serialVersionUID = -5998786589277554554L;
        public static final String ACTION = "/psearch";

        public PlainTextQueryServlet(String[] addresses, Map<Integer, Integer> mapping, String model) {
            super(addresses, mapping, model);
        }

        protected String getFormattedResults(Accumulator[] results, ServerThread[] servers) {
            StringBuffer sb = new StringBuffer();
            int k = 0;
            for (Accumulator a : results) {
                String origDocID = getOriginalDocID(a.docno, servers);
                if (origDocID == null) {
                    sLogger.info("Docno not found in all servers: " + a.docno + " !!");
                }
                sb.append(a.docno + "\t" + a.score + "\t" + origDocID + "\n");
                k++;
                //if (k >= 2000)
                if (k >= 10000)
                    break;
            }
            return sb.toString();
        }

        private String getOriginalDocID(int docno, ServerThread[] servers) {
            String s = "";
            for (ServerThread server : servers) {
                s = server.getOriginalDocid(docno);
                if (s != null)
                    return s;
            }
            return null;
        }

    }

    public static class HomeServlet extends HttpServlet {
        private static final long serialVersionUID = 7368950575963429946L;

        protected void doGet(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)
                throws ServletException, IOException {
            httpServletResponse.setContentType("text/html");
            PrintWriter out = httpServletResponse.getWriter();

            out.println("<html><head><title>Ivory Search Interface</title><head>");
            out.println("<body>");
            out.println("<h3>Run a query:</h3>");
            out.println("<form method=\"post\" action=\"" + QueryServlet.ACTION + "\">");
            out.println("<input type=\"text\" name=\"" + QueryServlet.QUERY_FIELD + "\" size=\"60\" />");
            out.println("<input type=\"submit\" value=\"Run query!\" />");
            out.println("</form>");
            out.println("</p>");

            out.print("</body></html>\n");

            out.close();
        }
    }

    public static class BrokerFetchServlet extends HttpServlet {
        private static final long serialVersionUID = -5998986589277554550L;

        public static final String ACTION = "/BrokerFetch";
        public static final String DOCNO_FIELD = "docno";

        private String[] serverAddresses;

        private Map<Integer, Integer> docnoToServerMapping = null;

        public BrokerFetchServlet(String[] addresses, Map<Integer, Integer> mapping) {
            serverAddresses = addresses;
            docnoToServerMapping = mapping;
        }

        public void doGet(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException {
            doPost(req, res);
        }

        public void doPost(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException {
            sLogger.info("Triggered servlet for fetching a document");
            res.setContentType("text/html");
            PrintWriter out = res.getWriter();

            String docno = null;
            if (req.getParameterValues(DOCNO_FIELD) != null)
                docno = req.getParameterValues(DOCNO_FIELD)[0];

            sLogger.info("Raw query: " + docno);

            Integer serverNo = docnoToServerMapping.get(Integer.parseInt(docno));
            if (serverNo == null) {
                sLogger.info("document not found in results/mapping-table!!");
                return;
            }

            long startTime = System.currentTimeMillis();
            String document = HttpUtils.fetchURL(
                    new URL("http://" + this.serverAddresses[serverNo] + RetrievalServer.FetchDocnoServlet.ACTION
                            + "?" + RetrievalServer.FetchDocnoServlet.DOCNO + "=" + docno));
            long endTime = System.currentTimeMillis();
            sLogger.info("document fetched in time (ms): " + (endTime - startTime));
            out.println(document);
            out.close();
        }

        public static String formatRequestURL(int docno) {
            return ACTION + "?" + DOCNO_FIELD + "=" + new Integer(docno).toString();
        }
    }
}