org.grouplens.lenskit.eval.traintest.TrainTestEvalJob.java Source code

Java tutorial

Introduction

Here is the source code for org.grouplens.lenskit.eval.traintest.TrainTestEvalJob.java

Source

/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2013 Regents of the University of Minnesota and contributors
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.eval.ExecutionInfo;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.RecommenderInstance;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.TestUserMetric;
import org.grouplens.lenskit.eval.metrics.TestUserMetricAccumulator;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.inject.Provider;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Run a single train-test evaluation of a single algorithm.
 *
 * <p>A job builds a recommender for one algorithm on one train-test data set, applies the
 * model metrics to the built recommender, evaluates every test user with the configured
 * per-user metrics, and finally writes one aggregate row of timings and results.
 *
 * @author <a href="http://www.grouplens.org">GroupLens Research</a>
 * @since 0.8
 */
class TrainTestEvalJob implements Runnable {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalJob.class);

    /** Number of recommendations requested for each test user. */
    private final int numRecs;

    @Nonnull
    private final AlgorithmInstance algorithm;
    @Nonnull
    private final List<TestUserMetric> evaluators;
    @Nonnull
    private final List<ModelMetric> modelMetrics;
    @Nonnull
    private final List<Pair<Symbol, String>> channels;
    @Nonnull
    private final TTDataSet data;
    @Nonnull
    private final Supplier<TableWriter> outputSupplier;
    @Nonnull
    private final Supplier<TableWriter> userOutputSupplier;
    @Nonnull
    private final Supplier<TableWriter> predictOutputSupplier;
    private final Provider<PreferenceSnapshot> snapshot;

    /**
     * Create a new train-test eval job.
     *
     * @param algo     The algorithm to test.
     * @param evals    The evaluators to use.
     * @param mMetrics The model metrics to measure on the built recommender; their results
     *                 are written to the aggregate output row after the two timing columns.
     * @param chans    The list of channels to extract.
     * @param ds       The data set to use.
     * @param snap     Provider giving access to a shared rating snapshot to use in the
     *                 build process.
     * @param out      The table writer to receive output. This writer is expected to
     *                 be prefixed with algorithm and group ID data, so only the times
     *                 and evaluation output need to be written.
     * @param userOut  The table writer to receive user output. The supplier may return
     *                 {@code null}.
     * @param predOut  The table writer to receive prediction output. The supplier may
     *                 return {@code null}.
     * @param nrecs    The number of recommendations to compute.
     */
    public TrainTestEvalJob(@Nonnull AlgorithmInstance algo, @Nonnull List<TestUserMetric> evals,
            @Nonnull List<ModelMetric> mMetrics, @Nonnull List<Pair<Symbol, String>> chans, @Nonnull TTDataSet ds,
            Provider<PreferenceSnapshot> snap, @Nonnull Supplier<TableWriter> out,
            @Nonnull Supplier<TableWriter> userOut, @Nonnull Supplier<TableWriter> predOut, int nrecs) {
        algorithm = algo;
        evaluators = evals;
        modelMetrics = mMetrics;
        channels = chans;
        data = ds;
        snapshot = snap;
        outputSupplier = out;
        userOutputSupplier = userOut;
        predictOutputSupplier = predOut;
        numRecs = nrecs;
    }

    /**
     * Run the evaluation, converting any exception into an unchecked
     * {@link TrainTestJobException} that carries the original failure as its cause.
     */
    @Override
    public void run() {
        try {
            runEvaluation();
        } catch (Exception e) {
            throw new TrainTestJobException(e);
        }
    }

    /**
     * Build the recommender, measure it, test it against every test user, and write the
     * results. Every table writer and cursor opened here is registered with a
     * {@link Closer} and closed when the method exits, whether normally or by exception.
     *
     * @throws IOException               if writing prediction or aggregate output fails.
     * @throws RecommenderBuildException if it is thrown while running the evaluation.
     */
    @SuppressWarnings("PMD.AvoidCatchingThrowable")
    private void runEvaluation() throws IOException, RecommenderBuildException {
        Closer closer = Closer.create();
        try {
            // Either optional table may be absent (null); only register the ones we have.
            TableWriter userTable = userOutputSupplier.get();
            if (userTable != null) {
                closer.register(userTable);
            }
            TableWriter predictTable = predictOutputSupplier.get();
            if (predictTable != null) {
                closer.register(predictTable);
            }

            // Accumulates the model-metric columns of the aggregate output row.
            List<Object> outputRow = Lists.newArrayList();

            ExecutionInfo execInfo = buildExecInfo();

            logger.info("Building {}", algorithm.getName());
            StopWatch buildTimer = new StopWatch();
            buildTimer.start();
            RecommenderInstance rec = algorithm.makeTestableRecommender(data, snapshot, execInfo);
            buildTimer.stop();
            logger.info("Built {} in {}", algorithm.getName(), buildTimer);

            logger.info("Measuring {}", algorithm.getName());
            for (ModelMetric metric : modelMetrics) {
                outputRow.addAll(metric.measureAlgorithm(algorithm, data, rec.getRecommender()));
            }

            logger.info("Testing {}", algorithm.getName());
            StopWatch testTimer = new StopWatch();
            testTimer.start();
            List<TestUserMetricAccumulator> evalAccums = new ArrayList<TestUserMetricAccumulator>(
                    evaluators.size());

            // Reused for every user; emptied again after each row is written.
            List<Object> userRow = new ArrayList<Object>();

            UserEventDAO testUsers = data.getTestData().getUserEventDAO();
            for (TestUserMetric eval : evaluators) {
                evalAccums.add(eval.makeAccumulator(algorithm, data));
            }

            Cursor<UserHistory<Event>> userProfiles = closer.register(testUsers.streamEventsByUser());
            for (UserHistory<Event> p : userProfiles) {
                assert userRow.isEmpty();
                long uid = p.getUserId();
                userRow.add(uid);

                LongSet testItems = p.itemSet();

                Supplier<SparseVector> preds = new PredictionSupplier(rec, uid, testItems);
                Supplier<List<ScoredId>> recs = new RecommendationSupplier(rec, uid, testItems, numRecs);
                Supplier<UserHistory<Event>> hist = new HistorySupplier(rec.getUserEventDAO(), uid);
                Supplier<UserHistory<Event>> testHist = Suppliers.ofInstance(p);

                TestUser test = new TestUser(uid, hist, testHist, preds, recs);

                for (TestUserMetricAccumulator accum : evalAccums) {
                    Object[] ures = accum.evaluate(test);
                    // A null result contributes no columns to this user's row.
                    if (ures != null) {
                        userRow.addAll(Arrays.asList(ures));
                    }
                }
                if (userTable != null) {
                    try {
                        userTable.writeRow(userRow);
                    } catch (IOException e) {
                        // Rethrown unchecked with a message naming the user table as the failing writer.
                        throw new RuntimeException("error writing user row", e);
                    }
                }
                userRow.clear();

                if (predictTable != null) {
                    writePredictions(predictTable, uid, RatingVectorUserHistorySummarizer.makeRatingVector(p),
                            test.getPredictions());
                }
            }
            testTimer.stop();
            logger.info("Tested {} in {}", algorithm.getName(), testTimer);

            writeOutput(buildTimer, testTimer, outputRow, evalAccums);
        } catch (Throwable th) {
            // Closer idiom: record the primary failure so close() can handle secondary ones.
            throw closer.rethrow(th, RecommenderBuildException.class);
        } finally {
            closer.close();
        }
    }

    /**
     * Assemble the execution info describing this job: the algorithm's name and attributes
     * and the data set's name and attributes.
     *
     * @return The execution info passed to the recommender build.
     */
    private ExecutionInfo buildExecInfo() {
        ExecutionInfo.Builder bld = new ExecutionInfo.Builder();
        bld.setAlgoName(algorithm.getName()).setAlgoAttributes(algorithm.getAttributes())
                .setDataName(data.getName()).setDataAttributes(data.getAttributes());
        return bld.build();
    }

    /**
     * Write one row per rated test item for a user. Columns are: user ID, item ID, rating,
     * prediction, then one column per configured channel in order. A missing prediction or
     * channel value is written as {@code null}.
     *
     * @param predictTable The table to write to; its layout must have at least
     *                     {@code 4 + channels.size()} columns.
     * @param uid          The user ID.
     * @param ratings      The user's test ratings; one row is written per entry.
     * @param predictions  The predictions (and channel vectors) for the user.
     * @throws IOException if the table writer fails.
     */
    private void writePredictions(TableWriter predictTable, long uid, SparseVector ratings,
            SparseVector predictions) throws IOException {
        final int ncols = predictTable.getLayout().getColumnCount();
        // The row array is reused across items; every cell is overwritten on each iteration.
        final String[] row = new String[ncols];
        row[0] = Long.toString(uid);
        for (VectorEntry e : ratings.fast()) {
            long iid = e.getKey();
            row[1] = Long.toString(iid);
            row[2] = Double.toString(e.getValue());
            if (predictions.containsKey(iid)) {
                row[3] = Double.toString(predictions.get(iid));
            } else {
                row[3] = null;
            }
            int i = 4;
            for (Pair<Symbol, String> pair : channels) {
                Symbol c = pair.getLeft();
                if (predictions.hasChannelVector(c) && predictions.getChannelVector(c).containsKey(iid)) {
                    row[i] = Double.toString(predictions.getChannelVector(c).get(iid));
                } else {
                    row[i] = null;
                }
                i += 1;
            }
            predictTable.writeRow(row);
        }
    }

    /**
     * Write the aggregate output row: build time, test time, the model-metric measurements,
     * and then each accumulator's final results, in that order. The output writer is
     * obtained from its supplier here and is always closed before returning.
     *
     * @param build    The stopped timer for the recommender build.
     * @param test     The stopped timer for the test phase.
     * @param measures The model-metric values to write after the timing columns.
     * @param accums   The per-user metric accumulators whose final results are appended.
     * @throws IOException if the table writer fails.
     */
    private void writeOutput(StopWatch build, StopWatch test, List<Object> measures,
            List<TestUserMetricAccumulator> accums) throws IOException {
        TableWriter output = outputSupplier.get();

        try {
            Object[] row = new Object[output.getLayout().getColumnCount()];
            row[0] = build.getTime();
            row[1] = test.getTime();
            int col = 2;
            for (Object o : measures) {
                row[col] = o;
                col += 1;
            }
            for (TestUserMetricAccumulator acc : accums) {
                Object[] ar = acc.finalResults();
                // A null result is skipped: no aggregated output is generated for that metric.
                if (ar != null) {
                    int n = ar.length;
                    System.arraycopy(ar, 0, row, col, n);
                    col += n;
                }
            }
            output.writeRow(row);
        } finally {
            output.close();
        }
    }

    /**
     * Supplier that computes a user's predictions for a set of items each time it is asked.
     * Static: it needs nothing from the enclosing job.
     */
    private static final class PredictionSupplier implements Supplier<SparseVector> {
        private final RecommenderInstance predictor;
        private final long user;
        private final LongSet items;

        PredictionSupplier(RecommenderInstance pred, long id, LongSet is) {
            predictor = pred;
            user = id;
            items = is;
        }

        /**
         * @return The predictions for the user; never {@code null}.
         * @throws IllegalArgumentException if there is no predictor or it returns no predictions.
         */
        @Override
        public SparseVector get() {
            if (predictor == null) {
                throw new IllegalArgumentException("cannot compute predictions without a predictor");
            }
            SparseVector preds = predictor.getPredictions(user, items);
            if (preds == null) {
                throw new IllegalArgumentException("no predictions");
            }
            return preds;
        }
    }

    /**
     * Supplier that computes a user's recommendation list each time it is asked. Static:
     * the list length is passed in rather than read from the enclosing job.
     */
    private static final class RecommendationSupplier implements Supplier<List<ScoredId>> {
        private final RecommenderInstance recommender;
        private final long user;
        private final LongSet items;
        private final int count;

        RecommendationSupplier(RecommenderInstance rec, long id, LongSet is, int n) {
            recommender = rec;
            user = id;
            items = is;
            count = n;
        }

        /**
         * @return The recommendations for the user; never {@code null}.
         * @throws IllegalArgumentException if there is no recommender or it returns no list.
         */
        @Override
        public List<ScoredId> get() {
            if (recommender == null) {
                throw new IllegalArgumentException("cannot compute recommendations without a recommender");
            }
            List<ScoredId> recs = recommender.getRecommendations(user, items, count);
            if (recs == null) {
                throw new IllegalArgumentException("no recommendations");
            }
            return recs;
        }
    }

    /**
     * Supplier that looks up a user's event history in a DAO each time it is asked.
     * Static: it needs nothing from the enclosing job.
     */
    private static final class HistorySupplier implements Supplier<UserHistory<Event>> {
        private final UserEventDAO userEventDAO;
        private final long user;

        HistorySupplier(UserEventDAO dao, long id) {
            userEventDAO = dao;
            user = id;
        }

        @Override
        public UserHistory<Event> get() {
            return userEventDAO.getEventsForUser(user);
        }
    }
}