aboutsummaryrefslogtreecommitdiff
path: root/lib/std/Thread/Pool.zig
diff options
context:
space:
mode:
authorAndrew Kelley <andrew@ziglang.org>2023-02-13 13:39:06 -0700
committerAndrew Kelley <andrew@ziglang.org>2023-03-15 10:48:12 -0700
commit5b90fa05a4e5b155f25319713acfc67ad9516c69 (patch)
tree6aaffe4ec16f7f6a18539bf2397176e001bf71cf /lib/std/Thread/Pool.zig
parent0b744d7d670d00fa865ebd17847cbdc1a909ba70 (diff)
downloadzig-5b90fa05a4e5b155f25319713acfc67ad9516c69.tar.gz
zig-5b90fa05a4e5b155f25319713acfc67ad9516c69.zip
extract ThreadPool and WaitGroup from compiler to std lib
Diffstat (limited to 'lib/std/Thread/Pool.zig')
-rw-r--r--lib/std/Thread/Pool.zig152
1 files changed, 152 insertions, 0 deletions
diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig
new file mode 100644
index 0000000000..930befbac5
--- /dev/null
+++ b/lib/std/Thread/Pool.zig
@@ -0,0 +1,152 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const Pool = @This();
+const WaitGroup = @import("WaitGroup.zig");
+
+mutex: std.Thread.Mutex = .{},
+cond: std.Thread.Condition = .{},
+run_queue: RunQueue = .{},
+is_running: bool = true,
+allocator: std.mem.Allocator,
+threads: []std.Thread,
+
+const RunQueue = std.SinglyLinkedList(Runnable);
+const Runnable = struct {
+ runFn: RunProto,
+};
+
+const RunProto = *const fn (*Runnable) void;
+
+pub fn init(pool: *Pool, allocator: std.mem.Allocator) !void {
+ pool.* = .{
+ .allocator = allocator,
+ .threads = &[_]std.Thread{},
+ };
+
+ if (builtin.single_threaded) {
+ return;
+ }
+
+ const thread_count = std.math.max(1, std.Thread.getCpuCount() catch 1);
+ pool.threads = try allocator.alloc(std.Thread, thread_count);
+ errdefer allocator.free(pool.threads);
+
+ // kill and join any threads we spawned previously on error.
+ var spawned: usize = 0;
+ errdefer pool.join(spawned);
+
+ for (pool.threads) |*thread| {
+ thread.* = try std.Thread.spawn(.{}, worker, .{pool});
+ spawned += 1;
+ }
+}
+
+pub fn deinit(pool: *Pool) void {
+ pool.join(pool.threads.len); // kill and join all threads.
+ pool.* = undefined;
+}
+
+fn join(pool: *Pool, spawned: usize) void {
+ if (builtin.single_threaded) {
+ return;
+ }
+
+ {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
+
+ // ensure future worker threads exit the dequeue loop
+ pool.is_running = false;
+ }
+
+ // wake up any sleeping threads (this can be done outside the mutex)
+ // then wait for all the threads we know are spawned to complete.
+ pool.cond.broadcast();
+ for (pool.threads[0..spawned]) |thread| {
+ thread.join();
+ }
+
+ pool.allocator.free(pool.threads);
+}
+
+pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
+ if (builtin.single_threaded) {
+ @call(.auto, func, args);
+ return;
+ }
+
+ const Args = @TypeOf(args);
+ const Closure = struct {
+ arguments: Args,
+ pool: *Pool,
+ run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },
+
+ fn runFn(runnable: *Runnable) void {
+ const run_node = @fieldParentPtr(RunQueue.Node, "data", runnable);
+ const closure = @fieldParentPtr(@This(), "run_node", run_node);
+ @call(.auto, func, closure.arguments);
+
+ // The thread pool's allocator is protected by the mutex.
+ const mutex = &closure.pool.mutex;
+ mutex.lock();
+ defer mutex.unlock();
+
+ closure.pool.allocator.destroy(closure);
+ }
+ };
+
+ {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
+
+ const closure = try pool.allocator.create(Closure);
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ };
+
+ pool.run_queue.prepend(&closure.run_node);
+ }
+
+ // Notify waiting threads outside the lock to try and keep the critical section small.
+ pool.cond.signal();
+}
+
+fn worker(pool: *Pool) void {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
+
+ while (true) {
+ while (pool.run_queue.popFirst()) |run_node| {
+ // Temporarily unlock the mutex in order to execute the run_node
+ pool.mutex.unlock();
+ defer pool.mutex.lock();
+
+ const runFn = run_node.data.runFn;
+ runFn(&run_node.data);
+ }
+
+ // Stop executing instead of waiting if the thread pool is no longer running.
+ if (pool.is_running) {
+ pool.cond.wait(&pool.mutex);
+ } else {
+ break;
+ }
+ }
+}
+
+pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
+ while (!wait_group.isDone()) {
+ if (blk: {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
+ break :blk pool.run_queue.popFirst();
+ }) |run_node| {
+ run_node.data.runFn(&run_node.data);
+ continue;
+ }
+
+ wait_group.wait();
+ return;
+ }
+}