aboutsummaryrefslogtreecommitdiff
path: root/lib/std/event/channel.zig
blob: e1c147d25a3d03a329d85081168acc705eda6648 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
const std = @import("../std.zig");
const builtin = @import("builtin");
const assert = std.debug.assert;
const testing = std.testing;
const Loop = std.event.Loop;

/// Many producer, many consumer, thread-safe, runtime configurable buffer size.
/// When buffer is empty, consumers suspend and are resumed by producers.
/// When buffer is full, producers suspend and are resumed by consumers.
///
/// Each `put`, `get`, and `getOrNull` call enqueues a request node that lives in
/// the caller's suspended frame and then runs `dispatch`. That function:
/// - pairs waiting getters with buffered items or waiting putters;
/// - moves waiting putters into free buffer slots;
/// - schedules every satisfied frame on the global event loop.
/// Only usable when `std.event.Loop.instance` is available (event-based I/O).
pub fn Channel(comptime T: type) type {
    return struct {
        /// Suspended `get`/`getOrNull` requests, oldest first.
        getters: std.atomic.Queue(GetNode),
        /// `getOrNull` requests not yet satisfied. `dispatch` schedules whatever
        /// is still listed here once it has handed out everything it can.
        or_null_queue: std.atomic.Queue(*std.atomic.Queue(GetNode).Node),
        /// Suspended `put` requests, oldest first.
        putters: std.atomic.Queue(PutNode),
        /// Atomic count of nodes in `getters` (`dispatch` keeps it one lower
        /// while it is matching).
        get_count: usize,
        /// Atomic count of nodes in `putters` (`dispatch` keeps it one lower
        /// while it is matching).
        put_count: usize,
        /// True while one thread is running the matching loop in `dispatch`.
        dispatch_lock: bool,
        /// Set when a request arrived that the `dispatch` lock holder has not
        /// processed yet.
        need_dispatch: bool,

        // simple fixed size ring buffer
        // `buffer_index` (taken modulo the buffer length) is the next slot to
        // write; the oldest buffered item sits `buffer_len` slots behind it.
        buffer_nodes: []T,
        buffer_index: usize,
        buffer_len: usize,

        const SelfChannel = @This();
        /// A suspended getter: the frame to wake and where to store the item.
        const GetNode = struct {
            tick_node: *Loop.NextTickNode,
            data: Data,

            const Data = union(enum) {
                /// Request from `get`; always waits for an item.
                Normal: Normal,
                /// Request from `getOrNull`; may be woken without an item.
                OrNull: OrNull,
            };

            const Normal = struct {
                ptr: *T,
            };

            const OrNull = struct {
                ptr: *?T,
                /// This request's entry in `or_null_queue`, so it can be
                /// removed from there once the request is satisfied.
                or_null: *std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node,
            };
        };
        /// A suspended putter: the item it offers and the frame to wake.
        const PutNode = struct {
            data: T,
            tick_node: *Loop.NextTickNode,
        };

        const global_event_loop = Loop.instance orelse
            @compileError("std.event.Channel currently only works with event-based I/O");

        /// Call `deinit` to free resources when done.
        /// `buffer` must live until `deinit` is called.
        /// For a zero length buffer, use `[0]T{}`.
        /// `buffer.len` must be 0 or a power of two (asserted).
        /// TODO https://github.com/ziglang/zig/issues/2765
        pub fn init(self: *SelfChannel, buffer: []T) void {
            // The ring buffer implementation only works with power of 2 buffer sizes
            // because it relies on wrapping subtraction across zero. For example, with
            // a 10-slot buffer, index 0 minus 1 should wrap to slot 9, but
            // (0 -% 1) % 10 == 5 for a 32- or 64-bit usize.
            assert(buffer.len == 0 or @popCount(buffer.len) == 1);

            self.* = SelfChannel{
                .buffer_len = 0,
                .buffer_nodes = buffer,
                .buffer_index = 0,
                .dispatch_lock = false,
                .need_dispatch = false,
                .getters = std.atomic.Queue(GetNode).init(),
                .putters = std.atomic.Queue(PutNode).init(),
                .or_null_queue = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).init(),
                .get_count = 0,
                .put_count = 0,
            };
        }

        /// Must be called when all calls to put and get have suspended and no more calls occur.
        /// This can be omitted if caller can guarantee that the suspended putters and getters
        /// do not need to be run to completion. Note that this may leave awaiters hanging.
        /// Suspended frames are resumed directly, not through the event loop:
        /// - a resumed `get` returns its never-written (undefined) result;
        /// - a resumed `getOrNull` returns `null`;
        /// - a resumed `put` returns with its item discarded.
        pub fn deinit(self: *SelfChannel) void {
            while (self.getters.get()) |get_node| {
                resume get_node.data.tick_node.data;
            }
            while (self.putters.get()) |put_node| {
                resume put_node.data.tick_node.data;
            }
            self.* = undefined;
        }

        /// puts a data item in the channel. The function returns when the value has been added to the
        /// buffer, or in the case of a zero size buffer, when the item has been retrieved by a getter.
        /// Or when the channel is destroyed.
        pub fn put(self: *SelfChannel, data: T) void {
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var queue_node = std.atomic.Queue(PutNode).Node{
                .data = PutNode{
                    .tick_node = &my_tick_node,
                    .data = data,
                },
            };

            suspend {
                // Publish the node before bumping the count: `dispatch` pops
                // `putters` with `.?` based on `put_count`.
                self.putters.put(&queue_node);
                _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                self.dispatch();
            }
        }

        /// await this function to get an item from the channel. If the buffer is empty, the frame will
        /// complete when the next item is put in the channel.
        pub fn get(self: *SelfChannel) callconv(.Async) T {
            // TODO https://github.com/ziglang/zig/issues/2765
            var result: T = undefined;
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var queue_node = std.atomic.Queue(GetNode).Node{
                .data = GetNode{
                    .tick_node = &my_tick_node,
                    .data = GetNode.Data{
                        .Normal = GetNode.Normal{ .ptr = &result },
                    },
                },
            };

            suspend {
                // Publish the node before bumping the count: `dispatch` pops
                // `getters` with `.?` based on `get_count`.
                self.getters.put(&queue_node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);

                self.dispatch();
            }
            return result;
        }

        //pub async fn select(comptime EnumUnion: type, channels: ...) EnumUnion {
        //    assert(@memberCount(EnumUnion) == channels.len); // enum union and channels mismatch
        //    assert(channels.len != 0); // enum unions cannot have 0 fields
        //    if (channels.len == 1) {
        //        const result = await (async channels[0].get() catch unreachable);
        //        return @unionInit(EnumUnion, @memberName(EnumUnion, 0), result);
        //    }
        //}

        /// Get an item from the channel. If the buffer is empty and there are no
        /// puts waiting, this returns `null`.
        pub fn getOrNull(self: *SelfChannel) ?T {
            // TODO integrate this function with named return values
            // so we can get rid of this extra result copy
            var result: ?T = null;
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var or_null_node = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node{ .data = undefined };
            var queue_node = std.atomic.Queue(GetNode).Node{
                .data = GetNode{
                    .tick_node = &my_tick_node,
                    .data = GetNode.Data{
                        .OrNull = GetNode.OrNull{
                            .ptr = &result,
                            .or_null = &or_null_node,
                        },
                    },
                },
            };
            or_null_node.data = &queue_node;

            suspend {
                self.getters.put(&queue_node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                // NOTE(review): `or_null_node` is queued only after this getter is
                // already visible through `getters`/`get_count`. Suppose a dispatch on
                // another thread satisfies the getter in between:
                // - its `or_null_queue.remove` finds nothing;
                // - a later dispatch then pops `or_null_node` and calls `onNextTick`
                //   for this frame a second time.
                // Confirm this ordering is safe.
                self.or_null_queue.put(&or_null_node);

                self.dispatch();
            }
            return result;
        }

        /// Hand out items and wake suspended frames until no request remains
        /// that can be satisfied.
        ///
        /// Only one thread runs the matching loop at a time (`dispatch_lock`).
        /// A caller that finds the lock taken returns immediately, leaving
        /// `need_dispatch` set so the lock holder makes another pass.
        fn dispatch(self: *SelfChannel) void {
            // set the "need dispatch" flag
            @atomicStore(bool, &self.need_dispatch, true, .SeqCst);

            lock: while (true) {
                // set the lock flag
                if (@atomicRmw(bool, &self.dispatch_lock, .Xchg, true, .SeqCst)) return;

                // clear the need_dispatch flag since we're about to do it
                @atomicStore(bool, &self.need_dispatch, false, .SeqCst);

                while (true) {
                    {
                        // `@atomicRmw` returns the value from *before* the subtraction.
                        // These locals therefore hold the number of queued getters/putters,
                        // while the shared counters are left one lower. Every later `.Sub`
                        // below works the same way; the extra subtraction is corrected
                        // after this block.
                        var get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        var put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);

                        // transfer self.buffer to self.getters, oldest item first
                        while (self.buffer_len != 0) {
                            // Out of getters: stop draining, but do NOT skip the
                            // putter -> buffer transfer below. Otherwise a producer would
                            // stay suspended while the buffer still has free slots.
                            if (get_count == 0) break;

                            const get_node = &self.getters.get().?.data;
                            switch (get_node.data) {
                                GetNode.Data.Normal => |info| {
                                    info.ptr.* = self.buffer_nodes[(self.buffer_index -% self.buffer_len) % self.buffer_nodes.len];
                                },
                                GetNode.Data.OrNull => |info| {
                                    _ = self.or_null_queue.remove(info.or_null);
                                    info.ptr.* = self.buffer_nodes[(self.buffer_index -% self.buffer_len) % self.buffer_nodes.len];
                                },
                            }
                            global_event_loop.onNextTick(get_node.tick_node);
                            self.buffer_len -= 1;

                            get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        }

                        // direct transfer self.putters to self.getters
                        // (get_count != 0 here implies the buffer was fully drained,
                        // so FIFO order is preserved)
                        while (get_count != 0 and put_count != 0) {
                            const get_node = &self.getters.get().?.data;
                            const put_node = &self.putters.get().?.data;

                            switch (get_node.data) {
                                GetNode.Data.Normal => |info| {
                                    info.ptr.* = put_node.data;
                                },
                                GetNode.Data.OrNull => |info| {
                                    _ = self.or_null_queue.remove(info.or_null);
                                    info.ptr.* = put_node.data;
                                },
                            }
                            global_event_loop.onNextTick(get_node.tick_node);
                            global_event_loop.onNextTick(put_node.tick_node);

                            get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                            put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                        }

                        // transfer self.putters to self.buffer
                        while (self.buffer_len != self.buffer_nodes.len and put_count != 0) {
                            const put_node = &self.putters.get().?.data;

                            self.buffer_nodes[self.buffer_index % self.buffer_nodes.len] = put_node.data;
                            global_event_loop.onNextTick(put_node.tick_node);
                            self.buffer_index +%= 1;
                            self.buffer_len += 1;

                            put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                        }
                    }

                    // undo the extra subtractions
                    _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                    _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                    // All the "get or null" functions should resume now.
                    // Any `getOrNull` request still listed here was not satisfied
                    // above: take it out of `getters` and wake it (its result stays null).
                    var remove_count: usize = 0;
                    while (self.or_null_queue.get()) |or_null_node| {
                        remove_count += @boolToInt(self.getters.remove(or_null_node.data));
                        global_event_loop.onNextTick(or_null_node.data.data.tick_node);
                    }
                    if (remove_count != 0) {
                        _ = @atomicRmw(usize, &self.get_count, .Sub, remove_count, .SeqCst);
                    }

                    // clear need-dispatch flag; if it was set again meanwhile,
                    // new requests arrived and need another pass
                    if (@atomicRmw(bool, &self.need_dispatch, .Xchg, false, .SeqCst)) continue;

                    assert(@atomicRmw(bool, &self.dispatch_lock, .Xchg, false, .SeqCst));

                    // we have to check again now that we unlocked
                    if (@atomicLoad(bool, &self.need_dispatch, .SeqCst)) continue :lock;

                    return;
                }
            }
        }
    };
}

