From f229b740999b58432dc49e3aa412fac14e3781f3 Mon Sep 17 00:00:00 2001 From: Igor Anić Date: Thu, 17 Nov 2022 20:58:45 +0100 Subject: stdlib: fix condition variable broadcast FutexImpl fixes #12877 Current implementation (before this fix) observes number of waiters when broadcast occurs and then makes that number of wakeups. If we have multiple threads waiting for wakeup which immediately go into wait if wakeup is not for that thread (as described in the issue). The same thread can get multiple wakeups while some got none. That is not consistent with documented behavior for condition variable broadcast: `Unblocks all threads currently blocked in a call to wait() or timedWait() with a given Mutex.`. This fix ensures that the thread waiting on futext is woken up on futex wake. --- lib/std/Thread/Condition.zig | 148 ++++++++++++++++++++++++++++++------------- 1 file changed, 105 insertions(+), 43 deletions(-) (limited to 'lib/std/Thread') diff --git a/lib/std/Thread/Condition.zig b/lib/std/Thread/Condition.zig index 1482c8166d..6829a9f15c 100644 --- a/lib/std/Thread/Condition.zig +++ b/lib/std/Thread/Condition.zig @@ -194,59 +194,50 @@ const FutexImpl = struct { const signal_mask = 0xffff << 16; fn wait(self: *Impl, mutex: *Mutex, timeout: ?u64) error{Timeout}!void { - // Register that we're waiting on the state by incrementing the wait count. - // This assumes that there can be at most ((1<<16)-1) or 65,355 threads concurrently waiting on the same Condvar. - // If this is hit in practice, then this condvar not working is the least of your concerns. + // Observe the epoch, then check the state again to see if we should wake up. + // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock: + // + // - T1: s = LOAD(&state) + // - T2: UPDATE(&s, signal) + // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch) + // - T1: e = LOAD(&epoch) (was reordered after the state load) + // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change) + // + // Acquire barrier to ensure the epoch load happens before the state load. + const epoch = self.epoch.load(.Acquire); var state = self.state.fetchAdd(one_waiter, .Monotonic); assert(state & waiter_mask != waiter_mask); state += one_waiter; + var futex_deadline = Futex.Deadline.init(timeout); - // Temporarily release the mutex in order to block on the condition variable. mutex.unlock(); defer mutex.lock(); - var futex_deadline = Futex.Deadline.init(timeout); - while (true) { - // Try to wake up by consuming a signal and decremented the waiter we added previously. - // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. - while (state & signal_mask != 0) { - const new_state = state - one_waiter - one_signal; - state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; - } + futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) { + // On timeout, we must decrement the waiter we added above. + error.Timeout => { + while (true) { + // If there's a signal when we're timing out, consume it and report being woken up instead. + // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. + while (state & signal_mask != 0) { + const new_state = state - one_waiter - one_signal; + state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; + } - // Observe the epoch, then check the state again to see if we should wake up. - // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock: - // - // - T1: s = LOAD(&state) - // - T2: UPDATE(&s, signal) - // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch) - // - T1: e = LOAD(&epoch) (was reordered after the state load) - // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change) - // - // Acquire barrier to ensure the epoch load happens before the state load. - const epoch = self.epoch.load(.Acquire); - state = self.state.load(.Monotonic); + // Remove the waiter we added and officially return timed out. + const new_state = state - one_waiter; + state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err; + } + }, + }; + + while (true) { + // Wait thread, decrement waiter and consume signal if exists. + var new_state = state - one_waiter; if (state & signal_mask != 0) { - continue; + new_state = state - one_signal; } - - futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) { - // On timeout, we must decrement the waiter we added above. - error.Timeout => { - while (true) { - // If there's a signal when we're timing out, consume it and report being woken up instead. - // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. - while (state & signal_mask != 0) { - const new_state = state - one_waiter - one_signal; - state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; - } - - // Remove the waiter we added and officially return timed out. - const new_state = state - one_waiter; - state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err; - } - }, - }; + state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; } } @@ -536,3 +527,74 @@ test "Condition - broadcasting" { t.join(); } } + +test "Condition - broadcasting - wake all threads" { + // Tests issue #12877 + // This test requires spawning threads + if (builtin.single_threaded) { + return error.SkipZigTest; + } + + const num_threads = 10; + + const BroadcastTest = struct { + mutex: Mutex = .{}, + cond: Condition = .{}, + completed: Condition = .{}, + count: usize = 0, + thread_id_to_wake: usize = 0, + threads: [num_threads]std.Thread = undefined, + wakeups: usize = 0, + + fn run(self: *@This(), thread_id: usize) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + // The last broadcast thread to start tells the main test thread it's completed. + self.count += 1; + if (self.count == num_threads) { + self.completed.signal(); + } + + while (self.thread_id_to_wake != thread_id) { + self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake }); + self.wakeups += 1; + } + if (self.thread_id_to_wake <= num_threads) { + // Signal next thread to wake up. + self.thread_id_to_wake += 1; + self.cond.broadcast(); + } + } + }; + + var broadcast_test = BroadcastTest{}; + var thread_id: usize = 1; + for (broadcast_test.threads) |*t| { + t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id }); + thread_id += 1; + } + + { + broadcast_test.mutex.lock(); + defer broadcast_test.mutex.unlock(); + + // Wait for all the broadcast threads to spawn. + // timedWait() to detect any potential deadlocks. + while (broadcast_test.count != num_threads) { + try broadcast_test.completed.timedWait( + &broadcast_test.mutex, + 1 * std.time.ns_per_s, + ); + } + + // Signal thread 1 to wake up + broadcast_test.thread_id_to_wake = 1; + broadcast_test.cond.broadcast(); + } + + for (broadcast_test.threads) |t| { + t.join(); + } + //std.debug.print("wakeups {d}\n", .{broadcast_test.wakeups}); +} -- cgit v1.2.3 From 9947b47d803415f40c82b6cbb510f47bc800658d Mon Sep 17 00:00:00 2001 From: Igor Anić Date: Mon, 21 Nov 2022 17:26:54 +0100 Subject: stdlib: Thread.Condition wake only if signaled Previous implementation didn't check whether there are pending signals after return from futex.wait. While it is ok for broadcast case it can result in multiple wakeups when only one thread is signaled. This implementation checks that there are pending signals before returning from wait. It is similar to the original implementation but the without initial signal check, here we first go to the futex and then check for pending signal. --- lib/std/Thread/Condition.zig | 220 +++++++++++++++++++++++++++++-------------- 1 file changed, 150 insertions(+), 70 deletions(-) (limited to 'lib/std/Thread') diff --git a/lib/std/Thread/Condition.zig b/lib/std/Thread/Condition.zig index 6829a9f15c..3625aab576 100644 --- a/lib/std/Thread/Condition.zig +++ b/lib/std/Thread/Condition.zig @@ -204,40 +204,44 @@ const FutexImpl = struct { // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change) // // Acquire barrier to ensure the epoch load happens before the state load. - const epoch = self.epoch.load(.Acquire); + var epoch = self.epoch.load(.Acquire); var state = self.state.fetchAdd(one_waiter, .Monotonic); assert(state & waiter_mask != waiter_mask); state += one_waiter; - var futex_deadline = Futex.Deadline.init(timeout); mutex.unlock(); defer mutex.lock(); - futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) { - // On timeout, we must decrement the waiter we added above. - error.Timeout => { - while (true) { - // If there's a signal when we're timing out, consume it and report being woken up instead. - // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. - while (state & signal_mask != 0) { - const new_state = state - one_waiter - one_signal; - state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; + var futex_deadline = Futex.Deadline.init(timeout); + + while (true) { + futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) { + // On timeout, we must decrement the waiter we added above. + error.Timeout => { + while (true) { + // If there's a signal when we're timing out, consume it and report being woken up instead. + // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. + while (state & signal_mask != 0) { + const new_state = state - one_waiter - one_signal; + state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; + } + + // Remove the waiter we added and officially return timed out. + const new_state = state - one_waiter; + state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err; } + }, + }; - // Remove the waiter we added and officially return timed out. - const new_state = state - one_waiter; - state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err; - } - }, - }; + epoch = self.epoch.load(.Acquire); + state = self.state.load(.Monotonic); - while (true) { - // Wait thread, decrement waiter and consume signal if exists. - var new_state = state - one_waiter; - if (state & signal_mask != 0) { - new_state = state - one_signal; + // Try to wake up by consuming a signal and decremented the waiter we added previously. + // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return. + while (state & signal_mask != 0) { + const new_state = state - one_waiter - one_signal; + state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; } - state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return; } } @@ -535,66 +539,142 @@ test "Condition - broadcasting - wake all threads" { return error.SkipZigTest; } + var num_runs: usize = 1; const num_threads = 10; - const BroadcastTest = struct { - mutex: Mutex = .{}, - cond: Condition = .{}, - completed: Condition = .{}, - count: usize = 0, - thread_id_to_wake: usize = 0, - threads: [num_threads]std.Thread = undefined, - wakeups: usize = 0, - - fn run(self: *@This(), thread_id: usize) void { - self.mutex.lock(); - defer self.mutex.unlock(); + while (num_runs > 0) : (num_runs -= 1) { + const BroadcastTest = struct { + mutex: Mutex = .{}, + cond: Condition = .{}, + completed: Condition = .{}, + count: usize = 0, + thread_id_to_wake: usize = 0, + threads: [num_threads]std.Thread = undefined, + wakeups: usize = 0, + + fn run(self: *@This(), thread_id: usize) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + // The last broadcast thread to start tells the main test thread it's completed. + self.count += 1; + if (self.count == num_threads) { + self.completed.signal(); + } - // The last broadcast thread to start tells the main test thread it's completed. - self.count += 1; - if (self.count == num_threads) { - self.completed.signal(); + while (self.thread_id_to_wake != thread_id) { + self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake }); + self.wakeups += 1; + } + if (self.thread_id_to_wake <= num_threads) { + // Signal next thread to wake up. + self.thread_id_to_wake += 1; + self.cond.broadcast(); + } } + }; - while (self.thread_id_to_wake != thread_id) { - self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake }); - self.wakeups += 1; - } - if (self.thread_id_to_wake <= num_threads) { - // Signal next thread to wake up. - self.thread_id_to_wake += 1; - self.cond.broadcast(); + var broadcast_test = BroadcastTest{}; + var thread_id: usize = 1; + for (broadcast_test.threads) |*t| { + t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id }); + thread_id += 1; + } + + { + broadcast_test.mutex.lock(); + defer broadcast_test.mutex.unlock(); + + // Wait for all the broadcast threads to spawn. + // timedWait() to detect any potential deadlocks. + while (broadcast_test.count != num_threads) { + try broadcast_test.completed.timedWait( + &broadcast_test.mutex, + 1 * std.time.ns_per_s, + ); } + + // Signal thread 1 to wake up + broadcast_test.thread_id_to_wake = 1; + broadcast_test.cond.broadcast(); } - }; - var broadcast_test = BroadcastTest{}; - var thread_id: usize = 1; - for (broadcast_test.threads) |*t| { - t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id }); - thread_id += 1; + for (broadcast_test.threads) |t| { + t.join(); + } } +} - { - broadcast_test.mutex.lock(); - defer broadcast_test.mutex.unlock(); +test "Condition - signal wakes one" { + // This test requires spawning threads + if (builtin.single_threaded) { + return error.SkipZigTest; + } - // Wait for all the broadcast threads to spawn. - // timedWait() to detect any potential deadlocks. - while (broadcast_test.count != num_threads) { - try broadcast_test.completed.timedWait( - &broadcast_test.mutex, - 1 * std.time.ns_per_s, - ); + var num_runs: usize = 1; + const num_threads = 3; + const timeoutDelay = 10 * std.time.ns_per_ms; + + while (num_runs > 0) : (num_runs -= 1) { + + // Start multiple runner threads, wait for them to start and send the signal + // then. Expect that one thread wake up and all other times out. + // + // Test depends on delay in timedWait! If too small all threads can timeout + // before any one gets wake up. + + const Runner = struct { + mutex: Mutex = .{}, + cond: Condition = .{}, + completed: Condition = .{}, + count: usize = 0, + threads: [num_threads]std.Thread = undefined, + wakeups: usize = 0, + timeouts: usize = 0, + + fn run(self: *@This()) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + // The last started thread tells the main test thread it's completed. + self.count += 1; + if (self.count == num_threads) { + self.completed.signal(); + } + + self.cond.timedWait(&self.mutex, timeoutDelay) catch { + self.timeouts += 1; + return; + }; + self.wakeups += 1; + } + }; + + // Start threads + var runner = Runner{}; + for (runner.threads) |*t| { + t.* = try std.Thread.spawn(.{}, Runner.run, .{&runner}); } - // Signal thread 1 to wake up - broadcast_test.thread_id_to_wake = 1; - broadcast_test.cond.broadcast(); - } + { + runner.mutex.lock(); + defer runner.mutex.unlock(); - for (broadcast_test.threads) |t| { - t.join(); + // Wait for all the threads to spawn. + // timedWait() to detect any potential deadlocks. + while (runner.count != num_threads) { + try runner.completed.timedWait(&runner.mutex, 1 * std.time.ns_per_s); + } + // Signal one thread, the others should get timeout. + runner.cond.signal(); + } + + for (runner.threads) |t| { + t.join(); + } + + // Expect that only one got singal + try std.testing.expectEqual(runner.wakeups, 1); + try std.testing.expectEqual(runner.timeouts, num_threads - 1); } - //std.debug.print("wakeups {d}\n", .{broadcast_test.wakeups}); } -- cgit v1.2.3