org.lenskit.mf.BPR.BPRMFModelProvider.java Source code

Java tutorial

Introduction

Here is the source code for org.lenskit.mf.BPR.BPRMFModelProvider.java

Source

/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2016 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.lenskit.mf.BPR;

import org.apache.commons.math3.linear.*;
import org.grouplens.lenskit.iterative.LearningRate;
import org.grouplens.lenskit.iterative.RegularizationTerm;
import org.grouplens.lenskit.iterative.StoppingCondition;
import org.grouplens.lenskit.iterative.TrainingLoopController;
import org.lenskit.data.dao.DataAccessObject;
import org.lenskit.data.entities.CommonTypes;
import org.lenskit.data.entities.Entities;
import org.lenskit.data.ratings.PreferenceDomain;
import org.lenskit.data.ratings.RatingMatrix;
import org.lenskit.inject.Transient;
import org.lenskit.mf.funksvd.FeatureCount;
import org.lenskit.mf.svd.MFModel;
import org.lenskit.util.keys.HashKeyIndex;
import org.lenskit.util.keys.KeyIndex;
import org.lenskit.util.math.RollingWindowMeanAccumulator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Provider;

import java.util.Random;

import static java.lang.Math.exp;

/**
 * This class trains a MF Model based on the BPR-Optimization Criteria
 */
public class BPRMFModelProvider implements Provider<MFModel> {

    private static Logger logger = LoggerFactory.getLogger(BPRMFModelProvider.class);

    private final int featureCount;
    private final double learningRate;
    private final double regularization;
    private final DataAccessObject dao;
    private final StoppingCondition stoppingCondition;
    private final TrainingPairGenerator pairGenerator;
    private final Random rand;

    private HashKeyIndex itemIndex;
    private HashKeyIndex userIndex;

    @Inject
    public BPRMFModelProvider(@Transient @Nonnull DataAccessObject dao,
            @Transient @Nonnull TrainingPairGenerator pairGenerator, @FeatureCount int featureCount,
            @LearningRate double learningRate, @RegularizationTerm double regularizaiton,
            @Transient StoppingCondition stoppingCondition, @Transient Random rand) {
        this.dao = dao;
        this.pairGenerator = pairGenerator;
        this.featureCount = featureCount;
        this.learningRate = learningRate;
        this.regularization = regularizaiton;
        this.stoppingCondition = stoppingCondition;
        this.rand = rand;

        itemIndex = new HashKeyIndex();
        userIndex = new HashKeyIndex();
    }

    @Override
    public MFModel get() {
        // This will accumulate BPR-Opt (minus the regularization) and will be negated to make an error.
        // -30 is arbitrary, but would indicate a _really_ consistently bad prediction (~ p=1*10^-13),
        // and is therefore a reasonable "max_error".
        RollingWindowMeanAccumulator optAccum = new RollingWindowMeanAccumulator(10000, -30);

        // set up user index and matrix
        int userCount = dao.getEntityIds(CommonTypes.USER).size();
        for (long uid : dao.getEntityIds(CommonTypes.USER)) {
            userIndex.internId(uid);
        }
        RealMatrix userFeatures = MatrixUtils.createRealMatrix(userCount, featureCount);
        initFeatures(userFeatures);

        // set up item index and matrix
        int itemCount = dao.getEntityIds(CommonTypes.ITEM).size();
        for (long iid : dao.getEntityIds(CommonTypes.ITEM)) {
            itemIndex.internId(iid);
        }
        RealMatrix itemFeatures = MatrixUtils.createRealMatrix(itemCount, featureCount);
        initFeatures(itemFeatures);

        logger.debug("Learning rate is {}", learningRate);
        logger.debug("Regularization term is {}", regularization);

        logger.info("Building MPR-MF with {} features for {} users and {} items", featureCount, userCount,
                itemCount);

        TrainingLoopController controller = stoppingCondition.newLoop();

        //REVIEW: because of the nature of training samples (and the point that the BPR paper makes that training
        // by-item or by-user are not optimal) one "iteration" here will be one training update. This leads to _really_
        // big iteration counts, which can actually overflow ints!. one suggestion would be to allow the iteration count
        // controller to accept longs, but I don't know if this will introduce backwards compatibility issues (I imagine
        // this depends on the robustness of our type conversion in the configuration.
        while (controller.keepTraining(-optAccum.getMean())) {
            for (TrainingItemPair pair : pairGenerator.nextBatch()) {
                // Note: bad code style variable names are generally to match BPR paper and enable easier implementation
                long iid = pair.g;
                int i = itemIndex.internId(iid);
                long jid = pair.l;
                int j = itemIndex.internId(jid);
                long uid = pair.u;
                int u = userIndex.internId(uid);

                RealVector w_u = userFeatures.getRowVector(u);
                RealVector h_i = itemFeatures.getRowVector(i);
                RealVector h_j = itemFeatures.getRowVector(j);

                double xui = w_u.dotProduct(h_i);
                double xuj = w_u.dotProduct(h_j);
                double xuij = xui - xuj;

                double bprTerm = 1 / (1 + exp(xuij));

                // w_u update
                RealVector h_i_j = h_i.subtract(h_j);
                RealVector w_u_update = w_u.mapMultiply(-regularization);
                w_u_update.combineToSelf(1, bprTerm, h_i_j);

                // h_i update
                RealVector h_i_update = h_i.mapMultiply(-regularization);
                h_i_update.combineToSelf(1, bprTerm, w_u);

                // h_j update
                RealVector h_j_update = h_j.mapMultiply(-regularization);
                h_j_update.combineToSelf(1, -bprTerm, w_u);

                // perform updates
                w_u.combineToSelf(1, learningRate, w_u_update);
                h_i.combineToSelf(1, learningRate, h_i_update);
                h_j.combineToSelf(1, learningRate, h_j_update);

                // update the optimization function accumulator (note we are not including the regularization term)
                optAccum.add(Math.log(1 / (1 + Math.exp(-xuij))));
            }
        }

        return new MFModel(userFeatures, itemFeatures, userIndex, itemIndex);
    }

    /**
     * Initialize the feature vectors.
     **/
    private void initFeatures(RealMatrix features) {
        features.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(int row, int column, double value) {
                return rand.nextDouble();
            }
        });
    }
}