diff options
Diffstat (limited to 'lib/std/Thread.zig')
| -rw-r--r-- | lib/std/Thread.zig | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/lib/std/Thread.zig b/lib/std/Thread.zig index a3b469ad6f..e7edc17a35 100644 --- a/lib/std/Thread.zig +++ b/lib/std/Thread.zig @@ -28,6 +28,8 @@ else if (use_pthreads) PosixThreadImpl else if (target.os.tag == .linux) LinuxThreadImpl +else if (target.os.tag == .wasi) + WasiThreadImpl else UnsupportedImpl; @@ -266,6 +268,7 @@ pub const Id = switch (target.os.tag) { .freebsd, .openbsd, .haiku, + .wasi, => u32, .macos, .ios, .watchos, .tvos => u64, .windows => os.windows.DWORD, @@ -296,6 +299,8 @@ pub const SpawnConfig = struct { /// Size in bytes of the Thread's stack stack_size: usize = 16 * 1024 * 1024, + /// The allocator to be used to allocate memory for the to-be-spawned thread + allocator: ?std.mem.Allocator = null, }; pub const SpawnError = error{ @@ -733,6 +738,193 @@ const PosixThreadImpl = struct { } }; +const WasiThreadImpl = struct { + comptime { + // Sets the stack pointer, which is needed after creating a new thread + // to ensure the stack of the main thread isn't being poluted. + asm ( + \\ .text + \\ .export_name __set_stack_pointer, __set_stack_pointer + \\ .globaltype __stack_pointer, i32 + \\ .hidden wasi_thread_start + \\ .globl wasi_thread_start + \\ .type __set_stack_pointer, @function + \\ + \\ __set_stack_pointer: + \\ .functype __set_stack_pointer (i32) -> () + \\ local.get 0 # The raw pointer which replaces the stack pointer + \\ global.set __stack_pointer + \\ end_function + ); + } + thread: *WasiThread, + + pub const ThreadHandle = i32; + threadlocal var tls_thread_id: Id = 0; + + const WasiThread = struct { + tid: Atomic(i32) = Atomic(i32).init(0), + memory: []u8, + }; + + /// A meta-data structure used to bootstrap a thread + const Instance = struct { + thread: WasiThread, + /// Address of this `Instance` + base: usize, + /// Contains the pointer of the new __tls_base. + tls_base: usize, + /// Contains the pointer to the stack for the newly spawned thread. + stack_pointer: usize, + /// Contains the pointer to the wrapper which holds all arguments + /// for the callback. + raw_ptr: usize, + /// Function pointer to a wrapping function which will call the user's + /// function upon thread spawn. The above mentioned pointer will be passed + /// to this function pointer as its argument. + call_back: *const fn (usize) void, + }; + + fn getCurrentId() Id { + return tls_thread_id; + } + + fn getHandle(self: Impl) ThreadHandle { + return self.thread.tid; + } + + fn detach(self: Impl) void { + _ = self; + } + + fn join(self: Impl) void { + _ = self; + } + + fn spawn(config: std.Thread.SpawnConfig, comptime f: anytype, args: anytype) !WasiThreadImpl { + if (config.allocator == null) return error.OutOfMemory; // an allocator is required to spawn a WASI-thread + + // Wrapping struct required to hold the user-provided function arguments. + const Wrapper = struct { + args: @TypeOf(args), + fn entry(ptr: usize) void { + const w = @intToPtr(*@This(), ptr); + @call(.auto, f, w.args); + } + }; + + var guard_offset: usize = undefined; + var stack_offset: usize = undefined; + var tls_offset: usize = undefined; + var wrapper_offset: usize = undefined; + var instance_offset: usize = undefined; + + // Calculate the bytes we have to allocate to store all thread information, including: + // - The actual stack for the thread + // - The TLS segment + // - `Instance` - containing information about how to call the user's function. + const map_bytes = blk: { + var bytes: usize = std.wasm.page_size; + guard_offset = bytes; + + bytes = std.mem.alignForward(usize, bytes, 16); // align stack to 16 bytes + stack_offset = bytes; + bytes += @max(std.wasm.page_size, config.stack_size); + + bytes = std.mem.alignForward(usize, bytes, __tls_align()); + tls_offset = bytes; + bytes += __tls_size(); + + bytes = std.mem.alignForward(usize, bytes, @alignOf(Wrapper)); + wrapper_offset = bytes; + bytes += @sizeOf(Wrapper); + + bytes = std.mem.alignForward(usize, bytes, @alignOf(Instance)); + instance_offset = bytes; + bytes += @sizeOf(Instance); + + bytes = std.mem.alignForward(usize, bytes, std.wasm.page_size); + break :blk bytes; + }; + + // Allocate the amount of memory required for all meta data. + const allocated_memory = try config.allocator.?.alloc(u8, map_bytes); + + const wrapper = @ptrCast(*Wrapper, @alignCast(@alignOf(Wrapper), &allocated_memory[wrapper_offset])); + wrapper.* = .{ .args = args }; + + const instance = @ptrCast(*Instance, @alignCast(@alignOf(Instance), &allocated_memory[instance_offset])); + instance.* = .{ + .thread = .{ .memory = allocated_memory }, + .base = @ptrToInt(allocated_memory.ptr), + .tls_base = tls_offset, + .stack_pointer = stack_offset, + .raw_ptr = @ptrToInt(wrapper), + .call_back = &Wrapper.entry, + }; + + const tid = spawnWasiThread(instance); + // The specification says any value lower than 0 indicates an error. + // The values of such error are unspecified. WASI-Libc treats it as EAGAIN. + if (tid < 0) { + return error.SystemResources; + } + instance.thread.tid.store(tid, .SeqCst); + + return .{ .thread = &instance.thread }; + } + + export fn wasi_thread_start(tid: i32, arg: *const Instance) void { + __set_stack_pointer(arg.thread.memory.ptr + arg.stack_pointer); + __wasm_init_tls(arg.thread.memory.ptr + arg.tls_base); + WasiThreadImpl.tls_thread_id = @intCast(u32, tid); + + // finished bootstrapping, call user's procedure. + arg.call_back(arg.raw_ptr); + } + + // Asks the host to create a new thread for us. + // Newly created thread wil lcall `wasi_tread_start` with the thread ID as well + // as the input `arg` that was provided to `spawnWasiThread` + const spawnWasiThread = @"thread-spawn"; + extern "wasi" fn @"thread-spawn"(arg: *const Instance) i32; + + /// Initializes the TLS data segment starting at `memory`. + /// This is a synthetic function, generated by the linker. + extern fn __wasm_init_tls(memory: [*]u8) void; + extern fn __set_stack_pointer(ptr: [*]u8) void; + + /// Returns a pointer to the base of the TLS data segment for the current thread + inline fn __tls_base() [*]u8 { + return asm ( + \\ .globaltype __tls_base, i32 + \\ global.get __tls_base + \\ local.set %[ret] + : [ret] "=r" (-> [*]u8), + ); + } + + /// Returns the size of the TLS segment + inline fn __tls_size() u32 { + return asm volatile ( + \\ .globaltype __tls_size, i32, immutable + \\ global.get __tls_size + \\ local.set %[ret] + : [ret] "=r" (-> u32), + ); + } + + /// Returns the alignment of the TLS segment + inline fn __tls_align() u32 { + return asm ( + \\ .globaltype __tls_align, i32, immutable + \\ global.get __tls_align + \\ local.set %[ret] + : [ret] "=r" (-> u32), + ); + } +}; + const LinuxThreadImpl = struct { const linux = os.linux; |
