| field | value | date |
|---|---|---|
| author | Andrew Kelley <andrew@ziglang.org> | 2023-02-13 13:39:06 -0700 |
| committer | Andrew Kelley <andrew@ziglang.org> | 2023-03-15 10:48:12 -0700 |
| commit | 5b90fa05a4e5b155f25319713acfc67ad9516c69 (patch) | |
| tree | 6aaffe4ec16f7f6a18539bf2397176e001bf71cf /lib/std/Thread/Pool.zig | |
| parent | 0b744d7d670d00fa865ebd17847cbdc1a909ba70 (diff) | |
| download | zig-5b90fa05a4e5b155f25319713acfc67ad9516c69.tar.gz zig-5b90fa05a4e5b155f25319713acfc67ad9516c69.zip | |
extract ThreadPool and WaitGroup from compiler to std lib
Diffstat (limited to 'lib/std/Thread/Pool.zig')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | lib/std/Thread/Pool.zig | 152 |
1 file changed, 152 insertions, 0 deletions
```diff
diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig
new file mode 100644
index 0000000000..930befbac5
--- /dev/null
+++ b/lib/std/Thread/Pool.zig
@@ -0,0 +1,152 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const Pool = @This();
+const WaitGroup = @import("WaitGroup.zig");
+
+mutex: std.Thread.Mutex = .{},
+cond: std.Thread.Condition = .{},
+run_queue: RunQueue = .{},
+is_running: bool = true,
+allocator: std.mem.Allocator,
+threads: []std.Thread,
+
+const RunQueue = std.SinglyLinkedList(Runnable);
+const Runnable = struct {
+    runFn: RunProto,
+};
+
+const RunProto = *const fn (*Runnable) void;
+
+pub fn init(pool: *Pool, allocator: std.mem.Allocator) !void {
+    pool.* = .{
+        .allocator = allocator,
+        .threads = &[_]std.Thread{},
+    };
+
+    if (builtin.single_threaded) {
+        return;
+    }
+
+    const thread_count = std.math.max(1, std.Thread.getCpuCount() catch 1);
+    pool.threads = try allocator.alloc(std.Thread, thread_count);
+    errdefer allocator.free(pool.threads);
+
+    // kill and join any threads we spawned previously on error.
+    var spawned: usize = 0;
+    errdefer pool.join(spawned);
+
+    for (pool.threads) |*thread| {
+        thread.* = try std.Thread.spawn(.{}, worker, .{pool});
+        spawned += 1;
+    }
+}
+
+pub fn deinit(pool: *Pool) void {
+    pool.join(pool.threads.len); // kill and join all threads.
+    pool.* = undefined;
+}
+
+fn join(pool: *Pool, spawned: usize) void {
+    if (builtin.single_threaded) {
+        return;
+    }
+
+    {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        // ensure future worker threads exit the dequeue loop
+        pool.is_running = false;
+    }
+
+    // wake up any sleeping threads (this can be done outside the mutex)
+    // then wait for all the threads we know are spawned to complete.
+    pool.cond.broadcast();
+    for (pool.threads[0..spawned]) |thread| {
+        thread.join();
+    }
+
+    pool.allocator.free(pool.threads);
+}
+
+pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
+    if (builtin.single_threaded) {
+        @call(.auto, func, args);
+        return;
+    }
+
+    const Args = @TypeOf(args);
+    const Closure = struct {
+        arguments: Args,
+        pool: *Pool,
+        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },
+
+        fn runFn(runnable: *Runnable) void {
+            const run_node = @fieldParentPtr(RunQueue.Node, "data", runnable);
+            const closure = @fieldParentPtr(@This(), "run_node", run_node);
+            @call(.auto, func, closure.arguments);
+
+            // The thread pool's allocator is protected by the mutex.
+            const mutex = &closure.pool.mutex;
+            mutex.lock();
+            defer mutex.unlock();
+
+            closure.pool.allocator.destroy(closure);
+        }
+    };
+
+    {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        const closure = try pool.allocator.create(Closure);
+        closure.* = .{
+            .arguments = args,
+            .pool = pool,
+        };
+
+        pool.run_queue.prepend(&closure.run_node);
+    }
+
+    // Notify waiting threads outside the lock to try and keep the critical section small.
+    pool.cond.signal();
+}
+
+fn worker(pool: *Pool) void {
+    pool.mutex.lock();
+    defer pool.mutex.unlock();
+
+    while (true) {
+        while (pool.run_queue.popFirst()) |run_node| {
+            // Temporarily unlock the mutex in order to execute the run_node
+            pool.mutex.unlock();
+            defer pool.mutex.lock();
+
+            const runFn = run_node.data.runFn;
+            runFn(&run_node.data);
+        }
+
+        // Stop executing instead of waiting if the thread pool is no longer running.
+        if (pool.is_running) {
+            pool.cond.wait(&pool.mutex);
+        } else {
+            break;
+        }
+    }
+}
+
+pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
+    while (!wait_group.isDone()) {
+        if (blk: {
+            pool.mutex.lock();
+            defer pool.mutex.unlock();
+            break :blk pool.run_queue.popFirst();
+        }) |run_node| {
+            run_node.data.runFn(&run_node.data);
+            continue;
+        }
+
+        wait_group.wait();
+        return;
+    }
+}
```
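
For context, here is a minimal usage sketch of the API this file introduces. It is not part of the commit; it assumes the new file is exposed as `std.Thread.Pool` and a Zig toolchain contemporary with this change (note that `init` takes a pointer to an undefined pool plus an allocator, rather than returning a pool):

```zig
const std = @import("std");

fn computeChunk(chunk: usize) void {
    // Placeholder work item; any function plus a matching argument tuple
    // can be handed to spawn().
    std.debug.print("processed chunk {d}\n", .{chunk});
}

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    var pool: std.Thread.Pool = undefined;
    try pool.init(gpa.allocator());
    // deinit() flips is_running under the mutex, broadcasts on the
    // condition variable, and joins every spawned thread.
    defer pool.deinit();

    var chunk: usize = 0;
    while (chunk < 8) : (chunk += 1) {
        // Each spawn() heap-allocates a closure, prepends its intrusive
        // node to the run queue, and signals one sleeping worker. In
        // single-threaded builds the function simply runs inline.
        try pool.spawn(computeChunk, .{chunk});
    }
}
```

Because each worker drains the run queue once more after waking, before it re-checks `is_running`, tasks still queued at `deinit` time are executed (and their closures destroyed) rather than leaked.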

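`waitAndWork` pairs the pool with the `WaitGroup` extracted by the same commit, whose source lies outside this diffstat. The following is a hedged sketch, assuming the companion file is exposed as `std.Thread.WaitGroup` and provides the `start`/`finish` counterparts to the `isDone`/`wait` calls visible above:

```zig
const std = @import("std");

// Illustrative task; the start/finish pattern is the point, not the body.
fn hashFile(wg: *std.Thread.WaitGroup, path: []const u8) void {
    // finish() must run on every exit path so the wait group balances.
    defer wg.finish();
    std.debug.print("hashing {s}\n", .{path});
}

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    var pool: std.Thread.Pool = undefined;
    try pool.init(gpa.allocator());
    defer pool.deinit();

    var wg: std.Thread.WaitGroup = .{};
    const paths = [_][]const u8{ "a.zig", "b.zig", "c.zig" };
    for (paths) |path| {
        wg.start(); // balanced by the deferred finish() in the task
        // (Error handling elided: a failed spawn would need a matching
        // finish() to keep waitAndWork from blocking forever.)
        try pool.spawn(hashFile, .{ &wg, path });
    }

    // Rather than sleeping, the calling thread keeps popping and running
    // queued tasks itself until the wait group reports done.
    pool.waitAndWork(&wg);
}
```

The design choice is visible in the diff: `waitAndWork` only falls back to `wait_group.wait()` once the run queue is empty, so the thread that would otherwise block contributes work instead of idling.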