Java tutorial
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mapr.stats.bandit;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.mapr.stats.random.BinomialDistributionSampler;
import com.mapr.stats.random.DistributionGenerator;
import com.mapr.stats.random.DistributionWithMean;
import com.mapr.stats.random.NormalDistributionSampler;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.stats.OnlineSummarizer;

import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Simulate a two-armed bandit playing against a beta-Bayesian model.
 * <p/>
 * The output indicates the quantiles of the distribution for regret relative to the optimal pick.
 * The regret distribution is estimated by picking two random conversion probabilities and then
 * running the beta-Bayesian model for a number of steps.  The regret is computed by taking the
 * expectation for the optimal choice and subtracting from the actual percentage of conversion
 * achieved.  On average, this should be somewhat negative since the model has to spend some effort
 * examining the sub-optimal choice.  The median, 25 and 75%-ile marks all scale downward fairly
 * precisely with the square root of the number of trials which is to be expected from theoretical
 * considerations.
 * <p/>
 * The beta-Bayesian model works by keeping an estimate of the posterior distribution for the
 * conversion probability for each of the bandits.  We take a uniform distribution as the prior so
 * the posterior is a beta distribution.  The model samples probabilities from the two posterior
 * distributions and chooses the model whose sample is larger.  As data is collected for the two
 * bandits, the better of the bandits will quickly have a pretty narrow posterior distribution and
 * the lesser bandit will rarely have a sampled probability higher than the better bandit.  This
 * means that we will stop getting data from the lesser bandit, but only when there is essentially
 * no chance that it is better.
 */
public class BanditTrainer {
  /** Number of consecutive steps that share one per-turn regret summary in {@link #totalRegret}. */
  private static final int BUCKET_SIZE = 1;

  /**
   * Runs a fixed set of regret simulations in parallel, each writing its own pair of TSV files
   * into the current working directory.
   * <p/>
   * A failing simulation does not stop the others; its failure is reported on standard error
   * once all simulations have finished.
   *
   * @param args Optional; {@code args[0]} is the number of worker threads (default 16).
   * @throws InterruptedException If the calling thread is interrupted while waiting for the
   *                              simulations to finish.
   */
  public static void main(String[] args) throws FileNotFoundException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, InterruptedException {
    int threads = 16;
    if (args.length > 0) {
      threads = Integer.parseInt(args[0]);
    }

    System.out.printf("regret\n");
    ExecutorService ex = Executors.newFixedThreadPool(threads);
    try {
      List<Callable<Void>> tasks = ImmutableList.of(
          regretTask("regret-normal-0.1.tsv", "local-normal-0.1.tsv", 1000, 2, 10000,
              new GammaNormalBayesFactory(), new NormalDistributionSampler(0.1, new Random()),
              "2\n"),
          regretTask("regret-epsilon-normal-1.tsv", "local-epsilon-normal-1.tsv", 1000, 2, 10000,
              new EpsilonGreedyFactory(0.05), new NormalDistributionSampler(1, new Random()),
              "2e\n"),
          regretTask("regret-normal-1.tsv", "local-normal-1.tsv", 300, 2, 200000,
              new GammaNormalBayesFactory(), new NormalDistributionSampler(1, new Random()),
              "normal 1\n"),
          regretTask("regret-normal-10x0.1.tsv", "local-normal-10x0.1.tsv", 1000, 10, 1000,
              new GammaNormalBayesFactory(), new NormalDistributionSampler(0.1, new Random()),
              "10\n"),
          regretTask("regret-normal-100x0.1.tsv", "local-normal-100x0.1.tsv", 1000, 100, 1000,
              new GammaNormalBayesFactory(), new NormalDistributionSampler(.1, new Random()),
              "100\n"),
          regretTask("regret.tsv", "local.tsv", 1000, 2, 1000,
              new BetaBayesFactory(), new BinomialDistributionSampler(1, 1, new Random()),
              "2\n"),
          regretTask("regret-100.tsv", "local-100.tsv", 1000, 100, 1000,
              new BetaBayesFactory(), new BinomialDistributionSampler(1, 1, new Random()),
              "100\n"),
          regretTask("regret-20.tsv", "local-20.tsv", 1000, 20, 1000,
              new BetaBayesFactory(), new BinomialDistributionSampler(1, 1, new Random()),
              "20\n"));

      // invokeAll only returns once every task is done, so get() never blocks here; it is called
      // purely to surface exceptions that would otherwise stay hidden inside the futures
      for (Future<Void> result : ex.invokeAll(tasks)) {
        try {
          result.get();
        } catch (ExecutionException e) {
          System.err.printf("Simulation task failed: %s\n", e.getCause());
          e.printStackTrace();
        }
      }
    } finally {
      // always release the worker threads, even if we were interrupted while waiting
      ex.shutdown();
    }
    System.out.printf("All done");

    // Other experiments that can be enabled by hand:
    //    System.out.printf("error rates\n");
    //    errorRate("errors.tsv");
    //    System.out.printf("commit time\n");
    //    commitTime("commit.tsv", 3000, 0.1, 0.12, 2000);
    //    System.out.printf("done\n");
  }

