main.BM25VASimilarity.java Source code

Java tutorial

Introduction

Here is the source code for main.BM25VASimilarity.java

Source

package main;/*
             * Licensed to the Apache Software Foundation (ASF) under one or more
             * contributor license agreements.  See the NOTICE file distributed with
             * this work for additional information regarding copyright ownership.
             * The ASF licenses this file to You under the Apache License, Version 2.0
             * (the "License"); you may not use this file except in compliance with
             * the License.  You may obtain a copy of the License at
             *
             *     http://www.apache.org/licenses/LICENSE-2.0
             *
             * Unless required by applicable law or agreed to in writing, software
             * distributed under the License is distributed on an "AS IS" BASIS,
             * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
             * See the License for the specific language governing permissions and
             * limitations under the License.
             */

import org.apache.lucene.index.*;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.SmallFloat;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * BM25 Similarity. Introduced in Stephen E. Robertson, Steve Walker,
 * Susan Jones, Micheline Hancock-Beaulieu, and Mike Gatford. Okapi at TREC-3.
 * In Proceedings of the Third <b>T</b>ext <b>RE</b>trieval <b>C</b>onference (TREC 1994).
 * Gaithersburg, USA, November 1994.
 */
public class BM25VASimilarity extends Similarity {
    private final float k1;
    private final float b;
    private final float delta;

    /**
     * BM25 with the supplied parameter values.
     *
     * @param k1 Controls non-linear term frequency normalization (saturation).
     * @param b  Controls to what degree document length normalizes tf values.
     * @throws IllegalArgumentException if {@code k1} is infinite or negative, or if {@code b} is
     *                                  not within the range {@code [0..1]}
     */
    public BM25VASimilarity(float k1, float b, float delta) {
        if (Float.isFinite(k1) == false || k1 < 0) {
            throw new IllegalArgumentException("illegal k1 value: " + k1 + ", must be a non-negative finite value");
        }
        if (Float.isNaN(b) || b < 0 || b > 1) {
            throw new IllegalArgumentException("illegal b value: " + b + ", must be between 0 and 1");
        }
        this.k1 = k1;
        this.b = b;
        this.delta = delta;
    }

    /**
     * BM25 with these default values:
     * <ul>
     * <li>{@code k1 = 1.2}</li>
     * <li>{@code b = 0.75}</li>
     * </ul>
     */
    public BM25VASimilarity() {
        this(1.2f, 0.75f, 0.5f);
    }

    /**
     * Implemented as <code>log(1 + (docCount - docFreq + 0.5)/(docFreq + 0.5))</code>.
     */
    protected float idf(long docFreq, long docCount) {
        return (float) Math.log(1 + (docCount - docFreq + 0.5D) / (docFreq + 0.5D));
    }

    /**
     * Implemented as <code>1 / (distance + 1)</code>.
     */
    protected float sloppyFreq(int distance) {
        return 1.0f / (distance + 1);
    }

    /**
     * The default implementation returns <code>1</code>
     */
    protected float scorePayload(int doc, int start, int end, BytesRef payload) {
        return 1;
    }

    /**
     * The default implementation computes the average as <code>sumTotalTermFreq / docCount</code>,
     * or returns <code>1</code> if the index does not store sumTotalTermFreq:
     * any field that omits frequency information).
     */
    protected float avgFieldLength(CollectionStatistics collectionStats) {
        final long sumTotalTermFreq = collectionStats.sumTotalTermFreq();
        if (sumTotalTermFreq <= 0) {
            return 1f; // field does not exist, or stat is unsupported
        } else {
            final long docCount = collectionStats.docCount() == -1 ? collectionStats.maxDoc()
                    : collectionStats.docCount();
            return (float) (sumTotalTermFreq / (double) docCount);
        }
    }

