Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Wait/Notify opcode, the waiters hashmap is now on the Memory itself #3723

Merged
merged 17 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions lib/vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ lazy_static = "1.4.0"
region = { version = "3.0" }
corosensei = { version = "0.1.2" }
derivative = { version = "^2" }
dashmap = { version = "5.4" }
fnv = "1.0.3"
# - Optional shared dependencies.
tracing = { version = "0.1", optional = true }

Expand Down
212 changes: 85 additions & 127 deletions lib/vm/src/instance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use crate::vmcontext::{
VMFunctionImport, VMFunctionKind, VMGlobalDefinition, VMGlobalImport, VMMemoryDefinition,
VMMemoryImport, VMSharedSignatureIndex, VMTableDefinition, VMTableImport, VMTrampoline,
};
use crate::LinearMemory;
use crate::{FunctionBodyPtr, MaybeInstanceOwned, TrapHandlerFn, VMFunctionBody};
use crate::{LinearMemory, NotifyLocation};
use crate::{VMFuncRef, VMFunction, VMGlobal, VMMemory, VMTable};
pub use allocator::InstanceAllocator;
use memoffset::offset_of;
Expand All @@ -33,29 +33,14 @@ use std::fmt;
use std::mem;
use std::ptr::{self, NonNull};
use std::slice;
use std::sync::{Arc, Mutex};
use std::thread::{current, park, park_timeout, Thread};
use std::sync::Arc;
use wasmer_types::entity::{packed_option::ReservedValue, BoxedSlice, EntityRef, PrimaryMap};
use wasmer_types::{
DataIndex, DataInitializer, ElemIndex, ExportIndex, FunctionIndex, GlobalIndex, GlobalInit,
LocalFunctionIndex, LocalGlobalIndex, LocalMemoryIndex, LocalTableIndex, MemoryError,
MemoryIndex, ModuleInfo, Pages, SignatureIndex, TableIndex, TableInitializer, VMOffsets,
};

#[derive(Hash, Eq, PartialEq, Clone, Copy)]
struct NotifyLocation {
memory_index: u32,
address: u32,
}

struct NotifyWaiter {
thread: Thread,
notified: bool,
}
struct NotifyMap {
map: HashMap<NotifyLocation, Vec<NotifyWaiter>>,
}

/// A WebAssembly instance.
///
/// The type is dynamically-sized. Indeed, the `vmctx` field can
Expand Down Expand Up @@ -105,9 +90,6 @@ pub(crate) struct Instance {
/// will point to elements here for functions imported by this instance.
imported_funcrefs: BoxedSlice<FunctionIndex, NonNull<VMCallerCheckedAnyfunc>>,

/// The Hasmap with the Notify for the Notify/wait opcodes
conditions: Arc<Mutex<NotifyMap>>,

/// Additional context used by compiled WebAssembly code. This
/// field is last, and represents a dynamically-sized array that
/// extends beyond the nominal end of the struct (similar to a
Expand Down Expand Up @@ -276,6 +258,31 @@ impl Instance {
}
}

/// Get a locally defined or imported memory.
fn get_vmmemory_mut(&mut self, index: MemoryIndex) -> &mut VMMemory {
if let Some(local_index) = self.module.local_memory_index(index) {
unsafe {
self.memories
.get_mut(local_index)
.unwrap()
.get_mut(self.context.as_mut().unwrap())
}
} else {
let import = self.imported_memory(index);
unsafe { import.handle.get_mut(self.context.as_mut().unwrap()) }
}
}

/// Get a locally defined memory as mutable.
fn get_local_vmmemory_mut(&mut self, local_index: LocalMemoryIndex) -> &mut VMMemory {
unsafe {
self.memories
.get_mut(local_index)
.unwrap()
.get_mut(self.context.as_mut().unwrap())
}
}

/// Return the indexed `VMGlobalDefinition`.
fn global(&self, index: LocalGlobalIndex) -> VMGlobalDefinition {
unsafe { self.global_ptr(index).as_ref().clone() }
Expand Down Expand Up @@ -797,53 +804,6 @@ impl Instance {
}
}

