diff options
| author | ippsav <69125922+ippsav@users.noreply.github.com> | 2024-11-11 22:34:24 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-11 15:34:24 -0600 |
| commit | d346d074ebe5347f730a70d3a88b12f279bb405d (patch) | |
| tree | 7060ccbffc33f60591299f0d1f9e72795157970b /lib/std/Thread/Pool.zig | |
| parent | 28bdab385a80aaff736707a9227af2d1a17e7fa4 (diff) | |
| download | zig-d346d074ebe5347f730a70d3a88b12f279bb405d.tar.gz zig-d346d074ebe5347f730a70d3a88b12f279bb405d.zip | |
Enable thread_pool function to throw errors (#20260)
* std.ThreadPool: allow error union return type
* allow noreturn in Pool.zig
Diffstat (limited to 'lib/std/Thread/Pool.zig')
| -rw-r--r-- | lib/std/Thread/Pool.zig | 44 |
1 files changed, 36 insertions, 8 deletions
diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig index 86bac7ce46..4dd7513373 100644 --- a/lib/std/Thread/Pool.zig +++ b/lib/std/Thread/Pool.zig @@ -97,7 +97,7 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args wait_group.start(); if (builtin.single_threaded) { - @call(.auto, func, args); + callFn(func, args); wait_group.finish(); return; } @@ -112,7 +112,7 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args fn runFn(runnable: *Runnable, _: ?usize) void { const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable); const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); - @call(.auto, func, closure.arguments); + callFn(func, closure.arguments); closure.wait_group.finish(); // The thread pool's allocator is protected by the mutex. @@ -129,7 +129,7 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args const closure = pool.allocator.create(Closure) catch { pool.mutex.unlock(); - @call(.auto, func, args); + callFn(func, args); wait_group.finish(); return; }; @@ -160,7 +160,7 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar wait_group.start(); if (builtin.single_threaded) { - @call(.auto, func, .{0} ++ args); + callFn(func, .{0} ++ args); wait_group.finish(); return; } @@ -175,7 +175,7 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar fn runFn(runnable: *Runnable, id: ?usize) void { const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable); const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); - @call(.auto, func, .{id.?} ++ closure.arguments); + callFn(func, .{id.?} ++ closure.arguments); closure.wait_group.finish(); // The thread pool's allocator is protected by the mutex. @@ -193,7 +193,7 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar const closure = pool.allocator.create(Closure) catch { const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId()); pool.mutex.unlock(); - @call(.auto, func, .{id.?} ++ args); + callFn(func, .{id.?} ++ args); wait_group.finish(); return; }; @@ -213,7 +213,7 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void { if (builtin.single_threaded) { - @call(.auto, func, args); + callFn(func, args); return; } @@ -226,7 +226,7 @@ pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void { fn runFn(runnable: *Runnable, _: ?usize) void { const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable); const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); - @call(.auto, func, closure.arguments); + callFn(func, closure.arguments); // The thread pool's allocator is protected by the mutex. const mutex = &closure.pool.mutex; @@ -321,3 +321,31 @@ pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void { pub fn getIdCount(pool: *Pool) usize { return @intCast(1 + pool.threads.len); } + +inline fn callFn(comptime f: anytype, args: anytype) void { + const bad_fn_ret = "expected return type of runFn to be 'void', '!void', noreturn, or !noreturn"; + + switch (@typeInfo(@typeInfo(@TypeOf(f)).@"fn".return_type.?)) { + .void, .noreturn => { + @call(.auto, f, args); + }, + .error_union => |info| { + switch (info.payload) { + void, noreturn => { + @call(.auto, f, args) catch |err| { + std.debug.print("error: {s}\n", .{@errorName(err)}); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + }; + }, + else => { + @compileError(bad_fn_ret); + }, + } + }, + else => { + @compileError(bad_fn_ret); + }, + } +} |
