com.cloudera.oryx.als.computation.iterate.row.ComputeUserAPFn.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.als.computation.iterate.row.ComputeUserAPFn.java:

Source

/*
 * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.als.computation.iterate.row;

import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.crunch.CrunchRuntimeException;
import org.apache.crunch.Emitter;
import org.apache.crunch.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Arrays;

import com.cloudera.oryx.als.common.StringLongMapping;
import com.cloudera.oryx.common.collection.LongObjectMap;
import com.cloudera.oryx.common.collection.LongSet;
import com.cloudera.oryx.common.io.DelimitedDataUtils;
import com.cloudera.oryx.common.iterator.FileLineIterable;
import com.cloudera.oryx.common.math.SimpleVectorMath;
import com.cloudera.oryx.common.servcomp.Store;
import com.cloudera.oryx.computation.common.fn.OryxDoFn;

/**
 * Crunch function that receives (user ID, user feature vector) pairs and, for every user
 * that has entries in the loaded test data, emits one average-precision value describing how
 * highly that user's test items score relative to all item vectors in {@code Y}.
 *
 * <p>The test data (user ID to set of item IDs) is loaded in {@link #initialize()} from the
 * files found under the prefix configured at {@code RowStep.MAP_KEY}. Users without test
 * items produce no output.
 */
public final class ComputeUserAPFn extends OryxDoFn<Pair<Long, float[]>, Double> {

    private static final Logger log = LoggerFactory.getLogger(ComputeUserAPFn.class);

    /** Provides the item-feature matrix {@code Y} (item ID to item vector). */
    private final YState yState;
    /** User ID to the set of item IDs listed for that user in the test files; set in initialize(). */
    private LongObjectMap<LongSet> testData;

    /**
     * @param yState source of the item-feature matrix used to score items in {@link #process}
     */
    public ComputeUserAPFn(YState yState) {
        this.yState = yState;
    }

    /**
     * Loads the test data into memory. Each line of each listed file is decoded as delimited
     * columns; column 0 is mapped to a user ID and column 1 to an item ID.
     *
     * @throws CrunchRuntimeException wrapping any {@link IOException} raised while listing or
     *  reading the files
     */
    @Override
    public void initialize() {
        super.initialize();
        Store store = Store.get();
        testData = new LongObjectMap<>();
        String prefix = getConfiguration().get(RowStep.MAP_KEY);
        try {
            for (String filePrefix : store.list(prefix, true)) {
                for (CharSequence line : new FileLineIterable(store.readFrom(filePrefix))) {
                    String[] columns = DelimitedDataUtils.decode(line);
                    long userID = StringLongMapping.toLong(columns[0]);
                    long itemID = StringLongMapping.toLong(columns[1]);
                    LongSet itemIDs = testData.get(userID);
                    if (itemIDs == null) {
                        itemIDs = new LongSet();
                        testData.put(userID, itemIDs);
                    }
                    itemIDs.add(itemID);
                }
            }
        } catch (IOException ioe) {
            throw new CrunchRuntimeException(ioe);
        }
    }

    /**
     * Computes and emits the average precision for one user. Nothing is emitted when the user
     * has no test items.
     *
     * @param input pair of user ID and that user's feature vector
     * @param emitter receives the user's average precision
     */
    @Override
    public void process(Pair<Long, float[]> input, Emitter<Double> emitter) {

        LongSet ids = testData.get(input.first());
        if (ids == null || ids.isEmpty()) {
            return;
        }

        float[] userVector = input.second();
        LongObjectMap<float[]> Y = yState.getY();

        double[] scores = scoreTestItems(userVector, ids.toArray(), Y);
        int[] rank = countHigherScoring(userVector, scores, Y);

        // Ascending order: best-ranked test item first
        Arrays.sort(rank);

        double averagePrecision = averagePrecision(rank);

        log.info("Average precision: {}", averagePrecision);

        emitter.emit(averagePrecision);
    }

    /**
     * Scores each test item as the dot product of the user vector and the item's vector.
     * A test item with no vector in {@code Y} keeps the array default score of 0.0.
     */
    private static double[] scoreTestItems(float[] userVector, long[] itemIDs, LongObjectMap<float[]> Y) {
        double[] scores = new double[itemIDs.length];
        for (int i = 0; i < itemIDs.length; i++) {
            float[] itemVector = Y.get(itemIDs[i]);
            if (itemVector != null) {
                scores[i] = SimpleVectorMath.dot(userVector, itemVector);
            }
        }
        return scores;
    }

    /**
     * For each test-item score, counts how many items in {@code Y} score strictly higher for
     * this user. Ties are not counted, so equal scores yield equal counts.
     */
    private static int[] countHigherScoring(float[] userVector, double[] scores, LongObjectMap<float[]> Y) {
        int[] rank = new int[scores.length];
        for (LongObjectMap.MapEntry<float[]> entry : Y.entrySet()) {
            double score = SimpleVectorMath.dot(userVector, entry.getValue());
            for (int i = 0; i < scores.length; i++) {
                if (score > scores[i]) {
                    rank[i]++;
                }
            }
        }
        return rank;
    }

    /**
     * Folds the sorted ranks into a single average-precision value: at each test item the
     * precision (relevant items retrieved so far / position) is added to a running mean, and
     * the running means are summed and divided by the number of test items.
     *
     * @param sortedRanks non-empty, ascending counts of higher-scoring items per test item
     */
    private static double averagePrecision(int[] sortedRanks) {
        Mean precision = new Mean();
        double totalPrecisionTimesRelevance = 0.0;
        for (int i = 0; i < sortedRanks.length; i++) {
            int relevantRetrieved = i + 1;
            // Tied test items, or test items missing from Y, are not counted ahead of each
            // other in the ranks, so rank + 1 can be smaller than the number of relevant items
            // already retrieved. A list position can never be smaller than that count, so
            // clamp it; otherwise the "precision" below would exceed 1.0.
            int precisionAt = Math.max(sortedRanks[i] + 1, relevantRetrieved);
            precision.increment((double) relevantRetrieved / precisionAt);
            // NOTE(review): this accumulates the running mean of the precisions, not the raw
            // precision at this position; kept as-is, confirm this is the intended metric.
            totalPrecisionTimesRelevance += precision.getResult();
        }
        return totalPrecisionTimesRelevance / sortedRanks.length;
    }

}