org.lenskit.eval.traintest.ExperimentJob.java Source code

Java tutorial

Introduction

Here is the source code for the class org.lenskit.eval.traintest.ExperimentJob (file ExperimentJob.java).

Source

/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.lenskit.eval.traintest;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.grouplens.grapht.Component;
import org.grouplens.grapht.Dependency;
import org.grouplens.grapht.InjectionException;
import org.grouplens.grapht.graph.DAGNode;
import org.grouplens.grapht.graph.MergePool;
import org.lenskit.LenskitConfiguration;
import org.lenskit.LenskitRecommender;
import org.lenskit.api.RecommenderBuildException;
import org.lenskit.data.dao.ItemDAO;
import org.lenskit.data.dao.ItemListItemDAO;
import org.lenskit.data.dao.UserEventDAO;
import org.lenskit.data.events.Event;
import org.lenskit.data.history.History;
import org.lenskit.data.history.UserHistory;
import org.lenskit.inject.GraphtUtils;
import org.lenskit.inject.NodeProcessors;
import org.lenskit.inject.RecommenderInstantiator;
import org.lenskit.util.ProgressLogger;
import org.lenskit.util.UncheckedInterruptException;
import org.lenskit.util.table.RowBuilder;
import org.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;

/**
 * Individual job evaluating a single experimental condition.
 */
/**
 * Individual job evaluating a single experimental condition: one algorithm instance on one
 * data set.
 *
 * <p>The job builds a recommender for the algorithm, runs every task's
 * {@link ConditionEvaluator} over each test user of the data set, optionally writes one row per
 * user to the experiment's user output table, and finally writes one aggregate row to the
 * experiment's global output table.
 */
class ExperimentJob extends RecursiveAction {
    private static final Logger logger = LoggerFactory.getLogger(ExperimentJob.class);

    /** The experiment this job belongs to; supplies the tasks, output layout and output tables. */
    private final TrainTestExperiment experiment;
    /** The algorithm being evaluated. */
    private final AlgorithmInstance algorithm;
    /** The train/test data set the algorithm is evaluated on. */
    private final DataSet dataSet;
    /** Configuration used as the base of the data configuration for each recommender build. */
    private final LenskitConfiguration sharedConfig;

    /** Cache for shareable components, or {@code null} to instantiate the graph directly. */
    @Nullable
    private final ComponentCache cache;
    /** Pool used to deduplicate configuration graphs, or {@code null} to skip deduplication. */
    @Nullable
    private final MergePool<Component, Dependency> mergePool;

    /**
     * Create a new experiment job.
     *
     * @param exp    the experiment this job belongs to.
     * @param algo   the algorithm to evaluate.
     * @param ds     the data set to evaluate on.
     * @param shared configuration used as the base for the recommender's data configuration.
     * @param cache  the component cache, or {@code null} to build without a cache.
     * @param pool   the graph merge pool, or {@code null} to skip graph deduplication.
     */
    ExperimentJob(TrainTestExperiment exp, @Nonnull AlgorithmInstance algo, @Nonnull DataSet ds,
            LenskitConfiguration shared, @Nullable ComponentCache cache,
            @Nullable MergePool<Component, Dependency> pool) {
        experiment = exp;
        algorithm = algo;
        dataSet = ds;
        sharedConfig = shared;
        this.cache = cache;
        mergePool = pool;
    }

