de.tuberlin.dima.recsys.ssnmm.ratingprediction.Evaluate.java Source code

Java tutorial

Introduction

Here is the source code for de.tuberlin.dima.recsys.ssnmm.ratingprediction.Evaluate.java

Source

/*
 * Copyright (C) 2012 Sebastian Schelter <sebastian.schelter [at] tu-berlin.de>
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */

package de.tuberlin.dima.recsys.ssnmm.ratingprediction;

import com.google.common.base.Preconditions;
import de.tuberlin.dima.recsys.ssnmm.Rating;
import de.tuberlin.dima.recsys.ssnmm.RatingsIterable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.FileLineIterable;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenIntDoubleHashMap;

import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.regex.Pattern;

/**
 * Reads the similarity matrix as well as the item and user biases into memory,
 * computes the prediction error towards held out ratings in a single read through the data.
 *
 * <p>The training ratings are consumed sequentially and a user is evaluated as soon as the user id
 * of the incoming rating changes, so the training file must be grouped by user. The holdout file is
 * consumed in lock-step: for every evaluated user the next {@link #HELD_OUT_RATINGS_PER_USER}
 * held out ratings must belong to exactly that user.
 *
 * <p>For a held out rating of user {@code u} on item {@code i} the baseline estimate is
 * {@code mu + userBias[u] + itemBias[i]}. The neighbourhood estimate adds the similarity-weighted
 * average of the baseline residuals of all items similar to {@code i} that {@code u} rated in the
 * training set. MAE and RMSE are reported for both estimates.
 */
public class Evaluate {

    /** Number of held out ratings the holdout file is expected to contain for every user. */
    private static final int HELD_OUT_RATINGS_PER_USER = 10;

    /** Intermediate results are printed every time this many users have been evaluated. */
    private static final int REPORT_INTERVAL = 10000;

    /** Column separator of the bias files. */
    private static final Pattern SEP = Pattern.compile("\t");

    /**
     * Runs the evaluation against the hard-coded dataset locations and prints the resulting errors
     * to standard out.
     *
     * @param args ignored
     * @throws IOException if one of the input files cannot be read
     */
    public static void main(String[] args) throws IOException {

        int numUsers = 1823179;
        int numItems = 136736;
        double mu = 3.157255412010664;

        String distributedSimilarityMatrixPath = "/home/ssc/Desktop/yahoo/similarityMatrix/";
        String itemBiasesFilePath = "/home/ssc/Desktop/yahoo/itemBiases.tsv";
        String userBiasesFilePath = "/home/ssc/Desktop/yahoo/userBiases.tsv";
        String trainingSetPath = "/home/ssc/Entwicklung/datasets/yahoo-songs/songs.tsv";
        // The leading slash was missing, which made this the only relative path of the five.
        String holdoutSetPath = "/home/ssc/Entwicklung/datasets/yahoo-songs/holdout.tsv";

        Matrix similarities = readSimilarities(distributedSimilarityMatrixPath, numItems);

        System.out.println("Reading item biases");
        double[] itemBiases = readBiases(itemBiasesFilePath, numItems);

        System.out.println("Reading user biases");
        double[] userBiases = readBiases(userBiasesFilePath, numUsers);

        Iterator<Rating> trainRatings = new RatingsIterable(new File(trainingSetPath)).iterator();
        Iterator<Rating> heldOutRatings = new RatingsIterable(new File(holdoutSetPath)).iterator();

        int currentUser = 0;
        boolean sawTrainingRating = false;
        // Training ratings (item -> rating) of the user that is currently being collected.
        OpenIntDoubleHashMap prefs = new OpenIntDoubleHashMap();

        int usersProcessed = 0;
        ErrorStats model = new ErrorStats();
        ErrorStats baseline = new ErrorStats();

        while (trainRatings.hasNext()) {
            Rating rating = trainRatings.next();
            if (rating.user() != currentUser) {
                // A new user starts: all training ratings of the previous one have been collected.
                evaluateUser(currentUser, prefs, heldOutRatings, similarities, mu, userBiases, itemBiases,
                        model, baseline);

                if (++usersProcessed % REPORT_INTERVAL == 0) {
                    report(usersProcessed, model, baseline);
                }

                currentUser = rating.user();
                prefs.clear();
            }
            prefs.put(rating.item(), rating.rating());
            sawTrainingRating = true;
        }

        // The loop above only evaluates a user when the NEXT user shows up, so the last user of the
        // training file has to be flushed explicitly. Skipped if the holdout file has no ratings left.
        if (sawTrainingRating && heldOutRatings.hasNext()) {
            evaluateUser(currentUser, prefs, heldOutRatings, similarities, mu, userBiases, itemBiases,
                    model, baseline);
            usersProcessed++;
        }

        report(usersProcessed, model, baseline);
    }

