diff options
| author | Josh Wolfe <thejoshwolfe@gmail.com> | 2018-09-27 00:35:38 -0400 |
|---|---|---|
| committer | Andrew Kelley <superjoe30@gmail.com> | 2018-09-27 00:35:38 -0400 |
| commit | e7d9d00ac8c2bb075d0f67842b58718aececdc09 (patch) | |
| tree | a9debecbf1bcfeca594c7f69557697a1e76c5c1e /std | |
| parent | 1c26c2f4d5c9029fe39cc413f75e547636e42a14 (diff) | |
| download | zig-e7d9d00ac8c2bb075d0f67842b58718aececdc09.tar.gz zig-e7d9d00ac8c2bb075d0f67842b58718aececdc09.zip | |
overhaul api for getting random integers (#1578)
* rand api overhaul
* no retry limits. instead documented a recommendation
to call int(T) % len directly.
Diffstat (limited to 'std')
| -rw-r--r-- | std/rand/index.zig | 333 |
1 files changed, 267 insertions, 66 deletions
diff --git a/std/rand/index.zig b/std/rand/index.zig index 2cbff049ea..6cad865d1e 100644 --- a/std/rand/index.zig +++ b/std/rand/index.zig @@ -5,11 +5,11 @@ // ``` // var buf: [8]u8 = undefined; // try std.os.getRandomBytes(buf[0..]); -// const seed = mem.readInt(buf[0..8], u64, builtin.Endian.Little); +// const seed = mem.readIntLE(u64, buf[0..8]); // // var r = DefaultPrng.init(seed); // -// const s = r.random.scalar(u64); +// const s = r.random.int(u64); // ``` // // TODO(tiehuis): Benchmark these against other reference implementations. @@ -35,60 +35,117 @@ pub const Random = struct { r.fillFn(r, buf); } - /// Return a random integer/boolean type. - pub fn scalar(r: *Random, comptime T: type) T { - var rand_bytes: [@sizeOf(T)]u8 = undefined; + pub fn boolean(r: *Random) bool { + return r.int(u1) != 0; + } + + /// Returns a random int `i` such that `0 <= i <= @maxValue(T)`. + /// `i` is evenly distributed. + pub fn int(r: *Random, comptime T: type) T { + const UnsignedT = @IntType(false, T.bit_count); + const ByteAlignedT = @IntType(false, @divTrunc(T.bit_count + 7, 8) * 8); + + var rand_bytes: [@sizeOf(ByteAlignedT)]u8 = undefined; r.bytes(rand_bytes[0..]); - if (T == bool) { - return rand_bytes[0] & 0b1 == 0; + // use LE instead of native endian for better portability maybe? + // TODO: endian portability is pointless if the underlying prng isn't endian portable. + // TODO: document the endian portability of this library. + const byte_aligned_result = mem.readIntLE(ByteAlignedT, rand_bytes); + const unsigned_result = @truncate(UnsignedT, byte_aligned_result); + return @bitCast(T, unsigned_result); + } + + /// Returns an evenly distributed random unsigned integer `0 <= i < less_than`. + /// This function assumes that the underlying ::fillFn produces evenly distributed values. + /// Within this assumption, the runtime of this function is exponentially distributed. + /// If ::fillFn were backed by a true random generator, + /// the runtime of this function would technically be unbounded. + /// However, if ::fillFn is backed by any evenly distributed pseudo random number generator, + /// this function is guaranteed to return. + /// If you need deterministic runtime bounds, consider instead using `r.int(T) % less_than`, + /// which will usually be biased toward smaller values. + pub fn uintLessThan(r: *Random, comptime T: type, less_than: T) T { + assert(T.is_signed == false); + assert(0 < less_than); + + const last_group_size_minus_one: T = @maxValue(T) % less_than; + if (last_group_size_minus_one == less_than - 1) { + // less_than is a power of two. + assert(math.floorPowerOfTwo(T, less_than) == less_than); + // There is no retry zone. The optimal retry_zone_start would be @maxValue(T) + 1. + return r.int(T) % less_than; + } + const retry_zone_start = @maxValue(T) - last_group_size_minus_one; + + while (true) { + const rand_val = r.int(T); + if (rand_val < retry_zone_start) { + return rand_val % less_than; + } + } + } + + /// Returns an evenly distributed random unsigned integer `0 <= i <= at_most`. + /// See ::uintLessThan, which this function uses in most cases, + /// for commentary on the runtime of this function. + pub fn uintAtMost(r: *Random, comptime T: type, at_most: T) T { + assert(T.is_signed == false); + if (at_most == @maxValue(T)) { + // have the full range + return r.int(T); + } + return r.uintLessThan(T, at_most + 1); + } + + /// Returns an evenly distributed random integer `at_least <= i < less_than`. + /// See ::uintLessThan, which this function uses in most cases, + /// for commentary on the runtime of this function. + pub fn intRangeLessThan(r: *Random, comptime T: type, at_least: T, less_than: T) T { + assert(at_least < less_than); + if (T.is_signed) { + // Two's complement makes this math pretty easy. + const UnsignedT = @IntType(false, T.bit_count); + const lo = @bitCast(UnsignedT, at_least); + const hi = @bitCast(UnsignedT, less_than); + const result = lo +% r.uintLessThan(UnsignedT, hi -% lo); + return @bitCast(T, result); + } else { + // The signed implementation would work fine, but we can use stricter arithmetic operators here. + return at_least + r.uintLessThan(T, less_than - at_least); + } + } + + /// Returns an evenly distributed random integer `at_least <= i <= at_most`. + /// See ::uintLessThan, which this function uses in most cases, + /// for commentary on the runtime of this function. + pub fn intRangeAtMost(r: *Random, comptime T: type, at_least: T, at_most: T) T { + assert(at_least <= at_most); + if (T.is_signed) { + // Two's complement makes this math pretty easy. + const UnsignedT = @IntType(false, T.bit_count); + const lo = @bitCast(UnsignedT, at_least); + const hi = @bitCast(UnsignedT, at_most); + const result = lo +% r.uintAtMost(UnsignedT, hi -% lo); + return @bitCast(T, result); } else { - // NOTE: Cannot @bitCast array to integer type. - return mem.readInt(rand_bytes, T, builtin.Endian.Little); + // The signed implementation would work fine, but we can use stricter arithmetic operators here. + return at_least + r.uintAtMost(T, at_most - at_least); } } + /// Return a random integer/boolean type. + /// TODO: deprecated. use ::boolean or ::int instead. + pub fn scalar(r: *Random, comptime T: type) T { + if (T == bool) return r.boolean(); + return r.int(T); + } + /// Return a random integer with even distribution between `start` /// inclusive and `end` exclusive. `start` must be less than `end`. + /// TODO: deprecated. renamed to ::intRangeLessThan pub fn range(r: *Random, comptime T: type, start: T, end: T) T { - assert(start < end); - if (T.is_signed) { - const uint = @IntType(false, T.bit_count); - if (start >= 0 and end >= 0) { - return @intCast(T, r.range(uint, @intCast(uint, start), @intCast(uint, end))); - } else if (start < 0 and end < 0) { - // Can't overflow because the range is over signed ints - return math.negateCast(r.range(uint, math.absCast(end), math.absCast(start)) + 1) catch unreachable; - } else if (start < 0 and end >= 0) { - const end_uint = @intCast(uint, end); - const total_range = math.absCast(start) + end_uint; - const value = r.range(uint, 0, total_range); - const result = if (value < end_uint) x: { - break :x @intCast(T, value); - } else if (value == end_uint) x: { - break :x start; - } else x: { - // Can't overflow because the range is over signed ints - break :x math.negateCast(value - end_uint) catch unreachable; - }; - return result; - } else { - unreachable; - } - } else { - const total_range = end - start; - const leftover = @maxValue(T) % total_range; - const upper_bound = @maxValue(T) - leftover; - var rand_val_array: [@sizeOf(T)]u8 = undefined; - - while (true) { - r.bytes(rand_val_array[0..]); - const rand_val = mem.readInt(rand_val_array, T, builtin.Endian.Little); - if (rand_val < upper_bound) { - return start + (rand_val % total_range); - } - } - } + return r.intRangeLessThan(T, start, end); } /// Return a floating point value evenly distributed in the range [0, 1). @@ -97,12 +154,12 @@ pub const Random = struct { // Note: The lowest mantissa bit is always set to 0 so we only use half the available range. switch (T) { f32 => { - const s = r.scalar(u32); + const s = r.int(u32); const repr = (0x7f << 23) | (s >> 9); return @bitCast(f32, repr) - 1.0; }, f64 => { - const s = r.scalar(u64); + const s = r.int(u64); const repr = (0x3ff << 52) | (s >> 12); return @bitCast(f64, repr) - 1.0; }, @@ -142,12 +199,167 @@ pub const Random = struct { var i: usize = 0; while (i < buf.len - 1) : (i += 1) { - const j = r.range(usize, i, buf.len); + const j = r.intRangeLessThan(usize, i, buf.len); mem.swap(T, &buf[i], &buf[j]); } } }; +const SequentialPrng = struct { + const Self = @This(); + random: Random, + next_value: u8, + + pub fn init() Self { + return Self{ + .random = Random{ .fillFn = fill }, + .next_value = 0, + }; + } + + fn fill(r: *Random, buf: []u8) void { + const self = @fieldParentPtr(Self, "random", r); + for (buf) |*b| { + b.* = self.next_value; + } + self.next_value +%= 1; + } +}; + +test "Random int" { + testRandomInt(); + comptime testRandomInt(); +} +fn testRandomInt() void { + var r = SequentialPrng.init(); + + assert(r.random.int(u0) == 0); + + r.next_value = 0; + assert(r.random.int(u1) == 0); + assert(r.random.int(u1) == 1); + assert(r.random.int(u2) == 2); + assert(r.random.int(u2) == 3); + assert(r.random.int(u2) == 0); + + r.next_value = 0xff; + assert(r.random.int(u8) == 0xff); + r.next_value = 0x11; + assert(r.random.int(u8) == 0x11); + + r.next_value = 0xff; + assert(r.random.int(u32) == 0xffffffff); + r.next_value = 0x11; + assert(r.random.int(u32) == 0x11111111); + + r.next_value = 0xff; + assert(r.random.int(i32) == -1); + r.next_value = 0x11; + assert(r.random.int(i32) == 0x11111111); + + r.next_value = 0xff; + assert(r.random.int(i8) == -1); + r.next_value = 0x11; + assert(r.random.int(i8) == 0x11); + + r.next_value = 0xff; + assert(r.random.int(u33) == 0x1ffffffff); + r.next_value = 0xff; + assert(r.random.int(i1) == -1); + r.next_value = 0xff; + assert(r.random.int(i2) == -1); + r.next_value = 0xff; + assert(r.random.int(i33) == -1); +} + +test "Random boolean" { + testRandomBoolean(); + comptime testRandomBoolean(); +} +fn testRandomBoolean() void { + var r = SequentialPrng.init(); + assert(r.random.boolean() == false); + assert(r.random.boolean() == true); + assert(r.random.boolean() == false); + assert(r.random.boolean() == true); +} + +test "Random intLessThan" { + @setEvalBranchQuota(10000); + testRandomIntLessThan(); + comptime testRandomIntLessThan(); +} +fn testRandomIntLessThan() void { + var r = SequentialPrng.init(); + r.next_value = 0xff; + assert(r.random.uintLessThan(u8, 4) == 3); + r.next_value = 0xff; + assert(r.random.uintLessThan(u8, 3) == 0); + assert(r.next_value == 1); + + r.next_value = 0xff; + assert(r.random.intRangeLessThan(u8, 0, 0x80) == 0x7f); + r.next_value = 0xff; + assert(r.random.intRangeLessThan(u8, 0x7f, 0xff) == 0xfe); + + r.next_value = 0xff; + assert(r.random.intRangeLessThan(i8, 0, 0x40) == 0x3f); + r.next_value = 0xff; + assert(r.random.intRangeLessThan(i8, -0x40, 0x40) == 0x3f); + r.next_value = 0xff; + assert(r.random.intRangeLessThan(i8, -0x80, 0) == -1); + + r.next_value = 0xff; + assert(r.random.intRangeLessThan(i64, -0x8000000000000000, 0) == -1); + r.next_value = 0xff; + assert(r.random.intRangeLessThan(i3, -4, 0) == -1); + r.next_value = 0xff; + assert(r.random.intRangeLessThan(i3, -2, 2) == 1); + + // test retrying and eventually getting a good value + // start just out of bounds + r.next_value = 0x81; + assert(r.random.uintLessThan(u8, 0x81) == 0); +} + +test "Random intAtMost" { + @setEvalBranchQuota(10000); + testRandomIntAtMost(); + comptime testRandomIntAtMost(); +} +fn testRandomIntAtMost() void { + var r = SequentialPrng.init(); + r.next_value = 0xff; + assert(r.random.uintAtMost(u8, 3) == 3); + r.next_value = 0xff; + assert(r.random.uintAtMost(u8, 2) == 0); + assert(r.next_value == 1); + + r.next_value = 0xff; + assert(r.random.intRangeAtMost(u8, 0, 0x7f) == 0x7f); + r.next_value = 0xff; + assert(r.random.intRangeAtMost(u8, 0x7f, 0xfe) == 0xfe); + + r.next_value = 0xff; + assert(r.random.intRangeAtMost(i8, 0, 0x3f) == 0x3f); + r.next_value = 0xff; + assert(r.random.intRangeAtMost(i8, -0x40, 0x3f) == 0x3f); + r.next_value = 0xff; + assert(r.random.intRangeAtMost(i8, -0x80, -1) == -1); + + r.next_value = 0xff; + assert(r.random.intRangeAtMost(i64, -0x8000000000000000, -1) == -1); + r.next_value = 0xff; + assert(r.random.intRangeAtMost(i3, -4, -1) == -1); + r.next_value = 0xff; + assert(r.random.intRangeAtMost(i3, -2, 1) == 1); + + // test retrying and eventually getting a good value + // start just out of bounds + r.next_value = 0x81; + assert(r.random.uintAtMost(u8, 0x80) == 0); +} + // Generator to extend 64-bit seed values into longer sequences. // // The number of cycles is thus limited to 64-bits regardless of the engine, but this @@ -622,17 +834,6 @@ test "Random float" { } } -test "Random scalar" { - var prng = DefaultPrng.init(0); - const s = prng.random.scalar(u64); -} - -test "Random bytes" { - var prng = DefaultPrng.init(0); - var buf: [2048]u8 = undefined; - prng.random.bytes(buf[0..]); -} - test "Random shuffle" { var prng = DefaultPrng.init(0); @@ -664,16 +865,16 @@ test "Random range" { testRange(&prng.random, -4, 3); testRange(&prng.random, -4, -1); testRange(&prng.random, 10, 14); - // TODO: test that prng.random.range(1, 1) causes an assertion error + testRange(&prng.random, -0x80, 0x7f); } -fn testRange(r: *Random, start: i32, end: i32) void { - const count = @intCast(usize, end - start); - var values_buffer = []bool{false} ** 20; +fn testRange(r: *Random, start: i8, end: i8) void { + const count = @intCast(usize, i32(end) - i32(start)); + var values_buffer = []bool{false} ** 0x100; const values = values_buffer[0..count]; var i: usize = 0; while (i < count) { - const value = r.range(i32, start, end); + const value: i32 = r.intRangeLessThan(i8, start, end); const index = @intCast(usize, value - start); if (!values[index]) { i += 1; |
