org.apache.solr.ltr.TestLTRScoringQuery.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.solr.ltr.TestLTRScoringQuery.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.model.TestLinearModel;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.norm.NormalizerException;
import org.junit.Test;

/**
 * Tests for {@link LTRScoringQuery}: equality/hashCode behaviour, and scoring of a tiny two-document
 * index with linear models built from {@link ValueFeature}s whose value equals their feature index.
 */
public class TestLTRScoringQuery extends LuceneTestCase {

    public final static SolrResourceLoader solrResourceLoader = new SolrResourceLoader();

    /**
     * Returns a test searcher over {@code r}.
     * NOTE(review): both flags are {@code false}, presumably so {@link #performQuery} receives the raw
     * {@link LTRScoringQuery.ModelWeight} rather than an asserting wrapper (see the cast there) — confirm
     * against {@code LuceneTestCase#newSearcher(IndexReader, boolean, boolean)}.
     */
    private IndexSearcher getSearcher(IndexReader r) {
        final IndexSearcher searcher = newSearcher(r, false, false);
        return searcher;
    }

    /**
     * Builds one {@link ValueFeature} per id: feature {@code i} is named {@code "f" + i}, has the
     * constant {@code "value"} parameter {@code i}, and is placed at index {@code i}.
     *
     * @param featureIds ids of the features to create, in the order they should be returned
     * @return the created features, in {@code featureIds} order
     */
    private static List<Feature> makeFeatures(int[] featureIds) {
        final List<Feature> features = new ArrayList<>(featureIds.length);
        for (final int id : featureIds) {
            final Map<String, Object> params = new HashMap<>();
            params.put("value", id);
            final Feature f = Feature.getInstance(solrResourceLoader, ValueFeature.class.getCanonicalName(),
                    "f" + id, params);
            f.setIndex(id);
            features.add(f);
        }
        return features;
    }

    /**
     * Builds the features used by the normalizer test. This is identical to {@link #makeFeatures}, so it
     * delegates instead of duplicating the construction loop.
     */
    private static List<Feature> makeFilterFeatures(int[] featureIds) {
        return makeFeatures(featureIds);
    }

    /**
     * Returns linear-model parameters of the form {@code {"weights": {featureName -> 0.1}}}, giving every
     * feature in {@code features} the same weight of 0.1.
     */
    private static Map<String, Object> makeFeatureWeights(List<Feature> features) {
        final Map<String, Double> modelWeights = new HashMap<>();
        for (final Feature feat : features) {
            modelWeights.put(feat.getName(), 0.1);
        }
        final Map<String, Object> nameParams = new HashMap<>();
        nameParams.put("weights", modelWeights);
        return nameParams;
    }

    /**
     * Creates the weight for {@code model}, positions its scorer on top-level document {@code docid},
     * scores it, and returns the weight as a {@link LTRScoringQuery.ModelWeight}. The callers' assertions
     * rely on {@code score()} populating the feature values exposed by the returned weight.
     *
     * @param hits     no longer consulted (the document is identified by {@code docid}); kept so the
     *                 signature stays unchanged
     * @param searcher searcher whose top-level reader contains {@code docid}
     * @param docid    top-level (not segment-relative) document id to score
     * @param model    query to build the weight from
     */
    private LTRScoringQuery.ModelWeight performQuery(TopDocs hits, IndexSearcher searcher, int docid,
            LTRScoringQuery model) throws IOException, ModelException {
        final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext().leaves();
        final LeafReaderContext context = leafContexts.get(ReaderUtil.subIndex(docid, leafContexts));
        // Scorers work on segment-relative ids.
        final int segmentDoc = docid - context.docBase;

        final Weight weight = searcher.createNormalizedWeight(model, true);
        final Scorer scorer = weight.scorer(context);
        assertNotNull("no scorer for the segment containing doc " + docid, scorer);

        scorer.iterator().advance(segmentDoc);
        scorer.score();

        assertTrue(weight instanceof LTRScoringQuery.ModelWeight);
        return (LTRScoringQuery.ModelWeight) weight;
    }

