aboutsummaryrefslogtreecommitdiff
path: root/lib/std/event/batch.zig
blob: ba50d4bee5ab13d0ec48c382c70e9d85b0270e7b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
const std = @import("../std.zig");
const testing = std.testing;

/// Performs multiple async functions in parallel, without heap allocation.
/// Async function frames are managed externally to this abstraction, and
/// passed in via the `add` function. Once all the jobs are added, call `wait`.
/// This API is *not* thread-safe. The object must be accessed from one thread at
/// a time, however, it need not be the same thread.
pub fn Batch(
    /// The return value for each job.
    /// If a job slot was re-used due to maxed out concurrency, then its result
    /// value will be overwritten. The values can be accessed through the `result`
    /// field of each entry in the `jobs` array.
    comptime Result: type,
    /// How many jobs to run in parallel. Must be at least 1.
    comptime max_jobs: comptime_int,
    /// Controls whether the `add` and `wait` functions will be async functions.
    comptime async_behavior: enum {
        /// Observe the value of `std.io.is_async` to decide whether `add`
        /// and `wait` will be async functions. Asserts that the jobs do not suspend when
        /// `std.io.mode == .blocking`. This is a generally safe assumption, and the
        /// usual recommended option for this parameter.
        auto_async,

        /// Always uses the `nosuspend` keyword when using `await` on the jobs,
        /// making `add` and `wait` non-async functions. Asserts that the jobs do not suspend.
        never_async,

        /// `add` and `wait` use regular `await` keyword, making them async functions.
        always_async,
    },
) type {
    // With zero slots there is nowhere to put a job, and `add` would compute `% 0`.
    if (max_jobs < 1) @compileError("Batch requires max_jobs >= 1");

    return struct {
        jobs: [max_jobs]Job,
        next_job_index: usize,
        collected_result: CollectedResult,

        const Job = struct {
            /// The frame currently occupying this slot, or null when the slot is free.
            frame: ?anyframe->Result,
            /// Result of the most recently awaited frame in this slot
            /// (`undefined` until a frame in this slot has been awaited).
            result: Result,
        };

        const Self = @This();

        /// What `wait` reports. For an error-union `Result` this is the error set
        /// of `Result` with a `void` payload: success, or the last error seen.
        /// Only errors are aggregated across jobs; per-job payloads stay in
        /// `jobs[i].result`. For any other `Result` it is `void`.
        const CollectedResult = switch (@typeInfo(Result)) {
            .ErrorUnion => |info| info.error_set!void,
            else => void,
        };

        const async_ok = switch (async_behavior) {
            .auto_async => std.io.is_async,
            .never_async => false,
            .always_async => true,
        };

        /// Returns an empty batch: every slot free, no error collected.
        pub fn init() Self {
            return Self{
                .jobs = [1]Job{
                    .{
                        .frame = null,
                        .result = undefined,
                    },
                } ** max_jobs,
                .next_job_index = 0,
                // `{}` is both a valid `void` and a successful `E!void`.
                .collected_result = {},
            };
        }

        /// Add a frame to the Batch. Slots are used round-robin. If the next slot
        /// still holds an in-flight frame, that frame is awaited first. Its result
        /// is stored in the slot (and any error collected) before the slot is reused.
        /// This function is *not* thread-safe. It must be called from one thread at
        /// a time, however, it need not be the same thread.
        /// TODO: "select" language feature to use the next available slot, rather than
        /// awaiting the next index.
        pub fn add(self: *Self, frame: anyframe->Result) void {
            const job = &self.jobs[self.next_job_index];
            self.next_job_index = (self.next_job_index + 1) % max_jobs;
            if (job.frame) |existing| {
                job.result = if (async_ok) await existing else nosuspend await existing;
                self.collect(job.result);
            }
            job.frame = frame;
        }

        /// Wait for all the jobs to complete.
        /// Safe to call any number of times. The collected result is never reset,
        /// so later calls report the same outcome.
        /// If `Result` is an error union, this function returns the last error that occurred, if any.
        /// Unlike the per-job `result` values, the return value of `wait` will report any error
        /// that occurred; hitting max parallelism will not compromise the result.
        /// This function is *not* thread-safe. It must be called from one thread at
        /// a time, however, it need not be the same thread.
        pub fn wait(self: *Self) CollectedResult {
            for (self.jobs) |*job| {
                if (job.frame) |f| {
                    job.result = if (async_ok) await f else nosuspend await f;
                    self.collect(job.result);
                    job.frame = null;
                }
            }
            return self.collected_result;
        }

        /// Records `result` into `collected_result` when it is an error.
        /// Successes leave the previously collected value untouched.
        /// This is a no-op when `Result` is not an error union.
        fn collect(self: *Self, result: Result) void {
            if (CollectedResult != void) {
                // if/else rather than `catch`, so that non-void payloads also compile.
                if (result) |_| {} else |err| {
                    self.collected_result = err;
                }
            }
        }
    };
}

test "std.event.Batch" {
    // Skipped on non-stage1 backends; this test relies on async frames.
    if (@import("builtin").zig_backend != .stage1) return error.SkipZigTest;

    // Non-error Result: `wait` returns void, so check the jobs' combined side effects.
    var count: usize = 0;
    var batch = Batch(void, 2, .auto_async).init();
    batch.add(&async sleepALittle(&count));
    batch.add(&async increaseByTen(&count));
    batch.wait();
    try testing.expectEqual(@as(usize, 11), count);
    // A second `wait` has nothing left to await and must not re-run any job.
    batch.wait();
    try testing.expectEqual(@as(usize, 11), count);

    // Error-union Result: `wait` surfaces the error produced by one of the jobs.
    var another = Batch(anyerror!void, 2, .auto_async).init();
    another.add(&async somethingElse());
    another.add(&async doSomethingThatFails());
    try testing.expectError(error.ItBroke, another.wait());
    // The collected error is sticky across repeated calls.
    try testing.expectError(error.ItBroke, another.wait());
}

/// Test job: blocks the calling thread for about one millisecond,
/// then atomically adds one to `count`.
fn sleepALittle(count: *usize) void {
    const pause_ns = std.time.ns_per_ms;
    std.time.sleep(pause_ns);
    _ = @atomicRmw(usize, count, .Add, 1, .SeqCst);
}

/// Test job: adds ten to `count` as ten separate atomic increments.
fn increaseByTen(count: *usize) void {
    var remaining: usize = 10;
    while (remaining > 0) : (remaining -= 1) {
        _ = @atomicRmw(usize, count, .Add, 1, .SeqCst);
    }
}

/// Test job that always fails with `error.ItBroke`.
fn doSomethingThatFails() anyerror!void {
    return error.ItBroke;
}

/// Test job that always succeeds.
fn somethingElse() anyerror!void {}