aboutsummaryrefslogtreecommitdiff
path: root/std/rand
diff options
context:
space:
mode:
authorMarc Tiehuis <marctiehuis@gmail.com>2018-03-30 01:36:04 +1300
committerMarc Tiehuis <marctiehuis@gmail.com>2018-03-30 01:50:58 +1300
commit0fd0f6fd1f4b9e02cc33eac304b9f6db242cbb67 (patch)
treedc06675959dce091f7fa58c138d4b567cf2ba87f /std/rand
parent032fccf6151ff201ce4b8c7ab28ca460fed794c0 (diff)
downloadzig-0fd0f6fd1f4b9e02cc33eac304b9f6db242cbb67.tar.gz
zig-0fd0f6fd1f4b9e02cc33eac304b9f6db242cbb67.zip
Rewrite Rand functions
We now use a generic Rand structure which abstracts the core functions from the backing engine. The old Mersenne Twister engine is removed and replaced instead with three alternatives: - Pcg32 - Xoroshiro128+ - Isaac64 These should provide sufficient coverage for most purposes, including a CSPRNG using Isaac64. Consumers of the library that do not care about the actual engine implementation should use DefaultPrng and DefaultCsprng.
Diffstat (limited to 'std/rand')
-rw-r--r--std/rand/index.zig652
1 files changed, 652 insertions, 0 deletions
diff --git a/std/rand/index.zig b/std/rand/index.zig
new file mode 100644
index 0000000000..ce69565473
--- /dev/null
+++ b/std/rand/index.zig
@@ -0,0 +1,652 @@
+// The engines provided here should be initialized from an external source. For now, getRandomBytes
+// from the os package is the most suitable. Be sure to use a CSPRNG when required, otherwise using
+// a normal PRNG will be faster and use substantially less stack space.
+//
+// ```
+// var buf: [8]u8 = undefined;
+// try std.os.getRandomBytes(buf[0..]);
+// const seed = mem.readInt(buf[0..8], u64, builtin.Endian.Little);
+//
+// var r = DefaultPrng.init(seed);
+//
+// const s = r.random.scalar(u64);
+// ```
+//
+// TODO(tiehuis): Benchmark these against other reference implementations.
+
+const std = @import("../index.zig");
+const builtin = @import("builtin");
+const assert = std.debug.assert;
+const mem = std.mem;
+const math = std.math;
+
+// When you need fast unbiased random numbers
+pub const DefaultPrng = Xoroshiro128;
+
+// When you need cryptographically secure random numbers
+pub const DefaultCsprng = Isaac64;
+
+pub const Rand = struct {
+ fillFn: fn(r: &Rand, buf: []u8) void,
+
+ /// Read random bytes into the specified buffer until fill.
+ pub fn bytes(r: &Rand, buf: []u8) void {
+ r.fillFn(r, buf);
+ }
+
+ /// Return a random integer/boolean type.
+ pub fn scalar(r: &Rand, comptime T: type) T {
+ var rand_bytes: [@sizeOf(T)]u8 = undefined;
+ r.bytes(rand_bytes[0..]);
+
+ if (T == bool) {
+ return rand_bytes[0] & 0b1 == 0;
+ } else {
+ // NOTE: Cannot @bitCast array to integer type.
+ return mem.readInt(rand_bytes, T, builtin.Endian.Little);
+ }
+ }
+
+ /// Get a random unsigned integer with even distribution between `start`
+ /// inclusive and `end` exclusive.
+ pub fn range(r: &Rand, comptime T: type, start: T, end: T) T {
+ assert(start <= end);
+ if (T.is_signed) {
+ const uint = @IntType(false, T.bit_count);
+ if (start >= 0 and end >= 0) {
+ return T(r.range(uint, uint(start), uint(end)));
+ } else if (start < 0 and end < 0) {
+ // Can't overflow because the range is over signed ints
+ return math.negateCast(r.range(uint, math.absCast(end), math.absCast(start)) + 1) catch unreachable;
+ } else if (start < 0 and end >= 0) {
+ const end_uint = uint(end);
+ const total_range = math.absCast(start) + end_uint;
+ const value = r.range(uint, 0, total_range);
+ const result = if (value < end_uint) x: {
+ break :x T(value);
+ } else if (value == end_uint) x: {
+ break :x start;
+ } else x: {
+ // Can't overflow because the range is over signed ints
+ break :x math.negateCast(value - end_uint) catch unreachable;
+ };
+ return result;
+ } else {
+ unreachable;
+ }
+ } else {
+ const total_range = end - start;
+ const leftover = @maxValue(T) % total_range;
+ const upper_bound = @maxValue(T) - leftover;
+ var rand_val_array: [@sizeOf(T)]u8 = undefined;
+
+ while (true) {
+ r.bytes(rand_val_array[0..]);
+ const rand_val = mem.readInt(rand_val_array, T, builtin.Endian.Little);
+ if (rand_val < upper_bound) {
+ return start + (rand_val % total_range);
+ }
+ }
+ }
+ }
+
+ /// Return a floating point value evenly distributed in the range [0, 1).
+ pub fn float(r: &Rand, comptime T: type) T {
+ // Generate a uniform value between [1, 2) and scale down to [0, 1).
+ // Note: The lowest mantissa bit is always set to 0 so we only use half the available range.
+ switch (T) {
+ f32 => {
+ const s = r.scalar(u32);
+ const repr = (0x7f << 23) | (s >> 9);
+ return @bitCast(f32, repr) - 1.0;
+ },
+ f64 => {
+ const s = r.scalar(u64);
+ const repr = (0x3ff << 52) | (s >> 12);
+ return @bitCast(f64, repr) - 1.0;
+ },
+ else => @compileError("unknown floating point type"),
+ }
+ }
+
+ /// Return a floating point value normally distributed in the range [0, 1].
+ pub fn floatNorm(r: &Rand, comptime T: type) T {
+ // TODO(tiehuis): See https://www.doornik.com/research/ziggurat.pdf
+ @compileError("floatNorm is unimplemented");
+ }
+
+ /// Return a exponentially distributed float between (0, @maxValue(f64))
+ pub fn floatExp(r: &Rand, comptime T: type) T {
+ @compileError("floatExp is unimplemented");
+ }
+
+ /// Shuffle a slice into a random order.
+ pub fn shuffle(r: &Rand, comptime T: type, buf: []T) void {
+ if (buf.len < 2) {
+ return;
+ }
+
+ var i: usize = 0;
+ while (i < buf.len - 1) : (i += 1) {
+ const j = r.range(usize, i, buf.len);
+ mem.swap(T, &buf[i], &buf[j]);
+ }
+ }
+};
+
+// Generator to extend 64-bit seed values into longer sequences.
+//
+// The number of cycles is thus limited to 64-bits regardless of the engine, but this
+// is still plenty for practical purposes.
+const SplitMix64 = struct {
+ s: u64,
+
+ pub fn init(seed: u64) SplitMix64 {
+ return SplitMix64 { .s = seed };
+ }
+
+ pub fn next(self: &SplitMix64) u64 {
+ self.s +%= 0x9e3779b97f4a7c15;
+
+ var z = self.s;
+ z = (z ^ (z >> 30)) *% 0xbf58476d1ce4e5b9;
+ z = (z ^ (z >> 27)) *% 0x94d049bb133111eb;
+ return z ^ (z >> 31);
+ }
+};
+
+test "splitmix64 sequence" {
+ var r = SplitMix64.init(0xaeecf86f7878dd75);
+
+ const seq = []const u64 {
+ 0x5dbd39db0178eb44,
+ 0xa9900fb66b397da3,
+ 0x5c1a28b1aeebcf5c,
+ 0x64a963238f776912,
+ 0xc6d4177b21d1c0ab,
+ 0xb2cbdbdb5ea35394,
+ };
+
+ for (seq) |s| {
+ std.debug.assert(s == r.next());
+ }
+}
+
+// PCG32 - http://www.pcg-random.org/
+//
+// PRNG
+pub const Pcg = struct {
+ const default_multiplier = 6364136223846793005;
+
+ random: Rand,
+
+ s: u64,
+ i: u64,
+
+ pub fn init(init_s: u64) Pcg {
+ var pcg = Pcg {
+ .random = Rand { .fillFn = fill },
+ .s = undefined,
+ .i = undefined,
+ };
+
+ pcg.seed(init_s);
+ return pcg;
+ }
+
+ fn next(self: &Pcg) u32 {
+ const l = self.s;
+ self.s = l *% default_multiplier +% (self.i | 1);
+
+ const xor_s = @truncate(u32, ((l >> 18) ^ l) >> 27);
+ const rot = u32(l >> 59);
+
+ return (xor_s >> u5(rot)) | (xor_s << u5((0 -% rot) & 31));
+ }
+
+ fn seed(self: &Pcg, init_s: u64) void {
+ // Pcg requires 128-bits of seed.
+ var gen = SplitMix64.init(init_s);
+ self.seedTwo(gen.next(), gen.next());
+ }
+
+ fn seedTwo(self: &Pcg, init_s: u64, init_i: u64) void {
+ self.s = 0;
+ self.i = (init_s << 1) | 1;
+ self.s = self.s *% default_multiplier +% self.i;
+ self.s +%= init_i;
+ self.s = self.s *% default_multiplier +% self.i;
+ }
+
+ fn fill(r: &Rand, buf: []u8) void {
+ const self = @fieldParentPtr(Pcg, "random", r);
+
+ var i: usize = 0;
+ const aligned_len = buf.len - (buf.len & 7);
+
+ // Complete 4 byte segments.
+ while (i < aligned_len) : (i += 4) {
+ var n = self.next();
+ comptime var j: usize = 0;
+ inline while (j < 4) : (j += 1) {
+ buf[i + j] = @truncate(u8, n);
+ n >>= 8;
+ }
+ }
+
+ // Remaining. (cuts the stream)
+ if (i != buf.len) {
+ var n = self.next();
+ while (i < buf.len) : (i += 1) {
+ buf[i] = @truncate(u8, n);
+ n >>= 4;
+ }
+ }
+ }
+};
+
+test "pcg sequence" {
+ var r = Pcg.init(0);
+ const s0: u64 = 0x9394bf54ce5d79de;
+ const s1: u64 = 0x84e9c579ef59bbf7;
+ r.seedTwo(s0, s1);
+
+ const seq = []const u32 {
+ 2881561918,
+ 3063928540,
+ 1199791034,
+ 2487695858,
+ 1479648952,
+ 3247963454,
+ };
+
+ for (seq) |s| {
+ std.debug.assert(s == r.next());
+ }
+}
+
+// Xoroshiro128+ - http://xoroshiro.di.unimi.it/
+//
+// PRNG
+pub const Xoroshiro128 = struct {
+ random: Rand,
+
+ s: [2]u64,
+
+ pub fn init(init_s: u64) Xoroshiro128 {
+ var x = Xoroshiro128 {
+ .random = Rand { .fillFn = fill },
+ .s = undefined,
+ };
+
+ x.seed(init_s);
+ return x;
+ }
+
+ fn next(self: &Xoroshiro128) u64 {
+ const s0 = self.s[0];
+ var s1 = self.s[1];
+ const r = s0 +% s1;
+
+ s1 ^= s0;
+ self.s[0] = math.rotl(u64, s0, u8(55)) ^ s1 ^ (s1 << 14);
+ self.s[1] = math.rotl(u64, s1, u8(36));
+
+ return r;
+ }
+
+ // Skip 2^64 places ahead in the sequence
+ fn jump(self: &Xoroshiro128) void {
+ var s0: u64 = 0;
+ var s1: u64 = 0;
+
+ const table = []const u64 {
+ 0xbeac0467eba5facb,
+ 0xd86b048b86aa9922
+ };
+
+ inline for (table) |entry| {
+ var b: usize = 0;
+ while (b < 64) : (b += 1) {
+ if ((entry & (u64(1) << u6(b))) != 0) {
+ s0 ^= self.s[0];
+ s1 ^= self.s[1];
+ }
+ _ = self.next();
+ }
+ }
+
+ self.s[0] = s0;
+ self.s[1] = s1;
+ }
+
+ fn seed(self: &Xoroshiro128, init_s: u64) void {
+ // Xoroshiro requires 128-bits of seed.
+ var gen = SplitMix64.init(init_s);
+
+ self.s[0] = gen.next();
+ self.s[1] = gen.next();
+ }
+
+ fn fill(r: &Rand, buf: []u8) void {
+ const self = @fieldParentPtr(Xoroshiro128, "random", r);
+
+ var i: usize = 0;
+ const aligned_len = buf.len - (buf.len & 7);
+
+ // Complete 8 byte segments.
+ while (i < aligned_len) : (i += 8) {
+ var n = self.next();
+ comptime var j: usize = 0;
+ inline while (j < 8) : (j += 1) {
+ buf[i + j] = @truncate(u8, n);
+ n >>= 8;
+ }
+ }
+
+ // Remaining. (cuts the stream)
+ if (i != buf.len) {
+ var n = self.next();
+ while (i < buf.len) : (i += 1) {
+ buf[i] = @truncate(u8, n);
+ n >>= 8;
+ }
+ }
+ }
+};
+
+test "xoroshiro sequence" {
+ var r = Xoroshiro128.init(0);
+ r.s[0] = 0xaeecf86f7878dd75;
+ r.s[1] = 0x01cd153642e72622;
+
+ const seq1 = []const u64 {
+ 0xb0ba0da5bb600397,
+ 0x18a08afde614dccc,
+ 0xa2635b956a31b929,
+ 0xabe633c971efa045,
+ 0x9ac19f9706ca3cac,
+ 0xf62b426578c1e3fb,
+ };
+
+ for (seq1) |s| {
+ std.debug.assert(s == r.next());
+ }
+
+
+ r.jump();
+
+ const seq2 = []const u64 {
+ 0x95344a13556d3e22,
+ 0xb4fb32dafa4d00df,
+ 0xb2011d9ccdcfe2dd,
+ 0x05679a9b2119b908,
+ 0xa860a1da7c9cd8a0,
+ 0x658a96efe3f86550,
+ };
+
+ for (seq2) |s| {
+ std.debug.assert(s == r.next());
+ }
+}
+
+// ISAAC64 - http://www.burtleburtle.net/bob/rand/isaacafa.html
+//
+// CSPRNG
+//
+// Follows the general idea of the implementation from here with a few shortcuts.
+// https://doc.rust-lang.org/rand/src/rand/prng/isaac64.rs.html
+pub const Isaac64 = struct {
+ random: Rand,
+
+ r: [256]u64,
+ m: [256]u64,
+ a: u64,
+ b: u64,
+ c: u64,
+ i: usize,
+
+ pub fn init(init_s: u64) Isaac64 {
+ var isaac = Isaac64 {
+ .random = Rand { .fillFn = fill },
+ .r = undefined,
+ .m = undefined,
+ .a = undefined,
+ .b = undefined,
+ .c = undefined,
+ .i = undefined,
+ };
+
+ // seed == 0 => same result as the unseeded reference implementation
+ isaac.seed(init_s, 1);
+ return isaac;
+ }
+
+ fn step(self: &Isaac64, mix: u64, base: usize, comptime m1: usize, comptime m2: usize) void {
+ const x = self.m[base + m1];
+ self.a = mix +% self.m[base + m2];
+
+ const y = self.a +% self.b +% self.m[(x >> 3) % self.m.len];
+ self.m[base + m1] = y;
+
+ self.b = x +% self.m[(y >> 11) % self.m.len];
+ self.r[self.r.len - 1 - base - m1] = self.b;
+ }
+
+ fn refill(self: &Isaac64) void {
+ const midpoint = self.r.len / 2;
+
+ self.c +%= 1;
+ self.b +%= self.c;
+
+ {
+ var i: usize = 0;
+ while (i < midpoint) : (i += 4) {
+ self.step( ~(self.a ^ (self.a << 21)), i + 0, 0, midpoint);
+ self.step( self.a ^ (self.a >> 5) , i + 1, 0, midpoint);
+ self.step( self.a ^ (self.a << 12) , i + 2, 0, midpoint);
+ self.step( self.a ^ (self.a >> 33) , i + 3, 0, midpoint);
+ }
+ }
+
+ {
+ var i: usize = 0;
+ while (i < midpoint) : (i += 4) {
+ self.step( ~(self.a ^ (self.a << 21)), i + 0, midpoint, 0);
+ self.step( self.a ^ (self.a >> 5) , i + 1, midpoint, 0);
+ self.step( self.a ^ (self.a << 12) , i + 2, midpoint, 0);
+ self.step( self.a ^ (self.a >> 33) , i + 3, midpoint, 0);
+ }
+ }
+
+ self.i = 0;
+ }
+
+ fn next(self: &Isaac64) u64 {
+ if (self.i >= self.r.len) {
+ self.refill();
+ }
+
+ const value = self.r[self.i];
+ self.i += 1;
+ return value;
+ }
+
+ fn seed(self: &Isaac64, init_s: u64, comptime rounds: usize) void {
+ // We ignore the multi-pass requirement since we don't currently expose full access to
+ // seeding the self.m array completely.
+ mem.set(u64, self.m[0..], 0);
+ self.m[0] = init_s;
+
+ // prescrambled golden ratio constants
+ var a = []const u64 {
+ 0x647c4677a2884b7c,
+ 0xb9f8b322c73ac862,
+ 0x8c0ea5053d4712a0,
+ 0xb29b2e824a595524,
+ 0x82f053db8355e0ce,
+ 0x48fe4a0fa5a09315,
+ 0xae985bf2cbfc89ed,
+ 0x98f5704f6c44c0ab,
+ };
+
+ comptime var i: usize = 0;
+ inline while (i < rounds) : (i += 1) {
+ var j: usize = 0;
+ while (j < self.m.len) : (j += 8) {
+ comptime var x1: usize = 0;
+ inline while (x1 < 8) : (x1 += 1) {
+ a[x1] +%= self.m[j + x1];
+ }
+
+ a[0] -%= a[4]; a[5] ^= a[7] >> 9; a[7] +%= a[0];
+ a[1] -%= a[5]; a[6] ^= a[0] << 9; a[0] +%= a[1];
+ a[2] -%= a[6]; a[7] ^= a[1] >> 23; a[1] +%= a[2];
+ a[3] -%= a[7]; a[0] ^= a[2] << 15; a[2] +%= a[3];
+ a[4] -%= a[0]; a[1] ^= a[3] >> 14; a[3] +%= a[4];
+ a[5] -%= a[1]; a[2] ^= a[4] << 20; a[4] +%= a[5];
+ a[6] -%= a[2]; a[3] ^= a[5] >> 17; a[5] +%= a[6];
+ a[7] -%= a[3]; a[4] ^= a[6] << 14; a[6] +%= a[7];
+
+ comptime var x2: usize = 0;
+ inline while (x2 < 8) : (x2 += 1) {
+ self.m[j + x2] = a[x2];
+ }
+ }
+ }
+
+ mem.set(u64, self.r[0..], 0);
+ self.a = 0;
+ self.b = 0;
+ self.c = 0;
+ self.i = self.r.len; // trigger refill on first value
+ }
+
+ fn fill(r: &Rand, buf: []u8) void {
+ const self = @fieldParentPtr(Isaac64, "random", r);
+
+ var i: usize = 0;
+ const aligned_len = buf.len - (buf.len & 7);
+
+ // Fill complete 64-byte segments
+ while (i < aligned_len) : (i += 8) {
+ var n = self.next();
+ comptime var j: usize = 0;
+ inline while (j < 8) : (j += 1) {
+ buf[i + j] = @truncate(u8, n);
+ n >>= 8;
+ }
+ }
+
+ // Fill trailing, ignoring excess (cut the stream).
+ if (i != buf.len) {
+ var n = self.next();
+ while (i < buf.len) : (i += 1) {
+ buf[i] = @truncate(u8, n);
+ n >>= 8;
+ }
+ }
+ }
+};
+
+test "isaac64 sequence" {
+ var r = Isaac64.init(0);
+
+ // from reference implementation
+ const seq = []const u64 {
+ 0xf67dfba498e4937c,
+ 0x84a5066a9204f380,
+ 0xfee34bd5f5514dbb,
+ 0x4d1664739b8f80d6,
+ 0x8607459ab52a14aa,
+ 0x0e78bc5a98529e49,
+ 0xfe5332822ad13777,
+ 0x556c27525e33d01a,
+ 0x08643ca615f3149f,
+ 0xd0771faf3cb04714,
+ 0x30e86f68a37b008d,
+ 0x3074ebc0488a3adf,
+ 0x270645ea7a2790bc,
+ 0x5601a0a8d3763c6a,
+ 0x2f83071f53f325dd,
+ 0xb9090f3d42d2d2ea,
+ };
+
+ for (seq) |s| {
+ std.debug.assert(s == r.next());
+ }
+}
+
+// Actual Rand helper function tests, pcg engine is assumed correct.
+test "Rand float" {
+ var prng = DefaultPrng.init(0);
+
+ var i: usize = 0;
+ while (i < 1000) : (i += 1) {
+ const val1 = prng.random.float(f32);
+ std.debug.assert(val1 >= 0.0);
+ std.debug.assert(val1 < 1.0);
+
+ const val2 = prng.random.float(f64);
+ std.debug.assert(val2 >= 0.0);
+ std.debug.assert(val2 < 1.0);
+ }
+}
+
+test "Rand scalar" {
+ var prng = DefaultPrng.init(0);
+ const s = prng .random.scalar(u64);
+}
+
+test "Rand bytes" {
+ var prng = DefaultPrng.init(0);
+ var buf: [2048]u8 = undefined;
+ prng.random.bytes(buf[0..]);
+}
+
+test "Rand shuffle" {
+ var prng = DefaultPrng.init(0);
+
+ var seq = []const u8 { 0, 1, 2, 3, 4 };
+ var seen = []bool {false} ** 5;
+
+ var i: usize = 0;
+ while (i < 1000) : (i += 1) {
+ prng.random.shuffle(u8, seq[0..]);
+ seen[seq[0]] = true;
+ std.debug.assert(sumArray(seq[0..]) == 10);
+ }
+
+ // we should see every entry at the head at least once
+ for (seen) |e| {
+ std.debug.assert(e == true);
+ }
+}
+
+fn sumArray(s: []const u8) u32 {
+ var r: u32 = 0;
+ for (s) |e| r += e;
+ return r;
+}
+
+test "Rand range" {
+ var prng = DefaultPrng.init(0);
+ testRange(&prng.random, -4, 3);
+ testRange(&prng.random, -4, -1);
+ testRange(&prng.random, 10, 14);
+}
+
+fn testRange(r: &Rand, start: i32, end: i32) void {
+ const count = usize(end - start);
+ var values_buffer = []bool{false} ** 20;
+ const values = values_buffer[0..count];
+ var i: usize = 0;
+ while (i < count) {
+ const value = r.range(i32, start, end);
+ const index = usize(value - start);
+ if (!values[index]) {
+ i += 1;
+ values[index] = true;
+ }
+ }
+}