// To implement Wait / Notify, a HasMap, behind a mutex, will be used
// to track the address of waiter. The key of the hashmap is based on the memory
// and waiter threads are "park"'d (with or without timeout)
// Notify will wake the waiters by simply "unpark" the thread
// as the Thread info is stored on the HashMap
// once unparked, the waiter thread will remove it's mark on the HashMap
// timeout / awake is tracked with a boolean in the HashMap
// because `park_timeout` doesn't gives any information on why it returns
fn do_wait(&mut self, index: u32, dst: u32, timeout: i64) -> u32 {
// fetch the notifier
let key = NotifyLocation {
memory_index: index,
address: dst,
};
let mut conds = self.conditions.lock().unwrap();
let v = conds.map.entry(key).or_insert_with(Vec::new);
v.push(NotifyWaiter {
thread: current(),
notified: false,
});
drop(conds);
if timeout < 0 {
park();
} else {
park_timeout(std::time::Duration::from_nanos(timeout as u64));
}
let mut conds = self.conditions.lock().unwrap();
let v = conds.map.get_mut(&key).unwrap();
let id = current().id();
let mut ret = 0;
v.retain(|cond| {
if cond.thread.id() == id {
ret = if cond.notified { 0 } else { 2 };
false
} else {
true
}
});
if v.is_empty() {
conds.map.remove(&key);
}
if conds.map.len() > 1 << 32 {
ret = 0xffff;
}
ret
}

/// Perform an Atomic.Wait32
pub(crate) fn local_memory_wait32(
&mut self,
Expand All @@ -861,11 +821,19 @@ impl Instance {

if let Ok(mut ret) = ret {
if ret == 0 {
ret = self.do_wait(memory_index.as_u32(), dst, timeout);
}
if ret == 0xffff {
// ret is 0xffff if there is more than 2^32 waiter in queue
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds));
let memory = self.get_local_vmmemory_mut(memory_index);
let location = NotifyLocation { address: dst };
let timeout = if timeout < 0 {
None
} else {
Some(std::time::Duration::from_nanos(timeout as u64))
};
let waiter = memory.do_wait(location, timeout);
if waiter.is_none() {
// ret is None if there is more than 2^32 waiter in queue or some other error
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds));
theduke marked this conversation as resolved.
Show resolved Hide resolved
}
ret = waiter.unwrap();
}
Ok(ret)
} else {
Expand All @@ -888,14 +856,21 @@ impl Instance {
//}

let ret = unsafe { memory32_atomic_check32(memory, dst, val) };

if let Ok(mut ret) = ret {
if ret == 0 {
ret = self.do_wait(memory_index.as_u32(), dst, timeout);
}
if ret == 0xffff {
// ret is 0xffff if there is more than 2^32 waiter in queue
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds));
let memory = self.get_vmmemory_mut(memory_index);
let location = NotifyLocation { address: dst };
ptitSeb marked this conversation as resolved.
Show resolved Hide resolved
let timeout = if timeout < 0 {
None
} else {
Some(std::time::Duration::from_nanos(timeout as u64))
};
let waiter = memory.do_wait(location, timeout);
if waiter.is_none() {
// ret is None if there is more than 2^32 waiter in queue or some other error
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds));
}
ret = waiter.unwrap();
}
Ok(ret)
} else {
Expand All @@ -920,11 +895,19 @@ impl Instance {

if let Ok(mut ret) = ret {
if ret == 0 {
ret = self.do_wait(memory_index.as_u32(), dst, timeout);
}
if ret == 0xffff {
// ret is 0xffff if there is more than 2^32 waiter in queue
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds));
let memory = self.get_local_vmmemory_mut(memory_index);
let location = NotifyLocation { address: dst };
let timeout = if timeout < 0 {
ptitSeb marked this conversation as resolved.
Show resolved Hide resolved
None
} else {
Some(std::time::Duration::from_nanos(timeout as u64))
};
let waiter = memory.do_wait(location, timeout);
if waiter.is_none() {
// ret is None if there is more than 2^32 waiter in queue or some other error
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds));
}
ret = waiter.unwrap();
}
Ok(ret)
} else {
Expand All @@ -950,51 +933,37 @@ impl Instance {

if let Ok(mut ret) = ret {
if ret == 0 {
ret = self.do_wait(memory_index.as_u32(), dst, timeout);
}
if ret == 0xffff {
// ret is 0xffff if there is more than 2^32 waiter in queue
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds));
let memory = self.get_vmmemory_mut(memory_index);
let location = NotifyLocation { address: dst };
let timeout = if timeout < 0 {
ptitSeb marked this conversation as resolved.
Show resolved Hide resolved
None
} else {
Some(std::time::Duration::from_nanos(timeout as u64))
};
let waiter = memory.do_wait(location, timeout);
if waiter.is_none() {
// ret is None if there is more than 2^32 waiter in queue or some other error
return Err(Trap::lib(TrapCode::TableAccessOutOfBounds));
}
ret = waiter.unwrap();
}
Ok(ret)
} else {
ret
}
}

