diff options
Diffstat (limited to 'lib/std/Thread.zig')
| -rw-r--r-- | lib/std/Thread.zig | 1106 |
1 files changed, 662 insertions, 444 deletions
diff --git a/lib/std/Thread.zig b/lib/std/Thread.zig index 06fe2a84dc..91f7ff58c3 100644 --- a/lib/std/Thread.zig +++ b/lib/std/Thread.zig @@ -8,7 +8,11 @@ //! primitives that operate on kernel threads. For concurrency primitives that support //! both evented I/O and async I/O, see the respective names in the top level std namespace. -data: Data, +const std = @import("std.zig"); +const os = std.os; +const assert = std.debug.assert; +const target = std.Target.current; +const Atomic = std.atomic.Atomic; pub const AutoResetEvent = @import("Thread/AutoResetEvent.zig"); pub const Futex = @import("Thread/Futex.zig"); @@ -18,117 +22,51 @@ pub const Mutex = @import("Thread/Mutex.zig"); pub const Semaphore = @import("Thread/Semaphore.zig"); pub const Condition = @import("Thread/Condition.zig"); -pub const use_pthreads = std.Target.current.os.tag != .windows and builtin.link_libc; +pub const spinLoopHint = @compileError("deprecated: use std.atomic.spinLoopHint"); -const Thread = @This(); -const std = @import("std.zig"); -const builtin = std.builtin; -const os = std.os; -const mem = std.mem; -const windows = std.os.windows; -const c = std.c; -const assert = std.debug.assert; +pub const use_pthreads = target.os.tag != .windows and std.builtin.link_libc; -const bad_startfn_ret = "expected return type of startFn to be 'u8', 'noreturn', 'void', or '!void'"; +const Thread = @This(); +const Impl = if (target.os.tag == .windows) + WindowsThreadImpl +else if (use_pthreads) + PosixThreadImpl +else if (target.os.tag == .linux) + LinuxThreadImpl +else + UnsupportedImpl; -/// Represents a kernel thread handle. -/// May be an integer or a pointer depending on the platform. -/// On Linux and POSIX, this is the same as Id. -pub const Handle = if (use_pthreads) - c.pthread_t -else switch (std.Target.current.os.tag) { - .linux => i32, - .windows => windows.HANDLE, - else => void, -}; +impl: Impl, /// Represents a unique ID per thread. -/// May be an integer or pointer depending on the platform. -/// On Linux and POSIX, this is the same as Handle. -pub const Id = switch (std.Target.current.os.tag) { - .windows => windows.DWORD, - else => Handle, -}; +pub const Id = u64; -pub const Data = if (use_pthreads) - struct { - handle: Thread.Handle, - memory: []u8, - } -else switch (std.Target.current.os.tag) { - .linux => struct { - handle: Thread.Handle, - memory: []align(mem.page_size) u8, - }, - .windows => struct { - handle: Thread.Handle, - alloc_start: *c_void, - heap_handle: windows.HANDLE, - }, - else => struct {}, -}; - -pub const spinLoopHint = @compileError("deprecated: use std.atomic.spinLoopHint"); - -/// Returns the ID of the calling thread. -/// Makes a syscall every time the function is called. -/// On Linux and POSIX, this Id is the same as a Handle. +/// Returns the platform ID of the callers thread. +/// Attempts to use thread locals and avoid syscalls when possible. pub fn getCurrentId() Id { - if (use_pthreads) { - return c.pthread_self(); - } else return switch (std.Target.current.os.tag) { - .linux => os.linux.gettid(), - .windows => windows.kernel32.GetCurrentThreadId(), - else => @compileError("Unsupported OS"), - }; + return Impl.getCurrentId(); } -/// Returns the handle of this thread. -/// On Linux and POSIX, this is the same as Id. -/// On Linux, it is possible that the thread spawned with `spawn` -/// finishes executing entirely before the clone syscall completes. In this -/// case, this function will return 0 rather than the no-longer-existing thread's -/// pid. -pub fn handle(self: Thread) Handle { - return self.data.handle; -} +pub const CpuCountError = error{ + PermissionDenied, + SystemResources, + Unexpected, +}; -pub fn wait(self: *Thread) void { - if (use_pthreads) { - const err = c.pthread_join(self.data.handle, null); - switch (err) { - 0 => {}, - os.EINVAL => unreachable, - os.ESRCH => unreachable, - os.EDEADLK => unreachable, - else => unreachable, - } - std.heap.c_allocator.free(self.data.memory); - std.heap.c_allocator.destroy(self); - } else switch (std.Target.current.os.tag) { - .linux => { - while (true) { - const pid_value = @atomicLoad(i32, &self.data.handle, .SeqCst); - if (pid_value == 0) break; - const rc = os.linux.futex_wait(&self.data.handle, os.linux.FUTEX_WAIT, pid_value, null); - switch (os.linux.getErrno(rc)) { - 0 => continue, - os.EINTR => continue, - os.EAGAIN => continue, - else => unreachable, - } - } - os.munmap(self.data.memory); - }, - .windows => { - windows.WaitForSingleObjectEx(self.data.handle, windows.INFINITE, false) catch unreachable; - windows.CloseHandle(self.data.handle); - windows.HeapFree(self.data.heap_handle, 0, self.data.alloc_start); - }, - else => @compileError("Unsupported OS"), - } +/// Returns the platforms view on the number of logical CPU cores available. +pub fn getCpuCount() CpuCountError!usize { + return Impl.getCpuCount(); } +/// Configuration options for hints on how to spawn threads. +pub const SpawnConfig = struct { + // TODO compile-time call graph analysis to determine stack upper bound + // https://github.com/ziglang/zig/issues/157 + + /// Size in bytes of the Thread's stack + stack_size: usize = 16 * 1024 * 1024, +}; + pub const SpawnError = error{ /// A system-imposed limit on the number of threads was encountered. /// There are a number of limits that may trigger this error: @@ -159,248 +97,552 @@ pub const SpawnError = error{ Unexpected, }; -// Given `T`, the type of the thread startFn, extract the expected type for the -// context parameter. -fn SpawnContextType(comptime T: type) type { - const TI = @typeInfo(T); - if (TI != .Fn) - @compileError("expected function type, found " ++ @typeName(T)); +/// Spawns a new thread which executes `function` using `args` and returns a handle the spawned thread. +/// `config` can be used as hints to the platform for now to spawn and execute the `function`. +/// The caller must eventually either call `join()` to wait for the thread to finish and free its resources +/// or call `detach()` to excuse the caller from calling `join()` and have the thread clean up its resources on completion`. +pub fn spawn(config: SpawnConfig, comptime function: anytype, args: anytype) SpawnError!Thread { + if (std.builtin.single_threaded) { + @compileError("Cannot spawn thread when building in single-threaded mode"); + } + + const impl = try Impl.spawn(config, function, args); + return Thread{ .impl = impl }; +} - if (TI.Fn.args.len != 1) - @compileError("expected function with single argument, found " ++ @typeName(T)); +/// Represents a kernel thread handle. +/// May be an integer or a pointer depending on the platform. +pub const Handle = Impl.ThreadHandle; - return TI.Fn.args[0].arg_type orelse - @compileError("cannot use a generic function as thread startFn"); +/// Retrns the handle of this thread +pub fn getHandle(self: Thread) Handle { + return self.impl.getHandle(); } -/// Spawns a new thread executing startFn, returning an handle for it. -/// Caller must call wait on the returned thread. -/// The `startFn` function must take a single argument of type T and return a -/// value of type u8, noreturn, void or !void. -/// The `context` parameter is of type T and is passed to the spawned thread. -pub fn spawn(comptime startFn: anytype, context: SpawnContextType(@TypeOf(startFn))) SpawnError!*Thread { - if (builtin.single_threaded) @compileError("cannot spawn thread when building in single-threaded mode"); - // TODO compile-time call graph analysis to determine stack upper bound - // https://github.com/ziglang/zig/issues/157 - const default_stack_size = 16 * 1024 * 1024; +/// Release the obligation of the caller to call `join()` and have the thread clean up its own resources on completion. +/// Once called, this consumes the Thread object and invoking any other functions on it is considered undefined behavior. +pub fn detach(self: Thread) void { + return self.impl.detach(); +} - const Context = @TypeOf(context); +/// Waits for the thread to complete, then deallocates any resources created on `spawn()`. +/// Once called, this consumes the Thread object and invoking any other functions on it is considered undefined behavior. +pub fn join(self: Thread) void { + return self.impl.join(); +} - if (std.Target.current.os.tag == .windows) { - const WinThread = struct { - const OuterContext = struct { - thread: Thread, - inner: Context, - }; - fn threadMain(raw_arg: windows.LPVOID) callconv(.C) windows.DWORD { - const arg = if (@sizeOf(Context) == 0) undefined // - else @ptrCast(*Context, @alignCast(@alignOf(Context), raw_arg)).*; - - switch (@typeInfo(@typeInfo(@TypeOf(startFn)).Fn.return_type.?)) { - .NoReturn => { - startFn(arg); - }, - .Void => { - startFn(arg); - return 0; - }, - .Int => |info| { - if (info.bits != 8) { - @compileError(bad_startfn_ret); - } - return startFn(arg); - }, - .ErrorUnion => |info| { - if (info.payload != void) { - @compileError(bad_startfn_ret); - } - startFn(arg) catch |err| { - std.debug.warn("error: {s}\n", .{@errorName(err)}); - if (@errorReturnTrace()) |trace| { - std.debug.dumpStackTrace(trace.*); - } - }; - return 0; - }, - else => @compileError(bad_startfn_ret), +/// State to synchronize detachment of spawner thread to spawned thread +const Completion = Atomic(enum(u8) { + running, + detached, + completed, +}); + +/// Used by the Thread implementations to call the spawned function with the arguments. +fn callFn(comptime f: anytype, args: anytype) switch (Impl) { + WindowsThreadImpl => std.os.windows.DWORD, + LinuxThreadImpl => u8, + PosixThreadImpl => ?*c_void, + else => unreachable, +} { + const default_value = if (Impl == PosixThreadImpl) null else 0; + const bad_fn_ret = "expected return type of startFn to be 'u8', 'noreturn', 'void', or '!void'"; + + switch (@typeInfo(@typeInfo(@TypeOf(f)).Fn.return_type.?)) { + .NoReturn => { + @call(.{}, f, args); + }, + .Void => { + @call(.{}, f, args); + return default_value; + }, + .Int => |info| { + if (info.bits != 8) { + @compileError(bad_fn_ret); + } + + const status = @call(.{}, f, args); + if (Impl != PosixThreadImpl) { + return status; + } + + // pthreads don't support exit status, ignore value + _ = status; + return default_value; + }, + .ErrorUnion => |info| { + if (info.payload != void) { + @compileError(bad_fn_ret); + } + + @call(.{}, f, args) catch |err| { + std.debug.warn("error: {s}\n", .{@errorName(err)}); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); } + }; + + return default_value; + }, + else => { + @compileError(bad_fn_ret); + }, + } +} + +/// We can't compile error in the `Impl` switch statement as its eagerly evaluated. +/// So instead, we compile-error on the methods themselves for platforms which don't support threads. +const UnsupportedImpl = struct { + pub const ThreadHandle = void; + + fn getCurrentId() u64 { + return unsupported({}); + } + + fn getCpuCount() !usize { + return unsupported({}); + } + + fn spawn(config: SpawnConfig, comptime f: anytype, args: anytype) !Impl { + return unsupported(.{ config, f, args }); + } + + fn getHandle(self: Impl) ThreadHandle { + return unsupported(self); + } + + fn detach(self: Impl) void { + return unsupported(self); + } + + fn join(self: Impl) void { + return unsupported(self); + } + + fn unsupported(unusued: anytype) noreturn { + @compileLog("Unsupported operating system", target.os.tag); + _ = unusued; + unreachable; + } +}; + +const WindowsThreadImpl = struct { + const windows = os.windows; + + pub const ThreadHandle = windows.HANDLE; + + fn getCurrentId() u64 { + return windows.kernel32.GetCurrentThreadId(); + } + + fn getCpuCount() !usize { + // Faster than calling into GetSystemInfo(), even if amortized. + return windows.peb().NumberOfProcessors; + } + + thread: *ThreadCompletion, + + const ThreadCompletion = struct { + completion: Completion, + heap_ptr: windows.PVOID, + heap_handle: windows.HANDLE, + thread_handle: windows.HANDLE = undefined, + + fn free(self: ThreadCompletion) void { + const status = windows.kernel32.HeapFree(self.heap_handle, 0, self.heap_ptr); + assert(status != 0); + } + }; + + fn spawn(config: SpawnConfig, comptime f: anytype, args: anytype) !Impl { + const Args = @TypeOf(args); + const Instance = struct { + fn_args: Args, + thread: ThreadCompletion, + + fn entryFn(raw_ptr: windows.PVOID) callconv(.C) windows.DWORD { + const self = @ptrCast(*@This(), @alignCast(@alignOf(@This()), raw_ptr)); + defer switch (self.thread.completion.swap(.completed, .SeqCst)) { + .running => {}, + .completed => unreachable, + .detached => self.thread.free(), + }; + return callFn(f, self.fn_args); } }; const heap_handle = windows.kernel32.GetProcessHeap() orelse return error.OutOfMemory; - const byte_count = @alignOf(WinThread.OuterContext) + @sizeOf(WinThread.OuterContext); - const bytes_ptr = windows.kernel32.HeapAlloc(heap_handle, 0, byte_count) orelse return error.OutOfMemory; - errdefer assert(windows.kernel32.HeapFree(heap_handle, 0, bytes_ptr) != 0); - const bytes = @ptrCast([*]u8, bytes_ptr)[0..byte_count]; - const outer_context = std.heap.FixedBufferAllocator.init(bytes).allocator.create(WinThread.OuterContext) catch unreachable; - outer_context.* = WinThread.OuterContext{ - .thread = Thread{ - .data = Thread.Data{ - .heap_handle = heap_handle, - .alloc_start = bytes_ptr, - .handle = undefined, - }, + const alloc_bytes = @alignOf(Instance) + @sizeOf(Instance); + const alloc_ptr = windows.kernel32.HeapAlloc(heap_handle, 0, alloc_bytes) orelse return error.OutOfMemory; + errdefer assert(windows.kernel32.HeapFree(heap_handle, 0, alloc_ptr) != 0); + + const instance_bytes = @ptrCast([*]u8, alloc_ptr)[0..alloc_bytes]; + const instance = std.heap.FixedBufferAllocator.init(instance_bytes).allocator.create(Instance) catch unreachable; + instance.* = .{ + .fn_args = args, + .thread = .{ + .completion = Completion.init(.running), + .heap_ptr = alloc_ptr, + .heap_handle = heap_handle, }, - .inner = context, }; - const parameter = if (@sizeOf(Context) == 0) null else @ptrCast(*c_void, &outer_context.inner); - outer_context.thread.data.handle = windows.kernel32.CreateThread(null, default_stack_size, WinThread.threadMain, parameter, 0, null) orelse { - switch (windows.kernel32.GetLastError()) { - else => |err| return windows.unexpectedError(err), - } + // Windows appears to only support SYSTEM_INFO.dwAllocationGranularity minimum stack size. + // Going lower makes it default to that specified in the executable (~1mb). + // Its also fine if the limit here is incorrect as stack size is only a hint. + var stack_size = std.math.cast(u32, config.stack_size) catch std.math.maxInt(u32); + stack_size = std.math.max(64 * 1024, stack_size); + + instance.thread.thread_handle = windows.kernel32.CreateThread( + null, + stack_size, + Instance.entryFn, + @ptrCast(*c_void, instance), + 0, + null, + ) orelse { + const errno = windows.kernel32.GetLastError(); + return windows.unexpectedError(errno); }; - return &outer_context.thread; + + return Impl{ .thread = &instance.thread }; } - const MainFuncs = struct { - fn linuxThreadMain(ctx_addr: usize) callconv(.C) u8 { - const arg = if (@sizeOf(Context) == 0) undefined // - else @intToPtr(*Context, ctx_addr).*; + fn getHandle(self: Impl) ThreadHandle { + return self.thread.thread_handle; + } - switch (@typeInfo(@typeInfo(@TypeOf(startFn)).Fn.return_type.?)) { - .NoReturn => { - startFn(arg); - }, - .Void => { - startFn(arg); - return 0; - }, - .Int => |info| { - if (info.bits != 8) { - @compileError(bad_startfn_ret); - } - return startFn(arg); - }, - .ErrorUnion => |info| { - if (info.payload != void) { - @compileError(bad_startfn_ret); - } - startFn(arg) catch |err| { - std.debug.warn("error: {s}\n", .{@errorName(err)}); - if (@errorReturnTrace()) |trace| { - std.debug.dumpStackTrace(trace.*); - } - }; - return 0; - }, - else => @compileError(bad_startfn_ret), - } + fn detach(self: Impl) void { + windows.CloseHandle(self.thread.thread_handle); + switch (self.thread.completion.swap(.detached, .SeqCst)) { + .running => {}, + .completed => self.thread.free(), + .detached => unreachable, } - fn posixThreadMain(ctx: ?*c_void) callconv(.C) ?*c_void { - const arg = if (@sizeOf(Context) == 0) undefined // - else @ptrCast(*Context, @alignCast(@alignOf(Context), ctx)).*; + } - switch (@typeInfo(@typeInfo(@TypeOf(startFn)).Fn.return_type.?)) { - .NoReturn => { - startFn(arg); - }, - .Void => { - startFn(arg); - return null; - }, - .Int => |info| { - if (info.bits != 8) { - @compileError(bad_startfn_ret); - } - // pthreads don't support exit status, ignore value - _ = startFn(arg); - return null; - }, - .ErrorUnion => |info| { - if (info.payload != void) { - @compileError(bad_startfn_ret); - } - startFn(arg) catch |err| { - std.debug.warn("error: {s}\n", .{@errorName(err)}); - if (@errorReturnTrace()) |trace| { - std.debug.dumpStackTrace(trace.*); - } - }; - return null; - }, - else => @compileError(bad_startfn_ret), - } + fn join(self: Impl) void { + windows.WaitForSingleObjectEx(self.thread.thread_handle, windows.INFINITE, false) catch unreachable; + windows.CloseHandle(self.thread.thread_handle); + assert(self.thread.completion.load(.SeqCst) == .completed); + self.thread.free(); + } +}; + +const PosixThreadImpl = struct { + const c = std.c; + + pub const ThreadHandle = c.pthread_t; + + fn getCurrentId() Id { + switch (target.os.tag) { + .linux => { + return LinuxThreadImpl.getCurrentId(); + }, + .macos, .ios, .watchos, .tvos => { + var thread_id: u64 = undefined; + // Pass thread=null to get the current thread ID. + assert(c.pthread_threadid_np(null, &thread_id) == 0); + return thread_id; + }, + .dragonfly => { + return @bitCast(u32, c.lwp_gettid()); + }, + .netbsd => { + return @bitCast(u32, c._lwp_self()); + }, + .freebsd => { + return @bitCast(u32, c.pthread_getthreadid_np()); + }, + .openbsd => { + return @bitCast(u32, c.getthrid()); + }, + .haiku => { + return @bitCast(u32, c.find_thread(null)); + }, + else => { + return @ptrToInt(c.pthread_self()); + }, } - }; + } + + fn getCpuCount() !usize { + switch (target.os.tag) { + .linux => { + return LinuxThreadImpl.getCpuCount(); + }, + .openbsd => { + var count: c_int = undefined; + var count_size: usize = @sizeOf(c_int); + const mib = [_]c_int{ os.CTL_HW, os.HW_NCPUONLINE }; + os.sysctl(&mib, &count, &count_size, null, 0) catch |err| switch (err) { + error.NameTooLong, error.UnknownName => unreachable, + else => |e| return e, + }; + return @intCast(usize, count); + }, + .haiku => { + var count: u32 = undefined; + var system_info: os.system_info = undefined; + _ = os.system.get_system_info(&system_info); // always returns B_OK + count = system_info.cpu_count; + return @intCast(usize, count); + }, + else => { + var count: c_int = undefined; + var count_len: usize = @sizeOf(c_int); + const name = if (comptime target.isDarwin()) "hw.logicalcpu" else "hw.ncpu"; + os.sysctlbynameZ(name, &count, &count_len, null, 0) catch |err| switch (err) { + error.NameTooLong, error.UnknownName => unreachable, + else => |e| return e, + }; + return @intCast(usize, count); + }, + } + } + + handle: ThreadHandle, + + fn spawn(config: SpawnConfig, comptime f: anytype, args: anytype) !Impl { + const Args = @TypeOf(args); + const allocator = std.heap.c_allocator; + + const Instance = struct { + fn entryFn(raw_arg: ?*c_void) callconv(.C) ?*c_void { + // @alignCast() below doesn't support zero-sized-types (ZST) + if (@sizeOf(Args) < 1) { + return callFn(f, @as(Args, undefined)); + } + + const args_ptr = @ptrCast(*Args, @alignCast(@alignOf(Args), raw_arg)); + defer allocator.destroy(args_ptr); + return callFn(f, args_ptr.*); + } + }; + + const args_ptr = try allocator.create(Args); + args_ptr.* = args; + errdefer allocator.destroy(args_ptr); - if (Thread.use_pthreads) { var attr: c.pthread_attr_t = undefined; if (c.pthread_attr_init(&attr) != 0) return error.SystemResources; defer assert(c.pthread_attr_destroy(&attr) == 0); - const thread_obj = try std.heap.c_allocator.create(Thread); - errdefer std.heap.c_allocator.destroy(thread_obj); - if (@sizeOf(Context) > 0) { - thread_obj.data.memory = try std.heap.c_allocator.allocAdvanced( - u8, - @alignOf(Context), - @sizeOf(Context), - .at_least, - ); - errdefer std.heap.c_allocator.free(thread_obj.data.memory); - mem.copy(u8, thread_obj.data.memory, mem.asBytes(&context)); - } else { - thread_obj.data.memory = @as([*]u8, undefined)[0..0]; - } - // Use the same set of parameters used by the libc-less impl. - assert(c.pthread_attr_setstacksize(&attr, default_stack_size) == 0); - assert(c.pthread_attr_setguardsize(&attr, mem.page_size) == 0); + const stack_size = std.math.max(config.stack_size, 16 * 1024); + assert(c.pthread_attr_setstacksize(&attr, stack_size) == 0); + assert(c.pthread_attr_setguardsize(&attr, std.mem.page_size) == 0); - const err = c.pthread_create( - &thread_obj.data.handle, + var handle: c.pthread_t = undefined; + switch (c.pthread_create( + &handle, &attr, - MainFuncs.posixThreadMain, - thread_obj.data.memory.ptr, - ); - switch (err) { - 0 => return thread_obj, + Instance.entryFn, + if (@sizeOf(Args) > 1) @ptrCast(*c_void, args_ptr) else undefined, + )) { + 0 => return Impl{ .handle = handle }, os.EAGAIN => return error.SystemResources, os.EPERM => unreachable, os.EINVAL => unreachable, - else => return os.unexpectedErrno(err), + else => |err| return os.unexpectedErrno(err), + } + } + + fn getHandle(self: Impl) ThreadHandle { + return self.handle; + } + + fn detach(self: Impl) void { + switch (c.pthread_detach(self.handle)) { + 0 => {}, + os.EINVAL => unreachable, // thread handle is not joinable + os.ESRCH => unreachable, // thread handle is invalid + else => unreachable, + } + } + + fn join(self: Impl) void { + switch (c.pthread_join(self.handle, null)) { + 0 => {}, + os.EINVAL => unreachable, // thread handle is not joinable (or another thread is already joining in) + os.ESRCH => unreachable, // thread handle is invalid + os.EDEADLK => unreachable, // two threads tried to join each other + else => unreachable, } + } +}; + +const LinuxThreadImpl = struct { + const linux = os.linux; - return thread_obj; + pub const ThreadHandle = i32; + + threadlocal var tls_thread_id: ?Id = null; + + fn getCurrentId() Id { + return tls_thread_id orelse { + const tid = @bitCast(u32, linux.gettid()); + tls_thread_id = tid; + return tid; + }; } - var guard_end_offset: usize = undefined; - var stack_end_offset: usize = undefined; - var thread_start_offset: usize = undefined; - var context_start_offset: usize = undefined; - var tls_start_offset: usize = undefined; - const mmap_len = blk: { - var l: usize = mem.page_size; - // Allocate a guard page right after the end of the stack region - guard_end_offset = l; - // The stack itself, which grows downwards. - l = mem.alignForward(l + default_stack_size, mem.page_size); - stack_end_offset = l; - // Above the stack, so that it can be in the same mmap call, put the Thread object. - l = mem.alignForward(l, @alignOf(Thread)); - thread_start_offset = l; - l += @sizeOf(Thread); - // Next, the Context object. - if (@sizeOf(Context) != 0) { - l = mem.alignForward(l, @alignOf(Context)); - context_start_offset = l; - l += @sizeOf(Context); + fn getCpuCount() !usize { + const cpu_set = try os.sched_getaffinity(0); + // TODO: should not need this usize cast + return @as(usize, os.CPU_COUNT(cpu_set)); + } + + thread: *ThreadCompletion, + + const ThreadCompletion = struct { + completion: Completion = Completion.init(.running), + child_tid: Atomic(i32) = Atomic(i32).init(1), + parent_tid: i32 = undefined, + mapped: []align(std.mem.page_size) u8, + + /// Calls `munmap(mapped.ptr, mapped.len)` then `exit(1)` without touching the stack (which lives in `mapped.ptr`). + /// Ported over from musl libc's pthread detached implementation: + /// https://github.com/ifduyue/musl/search?q=__unmapself + fn freeAndExit(self: *ThreadCompletion) noreturn { + const unmap_and_exit: []const u8 = switch (target.cpu.arch) { + .i386 => ( + \\ movl $91, %%eax + \\ movl %[ptr], %%ebx + \\ movl %[len], %%ecx + \\ int $128 + \\ movl $1, %%eax + \\ movl $0, %%ebx + \\ int $128 + ), + .x86_64 => ( + \\ movq $11, %%rax + \\ movq %[ptr], %%rbx + \\ movq %[len], %%rcx + \\ syscall + \\ movq $60, %%rax + \\ movq $1, %%rdi + \\ syscall + ), + .arm, .armeb, .thumb, .thumbeb => ( + \\ mov r7, #91 + \\ mov r0, %[ptr] + \\ mov r1, %[len] + \\ svc 0 + \\ mov r7, #1 + \\ mov r0, #0 + \\ svc 0 + ), + .aarch64, .aarch64_be, .aarch64_32 => ( + \\ mov x8, #215 + \\ mov x0, %[ptr] + \\ mov x1, %[len] + \\ svc 0 + \\ mov x8, #93 + \\ mov x0, #0 + \\ svc 0 + ), + .mips, .mipsel => ( + \\ move $sp, $25 + \\ li $2, 4091 + \\ move $4, %[ptr] + \\ move $5, %[len] + \\ syscall + \\ li $2, 4001 + \\ li $4, 0 + \\ syscall + ), + .mips64, .mips64el => ( + \\ li $2, 4091 + \\ move $4, %[ptr] + \\ move $5, %[len] + \\ syscall + \\ li $2, 4001 + \\ li $4, 0 + \\ syscall + ), + .powerpc, .powerpcle, .powerpc64, .powerpc64le => ( + \\ li 0, 91 + \\ mr %[ptr], 3 + \\ mr %[len], 4 + \\ sc + \\ li 0, 1 + \\ li 3, 0 + \\ sc + \\ blr + ), + .riscv64 => ( + \\ li a7, 215 + \\ mv a0, %[ptr] + \\ mv a1, %[len] + \\ ecall + \\ li a7, 93 + \\ mv a0, zero + \\ ecall + ), + else => |cpu_arch| { + @compileLog("Unsupported linux arch ", cpu_arch); + }, + }; + + asm volatile (unmap_and_exit + : + : [ptr] "r" (@ptrToInt(self.mapped.ptr)), + [len] "r" (self.mapped.len) + : "memory" + ); + + unreachable; } - // Finally, the Thread Local Storage, if any. - l = mem.alignForward(l, os.linux.tls.tls_image.alloc_align); - tls_start_offset = l; - l += os.linux.tls.tls_image.alloc_size; - // Round the size to the page size. - break :blk mem.alignForward(l, mem.page_size); }; - const mmap_slice = mem: { - // Map the whole stack with no rw permissions to avoid - // committing the whole region right away - const mmap_slice = os.mmap( + fn spawn(config: SpawnConfig, comptime f: anytype, args: anytype) !Impl { + const Args = @TypeOf(args); + const Instance = struct { + fn_args: Args, + thread: ThreadCompletion, + + fn entryFn(raw_arg: usize) callconv(.C) u8 { + const self = @intToPtr(*@This(), raw_arg); + defer switch (self.thread.completion.swap(.completed, .SeqCst)) { + .running => {}, + .completed => unreachable, + .detached => self.thread.freeAndExit(), + }; + return callFn(f, self.fn_args); + } + }; + + var guard_offset: usize = undefined; + var stack_offset: usize = undefined; + var tls_offset: usize = undefined; + var instance_offset: usize = undefined; + + const map_bytes = blk: { + var bytes: usize = std.mem.page_size; + guard_offset = bytes; + + bytes += std.math.max(std.mem.page_size, config.stack_size); + bytes = std.mem.alignForward(bytes, std.mem.page_size); + stack_offset = bytes; + + bytes = std.mem.alignForward(bytes, linux.tls.tls_image.alloc_align); + tls_offset = bytes; + bytes += linux.tls.tls_image.alloc_size; + + bytes = std.mem.alignForward(bytes, @alignOf(Instance)); + instance_offset = bytes; + bytes += @sizeOf(Instance); + + bytes = std.mem.alignForward(bytes, std.mem.page_size); + break :blk bytes; + }; + + // map all memory needed without read/write permissions + // to avoid committing the whole region right away + const mapped = os.mmap( null, - mmap_len, + map_bytes, os.PROT_NONE, os.MAP_PRIVATE | os.MAP_ANONYMOUS, -1, @@ -411,73 +653,57 @@ pub fn spawn(comptime startFn: anytype, context: SpawnContextType(@TypeOf(startF error.PermissionDenied => unreachable, else => |e| return e, }; - errdefer os.munmap(mmap_slice); + assert(mapped.len >= map_bytes); + errdefer os.munmap(mapped); - // Map everything but the guard page as rw + // map everything but the guard page as read/write os.mprotect( - mmap_slice[guard_end_offset..], + mapped[guard_offset..], os.PROT_READ | os.PROT_WRITE, ) catch |err| switch (err) { error.AccessDenied => unreachable, else => |e| return e, }; - break :mem mmap_slice; - }; - - const mmap_addr = @ptrToInt(mmap_slice.ptr); - - const thread_ptr = @alignCast(@alignOf(Thread), @intToPtr(*Thread, mmap_addr + thread_start_offset)); - thread_ptr.data.memory = mmap_slice; + // Prepare the TLS segment and prepare a user_desc struct when needed on i386 + var tls_ptr = os.linux.tls.prepareTLS(mapped[tls_offset..]); + var user_desc: if (target.cpu.arch == .i386) os.linux.user_desc else void = undefined; + if (target.cpu.arch == .i386) { + defer tls_ptr = @ptrToInt(&user_desc); + user_desc = .{ + .entry_number = os.linux.tls.tls_image.gdt_entry_number, + .base_addr = tls_ptr, + .limit = 0xfffff, + .seg_32bit = 1, + .contents = 0, // Data + .read_exec_only = 0, + .limit_in_pages = 1, + .seg_not_present = 0, + .useable = 1, + }; + } - var arg: usize = undefined; - if (@sizeOf(Context) != 0) { - arg = mmap_addr + context_start_offset; - const context_ptr = @alignCast(@alignOf(Context), @intToPtr(*Context, arg)); - context_ptr.* = context; - } + const instance = @ptrCast(*Instance, @alignCast(@alignOf(Instance), &mapped[instance_offset])); + instance.* = .{ + .fn_args = args, + .thread = .{ .mapped = mapped }, + }; - if (std.Target.current.os.tag == .linux) { - const flags: u32 = os.CLONE_VM | os.CLONE_FS | os.CLONE_FILES | - os.CLONE_SIGHAND | os.CLONE_THREAD | os.CLONE_SYSVSEM | + const flags: u32 = os.CLONE_THREAD | os.CLONE_DETACHED | + os.CLONE_VM | os.CLONE_FS | os.CLONE_FILES | os.CLONE_PARENT_SETTID | os.CLONE_CHILD_CLEARTID | - os.CLONE_DETACHED | os.CLONE_SETTLS; - // This structure is only needed when targeting i386 - var user_desc: if (std.Target.current.cpu.arch == .i386) os.linux.user_desc else void = undefined; - - const tls_area = mmap_slice[tls_start_offset..]; - const tp_value = os.linux.tls.prepareTLS(tls_area); - - const newtls = blk: { - if (std.Target.current.cpu.arch == .i386) { - user_desc = os.linux.user_desc{ - .entry_number = os.linux.tls.tls_image.gdt_entry_number, - .base_addr = tp_value, - .limit = 0xfffff, - .seg_32bit = 1, - .contents = 0, // Data - .read_exec_only = 0, - .limit_in_pages = 1, - .seg_not_present = 0, - .useable = 1, - }; - break :blk @ptrToInt(&user_desc); - } else { - break :blk tp_value; - } - }; + os.CLONE_SIGHAND | os.CLONE_SYSVSEM | os.CLONE_SETTLS; - const rc = os.linux.clone( - MainFuncs.linuxThreadMain, - mmap_addr + stack_end_offset, + switch (linux.getErrno(linux.clone( + Instance.entryFn, + @ptrToInt(&mapped[stack_offset]), flags, - arg, - &thread_ptr.data.handle, - newtls, - &thread_ptr.data.handle, - ); - switch (os.errno(rc)) { - 0 => return thread_ptr, + @ptrToInt(instance), + &instance.thread.parent_tid, + tls_ptr, + &instance.thread.child_tid.value, + ))) { + 0 => return Impl{ .thread = &instance.thread }, os.EAGAIN => return error.ThreadQuotaExceeded, os.EINVAL => unreachable, os.ENOMEM => return error.SystemResources, @@ -486,100 +712,92 @@ pub fn spawn(comptime startFn: anytype, context: SpawnContextType(@TypeOf(startF os.EUSERS => unreachable, else => |err| return os.unexpectedErrno(err), } - } else { - @compileError("Unsupported OS"); } -} -pub const CpuCountError = error{ - PermissionDenied, - SystemResources, - Unexpected, -}; + fn getHandle(self: Impl) ThreadHandle { + return self.thread.parent_tid; + } -pub fn cpuCount() CpuCountError!usize { - switch (std.Target.current.os.tag) { - .linux => { - const cpu_set = try os.sched_getaffinity(0); - return @as(usize, os.CPU_COUNT(cpu_set)); // TODO should not need this usize cast - }, - .windows => { - return os.windows.peb().NumberOfProcessors; - }, - .openbsd => { - var count: c_int = undefined; - var count_size: usize = @sizeOf(c_int); - const mib = [_]c_int{ os.CTL_HW, os.HW_NCPUONLINE }; - os.sysctl(&mib, &count, &count_size, null, 0) catch |err| switch (err) { - error.NameTooLong, error.UnknownName => unreachable, - else => |e| return e, - }; - return @intCast(usize, count); - }, - .haiku => { - var count: u32 = undefined; - // var system_info: os.system_info = undefined; - // const rc = os.system.get_system_info(&system_info); - count = system_info.cpu_count; - return @intCast(usize, count); - }, - else => { - var count: c_int = undefined; - var count_len: usize = @sizeOf(c_int); - const name = if (comptime std.Target.current.isDarwin()) "hw.logicalcpu" else "hw.ncpu"; - os.sysctlbynameZ(name, &count, &count_len, null, 0) catch |err| switch (err) { - error.NameTooLong, error.UnknownName => unreachable, - else => |e| return e, - }; - return @intCast(usize, count); - }, + fn detach(self: Impl) void { + switch (self.thread.completion.swap(.detached, .SeqCst)) { + .running => {}, + .completed => self.join(), + .detached => unreachable, + } } -} -pub fn getCurrentThreadId() u64 { - switch (std.Target.current.os.tag) { - .linux => { - // Use the syscall directly as musl doesn't provide a wrapper. - return @bitCast(u32, os.linux.gettid()); - }, - .windows => { - return os.windows.kernel32.GetCurrentThreadId(); - }, - .macos, .ios, .watchos, .tvos => { - var thread_id: u64 = undefined; - // Pass thread=null to get the current thread ID. - assert(c.pthread_threadid_np(null, &thread_id) == 0); - return thread_id; - }, - .dragonfly => { - return @bitCast(u32, c.lwp_gettid()); - }, - .netbsd => { - return @bitCast(u32, c._lwp_self()); - }, - .freebsd => { - return @bitCast(u32, c.pthread_getthreadid_np()); - }, - .openbsd => { - return @bitCast(u32, c.getthrid()); - }, - .haiku => { - return @bitCast(u32, c.find_thread(null)); - }, - else => { - @compileError("getCurrentThreadId not implemented for this platform"); - }, + fn join(self: Impl) void { + defer os.munmap(self.thread.mapped); + + var spin: u8 = 10; + while (true) { + const tid = self.thread.child_tid.load(.SeqCst); + if (tid == 0) { + break; + } + + if (spin > 0) { + spin -= 1; + std.atomic.spinLoopHint(); + continue; + } + + switch (linux.getErrno(linux.futex_wait( + &self.thread.child_tid.value, + linux.FUTEX_WAIT, + tid, + null, + ))) { + 0 => continue, + os.EINTR => continue, + os.EAGAIN => continue, + else => unreachable, + } + } } -} +}; test "std.Thread" { - if (!builtin.single_threaded) { - _ = AutoResetEvent; - _ = Futex; - _ = ResetEvent; - _ = StaticResetEvent; - _ = Mutex; - _ = Semaphore; - _ = Condition; - } + // Doesn't use testing.refAllDecls() since that would pull in the compileError spinLoopHint. + _ = AutoResetEvent; + _ = Futex; + _ = ResetEvent; + _ = StaticResetEvent; + _ = Mutex; + _ = Semaphore; + _ = Condition; +} + +fn testIncrementNotify(value: *usize, event: *ResetEvent) void { + value.* += 1; + event.set(); +} + +test "Thread.join" { + if (std.builtin.single_threaded) return error.SkipZigTest; + + var value: usize = 0; + var event: ResetEvent = undefined; + try event.init(); + defer event.deinit(); + + const thread = try Thread.spawn(.{}, testIncrementNotify, .{ &value, &event }); + thread.join(); + + try std.testing.expectEqual(value, 1); +} + +test "Thread.detach" { + if (std.builtin.single_threaded) return error.SkipZigTest; + + var value: usize = 0; + var event: ResetEvent = undefined; + try event.init(); + defer event.deinit(); + + const thread = try Thread.spawn(.{}, testIncrementNotify, .{ &value, &event }); + thread.detach(); + + event.wait(); + try std.testing.expectEqual(value, 1); } |
