diff options
Diffstat (limited to 'src/ThreadPool.zig')
| -rw-r--r-- | src/ThreadPool.zig | 81 |
1 files changed, 49 insertions, 32 deletions
diff --git a/src/ThreadPool.zig b/src/ThreadPool.zig index 55e40ea287..7115adbddd 100644 --- a/src/ThreadPool.zig +++ b/src/ThreadPool.zig @@ -1,6 +1,7 @@ const std = @import("std"); const builtin = @import("builtin"); const ThreadPool = @This(); +const WaitGroup = @import("WaitGroup.zig"); mutex: std.Thread.Mutex = .{}, cond: std.Thread.Condition = .{}, @@ -19,8 +20,8 @@ const RunProto = switch (builtin.zig_backend) { else => *const fn (*Runnable) void, }; -pub fn init(self: *ThreadPool, allocator: std.mem.Allocator) !void { - self.* = .{ +pub fn init(pool: *ThreadPool, allocator: std.mem.Allocator) !void { + pool.* = .{ .allocator = allocator, .threads = &[_]std.Thread{}, }; @@ -30,48 +31,48 @@ pub fn init(self: *ThreadPool, allocator: std.mem.Allocator) !void { } const thread_count = std.math.max(1, std.Thread.getCpuCount() catch 1); - self.threads = try allocator.alloc(std.Thread, thread_count); - errdefer allocator.free(self.threads); + pool.threads = try allocator.alloc(std.Thread, thread_count); + errdefer allocator.free(pool.threads); // kill and join any threads we spawned previously on error. var spawned: usize = 0; - errdefer self.join(spawned); + errdefer pool.join(spawned); - for (self.threads) |*thread| { - thread.* = try std.Thread.spawn(.{}, worker, .{self}); + for (pool.threads) |*thread| { + thread.* = try std.Thread.spawn(.{}, worker, .{pool}); spawned += 1; } } -pub fn deinit(self: *ThreadPool) void { - self.join(self.threads.len); // kill and join all threads. - self.* = undefined; +pub fn deinit(pool: *ThreadPool) void { + pool.join(pool.threads.len); // kill and join all threads. + pool.* = undefined; } -fn join(self: *ThreadPool, spawned: usize) void { +fn join(pool: *ThreadPool, spawned: usize) void { if (builtin.single_threaded) { return; } { - self.mutex.lock(); - defer self.mutex.unlock(); + pool.mutex.lock(); + defer pool.mutex.unlock(); // ensure future worker threads exit the dequeue loop - self.is_running = false; + pool.is_running = false; } // wake up any sleeping threads (this can be done outside the mutex) // then wait for all the threads we know are spawned to complete. - self.cond.broadcast(); - for (self.threads[0..spawned]) |thread| { + pool.cond.broadcast(); + for (pool.threads[0..spawned]) |thread| { thread.join(); } - self.allocator.free(self.threads); + pool.allocator.free(pool.threads); } -pub fn spawn(self: *ThreadPool, comptime func: anytype, args: anytype) !void { +pub fn spawn(pool: *ThreadPool, comptime func: anytype, args: anytype) !void { if (builtin.single_threaded) { @call(.{}, func, args); return; @@ -98,41 +99,57 @@ pub fn spawn(self: *ThreadPool, comptime func: anytype, args: anytype) !void { }; { - self.mutex.lock(); - defer self.mutex.unlock(); + pool.mutex.lock(); + defer pool.mutex.unlock(); - const closure = try self.allocator.create(Closure); + const closure = try pool.allocator.create(Closure); closure.* = .{ .arguments = args, - .pool = self, + .pool = pool, }; - self.run_queue.prepend(&closure.run_node); + pool.run_queue.prepend(&closure.run_node); } // Notify waiting threads outside the lock to try and keep the critical section small. - self.cond.signal(); + pool.cond.signal(); } -fn worker(self: *ThreadPool) void { - self.mutex.lock(); - defer self.mutex.unlock(); +fn worker(pool: *ThreadPool) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); while (true) { - while (self.run_queue.popFirst()) |run_node| { + while (pool.run_queue.popFirst()) |run_node| { // Temporarily unlock the mutex in order to execute the run_node - self.mutex.unlock(); - defer self.mutex.lock(); + pool.mutex.unlock(); + defer pool.mutex.lock(); const runFn = run_node.data.runFn; runFn(&run_node.data); } // Stop executing instead of waiting if the thread pool is no longer running. - if (self.is_running) { - self.cond.wait(&self.mutex); + if (pool.is_running) { + pool.cond.wait(&pool.mutex); } else { break; } } } + +pub fn waitAndWork(pool: *ThreadPool, wait_group: *WaitGroup) void { + while (!wait_group.isDone()) { + if (blk: { + pool.mutex.lock(); + defer pool.mutex.unlock(); + break :blk pool.run_queue.popFirst(); + }) |run_node| { + run_node.data.runFn(&run_node.data); + continue; + } + + wait_group.wait(); + return; + } +} |
