org.apache.solr.ltr.TestSelectiveWeightCreation.java Source code

Introduction

Here is the source code for org.apache.solr.ltr.TestSelectiveWeightCreation.java. The test verifies that LTRScoringQuery creates feature weights selectively: when the feature vector is not requested in the response, weights are created only for the model's own features; when it is requested, weights are created for all features in the feature store. It also checks that feature vectors can be logged from a different feature store than the one used by the reranking model.
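
Before the full listing, here is a minimal sketch (not part of the test itself) of the SolrJ request pattern the second test exercises: rerank with an LTR model and log a feature vector from a different feature store. The class and method names in the sketch are illustrative; the model name, store name, and efi parameters mirror the values used in the test below.

import org.apache.solr.client.solrj.SolrQuery;

public class LtrRerankRequestSketch {
    public static SolrQuery buildRerankQuery() {
        // match everything, then rerank the top 4 documents with the "externalmodel" LTR model;
        // efi.* passes external feature information (here, the user query) to the model's features
        SolrQuery query = new SolrQuery("*:*");
        query.add("rows", "4");
        query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
        // return stored fields, the score, and a feature vector logged from the "fstore4" store
        query.add("fl", "*,score,fv:[fv store=fstore4 efi.myPop=3]");
        return query;
    }
}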

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.model.TestLinearModel;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

public class TestSelectiveWeightCreation extends TestRerankBase {
    private IndexSearcher getSearcher(IndexReader r) {
        final IndexSearcher searcher = newSearcher(r, false, false);
        return searcher;
    }

    // Builds constant-valued ValueFeature instances named "f<i>", one per requested id, each with feature index i.
    private static List<Feature> makeFeatures(int[] featureIds) {
        final List<Feature> features = new ArrayList<>();
        for (final int i : featureIds) {
            Map<String, Object> params = new HashMap<String, Object>();
            params.put("value", i);
            final Feature f = Feature.getInstance(solrResourceLoader, ValueFeature.class.getCanonicalName(),
                    "f" + i, params);
            f.setIndex(i);
            features.add(f);
        }
        return features;
    }

    // Builds linear model parameters that assign each feature a weight of 0.1.
    private static Map<String, Object> makeFeatureWeights(List<Feature> features) {
        final Map<String, Object> nameParams = new HashMap<String, Object>();
        final HashMap<String, Double> modelWeights = new HashMap<String, Double>();
        for (final Feature feat : features) {
            modelWeights.put(feat.getName(), 0.1);
        }
        nameParams.put("weights", modelWeights);
        return nameParams;
    }

    // Creates the Weight for the given LTR scoring query, scores the top hit within its segment,
    // and returns the Weight cast to LTRScoringQuery.ModelWeight for inspection.
    private LTRScoringQuery.ModelWeight performQuery(TopDocs hits, IndexSearcher searcher, int docid,
            LTRScoringQuery model) throws IOException, ModelException {
        final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext().leaves();
        final int n = ReaderUtil.subIndex(hits.scoreDocs[0].doc, leafContexts);
        final LeafReaderContext context = leafContexts.get(n);
        final int deBasedDoc = hits.scoreDocs[0].doc - context.docBase;

        final Weight weight = searcher.createNormalizedWeight(model, true);
        final Scorer scorer = weight.scorer(context);

        // rerank using the field final-score
        scorer.iterator().advance(deBasedDoc);
        scorer.score();
        assertTrue(weight instanceof LTRScoringQuery.ModelWeight);
        final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) weight;
        return modelWeight;
    }

    // Indexes a small set of documents and loads the external feature and model definitions used by the tests.
    @BeforeClass
    public static void before() throws Exception {
        setuptest(false);

        assertU(adoc("id", "1", "title", "w1 w3", "description", "w1", "popularity", "1"));
        assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity", "2"));
        assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", "3"));
        assertU(adoc("id", "4", "title", "w4 w3", "description", "w4", "popularity", "4"));
        assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", "5"));
        assertU(commit());

        loadFeatures("external_features.json");
        loadModels("external_model.json");
        loadModels("external_model_store.json");
    }

    @AfterClass
    public static void after() throws Exception {
        aftertest();
    }

    // Verifies that feature weights are created only for the model's features when the feature
    // vector is not requested, and for all features in the store when it is requested.
    @Test
    public void testScoringQueryWeightCreation() throws IOException, ModelException {
        final Directory dir = newDirectory();
        final RandomIndexWriter w = new RandomIndexWriter(random(), dir);

        Document doc = new Document();
        doc.add(newStringField("id", "0", Field.Store.YES));
        doc.add(newTextField("field", "wizard the the the the the oz", Field.Store.NO));
        doc.add(new FloatDocValuesField("final-score", 1.0f));

        w.addDocument(doc);
        doc = new Document();
        doc.add(newStringField("id", "1", Field.Store.YES));
        // 1 extra token, but wizard and oz are close;
        doc.add(newTextField("field", "wizard oz the the the the the the", Field.Store.NO));
        doc.add(new FloatDocValuesField("final-score", 2.0f));
        w.addDocument(doc);

        final IndexReader r = w.getReader();
        w.close();

        // Do ordinary BooleanQuery:
        final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
        bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
        bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
        final IndexSearcher searcher = getSearcher(r);
        // first run the standard query
        final TopDocs hits = searcher.search(bqBuilder.build(), 10);
        assertEquals(2, hits.totalHits);
        assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
        assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));

        List<Feature> features = makeFeatures(new int[] { 0, 1, 2 });
        final List<Feature> allFeatures = makeFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
        final List<Normalizer> norms = new ArrayList<>();
        for (int k = 0; k < features.size(); ++k) {
            norms.add(IdentityNormalizer.INSTANCE);
        }

        // when features are NOT requested in the response, only the modelFeature weights should be created
        final LTRScoringModel ltrScoringModel1 = TestLinearModel.createLinearModel("test", features, norms, "test",
                allFeatures, makeFeatureWeights(features));
        LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc,
                new LTRScoringQuery(ltrScoringModel1, false)); // features not requested in response
        LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getFeaturesInfo();

        assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
        int validFeatures = 0;
        for (int i = 0; i < featuresInfo.length; ++i) {
            if (featuresInfo[i] != null && featuresInfo[i].isUsed()) {
                validFeatures += 1;
            }
        }
        assertEquals(validFeatures, features.size());

        // when features are requested in the response, weights should be created for all features
        final LTRScoringModel ltrScoringModel2 = TestLinearModel.createLinearModel("test", features, norms, "test",
                allFeatures, makeFeatureWeights(features));
        modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc,
                new LTRScoringQuery(ltrScoringModel2, true)); // features requested in response
        featuresInfo = modelWeight.getFeaturesInfo();

        assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
        assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length);

        validFeatures = 0;
        for (int i = 0; i < featuresInfo.length; ++i) {
            if (featuresInfo[i] != null && featuresInfo[i].isUsed()) {
                validFeatures += 1;
            }
        }
        assertEquals(validFeatures, allFeatures.size());

        assertU(delI("0"));
        assertU(delI("1"));
        r.close();
        dir.close();
    }

    // Verifies that the [fv] transformer can log features from a different feature store
    // (fstore4) than the store used by the reranking model.
    @Test
    public void testSelectiveWeightsRequestFeaturesFromDifferentStore() throws Exception {

        final String docs0fv = FeatureLoggerTestUtils.toFeatureVector("matchedTitle", "1.0", "titlePhraseMatch",
                "0.40254828");
        final String docs0fv_fstore4 = FeatureLoggerTestUtils.toFeatureVector("popularity", "3.0", "originalScore",
                "1.0");

        final SolrQuery query = new SolrQuery();
        query.setQuery("*:*");
        query.add("fl", "*,score");
        query.add("rows", "4");
        query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
        query.add("fl", "fv:[fv]");

        assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
        assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='3'");
        assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='4'");
        assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv=='" + docs0fv + "'"); // extract all features in default store

        query.remove("fl");
        query.remove("rq");
        query.add("fl", "*,score");
        query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
        query.add("fl", "fv:[fv store=fstore4 efi.myPop=3]");

        assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
        assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.999");
        assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv=='" + docs0fv_fstore4 + "'"); // extract all features from fstore4

        query.remove("fl");
        query.remove("rq");
        query.add("fl", "*,score");
        query.add("rq", "{!ltr reRankDocs=4 model=externalmodelstore efi.user_query=w3 efi.myconf=0.8}");
        query.add("fl", "fv:[fv store=fstore4 efi.myPop=3]");
        assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); // score using fstore2 used by externalmodelstore
        assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.7992");
        assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv=='" + docs0fv_fstore4 + "'"); // extract all features from fstore4
    }
}