  /**
   * Wraps one {@link #totalRegret} run as a task that prints {@code doneLabel} to standard output
   * after the run completes successfully.  Exceptions propagate to the task's future.
   */
  private static Callable<Void> regretTask(final String cumulativeOutput,
                                           final String perTurnOutput,
                                           final int replications,
                                           final int bandits,
                                           final int maxSteps,
                                           final BanditFactory modelFactory,
                                           final DistributionGenerator refSampler,
                                           final String doneLabel) {
    return new Callable<Void>() {
      @Override
      public Void call() throws FileNotFoundException {
        totalRegret(cumulativeOutput, perTurnOutput, replications, bandits, maxSteps,
            modelFactory, refSampler);
        System.out.print(doneLabel);
        return null;
      }
    };
  }

  /**
   * Records which bandit was chosen for many runs of the same scenario.  This output is kind of big
   * and hard to digest visually.  As such, it is probably better to reduce this using a mean.  In R,
   * this can be done like this:
   * <pre>
   *     plot(tapply(z$k, floor(z$i/10), mean), type='l')
   * </pre>
   * The scenario is repeated 1000 times, each time with a fresh {@code BetaBayesModel}.
   *
   * @param outputFile Where to write results
   * @param n          How many steps to follow
   * @param p1         First probability of reward
   * @param p2         Second probability of reward
   * @param cutoff     Only keep results after this many steps
   * @return Average chosen index over all steps with {@code i > cutoff}.  Since the probabilities
   *         are sorted ascending, index 1 is the better bandit.  The result is {@code NaN} when no
   *         step passes the cutoff, i.e. when {@code n <= cutoff + 1}.
   * @throws java.io.FileNotFoundException If the directory holding the output file doesn't exist.
   */
  public static double commitTime(String outputFile, int n, double p1, double p2, int cutoff) throws FileNotFoundException {
    // the probabilities are fixed for every run; sort once so the larger one sits at index 1
    double[] p = {p1, p2};
    Arrays.sort(p);

    try (PrintWriter out = new PrintWriter(outputFile)) {
      Random gen = new Random();
      out.printf("i\tk\n");
      int impressions = 0;
      int correct = 0;
      for (int j = 0; j < 1000; j++) {
        BetaBayesModel s = new BetaBayesModel();
        for (int i = 0; i < n; i++) {
          int k = s.sample();
          out.printf("%d\t%d\n", i, k);
          if (i > cutoff) {
            impressions++;
            correct += k;
          }

          // simulate a reward with probability p[k] and feed it back to the model
          boolean r = gen.nextDouble() <= p[k];
          s.train(k, r ? 1 : 0);
        }
      }
      return (double) correct / impressions;
    }
  }