    @Test
    public void testLTRScoringQueryEquality() throws ModelException {
        final List<Feature> features = makeFeatures(new int[] { 0, 1, 2 });
        final List<Normalizer> norms = new ArrayList<Normalizer>(
                Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
        final List<Feature> allFeatures = makeFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
        final Map<String, Object> modelParams = makeFeatureWeights(features);

        final LTRScoringModel algorithm1 = TestLinearModel.createLinearModel("testModelName", features, norms,
                "testStoreName", allFeatures, modelParams);

        final LTRScoringQuery m0 = new LTRScoringQuery(algorithm1);

        final HashMap<String, String[]> externalFeatureInfo = new HashMap<>();
        externalFeatureInfo.put("queryIntent", new String[] { "company" });
        externalFeatureInfo.put("user_query", new String[] { "abc" });
        final LTRScoringQuery m1 = new LTRScoringQuery(algorithm1, externalFeatureInfo, false, null);

        final HashMap<String, String[]> externalFeatureInfo2 = new HashMap<>();
        externalFeatureInfo2.put("user_query", new String[] { "abc" });
        externalFeatureInfo2.put("queryIntent", new String[] { "company" });
        int totalPoolThreads = 10, numThreadsPerRequest = 10;
        LTRThreadModule threadManager = new LTRThreadModule(totalPoolThreads, numThreadsPerRequest);
        final LTRScoringQuery m2 = new LTRScoringQuery(algorithm1, externalFeatureInfo2, false, threadManager);

        // Same model and same efis (inserted in a different order), but a different thread module:
        // the queries must still be equal.
        assertEquals(m1, m2);
        assertEquals(m1.hashCode(), m2.hashCode());

        // Same model, but efis vs. no efis: must not match.
        assertFalse(m1.equals(m0));
        assertFalse(m1.hashCode() == m0.hashCode());

        // Model built with a different model name.
        final LTRScoringModel algorithm2 = TestLinearModel.createLinearModel("testModelName2", features, norms,
                "testStoreName", allFeatures, modelParams);
        final LTRScoringQuery m3 = new LTRScoringQuery(algorithm2);

        assertFalse(m1.equals(m3));
        assertFalse(m1.hashCode() == m3.hashCode());

        // Model differing only in the fourth argument (presumably the feature-store name).
        final LTRScoringModel algorithm3 = TestLinearModel.createLinearModel("testModelName", features, norms,
                "testStoreName3", allFeatures, modelParams);
        final LTRScoringQuery m4 = new LTRScoringQuery(algorithm3);

        assertFalse(m1.equals(m4));
        assertFalse(m1.hashCode() == m4.hashCode());
    }

    @Test
    public void testLTRScoringQuery() throws IOException, ModelException {
        final Directory dir = newDirectory();
        try {
            final IndexReader r = indexTestDocuments(dir);
            try {
                final IndexSearcher searcher = getSearcher(r);
                final TopDocs hits = searchWizardOz(searcher);
                final int docid = hits.scoreDocs[0].doc;
                final List<Feature> allFeatures = makeFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
                final int[] mixPositions = new int[] { 8, 2, 4, 9, 0 };

                checkSequentialFeatures(hits, searcher, docid, allFeatures);
                checkMixedFeatures(hits, searcher, docid, allFeatures, mixPositions);
                checkEmptyFeatureListRejected(allFeatures);
                checkNormalizers(hits, searcher, docid, allFeatures, mixPositions);
            } finally {
                r.close();
            }
        } finally {
            dir.close();
        }
    }

    /**
     * Indexes two documents (ids "0" and "1") that both contain "wizard" and "oz" in {@code field}, and
     * returns a reader over them. The writer is closed before returning, including on failure.
     */
    private IndexReader indexTestDocuments(Directory dir) throws IOException {
        final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
        try {
            Document doc = new Document();
            doc.add(newStringField("id", "0", Field.Store.YES));
            doc.add(newTextField("field", "wizard the the the the the oz", Field.Store.NO));
            doc.add(new FloatDocValuesField("final-score", 1.0f));
            w.addDocument(doc);

            doc = new Document();
            doc.add(newStringField("id", "1", Field.Store.YES));
            // 1 extra token, but wizard and oz are close;
            doc.add(newTextField("field", "wizard oz the the the the the the", Field.Store.NO));
            doc.add(new FloatDocValuesField("final-score", 2.0f));
            w.addDocument(doc);

            return w.getReader();
        } finally {
            w.close();
        }
    }

    /** Runs the plain "wizard OR oz" BooleanQuery and checks that both documents match, doc "0" first. */
    private TopDocs searchWizardOz(IndexSearcher searcher) throws IOException {
        final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
        bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
        bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
        final TopDocs hits = searcher.search(bqBuilder.build(), 10);
        assertEquals(2, hits.totalHits);
        assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
        assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
        return hits;
    }

    /** Model over features f0,f1,f2: each extracted value and FeatureInfo entry must equal its index. */
    private void checkSequentialFeatures(TopDocs hits, IndexSearcher searcher, int docid,
            List<Feature> allFeatures) throws IOException, ModelException {
        final List<Feature> features = makeFeatures(new int[] { 0, 1, 2 });
        final List<Normalizer> norms = new ArrayList<Normalizer>(
                Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
        final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test", features, norms, "test",
                allFeatures, makeFeatureWeights(features));

        final LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher, docid,
                new LTRScoringQuery(ltrScoringModel));
        assertEquals(3, modelWeight.getModelFeatureValuesNormalized().length);
        for (int i = 0; i < 3; i++) {
            assertEquals(i, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
        }

        final int[] posVals = new int[] { 0, 1, 2 };
        int pos = 0;
        for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getFeaturesInfo()) {
            if (fInfo == null) {
                // Slots for features that the model does not use are left empty.
                continue;
            }
            assertEquals(posVals[pos], fInfo.getValue(), 0.0001);
            assertEquals("f" + posVals[pos], fInfo.getName());
            pos++;
        }
        // Guard against the loop above passing vacuously.
        assertEquals(posVals.length, pos);
    }

