| author | Andrew Kelley <andrew@ziglang.org> | 2022-11-23 16:24:55 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 16:24:55 -0500 |
| commit | 1d7faf30f9c40cd1d4cb74480d0cc4e51c43ff0b | |
| tree | 6f9a5b4b626a2e759e38ff3062953f3bb98b105d | lib/std/Thread/Condition.zig |
| parent | 258bee41bf58c9001ceaefc974fa594a26ac0fc5 | |
| parent | 9947b47d803415f40c82b6cbb510f47bc800658d | |
Merge pull request #13577 from ianic/issue-12877
stdlib: fix condition variable broadcast FutexImpl
Diffstat (limited to 'lib/std/Thread/Condition.zig')
| -rw-r--r-- | lib/std/Thread/Condition.zig | 196 |
|---|---|---|

1 file changed, 169 insertions, 27 deletions
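The fix reorders FutexImpl.wait(): the epoch is now observed with an Acquire load before the waiter is registered, and signals are consumed only after the futex wait returns, so a broadcast() can no longer leave waiters asleep. The usage pattern the fixed code has to support is the classic predicate loop below, given here as a minimal sketch rather than code from this commit; the `ready` flag, the `waiter` helper, and the thread count are made up:

```zig
const std = @import("std");

var mutex = std.Thread.Mutex{};
var cond = std.Thread.Condition{};
var ready = false;

fn waiter() void {
    mutex.lock();
    defer mutex.unlock();
    // Always re-check the predicate in a loop: wait() may wake spuriously.
    while (!ready) cond.wait(&mutex);
}

pub fn main() !void {
    var threads: [4]std.Thread = undefined;
    for (threads) |*t| t.* = try std.Thread.spawn(.{}, waiter, .{});

    {
        mutex.lock();
        defer mutex.unlock();
        ready = true;
    }
    // Issue #12877: some threads still blocked in wait() could miss this
    // broadcast and never wake.
    cond.broadcast();

    for (threads) |t| t.join();
}
```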
diff --git a/lib/std/Thread/Condition.zig b/lib/std/Thread/Condition.zig
index 1482c8166d..3625aab576 100644
--- a/lib/std/Thread/Condition.zig
+++ b/lib/std/Thread/Condition.zig
@@ -194,42 +194,27 @@ const FutexImpl = struct {
     const signal_mask = 0xffff << 16;

     fn wait(self: *Impl, mutex: *Mutex, timeout: ?u64) error{Timeout}!void {
-        // Register that we're waiting on the state by incrementing the wait count.
-        // This assumes that there can be at most ((1<<16)-1) or 65,535 threads concurrently waiting on the same Condvar.
-        // If this is hit in practice, then this condvar not working is the least of your concerns.
+        // Observe the epoch, then check the state again to see if we should wake up.
+        // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
+        //
+        // - T1: s = LOAD(&state)
+        // - T2: UPDATE(&s, signal)
+        // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
+        // - T1: e = LOAD(&epoch) (was reordered after the state load)
+        // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
+        //
+        // Acquire barrier to ensure the epoch load happens before the state load.
+        var epoch = self.epoch.load(.Acquire);
         var state = self.state.fetchAdd(one_waiter, .Monotonic);
         assert(state & waiter_mask != waiter_mask);
         state += one_waiter;

-        // Temporarily release the mutex in order to block on the condition variable.
         mutex.unlock();
         defer mutex.lock();

         var futex_deadline = Futex.Deadline.init(timeout);

-        while (true) {
-            // Try to wake up by consuming a signal and decrementing the waiter we added previously.
-            // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
-            while (state & signal_mask != 0) {
-                const new_state = state - one_waiter - one_signal;
-                state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
-            }
-
-            // Observe the epoch, then check the state again to see if we should wake up.
-            // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
-            //
-            // - T1: s = LOAD(&state)
-            // - T2: UPDATE(&s, signal)
-            // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
-            // - T1: e = LOAD(&epoch) (was reordered after the state load)
-            // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
-            //
-            // Acquire barrier to ensure the epoch load happens before the state load.
-            const epoch = self.epoch.load(.Acquire);
-            state = self.state.load(.Monotonic);
-            if (state & signal_mask != 0) {
-                continue;
-            }
+        while (true) {
             futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
                 // On timeout, we must decrement the waiter we added above.
                 error.Timeout => {
@@ -247,6 +232,16 @@ const FutexImpl = struct {
                     }
                 },
             };
+
+            epoch = self.epoch.load(.Acquire);
+            state = self.state.load(.Monotonic);
+
+            // Try to wake up by consuming a signal and decrementing the waiter we added previously.
+            // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
+            while (state & signal_mask != 0) {
+                const new_state = state - one_waiter - one_signal;
+                state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
+            }
         }
     }

@@ -536,3 +531,150 @@ test "Condition - broadcasting" {
         t.join();
     }
 }
+
+test "Condition - broadcasting - wake all threads" {
+    // Tests issue #12877.
+    // This test requires spawning threads.
+    if (builtin.single_threaded) {
+        return error.SkipZigTest;
+    }
+
+    var num_runs: usize = 1;
+    const num_threads = 10;
+
+    while (num_runs > 0) : (num_runs -= 1) {
+        const BroadcastTest = struct {
+            mutex: Mutex = .{},
+            cond: Condition = .{},
+            completed: Condition = .{},
+            count: usize = 0,
+            thread_id_to_wake: usize = 0,
+            threads: [num_threads]std.Thread = undefined,
+            wakeups: usize = 0,
+
+            fn run(self: *@This(), thread_id: usize) void {
+                self.mutex.lock();
+                defer self.mutex.unlock();
+
+                // The last broadcast thread to start tells the main test thread it's completed.
+                self.count += 1;
+                if (self.count == num_threads) {
+                    self.completed.signal();
+                }
+
+                while (self.thread_id_to_wake != thread_id) {
+                    self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake });
+                    self.wakeups += 1;
+                }
+                if (self.thread_id_to_wake <= num_threads) {
+                    // Signal the next thread to wake up.
+                    self.thread_id_to_wake += 1;
+                    self.cond.broadcast();
+                }
+            }
+        };
+
+        var broadcast_test = BroadcastTest{};
+        var thread_id: usize = 1;
+        for (broadcast_test.threads) |*t| {
+            t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id });
+            thread_id += 1;
+        }
+
+        {
+            broadcast_test.mutex.lock();
+            defer broadcast_test.mutex.unlock();
+
+            // Wait for all the broadcast threads to spawn.
+            // Use timedWait() to detect any potential deadlocks.
+            while (broadcast_test.count != num_threads) {
+                try broadcast_test.completed.timedWait(
+                    &broadcast_test.mutex,
+                    1 * std.time.ns_per_s,
+                );
+            }
+
+            // Signal thread 1 to wake up.
+            broadcast_test.thread_id_to_wake = 1;
+            broadcast_test.cond.broadcast();
+        }
+
+        for (broadcast_test.threads) |t| {
+            t.join();
+        }
+    }
+}
+
+test "Condition - signal wakes one" {
+    // This test requires spawning threads.
+    if (builtin.single_threaded) {
+        return error.SkipZigTest;
+    }
+
+    var num_runs: usize = 1;
+    const num_threads = 3;
+    const timeoutDelay = 10 * std.time.ns_per_ms;
+
+    while (num_runs > 0) : (num_runs -= 1) {
+
+        // Start multiple runner threads, wait for them to start, and then send
+        // the signal. Expect one thread to wake up and all the others to time out.
+        //
+        // The test depends on the delay in timedWait: if it is too small, all
+        // threads can time out before any of them is woken up.
+
+        const Runner = struct {
+            mutex: Mutex = .{},
+            cond: Condition = .{},
+            completed: Condition = .{},
+            count: usize = 0,
+            threads: [num_threads]std.Thread = undefined,
+            wakeups: usize = 0,
+            timeouts: usize = 0,
+
+            fn run(self: *@This()) void {
+                self.mutex.lock();
+                defer self.mutex.unlock();
+
+                // The last thread to start tells the main test thread it's completed.
+                self.count += 1;
+                if (self.count == num_threads) {
+                    self.completed.signal();
+                }
+
+                self.cond.timedWait(&self.mutex, timeoutDelay) catch {
+                    self.timeouts += 1;
+                    return;
+                };
+                self.wakeups += 1;
+            }
+        };
+
+        // Start the threads.
+        var runner = Runner{};
+        for (runner.threads) |*t| {
+            t.* = try std.Thread.spawn(.{}, Runner.run, .{&runner});
+        }
+
+        {
+            runner.mutex.lock();
+            defer runner.mutex.unlock();
+
+            // Wait for all the threads to spawn.
+            // Use timedWait() to detect any potential deadlocks.
+            while (runner.count != num_threads) {
+                try runner.completed.timedWait(&runner.mutex, 1 * std.time.ns_per_s);
+            }
+            // Signal one thread; the others should time out.
+            runner.cond.signal();
+        }
+
+        for (runner.threads) |t| {
+            t.join();
+        }
+
+        // Expect that only one thread got the signal.
+        try std.testing.expectEqual(runner.wakeups, 1);
+        try std.testing.expectEqual(runner.timeouts, num_threads - 1);
+    }
+}
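The load-ordering rule in the first hunk generalizes beyond this file. Below is a minimal sketch of the same two-load pattern, not the stdlib code itself; the names `epoch`, `state`, and `observe` are illustrative, and only the orderings match the diff:

```zig
const std = @import("std");
const Atomic = std.atomic.Atomic;

var epoch = Atomic(u32).init(0);
var state = Atomic(u32).init(0);

// The Acquire load of `epoch` forbids the later `state` load from being
// reordered before it, so a waiter that sees no signal in `state` is
// guaranteed to wait on an epoch value observed before that check.
fn observe() struct { e: u32, s: u32 } {
    const e = epoch.load(.Acquire); // must happen first
    const s = state.load(.Monotonic);
    return .{ .e = e, .s = s };
}
```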
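Both new tests bound every wait with timedWait() so that a lost wakeup surfaces as error.Timeout and a test failure rather than a hung test runner. A minimal sketch of that pattern, assuming an illustrative `ready` predicate and a one-second bound:

```zig
const std = @import("std");

var mutex = std.Thread.Mutex{};
var cond = std.Thread.Condition{};
var ready = false;

// Returns error.Timeout if no wakeup arrives within a second of each wait,
// turning a potential deadlock into a visible failure.
fn waitReady() error{Timeout}!void {
    mutex.lock();
    defer mutex.unlock();
    while (!ready) try cond.timedWait(&mutex, 1 * std.time.ns_per_s);
}
```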
