diff --git a/lib/wasi/src/os/task/control_plane.rs b/lib/wasi/src/os/task/control_plane.rs index cdb6e7545ac..47b93252bed 100644 --- a/lib/wasi/src/os/task/control_plane.rs +++ b/lib/wasi/src/os/task/control_plane.rs @@ -13,6 +13,28 @@ pub struct WasiControlPlane { state: Arc, } +#[derive(Debug, Clone)] +pub struct WasiControlPlaneHandle { + inner: std::sync::Weak, +} + +impl WasiControlPlaneHandle { + fn new(inner: &Arc) -> Self { + Self { + inner: Arc::downgrade(inner), + } + } + + pub fn upgrade(&self) -> Option { + self.inner.upgrade().map(|state| WasiControlPlane { state }) + } + + pub fn must_upgrade(&self) -> WasiControlPlane { + let state = self.inner.upgrade().expect("control plane unavailable"); + WasiControlPlane { state } + } +} + #[derive(Debug, Clone)] pub struct ControlPlaneConfig { /// Total number of tasks (processes + threads) that can be spawned. @@ -67,6 +89,10 @@ impl WasiControlPlane { } } + pub fn handle(&self) -> WasiControlPlaneHandle { + WasiControlPlaneHandle::new(&self.state) + } + /// Get the current count of active tasks (threads). fn active_task_count(&self) -> usize { self.state.task_count.load(Ordering::SeqCst) @@ -99,7 +125,7 @@ impl WasiControlPlane { } // Create the process first to do all the allocations before locking. - let mut proc = WasiProcess::new(WasiProcessId::from(0), self.clone()); + let mut proc = WasiProcess::new(WasiProcessId::from(0), self.handle()); let mut mutable = self.state.mutable.write().unwrap(); diff --git a/lib/wasi/src/os/task/process.rs b/lib/wasi/src/os/task/process.rs index 1d6a016b142..3be8e8be8a7 100644 --- a/lib/wasi/src/os/task/process.rs +++ b/lib/wasi/src/os/task/process.rs @@ -17,13 +17,12 @@ use wasmer_wasi_types::{ }; use crate::{ - os::task::{control_plane::WasiControlPlane, signal::WasiSignalInterval}, - syscalls::platform_clock_time_get, - WasiThread, WasiThreadHandle, WasiThreadId, + os::task::signal::WasiSignalInterval, syscalls::platform_clock_time_get, WasiThread, + WasiThreadHandle, WasiThreadId, }; use super::{ - control_plane::ControlPlaneError, + control_plane::{ControlPlaneError, WasiControlPlaneHandle}, signal::{SignalDeliveryError, SignalHandlerAbi}, task_join_handle::{OwnedTaskStatus, TaskJoinHandle}, }; @@ -81,7 +80,7 @@ pub struct WasiProcess { /// Reference back to the compute engine // TODO: remove this reference, access should happen via separate state instead // (we don't want cyclical references) - pub(crate) compute: WasiControlPlane, + pub(crate) compute: WasiControlPlaneHandle, /// Reference to the exit code for the main thread pub(crate) finished: Arc, /// List of all the children spawned from this thread @@ -134,11 +133,11 @@ impl Drop for WasiProcessWait { } impl WasiProcess { - pub fn new(pid: WasiProcessId, compute: WasiControlPlane) -> Self { + pub fn new(pid: WasiProcessId, plane: WasiControlPlaneHandle) -> Self { WasiProcess { pid, ppid: 0u32.into(), - compute, + compute: plane, inner: Arc::new(RwLock::new(WasiProcessInner { threads: Default::default(), thread_count: Default::default(), @@ -184,7 +183,7 @@ impl WasiProcess { /// Creates a a thread and returns it pub fn new_thread(&self) -> Result { - let task_count_guard = self.compute.register_task()?; + let task_count_guard = self.compute.must_upgrade().register_task()?; let mut inner = self.inner.write().unwrap(); let id = inner.thread_seed.inc(); @@ -232,7 +231,7 @@ impl WasiProcess { if self.waiting.load(Ordering::Acquire) > 0 { let mut triggered = false; for pid in children.iter() { - if let Some(process) = self.compute.get_process(*pid) { + if let Some(process) = self.compute.must_upgrade().get_process(*pid) { process.signal_process(signal); triggered = true; } @@ -301,7 +300,7 @@ impl WasiProcess { } let mut waits = Vec::new(); for pid in children { - if let Some(process) = self.compute.get_process(pid) { + if let Some(process) = self.compute.must_upgrade().get_process(pid) { let children = self.children.clone(); waits.push(async move { let join = process.join().await; @@ -330,7 +329,7 @@ impl WasiProcess { let mut waits = Vec::new(); for pid in children { - if let Some(process) = self.compute.get_process(pid) { + if let Some(process) = self.compute.must_upgrade().get_process(pid) { let children = self.children.clone(); waits.push(async move { let join = process.join().await; @@ -358,11 +357,6 @@ impl WasiProcess { thread.set_status_finished(Ok(exit_code)) } } - - /// Gains access to the compute control plane - pub fn control_plane(&self) -> &WasiControlPlane { - &self.compute - } } impl SignalHandlerAbi for WasiProcess { diff --git a/lib/wasi/src/state/env.rs b/lib/wasi/src/state/env.rs index 52fbc8830a3..670339095ac 100644 --- a/lib/wasi/src/state/env.rs +++ b/lib/wasi/src/state/env.rs @@ -260,6 +260,7 @@ impl WasiEnvInit { /// The environment provided to the WASI imports. #[derive(Debug)] pub struct WasiEnv { + pub control_plane: WasiControlPlane, /// Represents the process this environment is attached to pub process: WasiProcess, /// Represents the thread this environment is attached to @@ -312,6 +313,7 @@ impl WasiEnv { // Currently only used by fork/spawn related syscalls. pub(crate) fn duplicate(&self) -> Self { Self { + control_plane: self.control_plane.clone(), process: self.process.clone(), poll_seed: self.poll_seed, thread: self.thread.clone(), @@ -330,7 +332,7 @@ impl WasiEnv { /// Forking the WasiState is used when either fork or vfork is called pub fn fork(&self) -> Result<(Self, WasiThreadHandle), ControlPlaneError> { - let process = self.process.compute.new_process()?; + let process = self.control_plane.new_process()?; let handle = process.new_thread()?; let thread = handle.as_thread(); @@ -341,6 +343,7 @@ impl WasiEnv { let bin_factory = self.bin_factory.clone(); let new_env = Self { + control_plane: self.control_plane.clone(), process, thread, vfork: None, @@ -379,6 +382,7 @@ impl WasiEnv { }; let mut env = Self { + control_plane: init.control_plane, process, thread: thread.as_thread(), vfork: None, diff --git a/lib/wasi/src/syscalls/wasix/proc_join.rs b/lib/wasi/src/syscalls/wasix/proc_join.rs index f8bae8fc40f..4dc8c5dcf59 100644 --- a/lib/wasi/src/syscalls/wasix/proc_join.rs +++ b/lib/wasi/src/syscalls/wasix/proc_join.rs @@ -76,7 +76,7 @@ pub fn proc_join( // Otherwise we wait for the specific PID let env = ctx.data(); let pid: WasiProcessId = pid.into(); - let process = env.process.control_plane().get_process(pid); + let process = env.control_plane.get_process(pid); if let Some(process) = process { let exit_code = wasi_try_ok!(__asyncify(&mut ctx, None, async move { let code = process.join().await.unwrap_or(Errno::Child as u32); diff --git a/lib/wasi/src/syscalls/wasix/proc_parent.rs b/lib/wasi/src/syscalls/wasix/proc_parent.rs index 908a260dec4..032dc1228cf 100644 --- a/lib/wasi/src/syscalls/wasix/proc_parent.rs +++ b/lib/wasi/src/syscalls/wasix/proc_parent.rs @@ -15,14 +15,12 @@ pub fn proc_parent( if pid == env.process.pid() { let memory = env.memory_view(&ctx); wasi_try_mem!(ret_parent.write(&memory, env.process.ppid().raw() as Pid)); + Errno::Success + } else if let Some(process) = env.control_plane.get_process(pid) { + let memory = env.memory_view(&ctx); + wasi_try_mem!(ret_parent.write(&memory, process.pid().raw() as Pid)); + Errno::Success } else { - let control_plane = env.process.control_plane(); - if let Some(process) = control_plane.get_process(pid) { - let memory = env.memory_view(&ctx); - wasi_try_mem!(ret_parent.write(&memory, process.pid().raw() as Pid)); - } else { - return Errno::Badf; - } + Errno::Badf } - Errno::Success } diff --git a/lib/wasi/src/syscalls/wasix/proc_signal.rs b/lib/wasi/src/syscalls/wasix/proc_signal.rs index 7f5f695c447..3e2d1c83a77 100644 --- a/lib/wasi/src/syscalls/wasix/proc_signal.rs +++ b/lib/wasi/src/syscalls/wasix/proc_signal.rs @@ -23,7 +23,7 @@ pub fn proc_signal( let process = { let pid: WasiProcessId = pid.into(); - ctx.data().process.compute.get_process(pid) + ctx.data().control_plane.get_process(pid) }; if let Some(process) = process { process.signal_process(sig); diff --git a/lib/wasi/src/syscalls/wasix/proc_spawn.rs b/lib/wasi/src/syscalls/wasix/proc_spawn.rs index 84bfa4d0b0a..dfaecaf3787 100644 --- a/lib/wasi/src/syscalls/wasix/proc_spawn.rs +++ b/lib/wasi/src/syscalls/wasix/proc_spawn.rs @@ -39,7 +39,7 @@ pub fn proc_spawn( ret_handles: WasmPtr, ) -> Result { let env = ctx.data(); - let control_plane = env.process.control_plane(); + let control_plane = &env.control_plane; let memory = env.memory_view(&ctx); let name = unsafe { get_input_str_bus_ok!(&memory, name, name_len) }; let args = unsafe { get_input_str_bus_ok!(&memory, args, args_len) };