    /** Model over features in a non-sorted order: values must follow the model's feature order. */
    private void checkMixedFeatures(TopDocs hits, IndexSearcher searcher, int docid,
            List<Feature> allFeatures, int[] mixPositions) throws IOException, ModelException {
        final List<Feature> features = makeFeatures(mixPositions);
        final List<Normalizer> norms = new ArrayList<Normalizer>(
                Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
        final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test", features, norms, "test",
                allFeatures, makeFeatureWeights(features));

        final LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher, docid,
                new LTRScoringQuery(ltrScoringModel));
        assertEquals(mixPositions.length, modelWeight.getModelFeatureWeights().length);
        for (int i = 0; i < mixPositions.length; i++) {
            assertEquals(mixPositions[i], modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
        }
    }

    /** Creating a model with no features must fail with a specific ModelException. */
    private void checkEmptyFeatureListRejected(List<Feature> allFeatures) throws ModelException {
        final ModelException expectedModelException = new ModelException("no features declared for model test");
        final List<Feature> features = makeFeatures(new int[] {});
        final List<Normalizer> norms = new ArrayList<Normalizer>(
                Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
        try {
            TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures,
                    makeFeatureWeights(features));
            fail("unexpectedly got here instead of catching " + expectedModelException);
        } catch (ModelException actualModelException) {
            assertEquals(expectedModelException.toString(), actualModelException.toString());
        }
    }

    /** A normalizer that maps every value to 42.42 must be reflected in the normalized values. */
    private void checkNormalizers(TopDocs hits, IndexSearcher searcher, int docid,
            List<Feature> allFeatures, int[] mixPositions) throws IOException, ModelException {
        final List<Feature> features = makeFilterFeatures(mixPositions);
        final Normalizer norm = new Normalizer() {

            @Override
            public float normalize(float value) {
                return 42.42f;
            }

            @Override
            public LinkedHashMap<String, Object> paramsToMap() {
                return null;
            }

            @Override
            protected void validate() throws NormalizerException {
            }

        };
        final List<Normalizer> norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), norm));
        final LTRScoringModel normMeta = TestLinearModel.createLinearModel("test", features, norms, "test",
                allFeatures, makeFeatureWeights(features));

        final LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher, docid,
                new LTRScoringQuery(normMeta));
        // Apply the model's normalizers to the extracted values in place.
        normMeta.normalizeFeaturesInPlace(modelWeight.getModelFeatureValuesNormalized());
        assertEquals(mixPositions.length, modelWeight.getModelFeatureWeights().length);
        for (int i = 0; i < mixPositions.length; i++) {
            assertEquals(42.42f, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
        }
    }

}