List of usage examples for org.apache.spark.api.java.function.PairFlatMapFunction
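For context, PairFlatMapFunction<T, K, V> is the interface used by JavaRDD.flatMapToPair and JavaRDD.mapPartitionsToPair: each input element (or partition iterator) may be expanded into zero or more Tuple2<K, V> pairs. The return type of call changed between Spark releases, which is why the examples below differ: Spark 1.x expects an Iterable<Tuple2<K, V>>, while Spark 2.x expects an Iterator<Tuple2<K, V>>. The following minimal sketch (class name and sample data are illustrative, not taken from any example below) shows the Spark 2.x form, splitting each line into (word, 1) pairs:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFlatMapFunction;

import scala.Tuple2;

public class PairFlatMapFunctionSketch {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local[*]", "PairFlatMapFunctionSketch");
        JavaRDD<String> lines = sc.parallelize(Arrays.asList("a b", "b c"));

        // Spark 2.x signature: call(...) returns an Iterator of pairs.
        JavaPairRDD<String, Integer> pairs = lines
            .flatMapToPair(new PairFlatMapFunction<String, String, Integer>() {
                @Override
                public Iterator<Tuple2<String, Integer>> call(String line) {
                    List<Tuple2<String, Integer>> out = new ArrayList<>();
                    for (String word : line.split(" ")) {
                        out.add(new Tuple2<>(word, 1));
                    }
                    return out.iterator();
                }
            });

        // Word counts: [(a,1), (b,2), (c,1)] (order may vary)
        System.out.println(pairs.reduceByKey((a, b) -> a + b).collect());
        sc.stop();
    }
}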
From source file:$package.SparkPageRankProgram.java
License:Apache License
@Override
public void run(JavaSparkExecutionContext sec) throws Exception {
    JavaSparkContext jsc = new JavaSparkContext();

    LOG.info("Processing backlinkURLs data");
    JavaPairRDD<Long, String> backlinkURLs = sec.fromStream("backlinkURLStream", String.class);
    int iterationCount = getIterationCount(sec);

    LOG.info("Grouping data by key");
    // Group backlinks by unique URL in key
    JavaPairRDD<String, Iterable<String>> links = backlinkURLs.values()
        .mapToPair(new PairFunction<String, String, String>() {
            @Override
            public Tuple2<String, String> call(String s) {
                String[] parts = SPACES.split(s);
                return new Tuple2<>(parts[0], parts[1]);
            }
        }).distinct().groupByKey().cache();

    // Initialize default rank for each key URL
    JavaPairRDD<String, Double> ranks = links.mapValues(new Function<Iterable<String>, Double>() {
        @Override
        public Double call(Iterable<String> rs) {
            return 1.0;
        }
    });

    // Calculates and updates URL ranks continuously using the PageRank algorithm.
    for (int current = 0; current < iterationCount; current++) {
        LOG.debug("Processing data with PageRank algorithm. Iteration {}/{}", current + 1, iterationCount);
        // Calculates URL contributions to the rank of other URLs.
        JavaPairRDD<String, Double> contribs = links.join(ranks).values()
            .flatMapToPair(new PairFlatMapFunction<Tuple2<Iterable<String>, Double>, String, Double>() {
                @Override
                public Iterable<Tuple2<String, Double>> call(Tuple2<Iterable<String>, Double> s) {
                    LOG.debug("Processing {} with rank {}", s._1(), s._2());
                    int urlCount = Iterables.size(s._1());
                    List<Tuple2<String, Double>> results = new ArrayList<>();
                    for (String n : s._1()) {
                        results.add(new Tuple2<>(n, s._2() / urlCount));
                    }
                    return results;
                }
            });

        // Re-calculates URL ranks based on backlink contributions.
        ranks = contribs.reduceByKey(new Sum()).mapValues(new Function<Double, Double>() {
            @Override
            public Double call(Double sum) {
                return 0.15 + sum * 0.85;
            }
        });
    }

    LOG.info("Writing ranks data");

    final ServiceDiscoverer discoveryServiceContext = sec.getServiceDiscoverer();
    final Metrics sparkMetrics = sec.getMetrics();
    JavaPairRDD<byte[], Integer> ranksRaw = ranks
        .mapToPair(new PairFunction<Tuple2<String, Double>, byte[], Integer>() {
            @Override
            public Tuple2<byte[], Integer> call(Tuple2<String, Double> tuple) throws Exception {
                LOG.debug("URL {} has rank {}", Arrays.toString(tuple._1().getBytes(Charsets.UTF_8)), tuple._2());
                URL serviceURL = discoveryServiceContext.getServiceURL(SparkPageRankApp.SERVICE_HANDLERS);
                if (serviceURL == null) {
                    throw new RuntimeException("Failed to discover service: " + SparkPageRankApp.SERVICE_HANDLERS);
                }
                try {
                    URLConnection connection = new URL(serviceURL,
                        String.format("%s/%s", SparkPageRankApp.SparkPageRankServiceHandler.TRANSFORM_PATH,
                            tuple._2().toString())).openConnection();
                    try (BufferedReader reader = new BufferedReader(
                        new InputStreamReader(connection.getInputStream(), Charsets.UTF_8))) {
                        String pr = reader.readLine();
                        if ((Integer.parseInt(pr)) == POPULAR_PAGE_THRESHOLD) {
                            sparkMetrics.count(POPULAR_PAGES, 1);
                        } else if (Integer.parseInt(pr) <= UNPOPULAR_PAGE_THRESHOLD) {
                            sparkMetrics.count(UNPOPULAR_PAGES, 1);
                        } else {
                            sparkMetrics.count(REGULAR_PAGES, 1);
                        }
                        return new Tuple2<>(tuple._1().getBytes(Charsets.UTF_8), Integer.parseInt(pr));
                    }
                } catch (Exception e) {
                    LOG.warn("Failed to read the Stream for service {}", SparkPageRankApp.SERVICE_HANDLERS, e);
                    throw Throwables.propagate(e);
                }
            }
        });

    // Store calculated results in output Dataset.
    // All calculated results are stored in one row.
    // Each result, the calculated URL rank based on backlink contributions, is an entry of the row.
    // The value of the entry is the URL rank.
    sec.saveAsDataset(ranksRaw, "ranks");

    LOG.info("PageRanks successfully computed and written to \"ranks\" dataset");
}
From source file:biz.hangyang.knnspark.spark.KNNClassifySpark.java
public static JavaPairRDD<Entity, Object> calKDistance(final String trainingDataPath, String testingDataPath,
        final int k, final Map<Object, Double> weightMap, JavaSparkContext sc, int partition,
        final Accumulator<Integer> accum) {
    JavaRDD<String> testingDataRDD = sc.textFile(testingDataPath, partition);
    // Parse each line of the testing data into an Entity
    JavaRDD<Entity> testingEntityRDD = testingDataRDD.map(new Function<String, Entity>() {
        @Override
        public Entity call(String line) throws Exception {
            return new GeneEntity(line);
        }
    });
    // For every testing entity, compute the k nearest distances and emit (entity, KDistance) pairs
    JavaPairRDD<Entity, KDistance> ekRDD = testingEntityRDD
        .mapPartitionsToPair(new PairFlatMapFunction<Iterator<Entity>, Entity, KDistance>() {
            @Override
            public Iterable<Tuple2<Entity, KDistance>> call(Iterator<Entity> t) throws Exception {
                // Materialize all entities in this partition
                List<Entity> entityList = new ArrayList<>();
                while (t.hasNext()) {
                    entityList.add(t.next());
                }
                // One KDistance accumulator per entity
                List<KDistance> kDistanceList = new ArrayList<>();
                for (int i = 0; i < entityList.size(); i++) {
                    kDistanceList.add(new KDistance(k));
                }
                // Stream the training data from HDFS
                Configuration conf = new Configuration();
                FileSystem fs = FileSystem.get(URI.create(trainingDataPath), conf);
                FSDataInputStream in = fs.open(new Path(trainingDataPath));
                BufferedReader br = new BufferedReader(new InputStreamReader(in, "UTF-8"));
                String line;
                while ((line = br.readLine()) != null) {
                    Entity lineEntity = new GeneEntity(line);
                    for (int i = 0; i < entityList.size(); i++) {
                        kDistanceList.get(i).add(new DemoDistanceCatagory(
                            lineEntity.distance(entityList.get(i)), lineEntity.category));
                    }
                }
                List<Tuple2<Entity, KDistance>> tList = new ArrayList<>();
                for (int i = 0; i < entityList.size(); i++) {
                    tList.add(new Tuple2<>(entityList.get(i), kDistanceList.get(i)));
                }
                return tList;
            }
        });
    JavaPairRDD<Entity, Object> eoRDD = ekRDD
        .mapToPair(new PairFunction<Tuple2<Entity, KDistance>, Entity, Object>() {
            @Override
            public Tuple2<Entity, Object> call(Tuple2<Entity, KDistance> t) throws Exception {
                KDistance kDistance = t._2();
                // Pick a category from the k nearest neighbours using the supplied weights
                Object catagory = KDistance.getCatagory(kDistance.get(), weightMap);
                if (t._1().category.equals(catagory)) {
                    accum.add(1);
                }
                return new Tuple2<>(t._1(), catagory);
            }
        });
    return eoRDD;
}
From source file:cn.com.bsfit.frms.spark.PageRank.java
License:Apache License
public static void main(String[] args) throws Exception {
    if (args.length < 2) {
        System.err.println("Usage: JavaPageRank <file> <number_of_iterations>");
        System.exit(1);
    }

    showWarning();

    SparkSession spark = SparkSession.builder().appName("JavaPageRank").getOrCreate();

    // Loads in input file. It should be in format of:
    //     URL         neighbor URL
    //     URL         neighbor URL
    //     URL         neighbor URL
    //     ...
    JavaRDD<String> lines = spark.read().textFile(args[0]).javaRDD();

    // Loads all URLs from input file and initialize their neighbors.
    JavaPairRDD<String, Iterable<String>> links = lines.mapToPair(new PairFunction<String, String, String>() {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<String, String> call(String s) {
            String[] parts = SPACES.split(s);
            return new Tuple2<>(parts[0], parts[1]);
        }
    }).distinct().groupByKey().cache();

    // Loads all URLs that other URL(s) link to from the input file and initializes their ranks to one.
    JavaPairRDD<String, Double> ranks = links.mapValues(new Function<Iterable<String>, Double>() {
        private static final long serialVersionUID = 1L;

        @Override
        public Double call(Iterable<String> rs) {
            return 1.0;
        }
    });

    // Calculates and updates URL ranks continuously using the PageRank algorithm.
    for (int current = 0; current < Integer.parseInt(args[1]); current++) {
        // Calculates URL contributions to the rank of other URLs.
        JavaPairRDD<String, Double> contribs = links.join(ranks).values()
            .flatMapToPair(new PairFlatMapFunction<Tuple2<Iterable<String>, Double>, String, Double>() {
                private static final long serialVersionUID = 1L;

                @Override
                public Iterator<Tuple2<String, Double>> call(Tuple2<Iterable<String>, Double> s) {
                    int urlCount = Iterables.size(s._1);
                    List<Tuple2<String, Double>> results = new ArrayList<>();
                    for (String n : s._1) {
                        results.add(new Tuple2<>(n, s._2() / urlCount));
                    }
                    return results.iterator();
                }
            });

        // Re-calculates URL ranks based on neighbor contributions.
        ranks = contribs.reduceByKey(new Sum()).mapValues(new Function<Double, Double>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Double call(Double sum) {
                return 0.15 + sum * 0.85;
            }
        });
    }

    // Collects all URL ranks and dumps them to the console.
    List<Tuple2<String, Double>> output = ranks.collect();
    for (Tuple2<?, ?> tuple : output) {
        System.out.println(tuple._1() + " has rank: " + tuple._2() + ".");
    }

    spark.stop();
}
From source file:co.cask.cdap.etl.batch.spark.ETLSparkProgram.java
License:Apache License
@Override
public void run(DatasetContext datasetContext) throws Exception {
    BatchPhaseSpec phaseSpec = GSON.fromJson(sec.getSpecification().getProperty(Constants.PIPELINEID),
        BatchPhaseSpec.class);

    Set<StageInfo> aggregators = phaseSpec.getPhase().getStagesOfType(BatchAggregator.PLUGIN_TYPE);
    String aggregatorName = null;
    if (!aggregators.isEmpty()) {
        aggregatorName = aggregators.iterator().next().getName();
    }

    SparkBatchSourceFactory sourceFactory;
    SparkBatchSinkFactory sinkFactory;
    Integer numPartitions;
    try (InputStream is = new FileInputStream(sec.getLocalizationContext().getLocalFile("ETLSpark.config"))) {
        sourceFactory = SparkBatchSourceFactory.deserialize(is);
        sinkFactory = SparkBatchSinkFactory.deserialize(is);
        numPartitions = new DataInputStream(is).readInt();
    }

    JavaPairRDD<Object, Object> rdd = sourceFactory.createRDD(sec, jsc, Object.class, Object.class);
    JavaPairRDD<String, Object> resultRDD = doTransform(sec, jsc, datasetContext, phaseSpec, rdd, aggregatorName,
        numPartitions);

    Set<StageInfo> stagesOfTypeSparkSink = phaseSpec.getPhase().getStagesOfType(SparkSink.PLUGIN_TYPE);
    Set<String> namesOfTypeSparkSink = new HashSet<>();
    for (StageInfo stageInfo : stagesOfTypeSparkSink) {
        namesOfTypeSparkSink.add(stageInfo.getName());
    }

    for (final String sinkName : phaseSpec.getPhase().getSinks()) {
        JavaPairRDD<String, Object> filteredResultRDD = resultRDD
            .filter(new Function<Tuple2<String, Object>, Boolean>() {
                @Override
                public Boolean call(Tuple2<String, Object> v1) throws Exception {
                    return v1._1().equals(sinkName);
                }
            });

        if (namesOfTypeSparkSink.contains(sinkName)) {
            SparkSink sparkSink = sec.getPluginContext().newPluginInstance(sinkName);
            sparkSink.run(new BasicSparkExecutionPluginContext(sec, jsc, datasetContext, sinkName),
                filteredResultRDD.values());
        } else {
            JavaPairRDD<Object, Object> sinkRDD = filteredResultRDD
                .flatMapToPair(new PairFlatMapFunction<Tuple2<String, Object>, Object, Object>() {
                    @Override
                    public Iterable<Tuple2<Object, Object>> call(Tuple2<String, Object> input) throws Exception {
                        List<Tuple2<Object, Object>> result = new ArrayList<>();
                        KeyValue<Object, Object> keyValue = (KeyValue<Object, Object>) input._2();
                        result.add(new Tuple2<>(keyValue.getKey(), keyValue.getValue()));
                        return result;
                    }
                });
            sinkFactory.writeFromRDD(sinkRDD, sec, sinkName, Object.class, Object.class);
        }
    }
}
From source file:co.cask.cdap.spark.app.SparkLogParser.java
License:Apache License
@Override
public void run(JavaSparkExecutionContext sec) throws Exception {
    JavaSparkContext jsc = new JavaSparkContext();

    Map<String, String> runtimeArguments = sec.getRuntimeArguments();
    String inputFileSet = runtimeArguments.get("input");
    final String outputTable = runtimeArguments.get("output");

    JavaPairRDD<LongWritable, Text> input = sec.fromDataset(inputFileSet);

    final JavaPairRDD<String, String> aggregated = input
        .mapToPair(new PairFunction<Tuple2<LongWritable, Text>, LogKey, LogStats>() {
            @Override
            public Tuple2<LogKey, LogStats> call(Tuple2<LongWritable, Text> input) throws Exception {
                return SparkAppUsingGetDataset.parse(input._2());
            }
        }).reduceByKey(new Function2<LogStats, LogStats, LogStats>() {
            @Override
            public LogStats call(LogStats stats1, LogStats stats2) throws Exception {
                return stats1.aggregate(stats2);
            }
        }).mapPartitionsToPair(new PairFlatMapFunction<Iterator<Tuple2<LogKey, LogStats>>, String, String>() {
            @Override
            public Iterable<Tuple2<String, String>> call(Iterator<Tuple2<LogKey, LogStats>> itor) throws Exception {
                final Gson gson = new Gson();
                return Lists.newArrayList(Iterators.transform(itor,
                    new Function<Tuple2<LogKey, LogStats>, Tuple2<String, String>>() {
                        @Override
                        public Tuple2<String, String> apply(Tuple2<LogKey, LogStats> input) {
                            return new Tuple2<>(gson.toJson(input._1()), gson.toJson(input._2()));
                        }
                    }));
            }
        });

    // Collect all data to the driver and write to the dataset directly. That's the intent of the test.
    sec.execute(new TxRunnable() {
        @Override
        public void run(DatasetContext context) throws Exception {
            KeyValueTable kvTable = context.getDataset(outputTable);
            for (Map.Entry<String, String> entry : aggregated.collectAsMap().entrySet()) {
                kvTable.write(entry.getKey(), entry.getValue());
            }
        }
    });
}
From source file:com.andado.spark.examples.JavaPageRank.java
License:Apache License
public static void main(String[] args) throws Exception {
    if (args.length < 2) {
        System.err.println("Usage: JavaPageRank <file> <number_of_iterations>");
        System.exit(1);
    }

    showWarning();

    SparkSession spark = SparkSession.builder().appName("JavaPageRank").getOrCreate();

    // Loads in input file. It should be in format of:
    //     URL         neighbor URL
    //     URL         neighbor URL
    //     URL         neighbor URL
    //     ...
    JavaRDD<String> lines = spark.read().textFile(args[0]).javaRDD();

    // Loads all URLs from input file and initialize their neighbors.
    JavaPairRDD<String, Iterable<String>> links = lines.mapToPair(new PairFunction<String, String, String>() {
        @Override
        public Tuple2<String, String> call(String s) {
            String[] parts = SPACES.split(s);
            return new Tuple2<>(parts[0], parts[1]);
        }
    }).distinct().groupByKey().cache();

    // Loads all URLs that other URL(s) link to from the input file and initializes their ranks to one.
    JavaPairRDD<String, Double> ranks = links.mapValues(new Function<Iterable<String>, Double>() {
        @Override
        public Double call(Iterable<String> rs) {
            return 1.0;
        }
    });

    // Calculates and updates URL ranks continuously using the PageRank algorithm.
    for (int current = 0; current < Integer.parseInt(args[1]); current++) {
        // Calculates URL contributions to the rank of other URLs.
        JavaPairRDD<String, Double> contribs = links.join(ranks).values()
            .flatMapToPair(new PairFlatMapFunction<Tuple2<Iterable<String>, Double>, String, Double>() {
                @Override
                public Iterator<Tuple2<String, Double>> call(Tuple2<Iterable<String>, Double> s) {
                    int urlCount = Iterables.size(s._1);
                    List<Tuple2<String, Double>> results = new ArrayList<>();
                    for (String n : s._1) {
                        results.add(new Tuple2<>(n, s._2() / urlCount));
                    }
                    return results.iterator();
                }
            });

        // Re-calculates URL ranks based on neighbor contributions.
        ranks = contribs.reduceByKey(new Sum()).mapValues(new Function<Double, Double>() {
            @Override
            public Double call(Double sum) {
                return 0.15 + sum * 0.85;
            }
        });
    }

    // Collects all URL ranks and dumps them to the console.
    List<Tuple2<String, Double>> output = ranks.collect();
    for (Tuple2<?, ?> tuple : output) {
        System.out.println(tuple._1() + " has rank: " + tuple._2() + ".");
    }

    spark.stop();
}
From source file:com.anhth12.lambda.app.ml.als.Evaluation.java
/**
 * Computes AUC (area under the ROC curve) as a recommender evaluation metric.
 *
 * @param sparkContext
 * @param mfModel
 * @param positiveData
 * @return
 */
static double areaUnderCurve(JavaSparkContext sparkContext, MatrixFactorizationModel mfModel,
        JavaRDD<Rating> positiveData) {
    JavaPairRDD<Integer, Integer> positiveUserProducts = positiveData
        .mapToPair(new PairFunction<Rating, Integer, Integer>() {
            @Override
            public Tuple2<Integer, Integer> call(Rating t) throws Exception {
                return new Tuple2<>(t.user(), t.product());
            }
        });

    JavaPairRDD<Integer, Iterable<Rating>> positivePredictions = predictAll(mfModel, positiveData,
        positiveUserProducts);

    final Broadcast<List<Integer>> allItemIDsBC = sparkContext
        .broadcast(positiveUserProducts.values().distinct().collect());

    JavaPairRDD<Integer, Integer> negativeUserProducts = positiveUserProducts.groupByKey()
        .flatMapToPair(new PairFlatMapFunction<Tuple2<Integer, Iterable<Integer>>, Integer, Integer>() {
            private final RandomGenerator random = RandomManager.getRandom();

            @Override
            public Iterable<Tuple2<Integer, Integer>> call(Tuple2<Integer, Iterable<Integer>> userIDsAndItemIDs)
                    throws Exception {
                Integer userID = userIDsAndItemIDs._1;
                Collection<Integer> positiveItemIDs = Sets.newHashSet(userIDsAndItemIDs._2());
                int numPositive = positiveItemIDs.size();
                Collection<Tuple2<Integer, Integer>> negative = new ArrayList<>(numPositive);
                List<Integer> allItemIDs = allItemIDsBC.value();
                int numItems = allItemIDs.size();
                // Sample about as many negative examples as positive ones
                for (int i = 0; i < numItems && negative.size() < numPositive; i++) {
                    Integer itemID = allItemIDs.get(random.nextInt(numItems));
                    if (!positiveItemIDs.contains(itemID)) {
                        negative.add(new Tuple2<>(userID, itemID));
                    }
                }
                return negative;
            }
        });

    JavaPairRDD<Integer, Iterable<Rating>> negativePredictions = predictAll(mfModel, positiveData,
        negativeUserProducts);

    return positivePredictions.join(negativePredictions).values()
        .mapToDouble(new DoubleFunction<Tuple2<Iterable<Rating>, Iterable<Rating>>>() {
            @Override
            public double call(Tuple2<Iterable<Rating>, Iterable<Rating>> t) throws Exception {
                // AUC is also the probability that random positive examples rank higher than
                // random examples at large. Here we compare all random negative examples to
                // all positive examples and report the totals as an alternative computation for AUC.
                long correct = 0;
                long total = 0;
                for (Rating positive : t._1()) {
                    for (Rating negative : t._2()) {
                        if (positive.rating() > negative.rating()) {
                            correct++;
                        }
                        total++;
                    }
                }
                return (double) correct / total;
            }
        }).mean();
}
From source file:com.audaque.instancematch.match.GenerateSignature.java
/**
 * Computes min-hash signatures for the source data using the given q-gram length and hash seeds,
 * and returns them as a list of (seed, min-hash) tuples.
 *
 * @param srcFile  source data file
 * @param seedFile file containing the hash seeds
 * @param q        q-gram length (a negative value hashes whole lines instead of q-grams)
 * @param hashNum  number of hash functions
 * @param sc
 * @return
 * @throws IOException
 */
public static List<Tuple2<String, String>> generateSignature(String srcFile, String seedFile, final int q,
        int hashNum, JavaSparkContext sc) throws IOException {
    // final int[] seeds = Hash.loadSeeds("res/seed0.txt", hashNum);
    JavaRDD<String> srcRDD = sc.textFile(srcFile, 40);
    JavaRDD<String> seedRDD = sc.textFile(seedFile);
    final List<String> seedList = seedRDD.collect();

    JavaPairRDD<String, String> seeds_hashRDD = srcRDD
        .flatMapToPair(new PairFlatMapFunction<String, String, String>() {
            @Override
            public Iterable<Tuple2<String, String>> call(String line) throws Exception {
                List<Tuple2<String, String>> list = new ArrayList<Tuple2<String, String>>();
                for (int i = 0; i < seedList.size(); i++) {
                    int hash;
                    if (q < 0) {
                        hash = Hash.RSHash(line, Integer.valueOf(seedList.get(i)));
                    } else {
                        hash = QGramHash.RSHash(line, Integer.valueOf(seedList.get(i)), q);
                    }
                    list.add(new Tuple2<String, String>(seedList.get(i), String.valueOf(hash)));
                }
                return list;
            }
        });

    JavaPairRDD<String, String> seed_hashRDD = seeds_hashRDD.reduceByKey(new Function2<String, String, String>() {
        @Override
        public String call(String v1, String v2) throws Exception {
            if (Integer.valueOf(v1) < Integer.valueOf(v2)) {
                return v1;
            } else {
                return v2;
            }
        }
    });

    // seeds_hashRDD.sortByKey().saveAsTextFile("hdfs://172.16.1.101:8020/user/ALGO/result");
    return seed_hashRDD.sortByKey().collect();
}
From source file:com.audaque.instancematch.match.GenerateSignature2.java
public static List<Tuple2<String, String>> generateSignature(final String srcFile, String seedFile, final int q,
        int hashNum, JavaSparkContext sc) {
    JavaRDD<String> seedRDD = sc.textFile(seedFile, 40);
    JavaPairRDD<String, String> seeds_hashRDD = seedRDD
        .mapPartitionsToPair(new PairFlatMapFunction<Iterator<String>, String, String>() {
            @Override
            public Iterable<Tuple2<String, String>> call(Iterator<String> seed) throws Exception {
                List<Integer> seedList = new ArrayList<Integer>();
                while (seed.hasNext()) {
                    seedList.add(Integer.valueOf(seed.next()));
                }
                int[] minHash = new int[seedList.size()];
                for (int i = 0; i < minHash.length; i++) {
                    minHash[i] = Integer.MAX_VALUE;
                }
                // Read the source data from HDFS
                Configuration conf = new Configuration();
                FileSystem fs = FileSystem.get(URI.create(srcFile), conf);
                FSDataInputStream in = fs.open(new Path(srcFile));
                BufferedReader br = new BufferedReader(new InputStreamReader(in, "UTF-8"));
                String line;
                while ((line = br.readLine()) != null) {
                    for (int i = 0; i < seedList.size(); i++) {
                        int hash;
                        if (q < 0) {
                            hash = Hash.RSHash(line, seedList.get(i));
                        } else {
                            hash = QGramHash.RSHash(line, seedList.get(i), q);
                        }
                        if (hash < minHash[i]) {
                            minHash[i] = hash;
                        }
                    }
                }
                List<Tuple2<String, String>> tList = new ArrayList<Tuple2<String, String>>();
                for (int i = 0; i < seedList.size(); i++) {
                    tList.add(new Tuple2<String, String>(String.valueOf(seedList.get(i)),
                        String.valueOf(minHash[i])));
                }
                return tList;
            }
        });
    return seeds_hashRDD.sortByKey().collect();
}
From source file:com.cloudera.oryx.app.batch.mllib.als.Evaluation.java
License:Open Source License
/**
 * Computes AUC (area under the ROC curve) as a recommender evaluation metric.
 * Really, it computes what might be described as "Mean AUC", as it computes AUC per
 * user and averages them.
 */
static double areaUnderCurve(JavaSparkContext sparkContext, MatrixFactorizationModel mfModel,
        JavaRDD<Rating> positiveData) {
    // This does not use Spark's BinaryClassificationMetrics.areaUnderROC because it
    // is intended to operate on one large set of (score,label) pairs. The computation
    // here is really many small AUC problems, for which a much faster direct computation
    // is available.

    // Extract all positive (user,product) pairs
    JavaPairRDD<Integer, Integer> positiveUserProducts = positiveData
        .mapToPair(rating -> new Tuple2<>(rating.user(), rating.product()));

    JavaPairRDD<Integer, Iterable<Rating>> positivePredictions = predictAll(mfModel, positiveData,
        positiveUserProducts);

    // All distinct item IDs, to be broadcast
    Broadcast<List<Integer>> allItemIDsBC = sparkContext
        .broadcast(positiveUserProducts.values().distinct().collect());

    JavaPairRDD<Integer, Integer> negativeUserProducts = positiveUserProducts.groupByKey()
        .flatMapToPair(new PairFlatMapFunction<Tuple2<Integer, Iterable<Integer>>, Integer, Integer>() {
            private final RandomGenerator random = RandomManager.getRandom();

            @Override
            public Iterator<Tuple2<Integer, Integer>> call(Tuple2<Integer, Iterable<Integer>> userIDsAndItemIDs) {
                Integer userID = userIDsAndItemIDs._1();
                Collection<Integer> positiveItemIDs = Sets.newHashSet(userIDsAndItemIDs._2());
                int numPositive = positiveItemIDs.size();
                Collection<Tuple2<Integer, Integer>> negative = new ArrayList<>(numPositive);
                List<Integer> allItemIDs = allItemIDsBC.value();
                int numItems = allItemIDs.size();
                // Sample about as many negative examples as positive
                for (int i = 0; i < numItems && negative.size() < numPositive; i++) {
                    Integer itemID = allItemIDs.get(random.nextInt(numItems));
                    if (!positiveItemIDs.contains(itemID)) {
                        negative.add(new Tuple2<>(userID, itemID));
                    }
                }
                return negative.iterator();
            }
        });

    JavaPairRDD<Integer, Iterable<Rating>> negativePredictions = predictAll(mfModel, positiveData,
        negativeUserProducts);

    return positivePredictions.join(negativePredictions).values().mapToDouble(t -> {
        // AUC is also the probability that random positive examples
        // rank higher than random examples at large. Here we compare all random negative
        // examples to all positive examples and report the totals as an alternative
        // computation for AUC
        long correct = 0;
        long total = 0;
        for (Rating positive : t._1()) {
            for (Rating negative : t._2()) {
                if (positive.rating() > negative.rating()) {
                    correct++;
                }
                total++;
            }
        }
        if (total == 0) {
            return 0.0;
        }
        return (double) correct / total;
    }).mean();
}