com.cloudera.oryx.app.serving.als.model.ALSServingModelTest.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.app.serving.als.model.ALSServingModelTest.java

Source

/*
 * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.app.serving.als.model;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.cloudera.oryx.app.serving.als.CosineAverageFunction;
import com.cloudera.oryx.app.serving.als.DotsFunction;
import com.cloudera.oryx.common.OryxTest;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.math.VectorMath;
import com.cloudera.oryx.common.random.RandomManager;

/**
 * Unit tests for {@link ALSServingModel}: storage of user/item vectors and known items,
 * the {@code retainRecentAnd...} pruning methods, {@code toString()}, the fraction-loaded
 * indicator, and a comparison of top-N results between two models built with different
 * values of the third constructor argument.
 */
public final class ALSServingModelTest extends OryxTest {

    private static final Logger log = LoggerFactory.getLogger(ALSServingModelTest.class);

    /**
     * Verifies that constructor arguments are reported by the getters, and that a user
     * vector and an item vector can be stored, read back, and are listed among all IDs.
     */
    @Test
    public void testUserItemVector() {
        ALSServingModel model = new ALSServingModel(2, true, 1.0, null);
        assertEquals(2, model.getFeatures());
        assertTrue(model.isImplicit());
        assertNull(model.getRescorerProvider());
        model.setUserVector("U1", new float[] { 1.5f, -2.5f });
        assertArrayEquals(new float[] { 1.5f, -2.5f }, model.getUserVector("U1"));
        model.setItemVector("I0", new float[] { 0.5f, 0.0f });
        assertArrayEquals(new float[] { 0.5f, 0.0f }, model.getItemVector("I0"));
        assertContainsSame(Arrays.asList("U1"), model.getAllUserIDs());
        assertContainsSame(Arrays.asList("I0"), model.getAllItemIDs());
    }

    /**
     * Verifies known-item lookups and the per-user / per-item counts after the model has
     * been filled by {@link #populateKnownItems(ALSServingModel)}.
     */
    @Test
    public void testKnownItems() {
        ALSServingModel model = new ALSServingModel(2, true, 1.0, null);
        populateKnownItems(model);
        // User i knows items i-1, i, i+1 (clipped to 0..9), so edge users know only two items
        assertContainsSame(Arrays.asList("I0", "I1"), model.getKnownItems("U0"));
        assertContainsSame(Arrays.asList("I0", "I1", "I2"), model.getKnownItems("U1"));
        assertContainsSame(Arrays.asList("I8", "I9"), model.getKnownItems("U9"));
        Map<String, Integer> userCounts = model.getUserCounts();
        assertEquals(2, userCounts.get("U0").intValue());
        assertEquals(3, userCounts.get("U1").intValue());
        assertEquals(2, userCounts.get("U9").intValue());
        Map<String, Integer> itemCounts = model.getItemCounts();
        assertEquals(2, itemCounts.get("I0").intValue());
        assertEquals(3, itemCounts.get("I1").intValue());
        assertEquals(2, itemCounts.get("I9").intValue());
    }

    /**
     * Verifies {@code retainRecentAndUserIDs} and {@code retainRecentAndItemIDs}: a freshly
     * set vector survives the first retain call even when its ID is not in the retained
     * collection, is gone after a second such call, and always survives when its ID is
     * passed in the retained collection.
     */
    @Test
    public void testRetainUsersItems() {
        ALSServingModel model = new ALSServingModel(2, true, 1.0, null);

        model.setUserVector("U0", new float[] { 1.0f, 1.0f });
        model.retainRecentAndUserIDs(Collections.emptyList());
        // Protected because of recent user/items
        assertNotNull(model.getUserVector("U0"));
        model.retainRecentAndUserIDs(Collections.emptyList());
        assertNull(model.getUserVector("U0"));

        model.setUserVector("U0", new float[] { 1.0f, 1.0f });
        model.retainRecentAndUserIDs(Arrays.asList("U0"));
        assertNotNull(model.getUserVector("U0"));
        model.retainRecentAndUserIDs(Arrays.asList("U0"));
        assertNotNull(model.getUserVector("U0"));

        model.setItemVector("I0", new float[] { 1.0f, 1.0f });
        model.retainRecentAndItemIDs(Collections.emptyList());
        // Protected because of recent user/items
        assertNotNull(model.getItemVector("I0"));
        model.retainRecentAndItemIDs(Collections.emptyList());
        assertNull(model.getItemVector("I0"));

        model.setItemVector("I0", new float[] { 1.0f, 1.0f });
        model.retainRecentAndItemIDs(Arrays.asList("I0"));
        assertNotNull(model.getItemVector("I0"));
        model.retainRecentAndItemIDs(Arrays.asList("I0"));
        assertNotNull(model.getItemVector("I0"));
    }