  /**
   * Simulates a two-bandit {@code BetaBayesModel} for several trial counts and records how many
   * rewards were collected.  Each output row contains p1, p2 (the sorted reward probabilities),
   * n (number of trials), wins (total rewarded trials) and late (rewarded trials among the steps
   * with {@code i > n / 2}).
   * <p/>
   * For each trial count the experiment is repeated 10000 times when {@code n < 500} and 1000
   * times otherwise.
   * <p/>
   * The commitTime output is probably more interesting.
   *
   * @param outputFile Where to write the data.
   * @throws java.io.FileNotFoundException If we can't open our output
   */
  @Deprecated
  private static void errorRate(String outputFile) throws FileNotFoundException {
    try (PrintWriter out = new PrintWriter(outputFile)) {
      out.printf("p1\tp2\tn\twins\tlate\n");
      Random gen = new Random();
      for (int n : new int[]{20, 50, 100, 200, 500, 1000, 2000, 5000}) {
        System.out.printf("%d\n", n);
        for (int j = 0; j < 1000 * (n < 500 ? 10 : 1); j++) {
          // pick probabilities at random
          double[] p = {gen.nextDouble(), gen.nextDouble()};
          // order them to make error interpretation easier
          Arrays.sort(p);
          BetaBayesModel s = new BetaBayesModel();
          int wins = 0;
          int lateWins = 0;
          for (int i = 0; i < n; i++) {
            int k = s.sample();
            final double u = gen.nextDouble();
            boolean r = u <= p[k];
            wins += r ? 1 : 0;
            if (i > n / 2) {
              lateWins += r ? 1 : 0;
            }
            s.train(k, r ? 1 : 0);
          }
          out.printf("%.3f\t%.3f\t%d\t%d\t%d\n", p[0], p[1], n, wins, lateWins);
        }
      }
    }
  }

  /**
   * Computes regret relative to perfect knowledge given uniform random probabilities.  The
   * output contains the quartiles for different numbers of trials.  The quartiles are computed by
   * running many experiments for each specified number of trials.  The regret of one experiment
   * is the achieved reward rate minus the largest reward probability among the bandits.
   * <p/>
   * This can be plotted pretty much directly in R
   * <pre>
   * > x=read.delim(file='~/Apache/storm-aggregator/regret.tsv')
   * > bxp(list(com.mapr.stats=t(as.matrix(x[,2:6])), n=rep(1000,times=8),names=x$n))
   * </pre>
   *
   * @param outputFile   Where to put the output
   * @param sizes        The different size experiments to use
   * @param replications Number of times to repeat the experiment
   * @param bandits      How many bandits to simulate
   * @return The median regret observed for the last entry of {@code sizes}, or 0 if
   *         {@code sizes} is empty.
   * @throws java.io.FileNotFoundException If the output file can't be opened due to a missing directory.
   */
  public static double averageRegret(String outputFile, int[] sizes, int replications, int bandits) throws FileNotFoundException {
    try (PrintWriter out = new PrintWriter(outputFile)) {
      double finalMedianRegret = 0;
      Random gen = new Random();
      out.printf("n\tq0\tq1\tq2\tq3\tq4\n");
      // for each horizon time span of interest
      for (int n : sizes) {
        System.out.printf("%d\n", n);
        OnlineSummarizer summary = new OnlineSummarizer();
        // replicate the test many times
        for (int j = 0; j < replications; j++) {
          // pick probabilities at random
          double[] p = new double[bandits];
          for (int k = 0; k < bandits; k++) {
            p[k] = gen.nextDouble();
          }
          // order them to make error interpretation easier
          Arrays.sort(p);
          BetaBayesModel s = new BetaBayesModel(bandits, RandomUtils.getRandom());
          int wins = 0;
          for (int i = 0; i < n; i++) {
            int k = s.sample();
            final double u = gen.nextDouble();
            boolean r = u <= p[k];
            wins += r ? 1 : 0;
            s.train(k, r ? 1 : 0);
          }
          // after sorting, the last probability is the best achievable reward rate
          summary.add((double) wins / n - p[bandits - 1]);
        }
        out.printf("%d\t", n);
        for (int quartile = 0; quartile <= 4; quartile++) {
          out.printf("%.3f%s", summary.getQuartile(quartile), quartile < 4 ? "\t" : "\n");
        }
        out.flush();
        finalMedianRegret = summary.getMedian();
      }
      return finalMedianRegret;
    }
  }

