diff options
| author | Jacob Young <jacobly0@users.noreply.github.com> | 2024-06-15 16:10:53 -0400 |
|---|---|---|
| committer | Jacob Young <jacobly0@users.noreply.github.com> | 2024-07-07 22:59:52 -0400 |
| commit | 525f341f33af9b8aad53931fd5511f00a82cb090 (patch) | |
| tree | cec3280498c1122858580946ac5e31f8feb807ce /lib/std/Thread | |
| parent | 8f20e81b8816aadd8ceb1b04bd3727cc1d124464 (diff) | |
| download | zig-525f341f33af9b8aad53931fd5511f00a82cb090.tar.gz zig-525f341f33af9b8aad53931fd5511f00a82cb090.zip | |
Zcu: introduce `PerThread` and pass to all the functions
Diffstat (limited to 'lib/std/Thread')
| -rw-r--r-- | lib/std/Thread/Pool.zig | 96 |
1 files changed, 86 insertions, 10 deletions
diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig index 846c7035a7..03ca8ffc8e 100644 --- a/lib/std/Thread/Pool.zig +++ b/lib/std/Thread/Pool.zig @@ -9,17 +9,19 @@ run_queue: RunQueue = .{}, is_running: bool = true, allocator: std.mem.Allocator, threads: []std.Thread, +ids: std.AutoArrayHashMapUnmanaged(std.Thread.Id, void), const RunQueue = std.SinglyLinkedList(Runnable); const Runnable = struct { runFn: RunProto, }; -const RunProto = *const fn (*Runnable) void; +const RunProto = *const fn (*Runnable, id: ?usize) void; pub const Options = struct { allocator: std.mem.Allocator, n_jobs: ?u32 = null, + track_ids: bool = false, }; pub fn init(pool: *Pool, options: Options) !void { @@ -28,6 +30,7 @@ pub fn init(pool: *Pool, options: Options) !void { pool.* = .{ .allocator = allocator, .threads = &[_]std.Thread{}, + .ids = .{}, }; if (builtin.single_threaded) { @@ -35,6 +38,10 @@ pub fn init(pool: *Pool, options: Options) !void { } const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1); + if (options.track_ids) { + try pool.ids.ensureTotalCapacity(allocator, 1 + thread_count); + pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {}); + } // kill and join any threads we spawned and free memory on error. pool.threads = try allocator.alloc(std.Thread, thread_count); @@ -49,6 +56,7 @@ pub fn init(pool: *Pool, options: Options) !void { pub fn deinit(pool: *Pool) void { pool.join(pool.threads.len); // kill and join all threads. + pool.ids.deinit(pool.allocator); pool.* = undefined; } @@ -96,7 +104,7 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } }, wait_group: *WaitGroup, - fn runFn(runnable: *Runnable) void { + fn runFn(runnable: *Runnable, _: ?usize) void { const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable); const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); @call(.auto, func, closure.arguments); @@ -134,6 +142,70 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args pool.cond.signal(); } +/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and +/// `WaitGroup.finish` after it returns. +/// +/// The first argument passed to `func` is a dense `usize` thread id, the rest +/// of the arguments are passed from `args`. Requires the pool to have been +/// initialized with `.track_ids = true`. +/// +/// In the case that queuing the function call fails to allocate memory, or the +/// target is single-threaded, the function is called directly. +pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void { + wait_group.start(); + + if (builtin.single_threaded) { + @call(.auto, func, .{0} ++ args); + wait_group.finish(); + return; + } + + const Args = @TypeOf(args); + const Closure = struct { + arguments: Args, + pool: *Pool, + run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } }, + wait_group: *WaitGroup, + + fn runFn(runnable: *Runnable, id: ?usize) void { + const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable); + const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); + @call(.auto, func, .{id.?} ++ closure.arguments); + closure.wait_group.finish(); + + // The thread pool's allocator is protected by the mutex. + const mutex = &closure.pool.mutex; + mutex.lock(); + defer mutex.unlock(); + + closure.pool.allocator.destroy(closure); + } + }; + + { + pool.mutex.lock(); + + const closure = pool.allocator.create(Closure) catch { + const id = pool.ids.getIndex(std.Thread.getCurrentId()); + pool.mutex.unlock(); + @call(.auto, func, .{id.?} ++ args); + wait_group.finish(); + return; + }; + closure.* = .{ + .arguments = args, + .pool = pool, + .wait_group = wait_group, + }; + + pool.run_queue.prepend(&closure.run_node); + pool.mutex.unlock(); + } + + // Notify waiting threads outside the lock to try and keep the critical section small. + pool.cond.signal(); +} + pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void { if (builtin.single_threaded) { @call(.auto, func, args); @@ -181,14 +253,16 @@ fn worker(pool: *Pool) void { pool.mutex.lock(); defer pool.mutex.unlock(); + const id = if (pool.ids.count() > 0) pool.ids.count() else null; + if (id) |_| pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {}); + while (true) { while (pool.run_queue.popFirst()) |run_node| { // Temporarily unlock the mutex in order to execute the run_node pool.mutex.unlock(); defer pool.mutex.lock(); - const runFn = run_node.data.runFn; - runFn(&run_node.data); + run_node.data.runFn(&run_node.data, id); } // Stop executing instead of waiting if the thread pool is no longer running. @@ -201,16 +275,18 @@ fn worker(pool: *Pool) void { } pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void { + var id: ?usize = null; + while (!wait_group.isDone()) { - if (blk: { - pool.mutex.lock(); - defer pool.mutex.unlock(); - break :blk pool.run_queue.popFirst(); - }) |run_node| { - run_node.data.runFn(&run_node.data); + pool.mutex.lock(); + if (pool.run_queue.popFirst()) |run_node| { + id = id orelse pool.ids.getIndex(std.Thread.getCurrentId()); + pool.mutex.unlock(); + run_node.data.runFn(&run_node.data, id); continue; } + pool.mutex.unlock(); wait_group.wait(); return; } |
