approxnn.ANNRetriever.java Source code

Java tutorial

Introduction

Here is the source code for approxnn.ANNRetriever.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package approxnn;

import indexer.Cell;
import indexer.DocVector;
import indexer.QueryVector;
import java.io.File;
import java.io.FileReader;
import java.util.List;
import java.util.Properties;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.similarities.LMJelinekMercerSimilarity;
import sift.IndexedVecQueries;
import sift.VecQueries;
import org.apache.lucene.store.MMapDirectory;
import rndvecgen.RandomQueryGen;

/**
 *
 * @author Debasis
 */
public class ANNRetriever implements AutoCloseable {

    protected Properties prop;
    protected IndexReader reader; // the combined index to search
    protected IndexSearcher searcher;
    protected int numDimensions;
    protected int numIntervals;
    VecQueries indexedVecQueries;
    protected boolean debug;
    protected int subSpaceDimension;
    protected boolean syntheticQueries;
    protected int start, end;
    RandomQueryGen rqgen;

    /**
     * Approximate nearest-neighbour retriever over a Lucene index of vectors.
     *
     * <p>Each dimension of a query vector contributes one Lucene query, built from the
     * corresponding {@link Cell}. That query selects the indexed vectors lying in an epsilon
     * neighbourhood of the query's projection on that dimension. Per-dimension result lists are
     * intersected within a subspace, and the subspace results are unioned into the candidate set.
     *
     * <p>Properties read by the constructor:
     * <ul>
     *   <li>{@code vec.numdimensions} (required): number of vector dimensions;</li>
     *   <li>{@code data.source}: the value {@code synthetic} selects randomly generated
     *       queries and their index;</li>
     *   <li>{@code index}: path of the Lucene index (non-synthetic data only);</li>
     *   <li>{@code debug}: verbose console output (default {@code false});</li>
     *   <li>{@code subspace.dimension}: fixed subspace size, where {@code 0} (the default)
     *       selects adaptive projections;</li>
     *   <li>{@code retrieve.start} / {@code retrieve.end}: range of benchmark queries to run
     *       ({@code end = -1}, the default, means all).</li>
     * </ul>
     *
     * @param propFile path of the properties file
     * @throws Exception if the properties file or the index cannot be read
     */
    public ANNRetriever(String propFile) throws Exception {
        prop = new Properties();
        // Properties.load leaves the reader open, so close it explicitly.
        try (FileReader in = new FileReader(propFile)) {
            prop.load(in);
        }
        numDimensions = Integer.parseInt(prop.getProperty("vec.numdimensions"));

        // Constant-first comparison: a missing "data.source" means "not synthetic" rather than an NPE.
        syntheticQueries = "synthetic".equals(prop.getProperty("data.source"));

        if (syntheticQueries) {
            rqgen = new RandomQueryGen(prop);
        }
        // Read from optimized index (instead of the initial index)
        String indexPath = syntheticQueries
                ? rqgen.randomSamplesFileName() + ".index"
                : prop.getProperty("index");

        if (indexPath != null) {
            File indexDir = new File(indexPath);
            reader = DirectoryReader.open(MMapDirectory.open(indexDir.toPath()));
            searcher = new IndexSearcher(reader);
            searcher.setSimilarity(new LMJelinekMercerSimilarity(0.1f)); // almost close to tf
        }
        DocVector.initVectorRange(prop);
        numIntervals = DocVector.numIntervals;

        if (!syntheticQueries) {
            indexedVecQueries = new IndexedVecQueries(propFile);
        }

        debug = Boolean.parseBoolean(prop.getProperty("debug", "false"));
        subSpaceDimension = Integer.parseInt(prop.getProperty("subspace.dimension", "0"));

        start = Integer.parseInt(prop.getProperty("retrieve.start", "0"));
        end = Integer.parseInt(prop.getProperty("retrieve.end", "-1"));
    }

    /**
     * Closes the underlying index reader, if one was opened.
     *
     * @throws Exception if closing the reader fails
     */
    @Override
    public void close() throws Exception {
        // The reader is only opened when an index path was configured.
        if (reader != null) {
            reader.close();
        }
    }

    /**
     * Returns the number of hits to request from Lucene for one per-dimension query.
     *
     * @param configured an explicit hit count, or a non-positive value to fall back to
     *                   {@code numDocs / numIntervals}
     * @return the hit count, never less than 1 (Lucene's top-N search rejects a
     *         non-positive count, which the fallback yields when numDocs &lt; numIntervals)
     */
    private int numWanted(int configured) {
        int m = configured > 0 ? configured : reader.numDocs() / numIntervals;
        return Math.max(1, m);
    }