    /**
     * Verifies {@code retainRecentAndKnownItems}: while users/items are still recent, all
     * known items are kept; after the recent sets have been cleared, only known items for
     * the retained users and retained items remain.
     */
    @Test
    public void testRetainKnown() {
        ALSServingModel model = new ALSServingModel(2, true, 1.0, null);
        populateKnownItems(model);
        for (int i = 0; i < 10; i++) {
            model.setUserVector("U" + i, new float[] { 0.0f, 0.0f });
            model.setItemVector("I" + i, new float[] { 0.0f, 0.0f });
        }
        model.retainRecentAndKnownItems(Arrays.asList("U4", "U5", "U6"), Arrays.asList("I4", "I5", "I6"));
        assertContains(model.getKnownItems("U3"), "I4");
        assertContains(model.getKnownItems("U4"), "I4");
        assertContains(model.getKnownItems("U6"), "I6");
        assertContains(model.getKnownItems("U6"), "I7");
        // Protected because of recent user/items
        assertContains(model.getKnownItems("U2"), "I2");

        // Clears recent user/items
        model.retainRecentAndUserIDs(Collections.emptyList());
        model.retainRecentAndItemIDs(Collections.emptyList());

        model.retainRecentAndKnownItems(Arrays.asList("U4", "U5", "U6"), Arrays.asList("I4", "I5", "I6"));
        assertEquals(0, model.getKnownItems("U3").size());
        assertContains(model.getKnownItems("U4"), "I4");
        assertContains(model.getKnownItems("U6"), "I6");
        assertNotContains(model.getKnownItems("U6"), "I7");
        assertEquals(0, model.getKnownItems("U2").size());
    }

    /**
     * Verifies that {@code toString()} mentions the class name, feature count and
     * implicit flag.
     */
    @Test
    public void testToString() {
        String modelToString = new ALSServingModel(2, true, 1.0, null).toString();
        assertContains(modelToString, "ALSServingModel");
        assertContains(modelToString, "features:2");
        assertContains(modelToString, "implicit:true");
    }

    /**
     * Adds known items to the model for users {@code U0..U9}: user {@code Ui} gets item
     * {@code Ij} whenever {@code |i - j| <= 1}, with {@code j} in {@code 0..9}.
     *
     * @param model model to receive the known items
     */
    private static void populateKnownItems(ALSServingModel model) {
        for (int i = 0; i < 10; i++) {
            String userID = "U" + i;
            for (int j = 0; j < 10; j++) {
                if (Math.abs(i - j) <= 1) {
                    String itemID = "I" + j;
                    model.addKnownItems(userID, Collections.singleton(itemID));
                }
            }
        }
    }

    /**
     * Counts how many leading positions of the two lists hold equal elements, stopping
     * at the first mismatch or at the end of the shorter list.
     *
     * @param first list whose elements are the receivers of {@code equals}
     * @param second list whose elements are the arguments of {@code equals}
     * @return length of the common prefix; 0 if either list is empty
     */
    private static int matchingPrefixLength(List<?> first, List<?> second) {
        int matched = 0;
        while (matched < first.size() && matched < second.size()
                && first.get(matched).equals(second.get(matched))) {
            matched++;
        }
        return matched;
    }

