diff --git a/Cargo.lock b/Cargo.lock index 974b7c40cc5..cf566161b60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -980,6 +980,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if 1.0.0", + "hashbrown 0.12.3", + "lock_api", + "once_cell", + "parking_lot_core 0.9.7", +] + [[package]] name = "derivative" version = "2.2.0" @@ -5737,8 +5750,10 @@ dependencies = [ "cc", "cfg-if 1.0.0", "corosensei", + "dashmap", "derivative", "enum-iterator", + "fnv", "indexmap", "lazy_static", "libc", diff --git a/lib/vm/Cargo.toml b/lib/vm/Cargo.toml index a7e952f7803..99fe366d6f3 100644 --- a/lib/vm/Cargo.toml +++ b/lib/vm/Cargo.toml @@ -26,6 +26,8 @@ lazy_static = "1.4.0" region = { version = "3.0" } corosensei = { version = "0.1.2" } derivative = { version = "^2" } +dashmap = { version = "5.4" } +fnv = "1.0.3" # - Optional shared dependencies. 
tracing = { version = "0.1", optional = true } diff --git a/lib/vm/src/instance/mod.rs b/lib/vm/src/instance/mod.rs index bfc62fca674..0909e6927f5 100644 --- a/lib/vm/src/instance/mod.rs +++ b/lib/vm/src/instance/mod.rs @@ -19,8 +19,8 @@ use crate::vmcontext::{ VMFunctionImport, VMFunctionKind, VMGlobalDefinition, VMGlobalImport, VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, VMTableDefinition, VMTableImport, VMTrampoline, }; -use crate::LinearMemory; use crate::{FunctionBodyPtr, MaybeInstanceOwned, TrapHandlerFn, VMFunctionBody}; +use crate::{LinearMemory, NotifyLocation}; use crate::{VMFuncRef, VMFunction, VMGlobal, VMMemory, VMTable}; pub use allocator::InstanceAllocator; use memoffset::offset_of; @@ -33,8 +33,7 @@ use std::fmt; use std::mem; use std::ptr::{self, NonNull}; use std::slice; -use std::sync::{Arc, Mutex}; -use std::thread::{current, park, park_timeout, Thread}; +use std::sync::Arc; use wasmer_types::entity::{packed_option::ReservedValue, BoxedSlice, EntityRef, PrimaryMap}; use wasmer_types::{ DataIndex, DataInitializer, ElemIndex, ExportIndex, FunctionIndex, GlobalIndex, GlobalInit, @@ -42,20 +41,6 @@ use wasmer_types::{ MemoryIndex, ModuleInfo, Pages, SignatureIndex, TableIndex, TableInitializer, VMOffsets, }; -#[derive(Hash, Eq, PartialEq, Clone, Copy)] -struct NotifyLocation { - memory_index: u32, - address: u32, -} - -struct NotifyWaiter { - thread: Thread, - notified: bool, -} -struct NotifyMap { - map: HashMap>, -} - /// A WebAssembly instance. /// /// The type is dynamically-sized. Indeed, the `vmctx` field can @@ -105,9 +90,6 @@ pub(crate) struct Instance { /// will point to elements here for functions imported by this instance. imported_funcrefs: BoxedSlice>, - /// The Hasmap with the Notify for the Notify/wait opcodes - conditions: Arc>, - /// Additional context used by compiled WebAssembly code. 
This /// field is last, and represents a dynamically-sized array that /// extends beyond the nominal end of the struct (similar to a @@ -276,6 +258,31 @@ impl Instance { } } + /// Get a locally defined or imported memory as mutable. + fn get_vmmemory_mut(&mut self, index: MemoryIndex) -> &mut VMMemory { + if let Some(local_index) = self.module.local_memory_index(index) { + unsafe { + self.memories + .get_mut(local_index) + .unwrap() + .get_mut(self.context.as_mut().unwrap()) + } + } else { + let import = self.imported_memory(index); + unsafe { import.handle.get_mut(self.context.as_mut().unwrap()) } + } + } + + /// Get a locally defined memory as mutable. + fn get_local_vmmemory_mut(&mut self, local_index: LocalMemoryIndex) -> &mut VMMemory { + unsafe { + self.memories + .get_mut(local_index) + .unwrap() + .get_mut(self.context.as_mut().unwrap()) + } + } + /// Return the indexed `VMGlobalDefinition`. fn global(&self, index: LocalGlobalIndex) -> VMGlobalDefinition { unsafe { self.global_ptr(index).as_ref().clone() } @@ -797,51 +804,19 @@ impl Instance { } } - // To implement Wait / Notify, a HasMap, behind a mutex, will be used - // to track the address of waiter. 
The key of the hashmap is based on the memory - // and waiter threads are "park"'d (with or without timeout) - // Notify will wake the waiters by simply "unpark" the thread - // as the Thread info is stored on the HashMap - // once unparked, the waiter thread will remove it's mark on the HashMap - // timeout / awake is tracked with a boolean in the HashMap - // because `park_timeout` doesn't gives any information on why it returns - fn do_wait(&mut self, index: u32, dst: u32, timeout: i64) -> u32 { - // fetch the notifier - let key = NotifyLocation { - memory_index: index, - address: dst, - }; - let mut conds = self.conditions.lock().unwrap(); - let v = conds.map.entry(key).or_insert_with(Vec::new); - v.push(NotifyWaiter { - thread: current(), - notified: false, - }); - drop(conds); - if timeout < 0 { - park(); + fn memory_wait(memory: &mut VMMemory, dst: u32, timeout: i64) -> Result { + let location = NotifyLocation { address: dst }; + let timeout = if timeout < 0 { + None } else { - park_timeout(std::time::Duration::from_nanos(timeout as u64)); - } - let mut conds = self.conditions.lock().unwrap(); - let v = conds.map.get_mut(&key).unwrap(); - let id = current().id(); - let mut ret = 0; - v.retain(|cond| { - if cond.thread.id() == id { - ret = if cond.notified { 0 } else { 2 }; - false - } else { - true - } - }); - if v.is_empty() { - conds.map.remove(&key); - } - if conds.map.len() > 1 << 32 { - ret = 0xffff; + Some(std::time::Duration::from_nanos(timeout as u64)) + }; + let waiter = memory.do_wait(location, timeout); + if waiter.is_err() { + // do_wait returns Err when there are too many waiters or when wait is unimplemented for this memory + return Err(Trap::lib(TrapCode::TableAccessOutOfBounds)); } - ret + Ok(waiter.unwrap()) } /// Perform an Atomic.Wait32 @@ -861,11 +836,8 @@ impl Instance { if let Ok(mut ret) = ret { if ret == 0 { - ret = self.do_wait(memory_index.as_u32(), dst, timeout); - } - if ret == 0xffff { - // ret is 0xffff if there is more than 2^32 waiter in queue - 
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds)); + let memory = self.get_local_vmmemory_mut(memory_index); + ret = Instance::memory_wait(memory, dst, timeout)?; } Ok(ret) } else { @@ -888,14 +860,10 @@ impl Instance { //} let ret = unsafe { memory32_atomic_check32(memory, dst, val) }; - if let Ok(mut ret) = ret { if ret == 0 { - ret = self.do_wait(memory_index.as_u32(), dst, timeout); - } - if ret == 0xffff { - // ret is 0xffff if there is more than 2^32 waiter in queue - return Err(Trap::lib(TrapCode::TableAccessOutOfBounds)); + let memory = self.get_vmmemory_mut(memory_index); + ret = Instance::memory_wait(memory, dst, timeout)?; } Ok(ret) } else { @@ -920,11 +888,8 @@ impl Instance { if let Ok(mut ret) = ret { if ret == 0 { - ret = self.do_wait(memory_index.as_u32(), dst, timeout); - } - if ret == 0xffff { - // ret is 0xffff if there is more than 2^32 waiter in queue - return Err(Trap::lib(TrapCode::TableAccessOutOfBounds)); + let memory = self.get_local_vmmemory_mut(memory_index); + ret = Instance::memory_wait(memory, dst, timeout)?; } Ok(ret) } else { @@ -950,11 +915,8 @@ impl Instance { if let Ok(mut ret) = ret { if ret == 0 { - ret = self.do_wait(memory_index.as_u32(), dst, timeout); - } - if ret == 0xffff { - // ret is 0xffff if there is more than 2^32 waiter in queue - return Err(Trap::lib(TrapCode::TableAccessOutOfBounds)); + let memory = self.get_vmmemory_mut(memory_index); + ret = Instance::memory_wait(memory, dst, timeout)?; } Ok(ret) } else { @@ -962,21 +924,6 @@ impl Instance { } } - fn do_notify(&mut self, key: NotifyLocation, count: u32) -> Result { - let mut conds = self.conditions.lock().unwrap(); - let mut cnt = 0u32; - if let Some(v) = conds.map.get_mut(&key) { - for waiter in v { - if cnt < count { - waiter.notified = true; // mark as was waiked up - waiter.thread.unpark(); // wakeup! 
- cnt += 1; - } - } - } - Ok(cnt) - } - /// Perform an Atomic.Notify pub(crate) fn local_memory_notify( &mut self, @@ -984,17 +931,10 @@ impl Instance { dst: u32, count: u32, ) -> Result { - //let memory = self.memory(memory_index); - //if ! memory.shared { - // We should trap according to spec, but official test rely on not trapping... - //} - + let memory = self.get_local_vmmemory_mut(memory_index); // fetch the notifier - let key = NotifyLocation { - memory_index: memory_index.as_u32(), - address: dst, - }; - self.do_notify(key, count) + let location = NotifyLocation { address: dst }; + Ok(memory.do_notify(location, count)) } /// Perform an Atomic.Notify @@ -1004,18 +944,10 @@ impl Instance { dst: u32, count: u32, ) -> Result { - //let import = self.imported_memory(memory_index); - //let memory = unsafe { import.definition.as_ref() }; - //if ! memory.shared { - // We should trap according to spec, but official test rely on not trapping... - //} - + let memory = self.get_vmmemory_mut(memory_index); // fetch the notifier - let key = NotifyLocation { - memory_index: memory_index.as_u32(), - address: dst, - }; - self.do_notify(key, count) + let location = NotifyLocation { address: dst }; + Ok(memory.do_notify(location, count)) } } @@ -1125,9 +1057,6 @@ impl VMInstance { funcrefs, imported_funcrefs, vmctx: VMContext {}, - conditions: Arc::new(Mutex::new(NotifyMap { - map: HashMap::new(), - })), }; let mut instance_handle = allocator.into_vminstance(instance); diff --git a/lib/vm/src/lib.rs b/lib/vm/src/lib.rs index b2d69b8616d..b2ef5284bed 100644 --- a/lib/vm/src/lib.rs +++ b/lib/vm/src/lib.rs @@ -32,6 +32,7 @@ mod probestack; mod sig_registry; mod store; mod table; +mod threadconditions; mod trap; mod vmcontext; @@ -47,7 +48,8 @@ pub use crate::imports::Imports; #[allow(deprecated)] pub use crate::instance::{InstanceAllocator, InstanceHandle, VMInstance}; pub use crate::memory::{ - initialize_memory_with_data, LinearMemory, VMMemory, VMOwnedMemory, VMSharedMemory, + 
initialize_memory_with_data, LinearMemory, NotifyLocation, VMMemory, VMOwnedMemory, + VMSharedMemory, }; pub use crate::mmap::Mmap; pub use crate::probestack::PROBESTACK; diff --git a/lib/vm/src/memory.rs b/lib/vm/src/memory.rs index d3d288a0c04..a55b242f809 100644 --- a/lib/vm/src/memory.rs +++ b/lib/vm/src/memory.rs @@ -5,6 +5,8 @@ //! //! `Memory` is to WebAssembly linear memories what `Table` is to WebAssembly tables. +use crate::threadconditions::ThreadConditions; +pub use crate::threadconditions::{NotifyLocation, WaiterError}; use crate::trap::Trap; use crate::{mmap::Mmap, store::MaybeInstanceOwned, vmcontext::VMMemoryDefinition}; use more_asserts::assert_ge; @@ -13,6 +15,7 @@ use std::convert::TryInto; use std::ptr::NonNull; use std::slice; use std::sync::{Arc, RwLock}; +use std::time::Duration; use wasmer_types::{Bytes, MemoryError, MemoryStyle, MemoryType, Pages}; // The memory mapped area @@ -285,6 +288,7 @@ impl VMOwnedMemory { VMSharedMemory { mmap: Arc::new(RwLock::new(self.mmap)), config: self.config, + conditions: ThreadConditions::new(), } } @@ -346,6 +350,8 @@ pub struct VMSharedMemory { mmap: Arc>, // Configuration of this memory config: VMMemoryConfig, + // waiters list for this memory + conditions: ThreadConditions, } unsafe impl Send for VMSharedMemory {} @@ -381,6 +387,7 @@ impl VMSharedMemory { Ok(Self { mmap: Arc::new(RwLock::new(guard.duplicate()?)), config: self.config.clone(), + conditions: ThreadConditions::new(), }) } } @@ -431,6 +438,20 @@ impl LinearMemory for VMSharedMemory { let forked = Self::duplicate(self)?; Ok(Box::new(forked)) } + + // Add current thread to waiter list + fn do_wait( + &mut self, + dst: NotifyLocation, + timeout: Option, + ) -> Result { + self.conditions.do_wait(dst, timeout) + } + + /// Notify waiters from the wait list. 
Return the number of waiters notified + fn do_notify(&mut self, dst: NotifyLocation, count: u32) -> u32 { + self.conditions.do_notify(dst, count) + } } impl From for VMMemory { @@ -498,6 +519,20 @@ impl LinearMemory for VMMemory { fn duplicate(&mut self) -> Result, MemoryError> { self.0.duplicate() } + + // Add current thread to waiter list + fn do_wait( + &mut self, + dst: NotifyLocation, + timeout: Option, + ) -> Result { + self.0.do_wait(dst, timeout) + } + + /// Notify waiters from the wait list. Return the number of waiters notified + fn do_notify(&mut self, dst: NotifyLocation, count: u32) -> u32 { + self.0.do_notify(dst, count) + } } impl VMMemory { @@ -616,4 +651,19 @@ where /// Copies this memory to a new memory fn duplicate(&mut self) -> Result, MemoryError>; + + /// Add current thread to the waiter hash, and wait until notified or timeout. + /// Return Ok(0) if the waiter has been notified, Ok(2) if the timeout occurred, or an Err if an error happened + fn do_wait( + &mut self, + _dst: NotifyLocation, + _timeout: Option, + ) -> Result { + Err(WaiterError::Unimplemented) + } + + /// Notify waiters from the wait list. 
Return the number of waiters notified + fn do_notify(&mut self, _dst: NotifyLocation, _count: u32) -> u32 { + 0 + } } diff --git a/lib/vm/src/threadconditions.rs b/lib/vm/src/threadconditions.rs new file mode 100644 index 00000000000..8549453d185 --- /dev/null +++ b/lib/vm/src/threadconditions.rs @@ -0,0 +1,224 @@ +use dashmap::DashMap; +use fnv::FnvBuildHasher; +use std::sync::Arc; +use std::thread::{current, park, park_timeout, Thread}; +use std::time::Duration; +use thiserror::Error; + +/// Wait/Notify error type +#[derive(Debug, Error)] +pub enum WaiterError { + /// Wait/Notify is not implemented for this memory + Unimplemented, + /// Too many waiters + TooManyWaiters, +} + +impl std::fmt::Display for WaiterError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "WaiterError") + } +} + +/// A location in memory for a Waiter +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)] +pub struct NotifyLocation { + /// The address of the Waiter location + pub address: u32, +} + +#[derive(Debug)] +struct NotifyWaiter { + pub thread: Thread, + pub notified: bool, +} +#[derive(Debug, Default)] +struct NotifyMap { + pub map: DashMap, FnvBuildHasher>, +} + +/// HashMap of Waiters for the Wait/Notify opcodes +#[derive(Debug)] +pub struct ThreadConditions { + inner: Arc, // The HashMap of waiters for the Wait/Notify opcodes +} + +impl ThreadConditions { + /// Create a new ThreadConditions + pub fn new() -> Self { + Self { + inner: Arc::new(NotifyMap::default()), + } + } + + // To implement Wait / Notify, a concurrent HashMap (DashMap) is used + // to track the addresses of the waiters. 
The key of the hashmap is based on the memory + // and waiter threads are "park"'d (with or without timeout) + // Notify will wake the waiters by simply "unpark"ing the thread + // as the Thread info is stored on the HashMap + // once unparked, the waiter thread will remove its mark on the HashMap + // timeout / awake is tracked with a boolean in the HashMap + // because `park_timeout` doesn't give any information on why it returns + + /// Add current thread to the waiter hash and park it; returns 0 if it was notified, 2 otherwise + pub fn do_wait( + &mut self, + dst: NotifyLocation, + timeout: Option, + ) -> Result { + // refuse new waiters once the map is full (NOTE(review): `1 << 32` looks like it overflows usize on 32-bit targets; TODO confirm) + if self.inner.map.len() >= 1 << 32 { + return Err(WaiterError::TooManyWaiters); + } + self.inner + .map + .entry(dst) + .or_insert_with(Vec::new) + .push(NotifyWaiter { + thread: current(), + notified: false, + }); + if let Some(timeout) = timeout { + park_timeout(timeout); + } else { + park(); // NOTE(review): park() may return spuriously, which is reported below as 2 (timeout); TODO confirm + } + let mut bindding = self.inner.map.get_mut(&dst).unwrap(); + let v = bindding.value_mut(); + let id = current().id(); + let mut ret = 0; + v.retain(|cond| { + if cond.thread.id() == id { + ret = if cond.notified { 0 } else { 2 }; + false + } else { + true + } + }); + let empty = v.is_empty(); + drop(bindding); + if empty { + self.inner.map.remove(&dst); // NOTE(review): a waiter pushed after drop(bindding) would be discarded here; TODO confirm + } + Ok(ret) + } + + /// Notify waiters from the wait list. Return the number of waiters notified + pub fn do_notify(&mut self, dst: NotifyLocation, count: u32) -> u32 { + let mut count_token = 0u32; + if let Some(mut v) = self.inner.map.get_mut(&dst) { + for waiter in v.value_mut() { + if count_token < count && !waiter.notified { + waiter.notified = true; // mark as woken up + waiter.thread.unpark(); // wakeup! 
+ count_token += 1; + } + } + } + count_token + } +} + +impl Clone for ThreadConditions { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn threadconditions_notify_nowaiters() { + let mut conditions = ThreadConditions::new(); + let dst = NotifyLocation { address: 0 }; + let ret = conditions.do_notify(dst, 1); + assert_eq!(ret, 0); + } + + #[test] + fn threadconditions_notify_1waiter() { + use std::thread; + + let mut conditions = ThreadConditions::new(); + let mut threadcond = conditions.clone(); + + thread::spawn(move || { + let dst = NotifyLocation { address: 0 }; + let ret = threadcond.do_wait(dst.clone(), None).unwrap(); + assert_eq!(ret, 0); + }); + thread::sleep(Duration::from_millis(10)); + let dst = NotifyLocation { address: 0 }; + let ret = conditions.do_notify(dst, 1); + assert_eq!(ret, 1); + } + + #[test] + fn threadconditions_notify_waiter_timeout() { + use std::thread; + + let mut conditions = ThreadConditions::new(); + let mut threadcond = conditions.clone(); + + thread::spawn(move || { + let dst = NotifyLocation { address: 0 }; + let ret = threadcond + .do_wait(dst.clone(), Some(Duration::from_millis(1))) + .unwrap(); + assert_eq!(ret, 2); + }); + thread::sleep(Duration::from_millis(50)); + let dst = NotifyLocation { address: 0 }; + let ret = conditions.do_notify(dst, 1); + assert_eq!(ret, 0); + } + + #[test] + fn threadconditions_notify_waiter_mismatch() { + use std::thread; + + let mut conditions = ThreadConditions::new(); + let mut threadcond = conditions.clone(); + + thread::spawn(move || { + let dst = NotifyLocation { address: 8 }; + let ret = threadcond + .do_wait(dst.clone(), Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(ret, 2); + }); + thread::sleep(Duration::from_millis(1)); + let dst = NotifyLocation { address: 0 }; + let ret = conditions.do_notify(dst, 1); + assert_eq!(ret, 0); + thread::sleep(Duration::from_millis(100)); + } + + #[test] 
+ fn threadconditions_notify_2waiters() { + use std::thread; + + let mut conditions = ThreadConditions::new(); + let mut threadcond = conditions.clone(); + let mut threadcond2 = conditions.clone(); + + thread::spawn(move || { + let dst = NotifyLocation { address: 0 }; + let ret = threadcond.do_wait(dst.clone(), None).unwrap(); + assert_eq!(ret, 0); + }); + thread::spawn(move || { + let dst = NotifyLocation { address: 0 }; + let ret = threadcond2.do_wait(dst.clone(), None).unwrap(); + assert_eq!(ret, 0); + }); + thread::sleep(Duration::from_millis(20)); + let dst = NotifyLocation { address: 0 }; + let ret = conditions.do_notify(dst, 5); + assert_eq!(ret, 2); + } +}