    /**
     * Runs the epsilon-neighbourhood query for a single cell of the query vector.
     *
     * @param cell     the query cell for one dimension
     * @param epsilon  neighbourhood size passed to the cell's query builder
     * @param weighted whether to build the weighted variant of the query
     * @param sigma    parameter forwarded to the weighted query builder
     * @param numHits  maximum number of hits to retrieve (must be positive)
     * @return the retrieved documents for this dimension
     * @throws Exception if query construction or search fails
     */
    private ANNList searchCell(Cell cell, int epsilon, boolean weighted, float sigma, int numHits)
            throws Exception {
        Query query = weighted ? cell.constructWeightedQuery(epsilon, sigma)
                : cell.constructQuery(epsilon);
        return new ANNList(searcher.search(query, numHits));
    }

    /**
     * Retrieves ANN candidates using adaptively sized projections.
     *
     * <p>Starting at dimension {@code j}, per-dimension result lists are intersected for as long
     * as the intersection keeps at least {@code proj.threshold} neighbours. When adding dimension
     * {@code k} would shrink it below that, the projection {@code [j, k-1]} is closed and a new one
     * starts at {@code k}. The intersections of all projections are unioned.
     *
     * <p>Properties read: {@code nwanted} (hits per dimension; 0 = numDocs / numIntervals),
     * {@code query.weighted}, {@code sigma}, {@code proj.threshold}.
     *
     * @param qvec    the query vector
     * @param epsilon neighbourhood size passed to each {@link Cell} query builder
     * @return the union of the per-projection intersections
     * @throws Exception if searching fails
     */
    public ANNList retrieve(DocVector qvec, int epsilon) throws Exception {
        Cell[] cells = qvec.getCells();

        final int M = numWanted(Integer.parseInt(prop.getProperty("nwanted", "0")));
        boolean weightedQ = Boolean.parseBoolean(prop.getProperty("query.weighted", "false"));
        float sigma = Float.parseFloat(prop.getProperty("sigma", "0.1"));
        float projectionThreshold = Float.parseFloat(prop.getProperty("proj.threshold", "1"));

        ANNList union = null;
        int j = 0;
        while (j < numDimensions) {
            ANNList localIntersection = null;
            int k;

            // Grow the projection dimension by dimension while the intersection stays large enough.
            for (k = j; k < numDimensions; k++) {
                ANNList thisDimDist = searchCell(cells[k], epsilon, weightedQ, sigma, M);
                if (localIntersection == null) {
                    localIntersection = thisDimDist;
                    continue;
                }
                ANNList intersection = ANNList.getIntersection(localIntersection, thisDimDist);
                if (intersection.neighbors.size() < projectionThreshold) {
                    break; // dimension k starts the next projection
                }
                localIntersection = intersection;
            }

            if (debug) {
                System.out.println("Took projection: [" + j + ", " + (k - 1) + "] "
                        + localIntersection.neighbors.size());
            }

            union = (union == null)
                    ? new ANNList(localIntersection.neighbors)
                    : ANNList.getUnion(union, localIntersection);
            j = k;
        }

        if (debug) {
            System.out.println("#Points considered: " + union.neighbors.size());
        }

        return union;
    }

    /**
     * Retrieves ANN candidates using fixed, contiguous subspaces of size
     * {@code subspace.dimension}.
     *
     * <p>Within each subspace the per-dimension result lists are intersected, and the subspace
     * intersections are then unioned. Each per-dimension query requests
     * {@code numDocs / numIntervals} hits (at least 1).
     *
     * <p>Properties read: {@code query.weighted}, {@code sigma}.
     *
     * @param qvec    the query vector
     * @param epsilon neighbourhood size passed to each {@link Cell} query builder
     * @return the union of the per-subspace intersections
     * @throws IllegalStateException if {@code subspace.dimension} is not a positive divisor of
     *                               {@code vec.numdimensions}
     * @throws Exception             if searching fails
     */
    public ANNList retrieveFixedProjections(DocVector qvec, int epsilon) throws Exception {
        // Explicit check instead of an assert (disabled by default). Otherwise a bad value
        // surfaces as a division by zero or an out-of-bounds cell access.
        if (subSpaceDimension <= 0 || numDimensions % subSpaceDimension != 0) {
            throw new IllegalStateException("subspace.dimension (" + subSpaceDimension
                    + ") must be a positive divisor of vec.numdimensions (" + numDimensions + ")");
        }

        Cell[] cells = qvec.getCells();

        final int M = numWanted(0);
        boolean weightedQ = Boolean.parseBoolean(prop.getProperty("query.weighted", "false"));
        float sigma = Float.parseFloat(prop.getProperty("sigma", "0.1"));

        ANNList union = null;
        for (int j = 0; j < numDimensions; j += subSpaceDimension) {
            ANNList localIntersection = null;

            // Intersect the results of every dimension in subspace [j, j + subSpaceDimension).
            for (int k = j; k < j + subSpaceDimension; k++) {
                ANNList thisDimDist = searchCell(cells[k], epsilon, weightedQ, sigma, M);
                localIntersection = (localIntersection == null)
                        ? thisDimDist
                        : ANNList.getIntersection(localIntersection, thisDimDist);
            }

            if (debug) {
                System.out.println("Subspace [" + j + ", " + (j + subSpaceDimension - 1) + "]: "
                        + localIntersection.neighbors.size());
            }

            union = (union == null)
                    ? new ANNList(localIntersection.neighbors)
                    : ANNList.getUnion(union, localIntersection);
        }

        if (debug) {
            System.out.println("#Points considered: " + union.neighbors.size());
        }

        return union;
    }

