From 43abbc964e9cf296a852c22cd452405b82e76ee6 Mon Sep 17 00:00:00 2001
From: Amanieu d'Antras <amanieu@gmail.com>
Date: Sat, 21 May 2016 20:44:36 +0100
Subject: [PATCH] Add Windows keyed event implementation of ThreadParker

---
 Cargo.toml                   |   4 +
 src/lib.rs                   |  10 +-
 src/parking_lot.rs           | 126 +++++++++----------
 src/thread_parker/windows.rs | 226 +++++++++++++++++++++++++++++++++++
 src/word_lock.rs             |  38 +++---
 5 files changed, 321 insertions(+), 83 deletions(-)
 create mode 100644 src/thread_parker/windows.rs

diff --git a/Cargo.toml b/Cargo.toml
index 88a9cb97..a269c3b5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,6 +15,10 @@ smallvec = "0.1"
 [target.'cfg(target_os = "linux")'.dependencies]
 libc = "0.2"
 
+[target.'cfg(windows)'.dependencies]
+winapi = "0.2"
+kernel32-sys = "0.2"
+
 [dev-dependencies]
 rand = "0.3"
 lazy_static = "0.2"

diff --git a/src/lib.rs b/src/lib.rs
index e0a2899c..ebc2b0ab 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -111,13 +111,21 @@ extern crate lazy_static;
 #[cfg(all(feature = "nightly", target_os = "linux"))]
 extern crate libc;
 
+#[cfg(windows)]
+extern crate winapi;
+#[cfg(windows)]
+extern crate kernel32;
+
 // Spin limit from JikesRVM & Webkit experiments
 const SPIN_LIMIT: usize = 40;
 
 #[cfg(all(feature = "nightly", target_os = "linux"))]
 #[path = "thread_parker/linux.rs"]
 mod thread_parker;
-#[cfg(not(all(feature = "nightly", target_os = "linux")))]
+#[cfg(windows)]
+#[path = "thread_parker/windows.rs"]
+mod thread_parker;
+#[cfg(not(any(windows, all(feature = "nightly", target_os = "linux"))))]
 #[path = "thread_parker/generic.rs"]
 mod thread_parker;
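Three thread_parker backends (the Linux futex one, the new Windows keyed-event one, and the generic fallback) now compile under the same module name, so it helps to spell out the surface that parking_lot.rs relies on. The crate defines no trait for this; each backend simply exposes a ThreadParker with identical inherent methods. The sketch below is a reading aid only, written as a trait for compactness, with signatures taken from the Windows file further down:

    use std::time::Instant;

    // Reading aid only: no such trait exists in the crate. Each platform
    // module provides a ThreadParker with these inherent methods.
    trait ParkerApi {
        fn new() -> Self;
        fn prepare_park(&self);                   // call before queueing the thread
        fn park(&self);                           // block until unparked
        fn park_until(&self, t: Instant) -> bool; // true = unparked, false = timed out
        fn timed_out(&self) -> bool;              // only valid under the queue lock
        fn unpark_lock(&self) -> bool;            // under the queue lock; true = wakeup needed
        fn unpark(&self, need_wakeup: bool);      // after dropping the queue lock
    }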
diff --git a/src/parking_lot.rs b/src/parking_lot.rs
index 65227b7a..d5abaa72 100644
--- a/src/parking_lot.rs
+++ b/src/parking_lot.rs
@@ -268,81 +268,81 @@ pub unsafe fn park(key: usize, timeout: Option<Instant>) -> bool {
     // Grab our thread data, this also ensures that the hash table exists
-    THREAD_DATA.with(|thread_data| {
-        // Lock the bucket for the given key
-        let bucket = lock_bucket(key).unwrap();
+    let thread_data = &*THREAD_DATA.with(|x| x as *const ThreadData);
 
-        // If the validation function fails, just return
-        if !validate() {
-            bucket.mutex.unlock();
-            return false;
-        }
+    // Lock the bucket for the given key
+    let bucket = lock_bucket(key).unwrap();
 
-        // Append our thread data to the queue and unlock the bucket
-        thread_data.next_in_queue.set(ptr::null());
-        thread_data.key.set(key);
-        thread_data.parker.prepare_park();
-        if !bucket.queue_head.get().is_null() {
-            (*bucket.queue_tail.get()).next_in_queue.set(thread_data);
-        } else {
-            bucket.queue_head.set(thread_data);
-        }
-        bucket.queue_tail.set(thread_data);
+    // If the validation function fails, just return
+    if !validate() {
         bucket.mutex.unlock();
+        return false;
+    }
 
-        // Invoke the pre-sleep callback
-        before_sleep();
-
-        // Park our thread and determine whether we were woken up by an unpark
-        // or by our timeout. Note that this isn't precise: we can still be
-        // unparked since we are still in the queue.
-        let unparked = match timeout {
-            Some(timeout) => thread_data.parker.park_until(timeout),
-            None => {
-                thread_data.parker.park();
-                true
-            }
-        };
+    // Append our thread data to the queue and unlock the bucket
+    thread_data.next_in_queue.set(ptr::null());
+    thread_data.key.set(key);
+    thread_data.parker.prepare_park();
+    if !bucket.queue_head.get().is_null() {
+        (*bucket.queue_tail.get()).next_in_queue.set(thread_data);
+    } else {
+        bucket.queue_head.set(thread_data);
+    }
+    bucket.queue_tail.set(thread_data);
+    bucket.mutex.unlock();
 
-        // If we were unparked, return now
-        if unparked {
-            return true;
+    // Invoke the pre-sleep callback
+    before_sleep();
+
+    // Park our thread and determine whether we were woken up by an unpark or by
+    // our timeout. Note that this isn't precise: we can still be unparked
+    // since we are still in the queue.
+    let unparked = match timeout {
+        Some(timeout) => thread_data.parker.park_until(timeout),
+        None => {
+            thread_data.parker.park();
+            true
         }
+    };
 
-        // Lock our bucket again. Note that the hashtable may have been rehashed
-        // in the meantime.
-        let bucket = lock_bucket(key).unwrap();
+    // If we were unparked, return now
+    if unparked {
+        return true;
+    }
 
-        // Now we need to check again if we were unparked or timed out. Unlike
-        // the last check this is precise because we hold the bucket lock.
-        if !thread_data.parker.timed_out() {
-            bucket.mutex.unlock();
-            return true;
-        }
+    // Lock our bucket again. Note that the hashtable may have been rehashed in
+    // the meantime.
+    let bucket = lock_bucket(key).unwrap();
 
-        // We timed out, so we now need to remove our thread from the queue
-        let mut link = &bucket.queue_head;
-        let mut current = bucket.queue_head.get();
-        let mut previous = ptr::null();
-        while !current.is_null() {
-            if current == thread_data {
-                let next = (*current).next_in_queue.get();
-                link.set(next);
-                if bucket.queue_tail.get() == current {
-                    bucket.queue_tail.set(previous);
-                }
-                break;
-            } else {
-                link = &(*current).next_in_queue;
-                previous = current;
-                current = link.get();
+    // Now we need to check again if we were unparked or timed out. Unlike the
+    // last check this is precise because we hold the bucket lock.
+    if !thread_data.parker.timed_out() {
+        bucket.mutex.unlock();
+        return true;
+    }
+
+    // We timed out, so we now need to remove our thread from the queue
+    let mut link = &bucket.queue_head;
+    let mut current = bucket.queue_head.get();
+    let mut previous = ptr::null();
+    while !current.is_null() {
+        if current == thread_data {
+            let next = (*current).next_in_queue.get();
+            link.set(next);
+            if bucket.queue_tail.get() == current {
+                bucket.queue_tail.set(previous);
             }
+            break;
+        } else {
+            link = &(*current).next_in_queue;
+            previous = current;
+            current = link.get();
         }
+    }
 
-        // Unlock the bucket, we are done
-        bucket.mutex.unlock();
-        false
-    })
+    // Unlock the bucket, we are done
+    bucket.mutex.unlock();
+    false
 }
 
 /// Unparks one thread from the queue associated with the given key.
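One detail that is easy to miss in the rewrite above: park used to run entirely inside THREAD_DATA.with, and now it pulls the reference out with &*THREAD_DATA.with(|x| x as *const ThreadData). That flattens the control flow (plain early returns instead of returns from a closure), and it is sound only because the thread cannot exit, and so cannot drop its thread-local data, while it is still inside this call. A minimal sketch of the same borrow-extension pattern, using a hypothetical thread-local purely for illustration:

    use std::cell::Cell;

    // Hypothetical thread-local, standing in for parking_lot's THREAD_DATA.
    thread_local!(static COUNTER: Cell<u32> = Cell::new(0));

    fn bump() -> u32 {
        // Escape the `with` closure by casting the reference to a raw pointer
        // and re-borrowing it. Sound here only because the thread-local is
        // guaranteed to outlive this call on the current thread.
        let counter = unsafe { &*COUNTER.with(|x| x as *const Cell<u32>) };
        counter.set(counter.get() + 1);
        counter.get()
    }

    fn main() {
        assert_eq!(bump(), 1);
        assert_eq!(bump(), 2);
    }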
diff --git a/src/thread_parker/windows.rs b/src/thread_parker/windows.rs
new file mode 100644
index 00000000..6801f42f
--- /dev/null
+++ b/src/thread_parker/windows.rs
@@ -0,0 +1,226 @@
+// Copyright 2016 Amanieu d'Antras
+//
+// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option. This file may not be
+// copied, modified, or distributed except according to those terms.
+
+#[cfg(feature = "nightly")]
+use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering};
+#[cfg(not(feature = "nightly"))]
+use stable::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering};
+use std::time::Instant;
+use std::ptr;
+use std::mem;
+use winapi;
+use kernel32;
+
+#[allow(non_snake_case)]
+struct KeyedEvent {
+    handle: winapi::HANDLE,
+    NtReleaseKeyedEvent: extern "system" fn(EventHandle: winapi::HANDLE,
+                                            Key: winapi::PVOID,
+                                            Alertable: winapi::BOOLEAN,
+                                            Timeout: winapi::PLARGE_INTEGER)
+                                            -> winapi::NTSTATUS,
+    NtWaitForKeyedEvent: extern "system" fn(EventHandle: winapi::HANDLE,
+                                            Key: winapi::PVOID,
+                                            Alertable: winapi::BOOLEAN,
+                                            Timeout: winapi::PLARGE_INTEGER)
+                                            -> winapi::NTSTATUS,
+}
+
+impl KeyedEvent {
+    unsafe fn wait_for(&self,
+                       key: winapi::PVOID,
+                       timeout: winapi::PLARGE_INTEGER)
+                       -> winapi::NTSTATUS {
+        (self.NtWaitForKeyedEvent)(self.handle, key, 0, timeout)
+    }
+
+    unsafe fn release(&self, key: winapi::PVOID) -> winapi::NTSTATUS {
+        (self.NtReleaseKeyedEvent)(self.handle, key, 0, ptr::null_mut())
+    }
+
+    unsafe fn get() -> &'static KeyedEvent {
+        static KEYED_EVENT: AtomicUsize = ATOMIC_USIZE_INIT;
+
+        // Fast path: use the existing object
+        let keyed_event = KEYED_EVENT.load(Ordering::Acquire);
+        if keyed_event != 0 {
+            return &*(keyed_event as *const KeyedEvent);
+        }
+
+        // Try to create a new object
+        let keyed_event = Box::into_raw(KeyedEvent::create());
+        match KEYED_EVENT.compare_exchange(0,
+                                           keyed_event as usize,
+                                           Ordering::Release,
+                                           Ordering::Relaxed) {
+            Ok(_) => &*(keyed_event as *const KeyedEvent),
+            Err(x) => {
+                // We lost the race, free our object and return the global one
+                Box::from_raw(keyed_event);
+                &*(x as *const KeyedEvent)
+            }
+        }
+    }
+
+    #[allow(non_snake_case)]
+    unsafe fn create() -> Box<KeyedEvent> {
+        let ntdll = kernel32::GetModuleHandleA(b"ntdll.dll\0".as_ptr() as winapi::LPCSTR);
+        if ntdll.is_null() {
+            panic!("Could not get module handle for ntdll.dll");
+        }
+
+        let NtCreateKeyedEvent =
+            kernel32::GetProcAddress(ntdll, b"NtCreateKeyedEvent\0".as_ptr() as winapi::LPCSTR);
+        if NtCreateKeyedEvent.is_null() {
+            panic!("Entry point NtCreateKeyedEvent not found in ntdll.dll");
+        }
+        let NtReleaseKeyedEvent =
+            kernel32::GetProcAddress(ntdll, b"NtReleaseKeyedEvent\0".as_ptr() as winapi::LPCSTR);
+        if NtReleaseKeyedEvent.is_null() {
+            panic!("Entry point NtReleaseKeyedEvent not found in ntdll.dll");
+        }
+        let NtWaitForKeyedEvent =
+            kernel32::GetProcAddress(ntdll, b"NtWaitForKeyedEvent\0".as_ptr() as winapi::LPCSTR);
+        if NtWaitForKeyedEvent.is_null() {
+            panic!("Entry point NtWaitForKeyedEvent not found in ntdll.dll");
+        }
+
+        let NtCreateKeyedEvent: extern "system" fn(KeyedEventHandle: winapi::PHANDLE,
+                                                   DesiredAccess: winapi::ACCESS_MASK,
+                                                   ObjectAttributes: winapi::PVOID,
+                                                   Flags: winapi::ULONG)
+                                                   -> winapi::NTSTATUS =
+            mem::transmute(NtCreateKeyedEvent);
+        let mut handle = mem::uninitialized();
+        let status = NtCreateKeyedEvent(&mut handle,
+                                        winapi::GENERIC_READ | winapi::GENERIC_WRITE,
+                                        ptr::null_mut(),
+                                        0);
+        if status != winapi::STATUS_SUCCESS {
+            panic!("NtCreateKeyedEvent failed: {:x}", status);
+        }
+
+        Box::new(KeyedEvent {
+            handle: handle,
+            NtReleaseKeyedEvent: mem::transmute(NtReleaseKeyedEvent),
+            NtWaitForKeyedEvent: mem::transmute(NtWaitForKeyedEvent),
+        })
+    }
+}
+
+impl Drop for KeyedEvent {
+    fn drop(&mut self) {
+        unsafe {
+            let ok = kernel32::CloseHandle(self.handle);
+            debug_assert_eq!(ok, winapi::TRUE);
+        }
+    }
+}
+
+// Helper type for putting a thread to sleep until some other thread wakes it up
+pub struct ThreadParker {
+    key: AtomicUsize,
+    keyed_event: &'static KeyedEvent,
+}
+
+impl ThreadParker {
+    pub fn new() -> ThreadParker {
+        // Initialize the keyed event here to ensure we don't get any panics
+        // later on, which could leave synchronization primitives in a broken
+        // state.
+        ThreadParker {
+            key: AtomicUsize::new(0),
+            keyed_event: unsafe { KeyedEvent::get() },
+        }
+    }
+
+    // Prepares the parker. This should be called before adding it to the queue.
+    pub fn prepare_park(&self) {
+        self.key.store(1, Ordering::Relaxed);
+    }
+
+    // Checks if the park timed out. This should be called while holding the
+    // queue lock after park_until has returned false.
+    pub fn timed_out(&self) -> bool {
+        self.key.load(Ordering::Relaxed) != 0
+    }
+
+    // Parks the thread until it is unparked. This should be called after it has
+    // been added to the queue, after unlocking the queue.
+    pub fn park(&self) {
+        let status = unsafe {
+            self.keyed_event.wait_for(self as *const _ as winapi::PVOID, ptr::null_mut())
+        };
+        debug_assert_eq!(status, winapi::STATUS_SUCCESS);
+    }
+
+    // Parks the thread until it is unparked or the timeout is reached. This
+    // should be called after it has been added to the queue, after unlocking
+    // the queue. Returns true if we were unparked and false if we timed out.
+    pub fn park_until(&self, timeout: Instant) -> bool {
+        let now = Instant::now();
+        if timeout <= now {
+            // If another thread unparked us, we need to call
+            // NtWaitForKeyedEvent otherwise that thread will stay stuck at
+            // NtReleaseKeyedEvent.
+            if self.key.swap(2, Ordering::Relaxed) == 0 {
+                self.park();
+                return true;
+            }
+            return false;
+        }
+
+        // NT uses a timeout in units of 100ns. We use a negative value to
+        // indicate a relative timeout based on a monotonic clock.
+        let diff = timeout - now;
+        let nt_timeout = (diff.as_secs() as winapi::LARGE_INTEGER)
+            .checked_mul(-10000000)
+            .and_then(|x| x.checked_sub((diff.subsec_nanos() as winapi::LARGE_INTEGER + 99) / 100));
+        let mut nt_timeout = match nt_timeout {
+            Some(x) => x,
+            None => {
+                // Timeout overflowed, just sleep indefinitely
+                self.park();
+                return true;
+            }
+        };
+
+        let status = unsafe {
+            self.keyed_event.wait_for(self as *const _ as winapi::PVOID, &mut nt_timeout)
+        };
+        if status == winapi::STATUS_SUCCESS {
+            return true;
+        }
+        debug_assert_eq!(status, winapi::STATUS_TIMEOUT);
+
+        // If another thread unparked us, we need to call
+        // NtWaitForKeyedEvent otherwise that thread will stay stuck at
+        // NtReleaseKeyedEvent.
+        if self.key.swap(2, Ordering::Relaxed) == 0 {
+            self.park();
+            return true;
+        }
+        false
+    }
+
+    // Locks the parker to prevent the target thread from exiting. This is
+    // necessary to ensure that thread-local ThreadData objects remain valid.
+    // This should be called while holding the queue lock.
+    pub fn unpark_lock(&self) -> bool {
+        // If the state was 1 then we need to wake up the thread
+        self.key.swap(0, Ordering::Relaxed) == 1
+    }
+
+    // Wakes up the parked thread. This should be called after the queue lock is
+    // released to avoid blocking the queue for too long.
+    pub fn unpark(&self, need_wakeup: bool) {
+        if need_wakeup {
+            let status = unsafe { self.keyed_event.release(self as *const _ as winapi::PVOID) };
+            debug_assert_eq!(status, winapi::STATUS_SUCCESS);
+        }
+    }
+}
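The trickiest part of the new file is the timeout conversion in park_until: NT kernel waits take a LARGE_INTEGER where a negative value means a relative timeout in 100 ns units, and the code rounds the sub-second nanoseconds up so a wait can never return before its deadline. The same arithmetic, restated as a standalone runnable sketch with plain i64 standing in for winapi::LARGE_INTEGER:

    use std::time::Duration;

    // Same conversion as park_until above: negative means "relative",
    // units are 100ns ticks, rounding up so we never wake early.
    fn nt_relative_timeout(diff: Duration) -> Option<i64> {
        (diff.as_secs() as i64)
            .checked_mul(-10_000_000) // 10^7 hundred-nanosecond ticks per second
            .and_then(|x| x.checked_sub((diff.subsec_nanos() as i64 + 99) / 100))
    }

    fn main() {
        // 1.5s becomes -15_000_000 ticks of 100ns.
        assert_eq!(nt_relative_timeout(Duration::new(1, 500_000_000)),
                   Some(-15_000_000));
        // 250ns rounds up to 3 ticks.
        assert_eq!(nt_relative_timeout(Duration::new(0, 250)), Some(-3));
    }

On overflow (None) park_until simply parks without a timeout, which is the conservative choice for an absurdly distant deadline.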
diff --git a/src/word_lock.rs b/src/word_lock.rs
index 1b65cd74..87b0b9d6 100644
--- a/src/word_lock.rs
+++ b/src/word_lock.rs
@@ -107,6 +107,14 @@ impl WordLock {
             continue;
         }
 
+        // Get our thread data. We do this before locking the queue because
+        // the ThreadData constructor may panic and we don't want to leave
+        // the queue in a locked state.
+        let thread_data = &*THREAD_DATA.with(|x| x as *const ThreadData);
+        assert!(mem::align_of_val(thread_data) > !QUEUE_MASK);
+        thread_data.next_in_queue.set(ptr::null());
+        thread_data.parker.prepare_park();
+
         // Try locking the queue
         if let Err(x) = self.state
             .compare_exchange_weak(state,
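The hunk above is a panic-safety fix rather than a behavioural change: THREAD_DATA.with may run the ThreadData constructor, and a panic there after the queue-locked bit is set would wedge the lock forever, so anything fallible now happens before the lock is taken. A distilled, runnable sketch of the hazard; Queue and make_thread_data are hypothetical stand-ins, not crate API:

    use std::sync::atomic::{AtomicBool, Ordering};

    struct Queue {
        locked: AtomicBool,
    }

    impl Queue {
        fn lock(&self) {
            // Crude spin lock standing in for the queue-locked bit.
            while self.locked.swap(true, Ordering::Acquire) {}
        }
        fn unlock(&self) {
            self.locked.store(false, Ordering::Release);
        }
    }

    fn make_thread_data() -> usize {
        // Stand-in for a constructor that may panic (e.g. on allocation).
        42
    }

    fn park_slow(queue: &Queue) {
        // Run anything that can panic *before* taking the lock, so a panic
        // cannot leave `locked` set forever.
        let thread_data = make_thread_data();
        queue.lock();
        // ... enqueue thread_data while holding the lock ...
        let _ = thread_data;
        queue.unlock();
    }

    fn main() {
        let queue = Queue { locked: AtomicBool::new(false) };
        park_slow(&queue);
    }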
@@ -117,26 +125,18 @@
             continue;
         }
 
-        // Get our thread data
-        THREAD_DATA.with(|thread_data| {
-            assert!(mem::align_of_val(thread_data) > !QUEUE_MASK);
-
-            // Add our thread to the queue and unlock the queue
-            thread_data.next_in_queue.set(ptr::null());
-            thread_data.parker.prepare_park();
-            let mut queue_head = (state & QUEUE_MASK) as *const ThreadData;
-            if !queue_head.is_null() {
-                (*(*queue_head).queue_tail.get()).next_in_queue.set(thread_data);
-            } else {
-                queue_head = thread_data;
-            }
-            (*queue_head).queue_tail.set(thread_data);
-            self.state.store((queue_head as usize) | LOCKED_BIT, Ordering::Release);
-
-            // Sleep until we are woken up by an unlock
-            thread_data.parker.park();
-        });
+        // Add our thread to the queue and unlock the queue
+        let mut queue_head = (state & QUEUE_MASK) as *const ThreadData;
+        if !queue_head.is_null() {
+            (*(*queue_head).queue_tail.get()).next_in_queue.set(thread_data);
+        } else {
+            queue_head = thread_data;
+        }
+        (*queue_head).queue_tail.set(thread_data);
+        self.state.store((queue_head as usize) | LOCKED_BIT, Ordering::Release);
+        // Sleep until we are woken up by an unlock
+        thread_data.parker.park();
 
         state = self.state.load(Ordering::Relaxed);
     }
 }
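One last invariant worth restating: assert!(mem::align_of_val(thread_data) > !QUEUE_MASK) is what makes this scheme legal at all. The word lock packs the queue-head pointer and two flag bits into a single usize, so ThreadData must be at least 4-byte aligned or masking would corrupt the pointer. A self-contained sketch of that pointer-tagging round trip; the constant values mirror word_lock.rs but are written out here purely for illustration:

    use std::mem;

    // Mirrors the word_lock.rs scheme: two flag bits stolen from the low
    // bits of the queue-head pointer (illustrative values).
    const LOCKED_BIT: usize = 1;
    const QUEUE_LOCKED_BIT: usize = 2;
    const QUEUE_MASK: usize = !3;

    struct ThreadData {
        next: usize,
    }

    fn main() {
        let td = ThreadData { next: 0 };
        // The assert from the patch: alignment must exceed the stolen bits,
        // i.e. !QUEUE_MASK == 3, so alignment must be at least 4.
        assert!(mem::align_of_val(&td) > !QUEUE_MASK);

        let state = (&td as *const ThreadData as usize) | LOCKED_BIT | QUEUE_LOCKED_BIT;
        // Masking recovers the original pointer because its low bits were zero.
        assert_eq!((state & QUEUE_MASK) as *const ThreadData,
                   &td as *const ThreadData);
        println!("pointer tagging round-trips");
    }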