/*
 * com.tdunning.scan.Scanner.java source code
 *
 * Java tutorial
 *
 * Introduction
 *
 * Here is the source code for com.tdunning.scan.Scanner.java.
 *
 * Source
 */

/*
 * Licensed to the Ted Dunning under one or more contributor license
 * agreements.  See the NOTICE file that may be
 * distributed with this work for additional information
 * regarding copyright ownership.  Ted Dunning licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.tdunning.scan;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Queues;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.*;

/**
 * Brute-force nearest-neighbour search over fixed-width bit-vector records.
 *
 * <p>Each record occupies {@link #stride()} consecutive {@code long} words of one shared array.
 * {@link #scan} compares a query against every record by Hamming distance (the number of
 * differing bits) and keeps the {@code n} closest records. The work is split into batches that
 * run on an internal fixed-size thread pool, so a {@code Scanner} should be {@link #close()
 * closed} when it is no longer needed.
 *
 * <p>This class performs no synchronization of its own; do not call {@link #set} or
 * {@link #setThreads} concurrently with {@link #scan}.
 */
public class Scanner implements AutoCloseable {
    /** Record storage: record {@code r} occupies words {@code [r * stride, (r + 1) * stride)}. */
    private final long[] bits;
    /** Number of 64-bit words per record. */
    private final int stride;
    /** Number of records. */
    private final int size;
    /** Number of worker threads in {@link #pool}. */
    private int threads = 1;
    private ExecutorService pool;

    /**
     * Creates a scanner for {@code size} records of {@code bitsPerEntry} bits each. All records
     * start out as all-zero bits, and the internal pool starts with a single thread.
     *
     * @param size         number of records; must be non-negative
     * @param bitsPerEntry bits per record; must be positive and is rounded up to whole 64-bit
     *                     words
     * @throws IllegalArgumentException if {@code size} is negative or {@code bitsPerEntry} is not
     *                                  positive
     * @throws ArithmeticException      if the records would need more than
     *                                  {@code Integer.MAX_VALUE} words of storage
     */
    public Scanner(int size, int bitsPerEntry) {
        if (size < 0) {
            throw new IllegalArgumentException("size must be non-negative: " + size);
        }
        if (bitsPerEntry <= 0) {
            throw new IllegalArgumentException("bitsPerEntry must be positive: " + bitsPerEntry);
        }
        this.size = size;
        // Round up to whole words; long arithmetic keeps this from overflowing for huge widths.
        stride = (int) ((bitsPerEntry + 63L) / 64);
        // multiplyExact fails loudly instead of allocating a silently wrapped-around length.
        bits = new long[Math.multiplyExact(size, stride)];
        pool = Executors.newFixedThreadPool(threads);
    }

    /**
     * Overwrites record {@code recordNumber} with a copy of {@code bits}.
     *
     * @param recordNumber index of the record, in {@code [0, size)}
     * @param bits         exactly {@link #stride()} words of record data
     * @throws IllegalArgumentException  if {@code bits.length != stride()}
     * @throws IndexOutOfBoundsException if {@code recordNumber} is out of range
     */
    public void set(int recordNumber, long[] bits) {
        if (bits.length != stride) {
            throw new IllegalArgumentException(
                    "expected " + stride + " words per record but got " + bits.length);
        }
        System.arraycopy(bits, 0, this.bits, offsetOf(recordNumber), stride);
    }

    /**
     * Returns a view of record {@code recordNumber}. The view shares this scanner's storage, so
     * later {@link #set} calls on the same record are visible through it.
     *
     * @param recordNumber index of the record, in {@code [0, size)}
     * @throws IndexOutOfBoundsException if {@code recordNumber} is out of range
     */
    public Bits get(int recordNumber) {
        return new Bits(bits, offsetOf(recordNumber), stride);
    }

    /** Returns the number of 64-bit words used to store each record. */
    public int stride() {
        return stride;
    }

    /**
     * Replaces the worker pool with a fresh pool of {@code threads} threads. The previous pool is
     * shut down with {@code shutdownNow()}, which attempts to stop any tasks it is still running.
     *
     * @throws IllegalArgumentException if {@code threads} is not positive
     */
    public void setThreads(int threads) {
        if (threads <= 0) {
            throw new IllegalArgumentException("Must have at least one thread");
        }
        this.threads = threads;
        pool.shutdownNow();
        pool = Executors.newFixedThreadPool(threads);
    }

    /**
     * Finds the {@code n} records with the fewest bits differing from {@code q}.
     *
     * <p>The result is a min-heap under {@link Score#compareTo}, so its head ({@code peek}/
     * {@code poll}) is the <em>worst</em> of the kept matches; polling repeatedly yields them from
     * worst to best. When several records tie at the cut-off distance, the Score ordering keeps
     * the higher record numbers. The result does not depend on the configured thread count.
     *
     * @param n maximum number of results; must be non-negative ({@code 0} yields an empty queue)
     * @param q query; only its first {@link #stride()} words are read
     * @return a new queue holding at most {@code n} Scores
     * @throws IllegalArgumentException if {@code n} is negative
     * @throws RuntimeException         if interrupted while waiting (the interrupt flag is
     *                                  restored) or if a batch fails
     */
    public PriorityQueue<Score> scan(int n, long[] q) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        // About two batches per thread, presumably so one slow batch does not idle the others.
        // long arithmetic avoids overflow for very large thread counts.
        long batches = 2L * threads;
        int batchSize = (int) ((size + batches - 1) / batches);