fn do_notify(&mut self, key: NotifyLocation, count: u32) -> Result<u32, Trap> {
let mut conds = self.conditions.lock().unwrap();
let mut cnt = 0u32;
if let Some(v) = conds.map.get_mut(&key) {
for waiter in v {
if cnt < count {
waiter.notified = true; // mark as was waiked up
waiter.thread.unpark(); // wakeup!
cnt += 1;
}
}
}
Ok(cnt)
}

/// Perform an Atomic.Notify
pub(crate) fn local_memory_notify(
&mut self,
memory_index: LocalMemoryIndex,
dst: u32,
count: u32,
) -> Result<u32, Trap> {
//let memory = self.memory(memory_index);
//if ! memory.shared {
// We should trap according to spec, but official test rely on not trapping...
//}

let memory = self.get_local_vmmemory_mut(memory_index);
// fetch the notifier
let key = NotifyLocation {
memory_index: memory_index.as_u32(),
address: dst,
};
self.do_notify(key, count)
let location = NotifyLocation { address: dst };
Ok(memory.do_notify(location, count))
}

/// Perform an Atomic.Notify
Expand All @@ -1004,18 +973,10 @@ impl Instance {
dst: u32,
count: u32,
) -> Result<u32, Trap> {
//let import = self.imported_memory(memory_index);
//let memory = unsafe { import.definition.as_ref() };
//if ! memory.shared {
// We should trap according to spec, but official test rely on not trapping...
//}

let memory = self.get_vmmemory_mut(memory_index);
// fetch the notifier
let key = NotifyLocation {
memory_index: memory_index.as_u32(),
address: dst,
};
self.do_notify(key, count)
let location = NotifyLocation { address: dst };
Ok(memory.do_notify(location, count))
}
}

Expand Down Expand Up @@ -1125,9 +1086,6 @@ impl VMInstance {
funcrefs,
imported_funcrefs,
vmctx: VMContext {},
conditions: Arc::new(Mutex::new(NotifyMap {
map: HashMap::new(),
})),
};

let mut instance_handle = allocator.into_vminstance(instance);
Expand Down
4 changes: 3 additions & 1 deletion lib/vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod probestack;
mod sig_registry;
mod store;
mod table;
mod threadconditions;
mod trap;
mod vmcontext;

Expand All @@ -47,7 +48,8 @@ pub use crate::imports::Imports;
#[allow(deprecated)]
pub use crate::instance::{InstanceAllocator, InstanceHandle, VMInstance};
pub use crate::memory::{
initialize_memory_with_data, LinearMemory, VMMemory, VMOwnedMemory, VMSharedMemory,
initialize_memory_with_data, LinearMemory, NotifyLocation, VMMemory, VMOwnedMemory,
VMSharedMemory,
};
pub use crate::mmap::Mmap;
pub use crate::probestack::PROBESTACK;
Expand Down
Loading