    /**
     * Returns the benchmark queries. For synthetic data they are loaded from the random query
     * generator; otherwise they come from the indexed query set.
     *
     * @return the list of query vectors
     * @throws Exception if loading the queries fails
     */
    public List<DocVector> getQueries() throws Exception {
        if (!syntheticQueries) {
            return indexedVecQueries.getQueries();
        }
        rqgen.load();
        return rqgen.getQueries();
    }

    /**
     * Runs every benchmark query in {@code [retrieve.start, retrieve.end)} and prints, per query,
     * the top ANN result against the true nearest neighbour. It also prints the distance shift,
     * normalised by the maximum possible distance.
     *
     * <p>When {@code eval} is true, it also prints R@1 per query and overall. A hit is counted
     * when the retrieved id equals the NN id, or when the retrieved point has the same distance.
     * The average distance shift is printed as well.
     *
     * <p>Properties read: {@code eval}, {@code match.span} (epsilon passed to retrieval).
     *
     * @throws Exception if retrieval fails
     */
    public void searchWithBenchmarkQueries() throws Exception {
        List<DocVector> queries = getQueries();

        int numQueries = queries.size();
        boolean eval = Boolean.parseBoolean(prop.getProperty("eval", "false"));
        int span = Integer.parseInt(prop.getProperty("match.span", "1"));

        final float maxDist = (float) Math.sqrt(numDimensions) * (DocVector.MAXVAL - DocVector.MINVAL);

        // Resolve the range locally so the configured "end" (-1 = all) is not overwritten.
        int first = Math.max(0, start);
        int last = (end == -1) ? numQueries : Math.min(end, numQueries);

        int rAt1 = 0;
        float sumDistShift = 0;

        for (int i = first; i < last; i++) {
            DocVector qvec = queries.get(i);

            System.out.println("Retrieving for query: " + qvec);
            ANNList anns = subSpaceDimension == 0 ? retrieve(qvec, span) : retrieveFixedProjections(qvec, span);

            List<DocVector> retrDocvecs = anns.selectTop(qvec, reader);
            // Nothing retrieved for this query: skip it (guards both an empty list and a null head).
            if (retrDocvecs == null || retrDocvecs.isEmpty() || retrDocvecs.get(0) == null) {
                continue;
            }

            DocVector top = retrDocvecs.get(0);
            QueryVector query = (QueryVector) qvec;

            int retrDoc = top.getId();
            int relDoc = eval ? query.getNN() : 0;
            System.out.println("id(ANN) = " + retrDoc + ", id(NN) = " + relDoc);

            // Square roots are taken before printing; presumably the stored distances are squared — TODO confirm.
            float annDist = (float) Math.sqrt(top.getDistFromQuery());
            float nnDist = (float) Math.sqrt(query.getNNDist());
            System.out.println("dist(ANN) = " + annDist + ", dist(NN) = " + nnDist);

            float shift = (annDist - nnDist) / maxDist;
            System.out.println("shift = " + shift);
            sumDistShift += shift;

            if (eval) {
                // A different point at exactly the NN distance also counts as a hit.
                int thisRAt1 = (retrDoc == relDoc || top.getDistFromQuery() == query.getNNDist()) ? 1 : 0;
                rAt1 += thisRAt1;
                System.out.println("R@1 (" + i + ") = " + thisRAt1);
            }
        }

        if (eval) {
            int numEvaluated = last - first;
            if (numEvaluated <= 0) {
                System.out.println("No queries in range [" + first + ", " + last + ")");
                return;
            }
            System.out.println("R@1 = " + rAt1 / (float) numEvaluated);
            System.out.println("Dist margin = " + sumDistShift / (float) numEvaluated);
        }
    }

    /**
     * Command-line entry point: runs the benchmark queries described by a properties file.
     * Without arguments it prints the usage line and falls back to
     * {@code init_synthetic.properties}.
     *
     * @param args {@code args[0]} is the properties file path
     */
    public static void main(String[] args) {
        if (args.length == 0) {
            System.out.println("Usage: java ANNRetriever <prop-file>");
            args = new String[] { "init_synthetic.properties" };
        }

        // try-with-resources closes the index even when the search throws.
        try (ANNRetriever retriever = new ANNRetriever(args[0])) {
            retriever.searchWithBenchmarkQueries();
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

}