    /**
     * Builds two models over identical random data, differing only in the third
     * constructor argument (1.0 vs 0.5 -- presumably the LSH sample rate; confirm against
     * {@link ALSServingModel}), and asserts that their top-N results agree on a
     * sufficiently long prefix on average, both for user-based {@link DotsFunction}
     * scoring and for item-based {@link CosineAverageFunction} scoring.
     */
    @Test
    public void testLSHEffect() {
        RandomGenerator random = RandomManager.getRandom();
        PoissonDistribution itemPerUserDist = new PoissonDistribution(random, 20,
                PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
        int features = 20;
        ALSServingModel mainModel = new ALSServingModel(features, true, 1.0, null);
        ALSServingModel lshModel = new ALSServingModel(features, true, 0.5, null);

        // Same user vectors and known items go into both models
        int userItemCount = 20000;
        for (int user = 0; user < userItemCount; user++) {
            String userID = "U" + user;
            float[] vec = VectorMath.randomVectorF(features, random);
            mainModel.setUserVector(userID, vec);
            lshModel.setUserVector(userID, vec);
            int itemsPerUser = itemPerUserDist.sample();
            Collection<String> knownIDs = new ArrayList<>(itemsPerUser);
            for (int i = 0; i < itemsPerUser; i++) {
                knownIDs.add("I" + random.nextInt(userItemCount));
            }
            mainModel.addKnownItems(userID, knownIDs);
            lshModel.addKnownItems(userID, knownIDs);
        }

        // Same item vectors go into both models
        for (int item = 0; item < userItemCount; item++) {
            String itemID = "I" + item;
            float[] vec = VectorMath.randomVectorF(features, random);
            mainModel.setItemVector(itemID, vec);
            lshModel.setItemVector(itemID, vec);
        }

        int numRecs = 10;
        Mean meanMatchLength = new Mean();
        for (int user = 0; user < userItemCount; user++) {
            String userID = "U" + user;
            List<Pair<String, Double>> mainRecs = mainModel
                    .topN(new DotsFunction(mainModel.getUserVector(userID)), null, numRecs, null)
                    .collect(Collectors.toList());
            List<Pair<String, Double>> lshRecs = lshModel
                    .topN(new DotsFunction(lshModel.getUserVector(userID)), null, numRecs, null)
                    .collect(Collectors.toList());
            meanMatchLength.increment(matchingPrefixLength(lshRecs, mainRecs));
        }
        log.info("Mean matching prefix: {}", meanMatchLength.getResult());
        assertGreaterOrEqual(meanMatchLength.getResult(), 4.0);

        // Reuse the accumulator for the item-similarity comparison
        meanMatchLength.clear();
        for (int item = 0; item < userItemCount; item++) {
            String itemID = "I" + item;
            List<Pair<String, Double>> mainRecs = mainModel
                    .topN(new CosineAverageFunction(mainModel.getItemVector(itemID)), null, numRecs, null)
                    .collect(Collectors.toList());
            List<Pair<String, Double>> lshRecs = lshModel
                    .topN(new CosineAverageFunction(lshModel.getItemVector(itemID)), null, numRecs, null)
                    .collect(Collectors.toList());
            meanMatchLength.increment(matchingPrefixLength(lshRecs, mainRecs));
        }
        log.info("Mean matching prefix: {}", meanMatchLength.getResult());
        assertGreaterOrEqual(meanMatchLength.getResult(), 5.0);
    }

    /**
     * Verifies {@code getFractionLoaded()}: 1.0 for a new model, 0.0 after one expected
     * user ID and one expected item ID have been registered via the retain methods, then
     * 0.5 and 1.0 as the corresponding vectors are set.
     */
    @Test
    public void testFractionLoaded() {
        assertEquals(1.0f, new ALSServingModel(2, true, 1.0, null).getFractionLoaded());
        ALSServingModel model = new ALSServingModel(2, true, 1.0, null);
        assertNotNull(model.toString());
        model.retainRecentAndUserIDs(Collections.singleton("U1"));
        model.retainRecentAndItemIDs(Collections.singleton("I0"));
        assertEquals(0.0f, model.getFractionLoaded());
        model.setUserVector("U1", new float[] { 1.5f, -2.5f });
        assertEquals(0.5f, model.getFractionLoaded());
        model.setItemVector("I0", new float[] { 0.5f, 0.0f });
        assertEquals(1.0f, model.getFractionLoaded());
    }

}