From 5d9eeff062053f87ab900fc7ebde6eb13a2a1645 Mon Sep 17 00:00:00 2001
From: Mohsen Zohrevandi
Date: Wed, 21 Apr 2021 13:45:57 -0700
Subject: [PATCH 1/3] Ensure TLS destructors run before thread joins in SGX

---
 library/std/src/sys/sgx/abi/mod.rs    |  6 ++-
 library/std/src/sys/sgx/thread.rs     | 69 +++++++++++++++++++++++----
 library/std/src/thread/local/tests.rs | 55 ++++++++++++++++++++-
 3 files changed, 119 insertions(+), 11 deletions(-)

diff --git a/library/std/src/sys/sgx/abi/mod.rs b/library/std/src/sys/sgx/abi/mod.rs
index a5e453034762c..f9536c4203df2 100644
--- a/library/std/src/sys/sgx/abi/mod.rs
+++ b/library/std/src/sys/sgx/abi/mod.rs
@@ -62,10 +62,12 @@ unsafe extern "C" fn tcs_init(secondary: bool) {
 extern "C" fn entry(p1: u64, p2: u64, p3: u64, secondary: bool, p4: u64, p5: u64) -> EntryReturn {
     // FIXME: how to support TLS in library mode?
     let tls = Box::new(tls::Tls::new());
-    let _tls_guard = unsafe { tls.activate() };
+    let tls_guard = unsafe { tls.activate() };
 
     if secondary {
-        super::thread::Thread::entry();
+        let join_notifier = super::thread::Thread::entry();
+        drop(tls_guard);
+        drop(join_notifier);
 
         EntryReturn(0, 0)
     } else {
diff --git a/library/std/src/sys/sgx/thread.rs b/library/std/src/sys/sgx/thread.rs
index 55ef460cc90c5..67e2e8b59d397 100644
--- a/library/std/src/sys/sgx/thread.rs
+++ b/library/std/src/sys/sgx/thread.rs
@@ -9,26 +9,37 @@ pub struct Thread(task_queue::JoinHandle);
 
 pub const DEFAULT_MIN_STACK_SIZE: usize = 4096;
 
+pub use self::task_queue::JoinNotifier;
+
 mod task_queue {
-    use crate::sync::mpsc;
+    use super::wait_notify;
     use crate::sync::{Mutex, MutexGuard, Once};
 
-    pub type JoinHandle = mpsc::Receiver<()>;
+    pub type JoinHandle = wait_notify::Waiter;
+
+    pub struct JoinNotifier(Option<wait_notify::Notifier>);
+
+    impl Drop for JoinNotifier {
+        fn drop(&mut self) {
+            self.0.take().unwrap().notify();
+        }
+    }
 
     pub(super) struct Task {
         p: Box<dyn FnOnce()>,
-        done: mpsc::Sender<()>,
+        done: JoinNotifier,
     }
 
     impl Task {
         pub(super) fn new(p: Box<dyn FnOnce()>) -> (Task, JoinHandle) {
-            let (done, recv) = mpsc::channel();
+            let (done, recv) = wait_notify::new();
+            let done = JoinNotifier(Some(done));
             (Task { p, done }, recv)
         }
 
-        pub(super) fn run(self) {
+        pub(super) fn run(self) -> JoinNotifier {
             (self.p)();
-            let _ = self.done.send(());
+            self.done
         }
     }
 
@@ -47,6 +58,48 @@ mod task_queue {
     }
 }
 
+/// This module provides a synchronization primitive that does not use thread
+/// local variables. This is needed for signaling that a thread has finished
+/// execution. The signal is sent once all TLS destructors have finished, at
+/// which point no new thread locals should be created.
+pub mod wait_notify {
+    use super::super::waitqueue::{SpinMutex, WaitQueue, WaitVariable};
+    use crate::sync::Arc;
+
+    pub struct Notifier(Arc<SpinMutex<WaitVariable<bool>>>);
+
+    impl Notifier {
+        /// Notify the waiter. The waiter is either notified right away (if
+        /// currently blocked in `Waiter::wait()`) or later when it calls the
+        /// `Waiter::wait()` method.
+        pub fn notify(self) {
+            let mut guard = self.0.lock();
+            *guard.lock_var_mut() = true;
+            let _ = WaitQueue::notify_one(guard);
+        }
+    }
+
+    pub struct Waiter(Arc<SpinMutex<WaitVariable<bool>>>);
+
+    impl Waiter {
+        /// Wait for a notification. If `Notifier::notify()` has already been
+        /// called, this will return immediately, otherwise the current thread
+        /// is blocked until notified.
+        pub fn wait(self) {
+            let guard = self.0.lock();
+            if *guard.lock_var() {
+                return;
+            }
+            WaitQueue::wait(guard, || {});
+        }
+    }
+
+    pub fn new() -> (Notifier, Waiter) {
+        let inner = Arc::new(SpinMutex::new(WaitVariable::new(false)));
+        (Notifier(inner.clone()), Waiter(inner))
+    }
+}
+
 impl Thread {
     // unsafe: see thread::Builder::spawn_unchecked for safety requirements
     pub unsafe fn new(_stack: usize, p: Box<dyn FnOnce()>) -> io::Result<Thread> {
@@ -57,7 +110,7 @@ impl Thread {
         Ok(Thread(handle))
     }
 
-    pub(super) fn entry() {
+    pub(super) fn entry() -> JoinNotifier {
         let mut pending_tasks = task_queue::lock();
         let task = rtunwrap!(Some, pending_tasks.pop());
         drop(pending_tasks); // make sure to not hold the task queue lock longer than necessary
@@ -78,7 +131,7 @@ impl Thread {
     }
 
     pub fn join(self) {
-        let _ = self.0.recv();
+        self.0.wait();
     }
 }
diff --git a/library/std/src/thread/local/tests.rs b/library/std/src/thread/local/tests.rs
index 80e6798d847b1..98f525eafb0f6 100644
--- a/library/std/src/thread/local/tests.rs
+++ b/library/std/src/thread/local/tests.rs
@@ -1,5 +1,6 @@
 use crate::cell::{Cell, UnsafeCell};
-use crate::sync::mpsc::{channel, Sender};
+use crate::sync::atomic::{AtomicBool, Ordering};
+use crate::sync::mpsc::{self, channel, Sender};
 use crate::thread::{self, LocalKey};
 
 use crate::thread_local;
@@ -207,3 +208,55 @@ fn dtors_in_dtors_in_dtors_const_init() {
     });
     rx.recv().unwrap();
 }
+
+// This test tests that TLS destructors have run before the thread joins. The
+// test has no false positives (meaning: if the test fails, there's actually
+// an ordering problem). It may have false negatives, where the test passes but
+// join is not guaranteed to be after the TLS destructors. However, false
+// negatives should be exceedingly rare due to judicious use of
+// thread::yield_now and running the test several times.
+#[test]
+fn join_orders_after_tls_destructors() {
+    static THREAD2_LAUNCHED: AtomicBool = AtomicBool::new(false);
+
+    for _ in 0..10 {
+        let (tx, rx) = mpsc::sync_channel(0);
+        THREAD2_LAUNCHED.store(false, Ordering::SeqCst);
+
+        let jh = thread::spawn(move || {
+            struct RecvOnDrop(Cell<Option<mpsc::Receiver<()>>>);
+
+            impl Drop for RecvOnDrop {
+                fn drop(&mut self) {
+                    let rx = self.0.take().unwrap();
+                    while !THREAD2_LAUNCHED.load(Ordering::SeqCst) {
+                        thread::yield_now();
+                    }
+                    rx.recv().unwrap();
+                }
+            }
+
+            thread_local! {
+                static TL_RX: RecvOnDrop = RecvOnDrop(Cell::new(None));
+            }
+
+            TL_RX.with(|v| v.0.set(Some(rx)))
+        });
+
+        let tx_clone = tx.clone();
+        let jh2 = thread::spawn(move || {
+            THREAD2_LAUNCHED.store(true, Ordering::SeqCst);
+            jh.join().unwrap();
+            tx_clone.send(()).expect_err(
+                "Expecting channel to be closed because thread 1 TLS destructors must've run",
+            );
+        });
+
+        while !THREAD2_LAUNCHED.load(Ordering::SeqCst) {
+            thread::yield_now();
+        }
+        thread::yield_now();
+        tx.send(()).expect("Expecting channel to be live because thread 2 must block on join");
+        jh2.join().unwrap();
+    }
+}

From 8a0a4b1493264901e5e41c4285fc3cc9f8419c28 Mon Sep 17 00:00:00 2001
From: Mohsen Zohrevandi
Date: Thu, 29 Apr 2021 08:44:45 -0700
Subject: [PATCH 2/3] Use atomics in join_orders_after_tls_destructors test

std::sync::mpsc uses thread locals, and depending on the order in which TLS
dtors are run, `rx.recv()` can panic when used in a TLS dtor.
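To make the change concrete, here is a minimal, self-contained sketch of the
kind of rendezvous the reworked test builds out of nothing but an `AtomicU8`
and `thread::yield_now()`. The state names and the `main`-based harness are
illustrative only and are not part of this patch; the test itself uses the
larger FRESH/THREAD2_LAUNCHED/THREAD1_WAITING/MAIN_THREAD_RENDEZVOUS/
THREAD2_JOINED state machine shown in the diff below.

    use std::sync::atomic::{AtomicU8, Ordering};
    use std::thread;

    // Illustrative states: the "receiver" announces that it is waiting and
    // the "sender" completes the rendezvous.
    const FRESH: u8 = 0;
    const WAITING: u8 = 1;
    const RENDEZVOUS: u8 = 2;

    static STATE: AtomicU8 = AtomicU8::new(FRESH);

    fn main() {
        let receiver = thread::spawn(|| {
            // Announce that we are waiting, then spin politely until the
            // sender has moved the state to RENDEZVOUS.
            STATE.store(WAITING, Ordering::SeqCst);
            while STATE.load(Ordering::SeqCst) != RENDEZVOUS {
                thread::yield_now();
            }
        });

        // Sender: wait until the receiver has announced itself, then complete
        // the rendezvous; compare_exchange only succeeds for the
        // WAITING -> RENDEZVOUS transition.
        while STATE
            .compare_exchange(WAITING, RENDEZVOUS, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            thread::yield_now();
        }

        receiver.join().unwrap();
    }

Because neither side touches a thread local, this kind of rendezvous remains
usable from inside a TLS destructor, which is exactly the property the test
needs.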
---
 library/std/src/thread/local/tests.rs | 122 +++++++++++++++++++-------
 1 file changed, 88 insertions(+), 34 deletions(-)

diff --git a/library/std/src/thread/local/tests.rs b/library/std/src/thread/local/tests.rs
index 98f525eafb0f6..494ad4e5fda9a 100644
--- a/library/std/src/thread/local/tests.rs
+++ b/library/std/src/thread/local/tests.rs
@@ -1,6 +1,6 @@
 use crate::cell::{Cell, UnsafeCell};
-use crate::sync::atomic::{AtomicBool, Ordering};
-use crate::sync::mpsc::{self, channel, Sender};
+use crate::sync::atomic::{AtomicU8, Ordering};
+use crate::sync::mpsc::{channel, Sender};
 use crate::thread::{self, LocalKey};
 
 use crate::thread_local;
@@ -217,46 +217,100 @@ fn dtors_in_dtors_in_dtors_const_init() {
 // thread::yield_now and running the test several times.
 #[test]
 fn join_orders_after_tls_destructors() {
-    static THREAD2_LAUNCHED: AtomicBool = AtomicBool::new(false);
+    // We emulate a synchronous MPSC rendezvous channel using only atomics and
+    // thread::yield_now. We can't use std::mpsc as the implementation itself
+    // may rely on thread locals.
+    //
+    // The basic state machine for an SPSC rendezvous channel is:
+    //      FRESH -> THREAD1_WAITING -> MAIN_THREAD_RENDEZVOUS
+    // where the first transition is done by the “receiving” thread and the 2nd
+    // transition is done by the “sending” thread.
+    //
+    // We add an additional state `THREAD2_LAUNCHED` between `FRESH` and
+    // `THREAD1_WAITING` to block until all threads are actually running.
+    //
+    // A thread that joins on the “receiving” thread completion should never
+    // observe the channel in the `THREAD1_WAITING` state. If this does occur,
+    // we switch to the “poison” state `THREAD2_JOINED` and panic all around.
+    // (This is equivalent to “sending” from an alternate producer thread.)
+    const FRESH: u8 = 0;
+    const THREAD2_LAUNCHED: u8 = 1;
+    const THREAD1_WAITING: u8 = 2;
+    const MAIN_THREAD_RENDEZVOUS: u8 = 3;
+    const THREAD2_JOINED: u8 = 4;
+    static SYNC_STATE: AtomicU8 = AtomicU8::new(FRESH);
 
     for _ in 0..10 {
-        let (tx, rx) = mpsc::sync_channel(0);
-        THREAD2_LAUNCHED.store(false, Ordering::SeqCst);
+        SYNC_STATE.store(FRESH, Ordering::SeqCst);
+
+        let jh = thread::Builder::new()
+            .name("thread1".into())
+            .spawn(move || {
+                struct TlDrop;
+
+                impl Drop for TlDrop {
+                    fn drop(&mut self) {
+                        loop {
+                            match SYNC_STATE.load(Ordering::SeqCst) {
+                                FRESH => thread::yield_now(),
+                                THREAD2_LAUNCHED => break,
+                                v => unreachable!("sync state: {}", v),
+                            }
+                        }
+                        let mut sync_state = SYNC_STATE.swap(THREAD1_WAITING, Ordering::SeqCst);
+                        loop {
+                            match sync_state {
+                                THREAD2_LAUNCHED | THREAD1_WAITING => thread::yield_now(),
+                                MAIN_THREAD_RENDEZVOUS => break,
+                                THREAD2_JOINED => panic!(
+                                    "Thread 1 still running after thread 2 joined on thread 1"
+                                ),
+                                v => unreachable!("sync state: {}", v),
+                            }
+                            sync_state = SYNC_STATE.load(Ordering::SeqCst);
+                        }
+                    }
+                }
 
-        let jh = thread::spawn(move || {
-            struct RecvOnDrop(Cell<Option<mpsc::Receiver<()>>>);
+                thread_local! {
+                    static TL_DROP: TlDrop = TlDrop;
+                }
 
-            impl Drop for RecvOnDrop {
-                fn drop(&mut self) {
-                    let rx = self.0.take().unwrap();
-                    while !THREAD2_LAUNCHED.load(Ordering::SeqCst) {
-                        thread::yield_now();
+                TL_DROP.with(|_| {})
+            })
+            .unwrap();
+
+        let jh2 = thread::Builder::new()
+            .name("thread2".into())
+            .spawn(move || {
+                assert_eq!(SYNC_STATE.swap(THREAD2_LAUNCHED, Ordering::SeqCst), FRESH);
+                jh.join().unwrap();
+                match SYNC_STATE.swap(THREAD2_JOINED, Ordering::SeqCst) {
+                    MAIN_THREAD_RENDEZVOUS => return,
+                    THREAD2_LAUNCHED | THREAD1_WAITING => {
+                        panic!("Thread 2 running after thread 1 join before main thread rendezvous")
-                    rx.recv().unwrap();
+                    }
+                    v => unreachable!("sync state: {:?}", v),
                 }
+            })
+            .unwrap();
+
+        loop {
+            match SYNC_STATE.compare_exchange_weak(
+                THREAD1_WAITING,
+                MAIN_THREAD_RENDEZVOUS,
+                Ordering::SeqCst,
+                Ordering::SeqCst,
+            ) {
+                Ok(_) => break,
+                Err(FRESH) => thread::yield_now(),
+                Err(THREAD2_LAUNCHED) => thread::yield_now(),
+                Err(THREAD2_JOINED) => {
+                    panic!("Main thread rendezvous after thread 2 joined thread 1")
+                }
+                v => unreachable!("sync state: {:?}", v),
             }
-
-            thread_local! {
-                static TL_RX: RecvOnDrop = RecvOnDrop(Cell::new(None));
-            }
-
-            TL_RX.with(|v| v.0.set(Some(rx)))
-        });
-
-        let tx_clone = tx.clone();
-        let jh2 = thread::spawn(move || {
-            THREAD2_LAUNCHED.store(true, Ordering::SeqCst);
-            jh.join().unwrap();
-            tx_clone.send(()).expect_err(
-                "Expecting channel to be closed because thread 1 TLS destructors must've run",
-            );
-        });
-
-        while !THREAD2_LAUNCHED.load(Ordering::SeqCst) {
-            thread::yield_now();
         }
-        thread::yield_now();
-        tx.send(()).expect("Expecting channel to be live because thread 2 must block on join");
         jh2.join().unwrap();
     }
 }

From 2acd62d7c389bcdcf212673ed14120d2fd841df6 Mon Sep 17 00:00:00 2001
From: Mohsen Zohrevandi
Date: Thu, 6 May 2021 09:36:26 -0700
Subject: [PATCH 3/3] join_orders_after_tls_destructors: ensure thread 2 is
 launched before thread 1 enters TLS destructors

---
 library/std/src/thread/local/tests.rs | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/library/std/src/thread/local/tests.rs b/library/std/src/thread/local/tests.rs
index 494ad4e5fda9a..f33d612961931 100644
--- a/library/std/src/thread/local/tests.rs
+++ b/library/std/src/thread/local/tests.rs
@@ -250,13 +250,6 @@ fn join_orders_after_tls_destructors() {
 
                 impl Drop for TlDrop {
                     fn drop(&mut self) {
-                        loop {
-                            match SYNC_STATE.load(Ordering::SeqCst) {
-                                FRESH => thread::yield_now(),
-                                THREAD2_LAUNCHED => break,
-                                v => unreachable!("sync state: {}", v),
-                            }
-                        }
                         let mut sync_state = SYNC_STATE.swap(THREAD1_WAITING, Ordering::SeqCst);
                         loop {
                             match sync_state {
@@ -276,7 +269,15 @@ fn join_orders_after_tls_destructors() {
                     static TL_DROP: TlDrop = TlDrop;
                 }
 
-                TL_DROP.with(|_| {})
+                TL_DROP.with(|_| {});
+
+                loop {
+                    match SYNC_STATE.load(Ordering::SeqCst) {
+                        FRESH => thread::yield_now(),
+                        THREAD2_LAUNCHED => break,
+                        v => unreachable!("sync state: {}", v),
+                    }
+                }
             })
             .unwrap();
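For readers without an SGX toolchain, the following is a rough standalone
analogue of the `wait_notify` primitive introduced in PATCH 1/3, built on
`std::sync::{Mutex, Condvar}` instead of the SGX-internal
`SpinMutex`/`WaitQueue`. Only the shape of the API (`new()`,
`Notifier::notify()`, `Waiter::wait()`) mirrors the patch; the code below is
an illustration, not the implementation that lands in
library/std/src/sys/sgx/thread.rs.

    use std::sync::{Arc, Condvar, Mutex};
    use std::thread;

    pub struct Notifier(Arc<(Mutex<bool>, Condvar)>);

    impl Notifier {
        /// Record that the event happened and wake the waiter if it is
        /// currently blocked in `Waiter::wait()`.
        pub fn notify(self) {
            let (lock, cvar) = &*self.0;
            *lock.lock().unwrap() = true;
            cvar.notify_one();
        }
    }

    pub struct Waiter(Arc<(Mutex<bool>, Condvar)>);

    impl Waiter {
        /// Block until `notify` has been called; return immediately if it
        /// already has been.
        pub fn wait(self) {
            let (lock, cvar) = &*self.0;
            let mut notified = lock.lock().unwrap();
            while !*notified {
                notified = cvar.wait(notified).unwrap();
            }
        }
    }

    pub fn new() -> (Notifier, Waiter) {
        let inner = Arc::new((Mutex::new(false), Condvar::new()));
        (Notifier(inner.clone()), Waiter(inner))
    }

    fn main() {
        let (notifier, waiter) = new();
        let t = thread::spawn(move || notifier.notify());
        waiter.wait(); // returns once the spawned thread has notified
        t.join().unwrap();
    }

As in the patch, both halves are consumed on use: `notify()` and `wait()` take
`self`, so each side of the pair signals or waits exactly once.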