    /**
     * The default implementation encodes <code>boost / sqrt(length)</code>
     * with {@link SmallFloat#floatToByte315(float)}.  This is compatible with
     * Lucene's default implementation.  If you change this, then you should
     * change {@link #decodeNormValue(byte)} to match.
     */
    protected byte encodeNormValue(float boost, int fieldLength) {
        return SmallFloat.floatToByte315(boost / (float) Math.sqrt(fieldLength));
    }

    /**
     * The default implementation returns <code>1 / f<sup>2</sup></code>
     * where <code>f</code> is {@link SmallFloat#byte315ToFloat(byte)}.
     */
    protected float decodeNormValue(byte b) {
        return NORM_TABLE[b & 0xFF];
    }

    /**
     * True if overlap tokens (tokens with a position of increment of zero) are
     * discounted from the document's length.
     */
    protected boolean discountOverlaps = true;

    /**
     * Sets whether overlap tokens (Tokens with 0 position increment) are
     * ignored when computing norm.  By default this is true, meaning overlap
     * tokens do not count when computing norms.
     */
    public void setDiscountOverlaps(boolean v) {
        discountOverlaps = v;
    }

    /**
     * Returns true if overlap tokens are discounted from the document's length.
     *
     * @see #setDiscountOverlaps
     */
    public boolean getDiscountOverlaps() {
        return discountOverlaps;
    }

    /**
     * Cache of decoded bytes.
     */
    private static final float[] NORM_TABLE = new float[256];

    static {
        for (int i = 1; i < 256; i++) {
            float f = SmallFloat.byte315ToFloat((byte) i);
            NORM_TABLE[i] = 1.0f / (f * f);
        }
        NORM_TABLE[0] = 1.0f / NORM_TABLE[255]; // otherwise inf
    }

    @Override
    public final long computeNorm(FieldInvertState state) {
        final int numTerms = discountOverlaps ? state.getLength() - state.getNumOverlap() : state.getLength();
        state.getUniqueTermCount();
        //System.out.println(state.getName());
        //System.out.println(state.getUniqueTermCount());
        return encodeNormValue(state.getBoost(), numTerms);
    }

