path: root/std/atomic/queue.zig
const builtin = @import("builtin");
const AtomicOrder = builtin.AtomicOrder;
const AtomicRmwOp = builtin.AtomicRmwOp;

/// A many-reader, many-writer, non-allocating, thread-safe, lock-free queue.
/// Node memory is provided and owned by the caller.
pub fn Queue(comptime T: type) type {
    return struct {
        head: &Node,
        tail: &Node,
        root: Node,

        pub const Self = this;

        pub const Node = struct {
            next: ?&Node,
            data: T,
        };

        // TODO: well defined copy elision: https://github.com/zig-lang/zig/issues/287
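        /// Initializes the queue in place; `self.root` acts as the initial dummy node.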
        pub fn init(self: &Self) void {
            self.root.next = null;
            self.head = &self.root;
            self.tail = &self.root;
        }

        /// Appends `node` to the queue. The caller owns the node's memory and
        /// must keep it valid for as long as it remains in the queue.
        pub fn put(self: &Self, node: &Node) void {
            node.next = null;

            // Atomically swing the tail to the new node, then link the old tail to it.
            const tail = @atomicRmw(&Node, &self.tail, AtomicRmwOp.Xchg, node, AtomicOrder.SeqCst);
            _ = @atomicRmw(?&Node, &tail.next, AtomicRmwOp.Xchg, node, AtomicOrder.SeqCst);
        }

        /// Removes and returns the oldest node, or null if the queue is empty.
        pub fn get(self: &Self) ?&Node {
            var head = @atomicLoad(&Node, &self.head, AtomicOrder.Acquire);
            while (true) {
                // head is a dummy node; the first real element is head.next.
                const node = head.next ?? return null;
                // Try to advance head past that element; on success it is ours,
                // on failure retry with the freshly loaded head.
                head = @cmpxchgWeak(&Node, &self.head, head, node, AtomicOrder.Release, AtomicOrder.Acquire) ?? return node;
            }
        }
    };
}
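
// A minimal, single-threaded usage sketch of the Queue(T) API above. Nodes are
// caller-owned, so short-lived stack nodes are enough for this illustration.
test "std.atomic.queue single-threaded" {
    var queue: Queue(i32) = undefined;
    queue.init();

    var node_a = Queue(i32).Node {
        .next = null,
        .data = 1,
    };
    var node_b = Queue(i32).Node {
        .next = null,
        .data = 2,
    };
    queue.put(&node_a);
    queue.put(&node_b);

    // Nodes come back out in FIFO order; an empty queue yields null.
    const first = queue.get() ?? unreachable;
    const second = queue.get() ?? unreachable;
    std.debug.assert(first.data == 1);
    std.debug.assert(second.data == 2);
    std.debug.assert(queue.get() == null);
}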

const std = @import("std");
const Context = struct {
    allocator: &std.mem.Allocator,
    queue: &Queue(i32),
    put_sum: isize,
    get_sum: isize,
    get_count: usize,
    puts_done: u8, // TODO make this a bool
};
const puts_per_thread = 10000;
const put_thread_count = 3;

test "std.atomic.queue" {
    if (builtin.os != builtin.Os.linux) {
        // TODO implement kernel threads for windows and macos
        return;
    }
    var direct_allocator = std.heap.DirectAllocator.init();
    defer direct_allocator.deinit();

    var plenty_of_memory = try direct_allocator.allocator.alloc(u8, 64 * 1024 * 1024);
    defer direct_allocator.allocator.free(plenty_of_memory);

    var fixed_buffer_allocator = std.heap.ThreadSafeFixedBufferAllocator.init(plenty_of_memory);
    var a = &fixed_buffer_allocator.allocator;

    var queue: Queue(i32) = undefined;
    queue.init();
    var context = Context {
        .allocator = a,
        .queue = &queue,
        .put_sum = 0,
        .get_sum = 0,
        .puts_done = 0,
        .get_count = 0,
    };

    var putters: [put_thread_count]&std.os.Thread = undefined;
    for (putters) |*t| {
        *t = try std.os.spawnThreadAllocator(a, &context, startPuts);
    }
    var getters: [put_thread_count]&std.os.Thread = undefined;
    for (getters) |*t| {
        *t = try std.os.spawnThreadAllocator(a, &context, startGets);
    }

    for (putters) |t| t.wait();
    _ = @atomicRmw(u8, &context.puts_done, builtin.AtomicRmwOp.Xchg, 1, AtomicOrder.SeqCst);
    for (getters) |t| t.wait();

    std.debug.assert(context.put_sum == context.get_sum);
    std.debug.assert(context.get_count == puts_per_thread * put_thread_count);
}

fn startPuts(ctx: &Context) u8 {
    var put_count: usize = puts_per_thread;
    var r = std.rand.DefaultPrng.init(0xdeadbeef);
    while (put_count != 0) : (put_count -= 1) {
        std.os.time.sleep(0, 1); // let the OS scheduler be our fuzzer
        const x = @bitCast(i32, r.random.scalar(u32));
        const node = ctx.allocator.create(Queue(i32).Node) catch unreachable;
        node.data = x;
        ctx.queue.put(node);
        _ = @atomicRmw(isize, &ctx.put_sum, builtin.AtomicRmwOp.Add, x, AtomicOrder.SeqCst);
    }
    return 0;
}

fn startGets(ctx: &Context) u8 {
    while (true) {
        while (ctx.queue.get()) |node| {
            std.os.time.sleep(0, 1); // let the OS scheduler be our fuzzer
            _ = @atomicRmw(isize, &ctx.get_sum, builtin.AtomicRmwOp.Add, node.data, builtin.AtomicOrder.SeqCst);
            _ = @atomicRmw(usize, &ctx.get_count, builtin.AtomicRmwOp.Add, 1, builtin.AtomicOrder.SeqCst);
        }

        if (@atomicLoad(u8, &ctx.puts_done, builtin.AtomicOrder.SeqCst) == 1) {
            break;
        }
    }
    return 0;
}