From 1924ffa67d90f21852680b7c9a09df07fc218ebe Mon Sep 17 00:00:00 2001 From: Josh Wolfe Date: Wed, 21 Nov 2018 17:33:37 -0500 Subject: better debiased random range implementation --- std/rand/index.zig | 78 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 32 deletions(-) (limited to 'std') diff --git a/std/rand/index.zig b/std/rand/index.zig index bb607a067e..0d9e58fd87 100644 --- a/std/rand/index.zig +++ b/std/rand/index.zig @@ -69,22 +69,36 @@ pub const Random = struct { pub fn uintLessThan(r: *Random, comptime T: type, less_than: T) T { assert(T.is_signed == false); assert(0 < less_than); - - const last_group_size_minus_one: T = maxInt(T) % less_than; - if (last_group_size_minus_one == less_than - 1) { - // less_than is a power of two. - assert(math.floorPowerOfTwo(T, less_than) == less_than); - // There is no retry zone. The optimal retry_zone_start would be maxInt(T) + 1. - return r.int(T) % less_than; - } - const retry_zone_start = maxInt(T) - last_group_size_minus_one; - - while (true) { - const rand_val = r.int(T); - if (rand_val < retry_zone_start) { - return rand_val % less_than; + // Small is typically u32 + const Small = @IntType(false, @divTrunc(T.bit_count + 31, 32) * 32); + // Large is typically u64 + const Large = @IntType(false, Small.bit_count * 2); + + // adapted from: + // http://www.pcg-random.org/posts/bounded-rands.html + // "Lemire's (with an extra tweak from me)" + var x: Small = r.int(Small); + var m: Large = Large(x) * Large(less_than); + var l: Small = @truncate(Small, m); + if (l < less_than) { + // TODO: workaround for https://github.com/ziglang/zig/issues/1770 + // should be: + // var t: Small = -%less_than; + var t: Small = @bitCast(Small, -%@bitCast(@IntType(true, Small.bit_count), Small(less_than))); + + if (t >= less_than) { + t -= less_than; + if (t >= less_than) { + t %= less_than; + } + } + while (l < t) { + x = r.int(Small); + m = Large(x) * Large(less_than); + l = @truncate(Small, m); } } + return @intCast(T, m >> Small.bit_count); } /// Returns an evenly distributed random unsigned integer `0 <= i <= at_most`. @@ -294,10 +308,19 @@ fn testRandomIntLessThan() void { var r = SequentialPrng.init(); r.next_value = 0xff; assert(r.random.uintLessThan(u8, 4) == 3); - r.next_value = 0xff; - assert(r.random.uintLessThan(u8, 3) == 0); + assert(r.next_value == 0); + assert(r.random.uintLessThan(u8, 4) == 0); assert(r.next_value == 1); + r.next_value = 0; + assert(r.random.uintLessThan(u64, 32) == 0); + + // trigger the bias rejection code path + r.next_value = 0; + assert(r.random.uintLessThan(u8, 3) == 0); + // verify we incremented twice + assert(r.next_value == 2); + r.next_value = 0xff; assert(r.random.intRangeLessThan(u8, 0, 0x80) == 0x7f); r.next_value = 0xff; @@ -310,17 +333,10 @@ fn testRandomIntLessThan() void { r.next_value = 0xff; assert(r.random.intRangeLessThan(i8, -0x80, 0) == -1); - r.next_value = 0xff; - assert(r.random.intRangeLessThan(i64, -0x8000000000000000, 0) == -1); r.next_value = 0xff; assert(r.random.intRangeLessThan(i3, -4, 0) == -1); r.next_value = 0xff; assert(r.random.intRangeLessThan(i3, -2, 2) == 1); - - // test retrying and eventually getting a good value - // start just out of bounds - r.next_value = 0x81; - assert(r.random.uintLessThan(u8, 0x81) == 0); } test "Random intAtMost" { @@ -332,9 +348,14 @@ fn testRandomIntAtMost() void { var r = SequentialPrng.init(); r.next_value = 0xff; assert(r.random.uintAtMost(u8, 3) == 3); - r.next_value = 0xff; + assert(r.next_value == 0); + assert(r.random.uintAtMost(u8, 3) == 0); + + // trigger the bias rejection code path + r.next_value = 0; assert(r.random.uintAtMost(u8, 2) == 0); - assert(r.next_value == 1); + // verify we incremented twice + assert(r.next_value == 2); r.next_value = 0xff; assert(r.random.intRangeAtMost(u8, 0, 0x7f) == 0x7f); @@ -348,17 +369,10 @@ fn testRandomIntAtMost() void { r.next_value = 0xff; assert(r.random.intRangeAtMost(i8, -0x80, -1) == -1); - r.next_value = 0xff; - assert(r.random.intRangeAtMost(i64, -0x8000000000000000, -1) == -1); r.next_value = 0xff; assert(r.random.intRangeAtMost(i3, -4, -1) == -1); r.next_value = 0xff; assert(r.random.intRangeAtMost(i3, -2, 1) == 1); - - // test retrying and eventually getting a good value - // start just out of bounds - r.next_value = 0x81; - assert(r.random.uintAtMost(u8, 0x80) == 0); } // Generator to extend 64-bit seed values into longer sequences. -- cgit v1.2.3