    /**
     * Computes a score factor for a simple term and returns an explanation
     * for that score factor.
     * <p>
     * <p>
     * The default implementation uses:
     * <p>
     * <pre class="prettyprint">
     * idf(docFreq, docCount);
     * </pre>
     * <p>
     * Note that {@link CollectionStatistics#docCount()} is used instead of
     * {@link org.apache.lucene.index.IndexReader#numDocs() IndexReader#numDocs()} because also
     * {@link TermStatistics#docFreq()} is used, and when the latter
     * is inaccurate, so is {@link CollectionStatistics#docCount()}, and in the same direction.
     * In addition, {@link CollectionStatistics#docCount()} does not skew when fields are sparse.
     *
     * @param collectionStats collection-level statistics
     * @param termStats       term-level statistics for the term
     * @return an Explain object that includes both an idf score factor
     * and an explanation for the term.
     */
    public Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics termStats) {
        final long df = termStats.docFreq();
        final long docCount = collectionStats.docCount() == -1 ? collectionStats.maxDoc()
                : collectionStats.docCount();
        final float idf = idf(df, docCount);
        return Explanation.match(idf, "idf(docFreq=" + df + ", docCount=" + docCount + ")");
    }

    /**
     * Computes a score factor for a phrase.
     * <p>
     * <p>
     * The default implementation sums the idf factor for
     * each term in the phrase.
     *
     * @param collectionStats collection-level statistics
     * @param termStats       term-level statistics for the terms in the phrase
     * @return an Explain object that includes both an idf
     * score factor for the phrase and an explanation
     * for each term.
     */
    public Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics termStats[]) {
        final long docCount = collectionStats.docCount() == -1 ? collectionStats.maxDoc()
                : collectionStats.docCount();
        float idf = 0.0f;
        List<Explanation> details = new ArrayList<>();
        for (final TermStatistics stat : termStats) {
            final long df = stat.docFreq();
            final float termIdf = idf(df, docCount);
            details.add(Explanation.match(termIdf, "idf(docFreq=" + df + ", docCount=" + docCount + ")"));
            idf += termIdf;
        }
        return Explanation.match(idf, "idf(), sum of:", details);
    }

    @Override
    public final SimWeight computeWeight(CollectionStatistics collectionStats, TermStatistics... termStats) {
        Explanation idf = termStats.length == 1 ? idfExplain(collectionStats, termStats[0])
                : idfExplain(collectionStats, termStats);

        float avgdl = avgFieldLength(collectionStats);

        // compute freq-independent part of bm25 equation across all norm values
        float cache[] = new float[256];
        for (int i = 0; i < cache.length; i++) {
            //cache[i] = k1 * ((1 - b) + b * decodeNormValue((byte) i) / avgdl);
            //cache becomes cachePrime = B
            //cache[i] = ((1 - b) + b * decodeNormValue((byte) i) / avgdl);
            // B_VA
            //cache[i] = (1/(mavgtf * mavgtf) * decodeNormValue((byte) i) / Td) +
            //        ((1 - 1/mavgtf)*decodeNormValue((byte) i) / avgdl);

            //Cache should now only contain the length of doc d.
            cache[i] = decodeNormValue((byte) i);
        }
        return new BM25Stats(collectionStats.field(), idf, avgdl, cache);
    }

    @Override
    public final SimScorer simScorer(SimWeight stats, LeafReaderContext context) throws IOException {
        BM25Stats bm25stats = (BM25Stats) stats;
        LeafReader reader = context.reader();
        //int docCount = reader.getDocCount(bm25stats.field);

        //BVA calculated for each document
        float[] BVA = new float[reader.maxDoc()];
        float sumOfAverageTermFrequencies = 0.0f;

        //length of each doc
        float[] Ld = new float[reader.maxDoc()];
        //the number of unique terms in the doc.
        float[] Td = new float[reader.maxDoc()];

        NumericDocValues norms = reader.getNormValues(bm25stats.field);

        //        int nulldocs = 0;
        for (int i = 0; i < reader.maxDoc(); i++) {
            Terms terms = reader.getTermVector(i, bm25stats.field);
            //norm should be the decoded length of doc d, Ld.
            float norm = norms == null ? k1 : bm25stats.cache[(byte) norms.get(i) & 0xFF];
            Ld[i] = norm;
            //using terms.size() returns Td, the number of unique terms in the doc.
            Td[i] = terms.size();
            //            if (terms == null) {
            //                nulldocs++;
            //                continue;
            //            }

            float averageTermFrequency = Ld[i] / Td[i];
            sumOfAverageTermFrequencies += averageTermFrequency;
        }
        //calculate mean average term frequency of all documents
        float mavgtf = sumOfAverageTermFrequencies / reader.maxDoc();

        //calculate B_VA for each document
        for (int i = 0; i < reader.maxDoc(); i++) {
            BVA[i] = 1 / (mavgtf * mavgtf) * Ld[i] / Td[i] + (1 - 1 / mavgtf) * Ld[i] / bm25stats.avgdl;
        }

        //        System.out.println("Null docs: "+nulldocs);
        //        System.out.println("Max docs: "+reader.maxDoc());
        //        System.out.println("Doc count: "+reader.getDocCount(bm25stats.field));
        //        System.out.println("max docs minus null docs: "+(reader.maxDoc() - nulldocs));

        return new BM25DocScorer(bm25stats, BVA);
    }

    private class BM25DocScorer extends SimScorer {
        private final BM25Stats stats;
        private final float weightValue; // boost * idf * (k1 + 1)
        private final float[] BVA;
        private final float[] cache;

        BM25DocScorer(BM25Stats stats, float[] BVA) throws IOException {
            this.stats = stats;
            this.weightValue = stats.weight * (k1 + 1);
            this.cache = stats.cache;
            this.BVA = BVA;
        }

        @Override
        public float score(int doc, float freq) {
            // if there are no norms, we act as if b=0
            //float norm = norms == null ? k1 : cache[(byte) norms.get(doc) & 0xFF];
            return weightValue * freq / (freq + k1 * BVA[doc]);
        }

        @Override
        public Explanation explain(int doc, Explanation freq) {
            return explainScore(doc, freq, stats, null);
        }

        @Override
        public float computeSlopFactor(int distance) {
            return sloppyFreq(distance);
        }

        @Override
        public float computePayloadFactor(int doc, int start, int end, BytesRef payload) {
            return scorePayload(doc, start, end, payload);
        }
    }

    /**
     * Collection statistics for the BM25 model.
     */
    private static class BM25Stats extends SimWeight {
        /**
         * BM25's idf
         */
        private final Explanation idf;
        /**
         * The average document length.
         */
        private final float avgdl;
        /**
         * query boost
         */
        private float boost;
        /**
         * weight (idf * boost)
         */
        private float weight;
        /**
         * field name, for pulling norms
         */
        private final String field;
        /**
         * precomputed norm[256] with k1 * ((1 - b) + b * dl / avgdl)
         */
        private final float cache[];

        BM25Stats(String field, Explanation idf, float avgdl, float cache[]) {
            this.field = field;
            this.idf = idf;
            this.avgdl = avgdl;
            this.cache = cache;
            normalize(1f, 1f);
        }

        @Override
        public float getValueForNormalization() {
            // we return a TF-IDF like normalization to be nice, but we don't actually normalize ourselves.
            return weight * weight;
        }

        @Override
        public void normalize(float queryNorm, float boost) {
            // we don't normalize with queryNorm at all, we just capture the top-level boost
            this.boost = boost;
            this.weight = idf.getValue() * boost;
        }
    }

    private Explanation explainTFNorm(int doc, Explanation freq, BM25Stats stats, NumericDocValues norms) {
        List<Explanation> subs = new ArrayList<>();
        subs.add(freq);
        subs.add(Explanation.match(k1, "parameter k1"));
        if (norms == null) {
            subs.add(Explanation.match(0, "parameter b (norms omitted for field)"));
            return Explanation.match((freq.getValue() * (k1 + 1)) / (freq.getValue() + k1),
                    "tfNorm, computed from:", subs);
        } else {
            float doclen = decodeNormValue((byte) norms.get(doc));
            subs.add(Explanation.match(b, "parameter b"));
            subs.add(Explanation.match(stats.avgdl, "avgFieldLength"));
            subs.add(Explanation.match(doclen, "fieldLength"));
            return Explanation.match(
                    (freq.getValue() * (k1 + 1)) / (freq.getValue() + k1 * (1 - b + b * doclen / stats.avgdl)),
                    "tfNorm, computed from:", subs);
        }
    }

    private Explanation explainScore(int doc, Explanation freq, BM25Stats stats, NumericDocValues norms) {
        Explanation boostExpl = Explanation.match(stats.boost, "boost");
        List<Explanation> subs = new ArrayList<>();
        if (boostExpl.getValue() != 1.0f)
            subs.add(boostExpl);
        subs.add(stats.idf);
        Explanation tfNormExpl = explainTFNorm(doc, freq, stats, norms);
        subs.add(tfNormExpl);
        return Explanation.match(boostExpl.getValue() * stats.idf.getValue() * tfNormExpl.getValue(),
                "score(doc=" + doc + ",freq=" + freq + "), product of:", subs);
    }

    @Override
    public String toString() {
        return "BM25(k1=" + k1 + ",b=" + b + ")";
    }

    /**
     * Returns the <code>k1</code> parameter
     *
     * @see #BM25VASimilarity(float, float, float)
     */
    public final float getK1() {
        return k1;
    }

    /**
     * Returns the <code>b</code> parameter
     *
     * @see #BM25VASimilarity(float, float, float)
     */
    public final float getB() {
        return b;
    }

    /**
     * Returns the <code>delta</code> parameter
     *
     * @see #BM25VASimilarity(float, float, float)
     */
    public final float getDelta() {
        return delta;
    }
}