Coding pattern for random percentage branching?
EDIT: See the edit at the end for a more elegant solution. I'll leave this one in, though.
You can use a NavigableMap to store these methods mapped to their percentages.
NavigableMap<Double, Runnable> runnables = new TreeMap<>();
runnables.put(0.3, this::method30Percent);
runnables.put(1.0, this::method70Percent);
public static void runRandomly(NavigableMap<Double, Runnable> runnables) {
    double percentage = Math.random(); // uniform in [0, 1)
    // entries are iterated in ascending key order, so the first key
    // above the random value marks the range it falls into
    for (Map.Entry<Double, Runnable> entry : runnables.entrySet()) {
        if (percentage < entry.getKey()) {
            entry.getValue().run();
            return; // make sure you only call one method
        }
    }
    throw new RuntimeException("map not filled properly for " + percentage);
}
// or, because I'm still practicing streams by using them for everything
public static void runRandomly(NavigableMap<Double, Runnable> runnables) {
    double percentage = Math.random();
    runnables.entrySet().stream()
        .filter(e -> percentage < e.getKey())
        .findFirst()
        .orElseThrow(() ->
            new RuntimeException("map not filled properly for " + percentage))
        .getValue()
        .run();
}
The NavigableMap is sorted by its keys (unlike, for example, a HashMap, which gives no guarantees about entry order), so you get the entries ordered by their percentages. This matters because two items (3,r1) and (7,r2) result in the entries r1 = 0.3 and r2 = 1.0, and they need to be evaluated in that order: if they were evaluated in reverse order, the result would always be r2.
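Since the map is a NavigableMap anyway, the linear scan can also be replaced by a single lookup: `higherEntry(x)` returns the entry with the smallest key strictly greater than `x`, which is exactly the range the random value falls into. A minimal sketch, assuming the last key is 1.0; the class and method names are hypothetical, and Strings stand in for the Runnables so the result can be printed (with a `NavigableMap<Double, Runnable>` you would call `pick(map, Math.random()).run()`):

```java
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;

public class HigherEntryDemo {
    // Returns the value whose key is the smallest key strictly greater
    // than the given value, i.e. the "bucket" the value falls into.
    static <T> T pick(NavigableMap<Double, T> cumulative, double value) {
        Map.Entry<Double, T> entry = cumulative.higherEntry(value);
        if (entry == null) {
            throw new IllegalStateException("no key above " + value + "; last key must be 1.0");
        }
        return entry.getValue();
    }

    public static void main(String[] args) {
        NavigableMap<Double, String> map = new TreeMap<>();
        map.put(0.3, "30 percent branch");
        map.put(1.0, "70 percent branch");
        System.out.println(pick(map, 0.1));  // 30 percent branch
        System.out.println(pick(map, 0.3));  // 70 percent branch (0.3 is not below 0.3)
        System.out.println(pick(map, 0.95)); // 70 percent branch
        System.out.println(pick(map, Math.random())); // Math.random() is in [0, 1)
    }
}
```

This is O(log N) per draw instead of O(N), because the TreeMap lookup walks a balanced tree rather than every entry.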
As for the splitting (turning the weights into percentages), it could go something like this. With a Pair class like this:
static class Pair<X, Y>
{
public Pair(X f, Y s)
{
first = f;
second = s;
}
public final X first;
public final Y second;
}
You can create a map like this:
// the parameter contains the (1,m1), (1,m2), (3,m3) pairs
private static NavigableMap<Double,Runnable> splitToPercentageMap(Collection<Pair<Integer,Runnable>> runnables)
{
    // this adds all Runnables to lists of same int value,
    // overall those lists are sorted by that int (so least probable first)
    int total = 0;
    Map<Integer,List<Runnable>> byNumber = new TreeMap<>();
    for (Pair<Integer,Runnable> e : runnables)
    {
        total += e.first;
        byNumber.computeIfAbsent(e.first, k -> new ArrayList<>()).add(e.second);
    }
    NavigableMap<Double,Runnable> targetList = new TreeMap<>();
    // keep the running sum as an int and divide only when storing the key,
    // so the last key is exactly total / total = 1.0 (no rounding drift)
    int current = 0;
    for (Map.Entry<Integer,List<Runnable>> e : byNumber.entrySet())
    {
        for (Runnable r : e.getValue())
        {
            current += e.getKey();
            targetList.put((double) current / total, r);
        }
    }
    return targetList;
}
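For the example pairs in the comment, (1,m1), (1,m2), (3,m3), the total weight is 5, so the cumulative keys come out as 1/5, 2/5 and 5/5. A minimal standalone sketch of just that key computation, assuming the weights are passed already in ascending order (as the grouping TreeMap would produce); the class and method names are hypothetical, and Strings stand in for the Runnables:

```java
import java.util.NavigableMap;
import java.util.TreeMap;

public class PercentageKeysDemo {
    // Builds cumulative percentage keys from integer weights. The running
    // sum stays an int, so the last key is exactly total / total = 1.0.
    static NavigableMap<Double, String> keys(int[] weights, String[] names) {
        int total = 0;
        for (int w : weights) total += w;
        NavigableMap<Double, String> map = new TreeMap<>();
        int running = 0;
        for (int i = 0; i < weights.length; i++) {
            running += weights[i];
            map.put((double) running / total, names[i]);
        }
        return map;
    }

    public static void main(String[] args) {
        System.out.println(keys(new int[]{1, 1, 3}, new String[]{"m1", "m2", "m3"}));
        // {0.2=m1, 0.4=m2, 1.0=m3}
    }
}
```

So a random value below 0.2 runs m1, one in [0.2, 0.4) runs m2, and anything from 0.4 up runs m3, giving the intended 20/20/60 split.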
And all of this combined into a class:
class RandomRunner {
    private List<Pair<Integer, Runnable>> runnables = new ArrayList<>();
    public void add(int value, Runnable toRun) {
        runnables.add(new Pair<>(value, toRun));
    }
    public void remove(Runnable toRemove) {
        for (Iterator<Pair<Integer, Runnable>> r = runnables.iterator();
                r.hasNext(); ) {
            if (toRemove == r.next().second) {
                r.remove();
                break;
            }
        }
    }
    public void runRandomly() {
        // split list, use code from above
    }
}
EDIT:
Actually, the above is what you get when an idea gets stuck in your head and you don't question it properly.
Keeping the RandomRunner class interface, this is much easier:
class RandomRunner {
    private List<Runnable> runnables = new ArrayList<>();
    public void add(int value, Runnable toRun) {
        // add the methods as often as their weight indicates.
        // this should be fine for smaller numbers;
        // if you get lists with millions of entries, optimize
        for (int i = 0; i < value; i++) {
            runnables.add(toRun);
        }
    }
    public void remove(Runnable r) {
        // removes every copy, since add() may have inserted r several times
        Iterator<Runnable> myRunnables = runnables.iterator();
        while (myRunnables.hasNext()) {
            if (myRunnables.next() == r) {
                myRunnables.remove();
            }
        }
    }
    public void runRandomly() {
        if (runnables.isEmpty()) return;
        // roll n-sided die
        int runIndex = ThreadLocalRandom.current().nextInt(0, runnables.size());
        runnables.get(runIndex).run();
    }
}
All these answers seem quite complicated, so I'll just post the keep-it-simple alternative:
double rnd = Math.random();
if ((rnd -= 0.6) < 0)
    method60Percent();
else if ((rnd -= 0.3) < 0)
    method30Percent();
else
    method10Percent();
It doesn't require changing any other lines, and one can easily see what happens without digging into auxiliary classes. A small downside is that it doesn't enforce that the percentages sum to 100%.
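Because the chain only depends on the random value, it can be pulled into a method that takes that value as a parameter, which makes the thresholds easy to check without any randomness. A minimal sketch; the class and method names are hypothetical, and it returns a label instead of calling the three methods:

```java
public class SubtractionChain {
    // Same branching as above, but the random value is passed in
    // so the boundaries can be checked deterministically.
    static String branch(double rnd) {
        if ((rnd -= 0.6) < 0)
            return "60%";
        else if ((rnd -= 0.3) < 0)
            return "30%";
        else
            return "10%";
    }

    public static void main(String[] args) {
        System.out.println(branch(0.59)); // 60%
        System.out.println(branch(0.6));  // 30% (0.6 - 0.6 is 0, which is not < 0)
        System.out.println(branch(0.75)); // 30%
        System.out.println(branch(0.95)); // 10%
        System.out.println(branch(Math.random()));
    }
}
```

Each subtraction moves the value past one slice of [0, 1), so [0, 0.6) maps to the first branch, [0.6, 0.9) to the second and the rest to the last.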
I am not sure if there is a common name for this, but I think I learned it as the wheel of fortune back in university.
It basically works as you described: it receives a list of values and "frequency numbers", and one value is chosen according to the weighted probabilities.
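list = (1,a),(1,b),(2,c),(6,d)        // (weight, value) pairs
total = list.sum()                    // sum of all weights, 10 here
rnd = random(0, total)
sum = 0
for i from 0 to list.size() - 1:
    sum += list[i].weight
    if sum >= rnd:
        return list[i].value
return list.last().value              // guards against rounding errors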
The list can be a function parameter if you want to generalize this.
This also works with floating point numbers, and the numbers don't have to be normalized. If you normalize them (to sum up to 1, for example), you can skip the list.sum() part.
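For the example list, the running sums are 1, 2, 4 and 10, so each value owns a slice of [0, 10) whose length is its weight, and the selection probabilities work out as follows:

```latex
\text{total} = 1 + 1 + 2 + 6 = 10
P(a) = \tfrac{1}{10} = 0.1,\quad P(b) = \tfrac{1}{10} = 0.1,\quad P(c) = \tfrac{2}{10} = 0.2,\quad P(d) = \tfrac{6}{10} = 0.6
```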
EDIT:
Due to demand, here is an actual compiling Java implementation and usage example:
import java.util.ArrayList;
import java.util.Random;
public class RandomWheel<T>
{
private static final class RandomWheelSection<T>
{
public double weight;
public T value;
public RandomWheelSection(double weight, T value)
{
this.weight = weight;
this.value = value;
}
}
private ArrayList<RandomWheelSection<T>> sections = new ArrayList<>();
private double totalWeight = 0;
private Random random = new Random();
public void addWheelSection(double weight, T value)
{
sections.add(new RandomWheelSection<T>(weight, value));
totalWeight += weight;
}
public T draw()
{
double rnd = totalWeight * random.nextDouble();
double sum = 0;
for (int i = 0; i < sections.size(); i++)
{
sum += sections.get(i).weight;
if (sum >= rnd)
return sections.get(i).value;
}
return sections.get(sections.size() - 1).value;
}
public static void main(String[] args)
{
RandomWheel<String> wheel = new RandomWheel<String>();
wheel.addWheelSection(1, "a");
wheel.addWheelSection(1, "b");
wheel.addWheelSection(2, "c");
wheel.addWheelSection(6, "d");
for (int i = 0; i < 100; i++)
System.out.print(wheel.draw());
}
}
While the selected answer works, it is unfortunately asymptotically slow for your use case. Instead, you could use something called alias sampling. Alias sampling (or the alias method) is a technique for selecting elements from a weighted distribution. If the weights of the elements don't change, you can do selection in O(1) time! If they do change, you can still get amortized O(1) time as long as the number of selections you make is high compared to the number of changes you make to the alias table (changing the weights). The currently selected answer suggests an O(N) algorithm; the next best thing is O(log(N)), using cumulative probabilities (which are sorted) and binary search, but nothing is going to beat the O(1) time I suggested.
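For comparison, a minimal sketch of the O(log(N)) option mentioned above, assuming non-negative weights that sum to more than zero: the running sums are computed once, and each draw binary-searches for the first running sum strictly greater than a uniform random value. The class name is hypothetical:

```java
import java.util.Random;

public class CumulativeBinarySearch {
    private final double[] cumulative; // cumulative[i] = w[0] + ... + w[i], non-decreasing
    private final Random random = new Random();

    public CumulativeBinarySearch(double[] weights) {
        cumulative = new double[weights.length];
        double sum = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] < 0) throw new IllegalArgumentException("negative weight");
            sum += weights[i];
            cumulative[i] = sum;
        }
        if (sum <= 0) throw new IllegalArgumentException("weights must sum to more than 0");
    }

    // First index whose running sum is strictly greater than r;
    // falls back to the last index if none is (rounding safety net).
    public int indexFor(double r) {
        int lo = 0, hi = cumulative.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (cumulative[mid] > r) hi = mid;
            else lo = mid + 1;
        }
        return lo;
    }

    public int sampleIndex() {
        double total = cumulative[cumulative.length - 1];
        return indexFor(random.nextDouble() * total); // r in [0, total)
    }
}
```

With weights {1, 1, 2, 6} the running sums are {1, 2, 4, 10}, so `indexFor(3.9)` returns 2 and `indexFor(4.0)` returns 3. Zero-weight entries are never picked, because their running sum equals the previous one.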
This site provides a good overview of the alias method that is mostly language agnostic. Essentially, you create a table where each entry represents the outcome of two probabilities. Each table entry has a single threshold: below the threshold you get one value, above it you get another. You spread larger probabilities across multiple table entries in order to create a probability graph with a total area of one for all probabilities combined.
Say you have the outcomes A, B, C and D, with probabilities 0.1, 0.1, 0.1 and 0.7 respectively. The alias method spreads the 0.7 of D across all the other entries. There is one table index per outcome, each covering 0.25 of the total: the indices of A, B and C each hold 0.1 of their own outcome plus 0.15 of D, and D's index holds 0.25 of D. Normalizing each index gives a 0.4 chance of getting A and a 0.6 chance of getting D in A's index (0.1/(0.1 + 0.15) and 0.15/(0.1 + 0.15) respectively), the same for B's and C's indices, and a 100% chance of getting D in D's index (0.25/0.25 is 1).
Given an unbiased uniform PRNG (Math.random()) for indexing, you get an equal probability of choosing each index, and you also do a coin flip per index, which provides the weighted probability. You have a 25% chance of landing on the A/D slot, but within it you only have a 40% chance of picking A and a 60% chance of picking D. 0.40 * 0.25 = 0.1, our original probability, and if you add up all of D's probabilities strewn throughout the other indices, you get 0.70 again.
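Written out, the check for the example table (four equally likely indices, each followed by the coin flip) is:

```latex
P(A) = \tfrac{1}{4} \cdot 0.4 = 0.1 \qquad (\text{likewise } P(B) = P(C) = 0.1)
P(D) = 3 \cdot \tfrac{1}{4} \cdot 0.6 + \tfrac{1}{4} \cdot 1 = 0.45 + 0.25 = 0.7
```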
So to do random selection, you only need to generate a random index from 0 to N-1 and then do a coin flip; no matter how many items you add, this is very fast and has constant cost. Making an alias table doesn't take that many lines of code either: my Python version takes 80 lines including import statements and line breaks, and the version presented in the Pandas article is similarly sized (and it's C++).
For your Java implementation, you could map the probabilities to the list indices of the functions you must execute, creating a list of functions where each one is executed when its index is drawn; alternatively, you could use function objects (functors) with a method through which you pass in the parameters to execute them.
List<YourFunctionObject> functionList = new ArrayList<>();
// add functions
AliasSampler aliasSampler = new AliasSampler(listOfProbabilities);
// somewhere later, with some result type T and some parameter values
int index = aliasSampler.sampleIndex();
T result = functionList.get(index).apply(parameters);
EDIT:
I've created a Java version of the AliasSampler using classes; it provides the sampleIndex method and should be usable as shown above.
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
public class AliasSampler {
    private final ArrayList<Double> binaryProbabilityArray;
    private final ArrayList<Integer> aliasIndexList;
    private final Random random = new Random();
    public AliasSampler(List<Double> probabilities){
        // java 8 needed here; compare with a tolerance, since a floating point sum is rarely exactly 1.0
        assert Math.abs(probabilities.stream().mapToDouble(Double::doubleValue).sum() - 1.0) < 1e-9;
        int n = probabilities.size();
        // probabilityArray is the list of probabilities, this is the incoming probabilities scaled
        // by the number of probabilities. This allows us to figure out which probabilities need to be spread
        // to others since they are too large, ie [0.1 0.1 0.1 0.7] = [0.4 0.4 0.4 2.80]
        ArrayList<Double> probabilityArray = new ArrayList<Double>(n);
        for(Double probability : probabilities){
            probabilityArray.add(probability * n);
        }
        binaryProbabilityArray = new ArrayList<Double>(Collections.nCopies(n, 0.0));
        aliasIndexList = new ArrayList<Integer>(Collections.nCopies(n, 0));
        ArrayList<Integer> lessThanOneIndexList = new ArrayList<Integer>();
        ArrayList<Integer> greaterThanOneIndexList = new ArrayList<Integer>();
        for(int index = 0; index < probabilityArray.size(); index++){
            double probability = probabilityArray.get(index);
            if(probability < 1.0){
                lessThanOneIndexList.add(index);
            }
            else{
                greaterThanOneIndexList.add(index);
            }
        }
        // while we still have indices to check for in each list, we attempt to spread the probability of those larger
        // what this ends up doing in our first example is taking greater than one elements (2.80) and removing 0.6,
        // and spreading it to different indices, so (((2.80 - 0.6) - 0.6) - 0.6) will equal 1.0, and the rest will
        // be 0.4 + 0.6 = 1.0 as well.
        while(!lessThanOneIndexList.isEmpty() && !greaterThanOneIndexList.isEmpty()){
            //https://stackoverflow.com/questions/16987727/removing-last-object-of-arraylist-in-java
            // removing the last element is equivalent to a pop; ArrayList does this in constant time
            int lessThanOneIndex = lessThanOneIndexList.remove(lessThanOneIndexList.size() - 1);
            int greaterThanOneIndex = greaterThanOneIndexList.remove(greaterThanOneIndexList.size() - 1);
            double probabilityLessThanOne = probabilityArray.get(lessThanOneIndex);
            binaryProbabilityArray.set(lessThanOneIndex, probabilityLessThanOne);
            aliasIndexList.set(lessThanOneIndex, greaterThanOneIndex);
            probabilityArray.set(greaterThanOneIndex, probabilityArray.get(greaterThanOneIndex) + probabilityLessThanOne - 1);
            if(probabilityArray.get(greaterThanOneIndex) < 1){
                lessThanOneIndexList.add(greaterThanOneIndex);
            }
            else{
                greaterThanOneIndexList.add(greaterThanOneIndex);
            }
        }
        // if there are any probabilities left in either index list, they can't be spread across the other
        // indices, so they are set with probability 1.0. They still have the probabilities they should at this step, it works out mathematically.
        while(!greaterThanOneIndexList.isEmpty()){
            int greaterThanOneIndex = greaterThanOneIndexList.remove(greaterThanOneIndexList.size() - 1);
            binaryProbabilityArray.set(greaterThanOneIndex, 1.0);
        }
        while(!lessThanOneIndexList.isEmpty()){
            int lessThanOneIndex = lessThanOneIndexList.remove(lessThanOneIndexList.size() - 1);
            binaryProbabilityArray.set(lessThanOneIndex, 1.0);
        }
    }
    public int sampleIndex(){
        // pick a table index uniformly, then flip the biased coin stored at that index
        int index = random.nextInt(binaryProbabilityArray.size());
        double r = random.nextDouble();
        if(r < binaryProbabilityArray.get(index)){
            return index;
        }
        else{
            return aliasIndexList.get(index);
        }
    }
}
You could compute the cumulative probability for each outcome, pick a random number from [0, 1) and see where that number falls.
import java.util.Random;
class WeightedRandomPicker {
    private static final Random random = new Random();
    public static int choose(double[] probabilities) {
        double randomVal = random.nextDouble();
        double cumulativeProbability = 0;
        for (int i = 0; i < probabilities.length; ++i) {
            cumulativeProbability += probabilities[i];
            if (randomVal < cumulativeProbability) {
                return i;
            }
        }
        return probabilities.length - 1; // to account for numerical errors
    }
    public static void main (String[] args) {
        // the exact final value doesn't matter: the last index gets whatever probability remains
        double[] probabilities = new double[]{0.1, 0.1, 0.2, 0.6};
        for (int i = 0; i < 20; ++i) {
            System.out.printf("%d%n", choose(probabilities));
        }
    }
}