    /**
     * Build the recommender, measure it on every test user, and write the results.
     *
     * <p>Per-user rows (when a user output table is configured) are written as each user is
     * measured. The aggregate row is written to the global output only after every user has
     * been measured without error.
     *
     * @throws EvaluationException if the job is interrupted or an output row cannot be written.
     */
    @Override
    protected void compute() {
        ExperimentOutputLayout layout = experiment.getOutputLayout();
        TableWriter globalOutput = layout.prefixTable(experiment.getGlobalOutput(), dataSet, algorithm);
        // The user output table is optional; when it is absent no per-user rows are produced.
        TableWriter userOutput = layout.prefixTable(experiment.getUserOutput(), dataSet, algorithm);
        RowBuilder outputRow = globalOutput.getLayout().newRowBuilder();

        logger.info("Building {} on {}", algorithm, dataSet);
        Stopwatch buildTimer = Stopwatch.createStarted();
        try (LenskitRecommender rec = buildRecommender()) {
            buildTimer.stop();
            logger.info("Built {} in {}", algorithm.getName(), buildTimer);

            logger.info("Measuring {} on {}", algorithm.getName(), dataSet.getName());

            RowBuilder userRow = userOutput != null ? userOutput.getLayout().newRowBuilder() : null;

            Stopwatch testTimer = Stopwatch.createStarted();

            // One evaluator per task; tasks that cannot be instantiated are skipped with a warning.
            List<ConditionEvaluator> accumulators = Lists.newArrayList();
            for (EvalTask task : experiment.getTasks()) {
                ConditionEvaluator ce = task.createConditionEvaluator(algorithm, dataSet, rec);
                if (ce != null) {
                    accumulators.add(ce);
                } else {
                    logger.warn("Could not instantiate task {} for algorithm {} on data set {}", task, algorithm,
                            dataSet);
                }
            }

            LongSet testUsers = dataSet.getTestData().getUserDAO().getUserIds();
            UserEventDAO trainEvents = dataSet.getTrainingData().getUserEventDAO();
            UserEventDAO userEvents = dataSet.getTestData().getUserEventDAO();
            final int nusers = testUsers.size();
            logger.info("Testing {} on {} ({} users)", algorithm, dataSet, nusers);
            ProgressLogger progress = ProgressLogger.create(logger).setCount(nusers).setLabel("testing users")
                    .start();
            for (LongIterator iter = testUsers.iterator(); iter.hasNext();) {
                // Check without clearing the interrupt flag (unlike Thread.interrupted()), so the
                // interrupt is still visible to callers after this exception propagates.
                if (Thread.currentThread().isInterrupted()) {
                    throw new EvaluationException("eval job interrupted");
                }
                long uid = iter.nextLong();
                if (userRow != null) {
                    userRow.add("User", uid);
                }

                // Test users with no training events are measured against an empty history.
                UserHistory<Event> trainData = trainEvents.getEventsForUser(uid);
                if (trainData == null) {
                    trainData = History.forUser(uid);
                }
                UserHistory<Event> userData = userEvents.getEventsForUser(uid);
                TestUser user = new TestUser(trainData, userData);

                Stopwatch userTimer = Stopwatch.createStarted();

                for (ConditionEvaluator eval : accumulators) {
                    Map<String, Object> ures = eval.measureUser(user);
                    if (userRow != null) {
                        userRow.addAll(ures);
                    }
                }
                userTimer.stop();
                if (userRow != null) {
                    // Times are reported in seconds.
                    userRow.add("TestTime", userTimer.elapsed(TimeUnit.MILLISECONDS) * 0.001);
                    assert userOutput != null;
                    try {
                        userOutput.writeRow(userRow.buildList());
                    } catch (IOException e) {
                        throw new EvaluationException("error writing user row", e);
                    }
                    userRow.clear();
                }

                progress.advance();
            }

            progress.finish();
            testTimer.stop();
            logger.info("Tested {} in {}", algorithm.getName(), testTimer);
            outputRow.add("BuildTime", buildTimer.elapsed(TimeUnit.MILLISECONDS) * 0.001);
            outputRow.add("TestTime", testTimer.elapsed(TimeUnit.MILLISECONDS) * 0.001);
            for (ConditionEvaluator eval : accumulators) {
                outputRow.addAll(eval.finish());
            }
        } catch (UncheckedInterruptException ex) {
            logger.info("evaluation interrupted");
            throw ex;
        } catch (Throwable th) {
            // SLF4J treats a trailing Throwable argument as the exception to log.
            logger.error("Error evaluating {} on {}", algorithm, dataSet, th);
            throw th;
        }

        try {
            globalOutput.writeRow(outputRow.buildList());
        } catch (IOException e) {
            throw new EvaluationException("error writing output row", e);
        }
    }

    /**
     * Build the recommender for this job's algorithm on this job's data set.
     *
     * <p>The algorithm's configuration graph is built on top of a data configuration derived
     * from the shared configuration and the data set. It is deduplicated through the merge pool
     * (if any), then either instantiated directly or, when a component cache is present,
     * instantiated by resolving the graph's shareable nodes through the cache.
     *
     * @return a newly built recommender; the caller is responsible for closing it.
     * @throws RecommenderBuildException if the recommender cannot be built (for example, when
     *                                   pre-processing shareable components through the cache
     *                                   fails).
     */
    private LenskitRecommender buildRecommender() throws RecommenderBuildException {
        logger.debug("Starting recommender build");
        LenskitConfiguration dataConfig = new LenskitConfiguration(sharedConfig);
        dataSet.configure(dataConfig);

        LenskitConfiguration extraConfig = new LenskitConfiguration();

        // Fix the train items
        // NOTE(review): extraConfig is never passed to buildRecommenderGraph below, so this
        // item DAO override currently has no effect on the built recommender. Wiring it in
        // depends on how AlgorithmInstance combines configurations — TODO confirm and apply.
        LongSet trainItems = dataSet.getTrainingData().getItemDAO().getItemIds();
        LongSet allItems = dataSet.getAllItems();
        if (!trainItems.containsAll(allItems)) {
            logger.info("train data is missing items, overriding item DAO");
            extraConfig.bind(ItemDAO.class).to(new ItemListItemDAO(allItems));
        }

        DAGNode<Component, Dependency> cfgGraph = algorithm.buildRecommenderGraph(dataConfig);
        if (mergePool != null) {
            logger.debug("deduplicating configuration graph");
            // The pool may be shared between jobs, so merging is serialized on the pool itself.
            synchronized (mergePool) {
                cfgGraph = mergePool.merge(cfgGraph);
            }
        }
        DAGNode<Component, Dependency> graph;
        if (cache == null) {
            logger.debug("Building directly without a cache");
            RecommenderInstantiator ri = RecommenderInstantiator.create(cfgGraph);
            graph = ri.instantiate();
        } else {
            logger.debug("Instantiating graph with a cache");
            try {
                Set<DAGNode<Component, Dependency>> nodes = GraphtUtils.getShareableNodes(cfgGraph);
                logger.debug("resolving {} nodes", nodes.size());
                graph = NodeProcessors.processNodes(cfgGraph, nodes, cache);
                if (logger.isDebugEnabled()) {
                    // Only walk both graphs when the node counts will actually be logged.
                    logger.debug("graph went from {} to {} nodes", cfgGraph.getReachableNodes().size(),
                            graph.getReachableNodes().size());
                }
            } catch (InjectionException e) {
                logger.error("Error encountered while pre-processing algorithm components for sharing", e);
                throw new RecommenderBuildException("Pre-processing of algorithm components for sharing failed.",
                        e);
            }
        }
        return new LenskitRecommender(graph);
    }

    /**
     * Execute this job immediately on the calling thread, calling {@link #compute()} directly
     * rather than going through fork/join scheduling.
     */
    public void execute() {
        compute();
    }
}