test "std.event.Channel buffered puts do not wait for a getter" {
    // TODO provide a way to run tests in evented I/O mode
    if (!std.io.is_async) return error.SkipZigTest;

    var buf: [2]i32 = undefined;
    var channel: Channel(i32) = undefined;
    channel.init(&buf);
    defer channel.deinit();

    // The buffer holds two items, so both puts must complete before any
    // getter exists; the second put must not stay suspended while a slot is free.
    channel.put(1);
    channel.put(2);
    try testing.expectEqual(@as(i32, 1), channel.get());
    try testing.expectEqual(@as(i32, 2), channel.get());
}

test "std.event.Channel" {
    if (!std.io.is_async) return error.SkipZigTest;

    // https://github.com/ziglang/zig/issues/1908
    if (builtin.single_threaded) return error.SkipZigTest;

    // https://github.com/ziglang/zig/issues/3251
    if (builtin.os.tag == .freebsd) return error.SkipZigTest;

    var channel: Channel(i32) = undefined;
    channel.init(&[0]i32{});
    defer channel.deinit();

    var handle = async testChannelGetter(&channel);
    var putter = async testChannelPutter(&channel);

    // The getter reports failed expectations through its error union;
    // propagate them instead of discarding them.
    try await handle;
    await putter;
}

test "std.event.Channel wraparound" {
    // TODO provide a way to run tests in evented I/O mode
    if (!std.io.is_async) return error.SkipZigTest;

    const channel_size = 2;

    var buf: [channel_size]i32 = undefined;
    var channel: Channel(i32) = undefined;
    channel.init(&buf);
    defer channel.deinit();

    // add items to channel and pull them out until
    // the buffer wraps around, make sure it doesn't crash.
    channel.put(5);
    try testing.expectEqual(@as(i32, 5), channel.get());
    channel.put(6);
    try testing.expectEqual(@as(i32, 6), channel.get());
    channel.put(7);
    try testing.expectEqual(@as(i32, 7), channel.get());
}
/// Consumer half of the "std.event.Channel" test. Returns an error (instead of
/// `void`) so that `try testing.expect...` failures reach the awaiting test.
fn testChannelGetter(channel: *Channel(i32)) callconv(.Async) !void {
    const value1 = channel.get();
    try testing.expectEqual(@as(i32, 1234), value1);

    const value2 = channel.get();
    try testing.expectEqual(@as(i32, 4567), value2);

    // Nothing is buffered and no putter is waiting, so this yields null.
    const value3 = channel.getOrNull();
    try testing.expectEqual(@as(?i32, null), value3);

    // The putter started here suspends on the zero-size channel, so the
    // following getOrNull receives its value directly.
    var last_put = async testPut(channel, 4444);
    const value4 = channel.getOrNull();
    try testing.expectEqual(@as(?i32, 4444), value4);
    await last_put;
}
/// Producer half of the "std.event.Channel" test.
fn testChannelPutter(channel: *Channel(i32)) callconv(.Async) void {
    channel.put(1234);
    channel.put(4567);
}
/// Puts a single `value` into `channel`.
fn testPut(channel: *Channel(i32), value: i32) callconv(.Async) void {
    channel.put(value);
}