    /**
     * Loads the item-item similarity matrix from the part files of a sequence file directory.
     *
     * @param path directory holding the sequence files (item id -> vector of similarities)
     * @param numItems number of rows and columns of the resulting matrix
     * @return the similarity matrix
     * @throws IOException if the files cannot be read
     */
    private static Matrix readSimilarities(String path, int numItems) throws IOException {
        Matrix similarities = new SparseRowMatrix(numItems, numItems);

        System.out.println("Reading similarities...");
        int similaritiesRead = 0;
        Configuration conf = new Configuration();
        for (Pair<IntWritable, VectorWritable> pair : new SequenceFileDirIterable<IntWritable, VectorWritable>(
                new Path(path), PathType.LIST, PathFilters.partFilter(), conf)) {

            int item = pair.getFirst().get();
            Iterator<Vector.Element> elements = pair.getSecond().get().iterateNonZero();

            while (elements.hasNext()) {
                Vector.Element elem = elements.next();
                similarities.setQuick(item, elem.index(), elem.get());
                similaritiesRead++;
            }
        }
        System.out.println("Found " + similaritiesRead + " similarities");

        return similarities;
    }

    /**
     * Reads a tab separated file of {@code id<TAB>bias} lines into an array indexed by id.
     *
     * @param path location of the bias file
     * @param size length of the resulting array, ids must be smaller than this
     * @return the biases, ids that do not occur in the file keep a bias of 0
     * @throws IOException if the file cannot be read
     */
    private static double[] readBiases(String path, int size) throws IOException {
        double[] biases = new double[size];
        for (String line : new FileLineIterable(new File(path))) {
            String[] parts = SEP.split(line);
            biases[Integer.parseInt(parts[0])] = Double.parseDouble(parts[1]);
        }
        return biases;
    }

    /**
     * Consumes the next {@link #HELD_OUT_RATINGS_PER_USER} held out ratings, predicts each of them for
     * the given user and records the absolute errors of the baseline and the neighbourhood estimate.
     *
     * @param user id of the user to evaluate
     * @param prefs training ratings of that user (item -> rating)
     * @param heldOutRatings holdout iterator, positioned at the first held out rating of the user
     * @param similarities item-item similarity matrix
     * @param mu global rating average used in the baseline estimate
     * @param userBiases user biases indexed by user id
     * @param itemBiases item biases indexed by item id
     * @param model receives the errors of the neighbourhood estimate
     * @param baseline receives the errors of the baseline estimate
     * @throws IllegalStateException if the holdout set is exhausted or its next rating belongs to
     *         a different user
     */
    private static void evaluateUser(int user, OpenIntDoubleHashMap prefs, Iterator<Rating> heldOutRatings,
            Matrix similarities, double mu, double[] userBiases, double[] itemBiases, ErrorStats model,
            ErrorStats baseline) {

        for (int n = 0; n < HELD_OUT_RATINGS_PER_USER; n++) {
            Preconditions.checkState(heldOutRatings.hasNext(),
                    "Holdout set exhausted while evaluating user %s", user);
            Rating heldOutRating = heldOutRatings.next();
            Preconditions.checkState(heldOutRating.user() == user,
                    "Expected a held out rating of user %s but found one of user %s", user,
                    heldOutRating.user());

            double preference = 0.0;
            double totalSimilarity = 0.0;
            int count = 0;

            Iterator<Vector.Element> similarItems = similarities.viewRow(heldOutRating.item())
                    .iterateNonZero();
            while (similarItems.hasNext()) {
                Vector.Element similarity = similarItems.next();
                int similarItem = similarity.index();
                if (prefs.containsKey(similarItem)) {
                    // Weight the residual of the known rating towards its own baseline estimate.
                    preference += similarity.get() * (prefs.get(similarItem)
                            - (mu + userBiases[user] + itemBiases[similarItem]));
                    totalSimilarity += Math.abs(similarity.get());
                    count++;
                }
            }

            double baselineEstimate = mu + userBiases[user] + itemBiases[heldOutRating.item()];
            double estimate = baselineEstimate;

            // NOTE(review): the correction is only applied with at least two rated neighbours;
            // the threshold is kept exactly as in the original code.
            if (count > 1) {
                estimate += preference / totalSimilarity;
            }

            baseline.addAbsoluteError(Math.abs(heldOutRating.rating() - baselineEstimate));
            model.addAbsoluteError(Math.abs(heldOutRating.rating() - estimate));
        }
    }

    /** Prints the number of evaluated users together with the current MAE and RMSE values. */
    private static void report(int usersProcessed, ErrorStats model, ErrorStats baseline) {
        System.out.println(usersProcessed + " users processed, MAE " + model.mae() + ", RMSE "
                + model.rmse() + " | baseline MAE " + baseline.mae() + ", baseline RMSE "
                + baseline.rmse());
    }

    /** Accumulates absolute prediction errors and exposes them as MAE and RMSE. */
    private static final class ErrorStats {

        private final RunningAverage absoluteErrors = new FullRunningAverage();
        private final RunningAverage squaredErrors = new FullRunningAverage();

        /** Records a single absolute prediction error. */
        void addAbsoluteError(double error) {
            absoluteErrors.addDatum(error);
            squaredErrors.addDatum(error * error);
        }

        /** @return the mean of the recorded absolute errors */
        double mae() {
            return absoluteErrors.getAverage();
        }

        /** @return the square root of the mean of the squared recorded errors */
        double rmse() {
            return Math.sqrt(squaredErrors.getAverage());
        }
    }
}