  /**
   * Simulates a bandit solver against randomly drawn reward distributions and records its regret
   * relative to always playing the bandit with the highest mean reward.
   * <p/>
   * Two files are written, each with columns {@code n} and {@code mean}.  The cumulative file
   * holds the mean (over replications) of the regret accumulated up to step n, sampled at steps
   * whose spacing grows as 1, 2, 5, 10, 20, ...  The per-turn file holds the mean regret of each
   * individual step.
   * <p/>
   * The cumulative file can be plotted pretty much directly in R after reading it with
   * <pre>
   * > x=read.delim(file='regret.tsv')
   * </pre>
   *
   * @param cumulativeOutput Where to write the cumulative regret results
   * @param perTurnOutput    Where to write the per step regret results
   * @param replications     How many times to replicate the experiment
   * @param bandits          How many bandits to emulate
   * @param maxSteps         Maximum number of trials to run per experiment
   * @param modelFactory     How to construct the solver.
   * @param refSampler       How to get reward distributions for bandits
   * @return The median, over all replications, of the cumulative regret at the last recorded
   *         step, or {@code NaN} if nothing was recorded because {@code replications} or
   *         {@code maxSteps} is not positive.
   * @throws java.io.FileNotFoundException If the output file can't be opened due to
   *                                       a missing directory.
   */
  public static double totalRegret(String cumulativeOutput, String perTurnOutput, int replications, int bandits, int maxSteps, BanditFactory modelFactory, DistributionGenerator refSampler) throws FileNotFoundException {
    List<OnlineSummarizer> cumulativeRegret = Lists.newArrayList();
    List<OnlineSummarizer> localRegret = Lists.newArrayList();
    List<Integer> steps = Lists.newArrayList();
    List<Integer> localSteps = Lists.newArrayList();

    Random gen = new Random();
    for (int j = 0; j < replications; j++) {
      BayesianBandit s = modelFactory.createBandit(bandits, gen);

      List<DistributionWithMean> refs = Lists.newArrayList();
      for (int k = 0; k < bandits; k++) {
        refs.add(refSampler.nextDistribution());
      }
      // the code below treats the last element as the best bandit, so this sort is presumably
      // ascending by mean -- TODO confirm against DistributionWithMean.compareTo
      Collections.sort(refs);
      double bestMean = refs.get(bandits - 1).getMean();

      int k = 0;
      int delta = 1;
      double totalRegret = 0;
      for (int i = 0; i < maxSteps; i++) {
        // widen the recording interval as the run gets longer
        if (i > 50 * delta) {
          delta = bump(delta);
        }

        int choice = s.sample();
        double r = refs.get(choice).nextDouble();
        double thisTrialRegret = bestMean - refs.get(choice).getMean();
        totalRegret += thisTrialRegret;

        if ((i + 1) % delta == 0) {
          // the first replication creates the summaries; later ones add to them
          if (cumulativeRegret.size() <= k) {
            cumulativeRegret.add(new OnlineSummarizer());
            steps.add(i + 1);
          }
          cumulativeRegret.get(k).add(totalRegret);
          k++;
        }

        if (localRegret.size() <= i / BUCKET_SIZE) {
          localRegret.add(new OnlineSummarizer());
          localSteps.add(i);
        }
        localRegret.get(i / BUCKET_SIZE).add(thisTrialRegret);

        s.train(choice, r);
      }
    }

    printRegret(cumulativeOutput, cumulativeRegret, steps);
    printRegret(perTurnOutput, localRegret, localSteps);

    if (cumulativeRegret.isEmpty()) {
      // no replication ran a single step, so there is nothing to estimate from
      return Double.NaN;
    }
    return cumulativeRegret.get(cumulativeRegret.size() - 1).getMedian();
  }

  /**
   * Writes one line per summary containing the matching step number and the summary's mean,
   * preceded by a {@code n}/{@code mean} header line.
   *
   * @param outputFile Where to write the table.
   * @param summaries  The regret summaries, one per recorded step.
   * @param steps      The step number for each summary; must be at least as long as
   *                   {@code summaries}.
   * @throws FileNotFoundException If the output file can't be opened.
   */
  private static void printRegret(String outputFile, List<OnlineSummarizer> summaries, List<Integer> steps) throws FileNotFoundException {
    try (PrintWriter out = new PrintWriter(outputFile)) {
      out.printf("n\tmean\n");
      for (int row = 0; row < summaries.size(); row++) {
        out.printf("%d\t%.4f\n", steps.get(row), summaries.get(row).getMean());
      }
    }
  }

  /**
   * Advances a recording interval along the sequence 1, 2, 5, 10, 20, 50, 100, ...
   *
   * @param delta The current interval; expected to be 1, 2 or 5 times a power of ten.
   * @return The next interval in the sequence.
   */
  private static int bump(int delta) {
    // split delta into a leading digit and a power of ten
    int multiplier = 1;
    while (delta >= 10) {
      multiplier *= 10;
      delta /= 10;
    }

    // maps 1 -> 2, 2 -> 5 and 5 -> 10; this relies on delta * delta / 3 being integer division
    delta = (int) (4 * delta - delta * delta / 3 - 1.5);
    return delta * multiplier;
  }
}