com.anhth12.lambda.app.serving.als.Because.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.lambda.app.serving.als.Because.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.app.serving.als;

import com.anhth12.lambda.app.serving.CSVMessageBodyWriter;
import com.anhth12.lambda.app.serving.IDValue;
import com.anhth12.lambda.app.serving.LambdaServingException;
import com.anhth12.lambda.app.serving.als.model.ALSServingModel;
import com.anhth12.lambda.common.collection.Pair;
import com.anhth12.lambda.common.collection.PairComparators;
import com.anhth12.lambda.common.math.VectorMath;
import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import java.util.List;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.core.MediaType;

/**
 *
 * @author Tong Hoang Anh
 */
/**
 * REST resource that explains a recommendation: for a given user and item, it lists the
 * user's known items whose feature vectors are most similar (by cosine similarity) to the
 * given item's feature vector.
 * <p>
 * Endpoint: {@code GET /because/{userID}/{itemID}?howMany=..&offset=..}
 *
 * @author Tong Hoang Anh
 */
@Path("/because")
public final class Because extends AbstractALSResource {

    /**
     * Ranks the user's known items by cosine similarity to {@code itemID}'s vector.
     *
     * @param userID  user whose known items are the candidates
     * @param itemID  item to explain; its vector is the reference for similarity
     * @param howMany maximum number of results to return; must be positive (default 10)
     * @param offset  number of top-ranked results to skip; must be non-negative (default 0)
     * @return (item ID, similarity) pairs built by {@code toIDValueResponse} from the
     *         top {@code howMany + offset} candidates, as selected by
     *         {@link Ordering#greatestOf(Iterable, int)}
     * @throws LambdaServingException if an argument check fails, or (presumably via
     *         {@code checkExists}) if the item has no vector or the user has no known items
     */
    @GET
    @Path("{userID}/{itemID}")
    @Produces({ MediaType.TEXT_PLAIN, CSVMessageBodyWriter.TEXT_CSV, MediaType.APPLICATION_JSON })
    public List<IDValue> get(@PathParam("userID") String userID, @PathParam("itemID") String itemID,
            @DefaultValue("10") @QueryParam("howMany") int howMany,
            @DefaultValue("0") @QueryParam("offset") int offset) throws LambdaServingException {
        check(howMany > 0, "howMany must be positive");
        check(offset >= 0, "offset must be non-negative");

        ALSServingModel model = getALSServingModel();
        float[] itemVector = model.getItemVector(itemID);
        checkExists(itemVector != null, itemID);
        List<Pair<String, float[]>> knownItemVectors = model.getKnowItemVectorsForUser(userID);
        // The lookup that failed here is keyed by the user, so report the user ID
        // (the item ID was already validated above).
        checkExists(knownItemVectors != null, userID);

        // Lazily maps each known (itemID, vector) pair to (itemID, cosine similarity).
        Iterable<Pair<String, Double>> idSimilarities = Iterables.transform(knownItemVectors,
                new CosineSimilarityFunction(itemVector));

        Ordering<Pair<?, Double>> ordering = Ordering.from(PairComparators.<Double>bySecond());

        // howMany and offset come straight from the query string; their int sum can
        // overflow to a negative value, which greatestOf rejects with an
        // IllegalArgumentException. Saturate at Integer.MAX_VALUE instead.
        int topN = (int) Math.min((long) howMany + offset, Integer.MAX_VALUE);
        return toIDValueResponse(ordering.greatestOf(idSimilarities, topN), howMany, offset);
    }

    /**
     * Maps an (item ID, item vector) pair to (item ID, cosine similarity with a fixed
     * reference vector). The reference vector's norm is computed once, in the constructor.
     * <p>
     * NOTE(review): if either vector has zero norm, the division below is by zero and
     * presumably yields NaN; how such values rank depends on
     * {@code PairComparators.bySecond()}. Confirm whether that case can occur.
     */
    private static final class CosineSimilarityFunction
            implements Function<Pair<String, float[]>, Pair<String, Double>> {

        /** Reference vector that every other vector is compared against. */
        private final float[] itemVector;
        /** Cached norm of {@link #itemVector}. */
        private final double itemVectorNorm;

        /**
         * @param itemVector reference vector; assumed to have the same length as the
         *                   vectors passed to {@link #apply} (TODO confirm)
         */
        public CosineSimilarityFunction(float[] itemVector) {
            this.itemVector = itemVector;
            this.itemVectorNorm = VectorMath.norm(itemVector);
        }

        /**
         * @param f pair of (item ID, item vector)
         * @return pair of (the same item ID, cosine similarity of its vector with the
         *         reference vector)
         */
        @Override
        public Pair<String, Double> apply(Pair<String, float[]> f) {
            float[] otherItemVector = f.getSecond();
            double cosineSimilarity = VectorMath.dot(itemVector, otherItemVector)
                    / (itemVectorNorm * VectorMath.norm(otherItemVector));
            return new Pair<>(f.getFirst(), cosineSimilarity);
        }

    }

}