        List<Callable<PriorityQueue<Score>>> tasks = new ArrayList<>();
        int start = 0;
        while (start < size) {
            // size - start > 0 here, so end never exceeds size and cannot overflow.
            int end = start + Math.min(batchSize, size - start);
            int batchStart = start; // effectively-final copy for the lambda
            tasks.add(() -> scanBatch(n, q, batchStart, end));
            start = end;
        }

        try {
            PriorityQueue<Score> best = new PriorityQueue<>();
            for (Future<PriorityQueue<Score>> batch : pool.invokeAll(tasks)) {
                for (Score candidate : batch.get()) {
                    offer(best, n, candidate);
                }
            }
            return best;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Aborted execution", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Error during execution", e);
        }
    }

    /** Scores records {@code [start, end)} against {@code q}, keeping the best {@code n}. */
    private PriorityQueue<Score> scanBatch(int n, long[] q, int start, int end) {
        PriorityQueue<Score> top = new PriorityQueue<>();
        for (int record = start; record < end; record++) {
            int base = record * stride;
            int distance = 0;
            for (int k = 0; k < stride; k++) {
                distance += Long.bitCount(q[k] ^ bits[base + k]);
            }
            // Cheap pre-filter: once the queue is full, a record farther away than the current
            // worst entry can never be kept, so skip allocating a Score for it.
            if (top.size() < n || (n > 0 && distance <= top.peek().matchingBits)) {
                offer(top, n, new Score(record, distance));
            }
        }
        return top;
    }

    /**
     * Adds {@code candidate} to {@code top} if it ranks among the {@code n} greatest Scores,
     * evicting the least one when the queue is already full.
     */
    private static void offer(PriorityQueue<Score> top, int n, Score candidate) {
        if (top.size() < n) {
            top.add(candidate);
        } else if (n > 0 && candidate.compareTo(top.peek()) > 0) {
            top.poll();
            top.add(candidate);
        }
    }

    /** Validates {@code recordNumber} and returns the index of its first word in {@link #bits}. */
    private int offsetOf(int recordNumber) {
        if (recordNumber < 0 || recordNumber >= size) {
            throw new IndexOutOfBoundsException(
                    "record " + recordNumber + " out of range [0, " + size + ")");
        }
        return recordNumber * stride;
    }

    /**
     * Shuts down the worker pool with {@code shutdownNow()}. Subsequent {@link #scan} calls will
     * be rejected by the pool.
     */
    @Override
    public void close() {
        pool.shutdownNow();
    }

    /**
     * One scan result: a record number and its distance from the query.
     *
     * <p>Note: {@link #compareTo} is not consistent with {@code equals}, which is inherited
     * identity equality.
     */
    public static class Score implements Comparable<Score> {
        /** Index of the scored record. */
        int recordNumber;
        /**
         * Despite its name, as filled in by {@link Scanner#scan} this is the number of bit
         * positions where the record differs from the query (its Hamming distance); smaller is a
         * closer match.
         */
        int matchingBits;

        public Score(int recordNumber, int matchingBits) {
            this.recordNumber = recordNumber;
            this.matchingBits = matchingBits;
        }

        /** Returns the index of the scored record. */
        public int getRecordNumber() {
            return recordNumber;
        }

        /** Returns the stored bit count (the Hamming distance, for Scores produced by scan). */
        public int getMatchingBits() {
            return matchingBits;
        }

        /**
         * Orders by descending {@code matchingBits}, then ascending record number. In a
         * {@link PriorityQueue} the head is therefore the Score with the most differing bits,
         * i.e. the worst match.
         *
         * @param o The other thing to compare
         */
        @Override
        public int compareTo(Score o) {
            int byBits = Integer.compare(o.matchingBits, matchingBits);
            return byBits != 0 ? byBits : Integer.compare(recordNumber, o.recordNumber);
        }

        @Override
        public String toString() {
            return "Score{recordNumber=" + recordNumber + ", matchingBits=" + matchingBits + "}";
        }
    }

    /** Read-only view of one record's words inside a scanner's storage array (no copy is made). */
    public static class Bits {
        private final long[] bits;
        private final int offset;
        private final int size;

        /**
         * @param bits   backing storage, shared rather than copied
         * @param offset index of the record's first word in {@code bits}
         * @param size   number of words in the record
         */
        public Bits(long[] bits, int offset, int size) {
            this.bits = bits;
            this.offset = offset;
            this.size = size;
        }

        /**
         * Returns the number of bit positions where this record differs from {@code other}.
         *
         * @param other at least as many words as this record; extra words are ignored
         */
        public int compare(long[] other) {
            int distance = 0;
            for (int j = 0; j < size; j++) {
                distance += Long.bitCount(bits[offset + j] ^ other[j]);
            }
            return distance;
        }
    }
}