com.o19s.es.ltr.logging.LoggingFetchSubPhaseTests.java Source code

Java tutorial

Introduction

Here is the source code for com.o19s.es.ltr.logging.LoggingFetchSubPhaseTests.java

Source

/*
 * Copyright [2017] Wikimedia Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.o19s.es.ltr.logging;

import com.o19s.es.ltr.LtrTestUtils;
import com.o19s.es.ltr.feature.PrebuiltFeature;
import com.o19s.es.ltr.feature.PrebuiltFeatureSet;
import com.o19s.es.ltr.feature.PrebuiltLtrModel;
import com.o19s.es.ltr.query.RankerQuery;
import com.o19s.es.ltr.ranker.LtrRanker;
import com.o19s.es.ltr.ranker.linear.LinearRankerTests;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.fielddata.plain.SortedNumericDVIndexFieldData;
import org.elasticsearch.search.SearchHit;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import static java.util.Collections.singletonList;
import static org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction.Modifier.LN2P;
import static org.elasticsearch.index.fielddata.IndexNumericFieldData.NumericType.FLOAT;

public class LoggingFetchSubPhaseTests extends LuceneTestCase {
    public static final float FACTOR = 1.2F;
    private static Directory directory;
    private static IndexSearcher searcher;
    private static Map<String, Document> docs;

    @BeforeClass
    public static void init() throws Exception {
        directory = newDirectory(random());

        try (IndexWriter writer = new IndexWriter(directory, newIndexWriterConfig(new StandardAnalyzer()))) {
            int nDoc = TestUtil.nextInt(random(), 20, 100);
            docs = new HashMap<>();
            for (int i = 0; i < nDoc; i++) {
                Document d = buildDoc(random().nextBoolean() ? "foo" : "bar", random().nextFloat());
                writer.addDocument(d);
                if (random().nextInt(4) == 0) {
                    writer.commit();
                }
                docs.put(d.get("id"), d);
            }
            writer.commit();
        }
        IndexReader reader = closeAfterSuite(DirectoryReader.open(directory));
        searcher = new IndexSearcher(reader);
    }

    @AfterClass
    public static void cleanup() throws IOException {
        try {
            searcher.getIndexReader().close();
        } finally {
            directory.close();
        }
    }

    public void testLogging() throws IOException {
        RankerQuery query1 = buildQuery("foo");
        RankerQuery query2 = buildQuery("bar");
        LoggingFetchSubPhase.HitLogConsumer logger1 = new LoggingFetchSubPhase.HitLogConsumer("logger1",
                query1.featureSet(), true);
        LoggingFetchSubPhase.HitLogConsumer logger2 = new LoggingFetchSubPhase.HitLogConsumer("logger2",
                query2.featureSet(), false);
        query1 = query1.toLoggerQuery(logger1, true);
        query2 = query2.toLoggerQuery(logger2, true);
        BooleanQuery query = new BooleanQuery.Builder().add(new BooleanClause(query1, BooleanClause.Occur.MUST))
                .add(new BooleanClause(query2, BooleanClause.Occur.MUST)).build();
        LoggingFetchSubPhase subPhase = new LoggingFetchSubPhase();
        SearchHit[] hits = selectRandomHits();
        subPhase.doLog(query, Arrays.asList(logger1, logger2), searcher, hits);
        for (SearchHit hit : hits) {
            assertTrue(docs.containsKey(hit.getId()));
            Document d = docs.get(hit.getId());
            assertTrue(hit.getFields().containsKey("_ltrlog"));
            Map<String, List<Map<String, Object>>> logs = hit.getFields().get("_ltrlog").getValue();
            assertTrue(logs.containsKey("logger1"));
            assertTrue(logs.containsKey("logger2"));

            List<Map<String, Object>> log1 = logs.get("logger1");
            List<Map<String, Object>> log2 = logs.get("logger2");
            if (d.get("text").equals("foo")) {
                assertEquals(log1.get(0).get("name"), "text_feat");
                assertFalse(log2.get(0).containsKey("value"));
                //assertNotEquals(log2.get(0).v1(), "text_feat");
                assertTrue((Float) log1.get(0).get("value") > 0F);
            } else {
                //                assertEquals("bar", d.get("text"));
                //
                //
                //                assertTrue(log1.containsKey("text_feat"));
                //                assertTrue(log2.containsKey("text_feat"));
                //                assertTrue(log2.get("text_feat") > 0F);
                //                assertEquals(0F, log1.get("text_feat"), 0F);
            }
            int bits = (int) (long) d.getField("score").numericValue();
            float rawScore = Float.intBitsToFloat(bits);
            double expectedScore = rawScore * FACTOR;
            expectedScore = Math.log1p(expectedScore + 1);
            assertEquals((float) expectedScore, (Float) log1.get(1).get("value"), Math.ulp((float) expectedScore));
            assertEquals((float) expectedScore, (Float) log1.get(1).get("value"), Math.ulp((float) expectedScore));
        }
    }

    public void testBogusQuery() throws IOException {
        PrebuiltFeatureSet set = new PrebuiltFeatureSet("test",
                singletonList(new PrebuiltFeature("test", new BoostQuery(new MatchAllDocsQuery(), Float.NaN))));
        LoggingFetchSubPhase.HitLogConsumer logger1 = new LoggingFetchSubPhase.HitLogConsumer("logger1", set, true);
        RankerQuery q = RankerQuery
                .build(new PrebuiltLtrModel("test", LtrTestUtils.buildRandomRanker(set.size()), set));
        Query lq = q.toLoggerQuery(logger1, true);
        LoggingFetchSubPhase subPhase = new LoggingFetchSubPhase();
        SearchHit[] hits = selectRandomHits();
        expectThrows(LtrLoggingException.class, () -> subPhase.doLog(lq, singletonList(logger1), searcher, hits));
    }

    public SearchHit[] selectRandomHits() throws IOException {
        int minHits = TestUtil.nextInt(random(), 5, 10);
        int maxHits = TestUtil.nextInt(random(), minHits, minHits + 10);
        List<SearchHit> hits = new ArrayList<>(maxHits);
        searcher.search(new MatchAllDocsQuery(), new SimpleCollector() {
            LeafReaderContext context;

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
                super.doSetNextReader(context);
                this.context = context;
            }

            @Override
            public void collect(int doc) throws IOException {
                if (hits.size() < minHits || (random().nextBoolean() && hits.size() < maxHits)) {
                    Document d = context.reader().document(doc);
                    String id = d.get("id");
                    SearchHit hit = new SearchHit(doc + context.docBase, id, new Text("text"),
                            random().nextBoolean() ? new HashMap<>() : null);
                    hits.add(hit);
                }
            }

            @Override
            public boolean needsScores() {
                return false;
            }
        });
        assert hits.size() >= minHits;
        Collections.shuffle(hits, random());
        return hits.toArray(new SearchHit[hits.size()]);
    }

    public static Document buildDoc(String text, float value) throws IOException {
        String id = UUID.randomUUID().toString();
        Document d = new Document();
        d.add(newStringField("id", id, Field.Store.YES));
        d.add(newStringField("text", text, Field.Store.NO));
        d.add(new FloatDocValuesField("score", value));
        return d;
    }

    public RankerQuery buildQuery(String text) {
        List<PrebuiltFeature> features = new ArrayList<>(2);
        features.add(new PrebuiltFeature("text_feat", new TermQuery(new Term("text", text))));
        features.add(new PrebuiltFeature("score_feat", buildFunctionScore()));
        PrebuiltFeatureSet set = new PrebuiltFeatureSet("my_set", features);
        LtrRanker ranker = LinearRankerTests.generateRandomRanker(set.size());
        return RankerQuery.build(new PrebuiltLtrModel("my_model", ranker, set));

    }

    public Query buildFunctionScore() {
        FieldValueFactorFunction fieldValueFactorFunction = new FieldValueFactorFunction("score", FACTOR, LN2P, 0D,
                new SortedNumericDVIndexFieldData(new Index("test", "123"), "score", FLOAT));
        return new FunctionScoreQuery(new MatchAllDocsQuery(), fieldValueFactorFunction, CombineFunction.MULTIPLY,
                0F, Float.MAX_VALUE);
    }
}