From 11af6aff01313ed1ba0a154507e5609749f8e5e1 Mon Sep 17 00:00:00 2001
From: Andrew Brown
Date: Wed, 16 Aug 2023 16:14:26 -0700
Subject: [PATCH] wasi-nn: update upstream specification (#6853)

This change removes the temporary `*.wit` files and bumps the `spec`
directory to the latest [wasi-nn] commit. This is now possible because
the upstream `spec` repository has all of the updated WIT and WITX
bits. The `load_by_name` implementations are left as TODOs for now and
will be included in a subsequent PR.

One other change is a refactoring: we introduce two wrapper types,
`Graph` and `ExecutionContext`, to avoid directly passing around
`Box<dyn BackendGraph>` and `Box<dyn BackendExecutionContext>` trait
objects. This simplifies some of the code but should not change
behavior.

Run all tests: prtest:full

[wasi-nn]: https://github.com/WebAssembly/wasi-nn
---
 crates/wasi-nn/spec                    |  2 +-
 crates/wasi-nn/src/backend/mod.rs      | 21 +++---
 crates/wasi-nn/src/backend/openvino.rs | 22 +++----
 crates/wasi-nn/src/ctx.rs              | 34 +++++-----
 crates/wasi-nn/src/lib.rs              | 38 +++++++++++
 crates/wasi-nn/src/wit.rs              | 57 +++++++++++------
 crates/wasi-nn/src/witx.rs             | 24 +++++--
 crates/wasi-nn/wit/inference.wit       | 24 -------
 crates/wasi-nn/wit/types.wit           | 88 --------------------------
 crates/wasi-nn/wit/world.wit           | 20 ------
 src/commands/run.rs                    |  2 +-
 11 files changed, 131 insertions(+), 201 deletions(-)
 delete mode 100644 crates/wasi-nn/wit/inference.wit
 delete mode 100644 crates/wasi-nn/wit/types.wit
 delete mode 100644 crates/wasi-nn/wit/world.wit

diff --git a/crates/wasi-nn/spec b/crates/wasi-nn/spec
index 8adc5b9b3bb8..c1f8b87e923a 160000
--- a/crates/wasi-nn/spec
+++ b/crates/wasi-nn/spec
@@ -1 +1 @@
-Subproject commit 8adc5b9b3bb8f885d44f55b464718e24af892c94
+Subproject commit c1f8b87e923aedda02964c31b0e1d37e331ec402
diff --git a/crates/wasi-nn/src/backend/mod.rs b/crates/wasi-nn/src/backend/mod.rs
index 19b6610f1581..ad929a9a74e0 100644
--- a/crates/wasi-nn/src/backend/mod.rs
+++ b/crates/wasi-nn/src/backend/mod.rs
@@ -6,33 +6,30 @@ mod openvino;
 
 use self::openvino::OpenvinoBackend;
 use crate::wit::types::{ExecutionTarget, Tensor};
+use crate::{ExecutionContext, Graph};
 use thiserror::Error;
 use wiggle::GuestError;
 
 /// Return a list of all available backend frameworks.
-pub(crate) fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
+pub fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
     vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))]
 }
 
 /// A [Backend] contains the necessary state to load [BackendGraph]s.
-pub(crate) trait Backend: Send + Sync {
+pub trait Backend: Send + Sync {
     fn name(&self) -> &str;
-    fn load(
-        &mut self,
-        builders: &[&[u8]],
-        target: ExecutionTarget,
-    ) -> Result<Box<dyn BackendGraph>, BackendError>;
+    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
 }
 
 /// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
 /// implementation for a [crate::witx::types::Graph].
-pub(crate) trait BackendGraph: Send + Sync {
-    fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError>;
+pub trait BackendGraph: Send + Sync {
+    fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError>;
 }
 
 /// A [BackendExecutionContext] performs the actual inference; this is the
 /// backing implementation for a [crate::witx::types::GraphExecutionContext].
-pub(crate) trait BackendExecutionContext: Send + Sync {
+pub trait BackendExecutionContext: Send + Sync {
     fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>;
     fn compute(&mut self) -> Result<(), BackendError>;
     fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
@@ -52,7 +49,7 @@ pub enum BackendError {
     NotEnoughMemory(usize),
 }
 
-#[derive(Hash, PartialEq, Eq, Clone, Copy)]
-pub(crate) enum BackendKind {
+#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)]
+pub enum BackendKind {
     OpenVINO,
 }
diff --git a/crates/wasi-nn/src/backend/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs
index d44236250760..93f51771c95f 100644
--- a/crates/wasi-nn/src/backend/openvino.rs
+++ b/crates/wasi-nn/src/backend/openvino.rs
@@ -2,6 +2,7 @@
 
 use super::{Backend, BackendError, BackendExecutionContext, BackendGraph};
 use crate::wit::types::{ExecutionTarget, Tensor, TensorType};
+use crate::{ExecutionContext, Graph};
 use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
 use std::sync::Arc;
 
@@ -15,11 +16,7 @@ impl Backend for OpenvinoBackend {
         "openvino"
     }
 
-    fn load(
-        &mut self,
-        builders: &[&[u8]],
-        target: ExecutionTarget,
-    ) -> Result<Box<dyn BackendGraph>, BackendError> {
+    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
         if builders.len() != 2 {
             return Err(BackendError::InvalidNumberOfBuilders(2, builders.len()).into());
         }
@@ -54,8 +51,9 @@ impl Backend for OpenvinoBackend {
 
         let exec_network =
             core.load_network(&cnn_network, map_execution_target_to_string(target))?;
-
-        Ok(Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network)))
+        let box_: Box<dyn BackendGraph> =
+            Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network));
+        Ok(box_.into())
     }
 }
 
@@ -65,12 +63,11 @@ unsafe impl Send for OpenvinoGraph {}
 unsafe impl Sync for OpenvinoGraph {}
 
 impl BackendGraph for OpenvinoGraph {
-    fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError> {
+    fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError> {
         let infer_request = self.1.create_infer_request()?;
-        Ok(Box::new(OpenvinoExecutionContext(
-            self.0.clone(),
-            infer_request,
-        )))
+        let box_: Box<dyn BackendExecutionContext> =
+            Box::new(OpenvinoExecutionContext(self.0.clone(), infer_request));
+        Ok(box_.into())
     }
 }
 
@@ -145,5 +142,6 @@ fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision
         TensorType::Fp32 => Precision::FP32,
         TensorType::U8 => Precision::U8,
         TensorType::I32 => Precision::I32,
+        TensorType::Bf16 => todo!("not yet supported in `openvino` bindings"),
     }
 }
diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs
index e2cea15f9654..e961938a8f3d 100644
--- a/crates/wasi-nn/src/ctx.rs
+++ b/crates/wasi-nn/src/ctx.rs
@@ -1,36 +1,36 @@
 //! Implements the host state for the `wasi-nn` API: [WasiNnCtx].
 
-use crate::backend::{
-    self, Backend, BackendError, BackendExecutionContext, BackendGraph, BackendKind,
-};
+use crate::backend::{self, Backend, BackendError, BackendKind};
 use crate::wit::types::GraphEncoding;
-use std::collections::HashMap;
-use std::hash::Hash;
+use crate::{ExecutionContext, Graph};
+use std::{collections::HashMap, hash::Hash};
 use thiserror::Error;
 use wiggle::GuestError;
 
+type Backends = HashMap<BackendKind, Box<dyn Backend>>;
 type GraphId = u32;
 type GraphExecutionContextId = u32;
 
 /// Capture the state necessary for calling into the backend ML libraries.
 pub struct WasiNnCtx {
-    pub(crate) backends: HashMap<BackendKind, Box<dyn Backend>>,
-    pub(crate) graphs: Table<GraphId, Box<dyn BackendGraph>>,
-    pub(crate) executions: Table<GraphExecutionContextId, Box<dyn BackendExecutionContext>>,
+    pub(crate) backends: Backends,
+    pub(crate) graphs: Table<GraphId, Graph>,
+    pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
 }
 
 impl WasiNnCtx {
     /// Make a new context from the default state.
-    pub fn new() -> WasiNnResult<Self> {
-        let mut backends = HashMap::new();
-        for (kind, backend) in backend::list() {
-            backends.insert(kind, backend);
-        }
-        Ok(Self {
+    pub fn new(backends: Backends) -> Self {
+        Self {
             backends,
             graphs: Table::default(),
             executions: Table::default(),
-        })
+        }
+    }
+}
+impl Default for WasiNnCtx {
+    fn default() -> Self {
+        WasiNnCtx::new(backend::list().into_iter().collect())
     }
 }
@@ -59,6 +59,8 @@ pub enum UsageError {
     InvalidExecutionContextHandle,
     #[error("Not enough memory to copy tensor data of size: {0}")]
     NotEnoughMemory(u32),
+    #[error("No graph found with name: {0}")]
+    NotFound(String),
 }
 
 pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
@@ -105,6 +107,6 @@ mod test {
 
     #[test]
     fn instantiate() {
-        WasiNnCtx::new().unwrap();
+        WasiNnCtx::default();
     }
 }
diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs
index 2cf8d6e8e56b..1abd6c0b1372 100644
--- a/crates/wasi-nn/src/lib.rs
+++ b/crates/wasi-nn/src/lib.rs
@@ -4,3 +4,41 @@ mod ctx;
 pub use ctx::WasiNnCtx;
 pub mod wit;
 pub mod witx;
+
+/// A backend-defined graph (i.e., ML model).
+pub struct Graph(Box<dyn backend::BackendGraph>);
+impl From<Box<dyn backend::BackendGraph>> for Graph {
+    fn from(value: Box<dyn backend::BackendGraph>) -> Self {
+        Self(value)
+    }
+}
+impl std::ops::Deref for Graph {
+    type Target = dyn backend::BackendGraph;
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+impl std::ops::DerefMut for Graph {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.0.as_mut()
+    }
+}
+
+/// A backend-defined execution context.
+pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
+impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {
+    fn from(value: Box<dyn backend::BackendExecutionContext>) -> Self {
+        Self(value)
+    }
+}
+impl std::ops::Deref for ExecutionContext {
+    type Target = dyn backend::BackendExecutionContext;
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+impl std::ops::DerefMut for ExecutionContext {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.0.as_mut()
+    }
+}
diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs
index e5a823bde939..b63a22025cbc 100644
--- a/crates/wasi-nn/src/wit.rs
+++ b/crates/wasi-nn/src/wit.rs
@@ -17,23 +17,29 @@
 use crate::{backend::BackendKind, ctx::UsageError, WasiNnCtx};
 
-pub use gen::types;
-pub use gen_::Ml as ML;
-
 /// Generate the traits and types from the `wasi-nn` WIT specification.
 mod gen_ {
-    wasmtime::component::bindgen!("ml");
+    wasmtime::component::bindgen!("ml" in "spec/wit/wasi-nn.wit");
 }
 use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need.
 
-impl gen::inference::Host for WasiNnCtx {
+// Export the `types` used in this crate as well as `ML::add_to_linker`.
+pub mod types {
+    use super::gen;
+    pub use gen::graph::{ExecutionTarget, Graph, GraphEncoding};
+    pub use gen::inference::GraphExecutionContext;
+    pub use gen::tensor::{Tensor, TensorType};
+}
+pub use gen_::Ml as ML;
+
+impl gen::graph::Host for WasiNnCtx {
     /// Load an opaque sequence of bytes to use for inference.
     fn load(
         &mut self,
-        builders: gen::types::GraphBuilderArray,
-        encoding: gen::types::GraphEncoding,
-        target: gen::types::ExecutionTarget,
-    ) -> wasmtime::Result<Result<gen::types::Graph, gen::types::Error>> {
+        builders: Vec<gen::graph::GraphBuilder>,
+        encoding: gen::graph::GraphEncoding,
+        target: gen::graph::ExecutionTarget,
+    ) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
         let backend_kind: BackendKind = encoding.try_into()?;
         let graph = if let Some(backend) = self.backends.get_mut(&backend_kind) {
             let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
@@ -45,13 +51,22 @@ impl gen::inference::Host for WasiNnCtx {
         Ok(Ok(graph_id))
     }
 
+    fn load_by_name(
+        &mut self,
+        _name: String,
+    ) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
+        todo!()
+    }
+}
+
+impl gen::inference::Host for WasiNnCtx {
     /// Create an execution instance of a loaded graph.
     ///
     /// TODO: remove completely?
     fn init_execution_context(
         &mut self,
-        graph_id: gen::types::Graph,
-    ) -> wasmtime::Result<Result<gen::types::GraphExecutionContext, gen::types::Error>> {
+        graph_id: gen::graph::Graph,
+    ) -> wasmtime::Result<Result<gen::inference::GraphExecutionContext, gen::errors::Error>> {
         let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
             graph.init_execution_context()?
         } else {
@@ -65,10 +80,10 @@ impl gen::inference::Host for WasiNnCtx {
     /// Define the inputs to use for inference.
     fn set_input(
         &mut self,
-        exec_context_id: gen::types::GraphExecutionContext,
+        exec_context_id: gen::inference::GraphExecutionContext,
         index: u32,
-        tensor: gen::types::Tensor,
-    ) -> wasmtime::Result<Result<(), gen::types::Error>> {
+        tensor: gen::tensor::Tensor,
+    ) -> wasmtime::Result<Result<(), gen::errors::Error>> {
         if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
             exec_context.set_input(index, &tensor)?;
             Ok(Ok(()))
@@ -82,8 +97,8 @@ impl gen::inference::Host for WasiNnCtx {
     /// TODO: refactor to compute(list<tensor>) -> result<list<tensor>, error>
     fn compute(
         &mut self,
-        exec_context_id: gen::types::GraphExecutionContext,
-    ) -> wasmtime::Result<Result<(), gen::types::Error>> {
+        exec_context_id: gen::inference::GraphExecutionContext,
+    ) -> wasmtime::Result<Result<(), gen::errors::Error>> {
         if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
             exec_context.compute()?;
             Ok(Ok(()))
@@ -95,9 +110,9 @@ impl gen::inference::Host for WasiNnCtx {
     /// Extract the outputs after inference.
     fn get_output(
         &mut self,
-        exec_context_id: gen::types::GraphExecutionContext,
+        exec_context_id: gen::inference::GraphExecutionContext,
         index: u32,
-    ) -> wasmtime::Result<Result<gen::types::TensorData, gen::types::Error>> {
+    ) -> wasmtime::Result<Result<gen::tensor::TensorData, gen::errors::Error>> {
         if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
             // Read the output bytes. TODO: this involves a hard-coded upper
             // limit on the tensor size that is necessary because there is no
@@ -113,11 +128,11 @@ impl gen::inference::Host for WasiNnCtx {
     }
 }
 
-impl TryFrom<gen::types::GraphEncoding> for crate::backend::BackendKind {
+impl TryFrom<gen::graph::GraphEncoding> for crate::backend::BackendKind {
     type Error = UsageError;
-    fn try_from(value: gen::types::GraphEncoding) -> Result<Self, Self::Error> {
+    fn try_from(value: gen::graph::GraphEncoding) -> Result<Self, Self::Error> {
         match value {
-            gen::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
+            gen::graph::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
             _ => Err(UsageError::InvalidEncoding(value.into())),
         }
     }
diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs
index 7371017d0424..b339d9a8d389 100644
--- a/crates/wasi-nn/src/witx.rs
+++ b/crates/wasi-nn/src/witx.rs
@@ -13,7 +13,6 @@
 //!
 //! [`types`]: crate::wit::types
 
-use crate::backend::BackendKind;
 use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result};
 use wiggle::GuestPtr;
 
@@ -23,7 +22,7 @@ pub use gen::wasi_ephemeral_nn::add_to_linker;
 
 mod gen {
     use super::*;
     wiggle::from_witx!({
-        witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"],
+        witx: ["$WASI_ROOT/wasi-nn.witx"],
         errors: { nn_errno => WasiNnError }
     });
 
@@ -59,7 +58,7 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
         encoding: gen::types::GraphEncoding,
         target: gen::types::ExecutionTarget,
     ) -> Result<gen::types::Graph> {
-        let graph = if let Some(backend) = self.backends.get_mut(&encoding.into()) {
+        let graph = if let Some(backend) = self.backends.get_mut(&encoding.try_into()?) {
             // Retrieve all of the "builder lists" from the Wasm memory (see
             // $graph_builder_array) as slices for a backend to operate on.
             let mut slices = vec![];
@@ -79,6 +78,10 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
         Ok(graph_id.into())
     }
 
+    fn load_by_name<'b>(&mut self, _name: &wiggle::GuestPtr<'b, str>) -> Result<gen::types::Graph> {
+        todo!()
+    }
+
     fn init_execution_context(
         &mut self,
         graph_id: gen::types::Graph,
@@ -140,10 +143,12 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
 
 // Implement some conversion from `witx::types::*` to this crate's version.
 
-impl From<gen::types::GraphEncoding> for BackendKind {
-    fn from(value: gen::types::GraphEncoding) -> Self {
+impl TryFrom<gen::types::GraphEncoding> for crate::backend::BackendKind {
+    type Error = UsageError;
+    fn try_from(value: gen::types::GraphEncoding) -> std::result::Result<Self, Self::Error> {
         match value {
-            gen::types::GraphEncoding::Openvino => BackendKind::OpenVINO,
+            gen::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
+            _ => Err(UsageError::InvalidEncoding(value.into())),
         }
     }
 }
@@ -160,6 +165,13 @@ impl From<gen::types::GraphEncoding> for crate::wit::types::GraphEncoding {
     fn from(value: gen::types::GraphEncoding) -> Self {
         match value {
             gen::types::GraphEncoding::Openvino => crate::wit::types::GraphEncoding::Openvino,
+            gen::types::GraphEncoding::Onnx => crate::wit::types::GraphEncoding::Onnx,
+            gen::types::GraphEncoding::Tensorflow => crate::wit::types::GraphEncoding::Tensorflow,
+            gen::types::GraphEncoding::Pytorch => crate::wit::types::GraphEncoding::Pytorch,
+            gen::types::GraphEncoding::Tensorflowlite => {
+                crate::wit::types::GraphEncoding::Tensorflowlite
+            }
+            gen::types::GraphEncoding::Autodetect => todo!("autodetect not supported"),
         }
     }
 }
diff --git a/crates/wasi-nn/wit/inference.wit b/crates/wasi-nn/wit/inference.wit
deleted file mode 100644
index df754231f696..000000000000
--- a/crates/wasi-nn/wit/inference.wit
+++ /dev/null
@@ -1,24 +0,0 @@
-interface inference {
-    use types.{graph-builder-array, graph-encoding, execution-target, graph,
-               tensor, tensor-data, error, graph-execution-context}
-
-    /// Load an opaque sequence of bytes to use for inference.
-    load: func(builder: graph-builder-array, encoding: graph-encoding,
-               target: execution-target) -> result<graph, error>
-
-    /// Create an execution instance of a loaded graph.
-    ///
-    /// TODO: remove completely?
-    init-execution-context: func(graph: graph) -> result<graph-execution-context, error>
-
-    /// Define the inputs to use for inference.
-    set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>
-
-    /// Compute the inference on the given inputs.
-    ///
-    /// TODO: refactor to compute(list<tensor>) -> result<list<tensor>, error>
-    compute: func(ctx: graph-execution-context) -> result<_, error>
-
-    /// Extract the outputs after inference.
-    get-output: func(ctx: graph-execution-context, index: u32) -> result<tensor-data, error>
-}
diff --git a/crates/wasi-nn/wit/types.wit b/crates/wasi-nn/wit/types.wit
deleted file mode 100644
index f134b730e1a8..000000000000
--- a/crates/wasi-nn/wit/types.wit
+++ /dev/null
@@ -1,88 +0,0 @@
-interface types {
-    /// The dimensions of a tensor.
-    ///
-    /// The array length matches the tensor rank and each element in the array
-    /// describes the size of each dimension.
-    type tensor-dimensions = list<u32>
-
-    /// The type of the elements in a tensor.
-    enum tensor-type {
-        FP16,
-        FP32,
-        U8,
-        I32
-    }
-
-    /// The tensor data.
-    ///
-    /// Initially conceived as a sparse representation, each empty cell would be filled with zeros and
-    /// the array length must match the product of all of the dimensions and the number of bytes in the
-    /// type (e.g., a 2x2 tensor with 4-byte f32 elements would have a data array of length 16).
-    /// Naturally, this representation requires some knowledge of how to lay out data in memory--e.g.,
-    /// using row-major ordering--and could perhaps be improved.
-    type tensor-data = list<u8>
-
-    /// A tensor.
-    record tensor {
-        /// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
-        /// containing a single value, use `[1]` for the tensor dimensions.
-        dimensions: tensor-dimensions,
-
-        /// Describe the type of element in the tensor (e.g., f32).
-        tensor-type: tensor-type,
-
-        /// Contains the tensor data.
-        data: tensor-data,
-    }
-
-    /// The graph initialization data.
-    //
-    /// This consists of an array of buffers because implementing backends may encode their graph IR in
-    /// parts (e.g., OpenVINO stores its IR and weights separately).
-    type graph-builder = list<u8>
-    type graph-builder-array = list<graph-builder>
-
-    /// An execution graph for performing inference (i.e., a model).
-    ///
-    /// TODO: replace with `resource`
-    type graph = u32
-
-    /// Describes the encoding of the graph. This allows the API to be implemented by various backends
-    /// that encode (i.e., serialize) their graph IR with different formats.
-    enum graph-encoding {
-        openvino,
-        onnx,
-        tensorflow,
-        pytorch,
-        tensorflowlite
-    }
-
-    /// Define where the graph should be executed.
-    enum execution-target {
-        cpu,
-        gpu,
-        tpu
-    }
-
-    /// Bind a `graph` to the input and output tensors for an inference.
-    ///
-    /// TODO: replace with `resource`
-    /// TODO: remove execution contexts completely
-    type graph-execution-context = u32
-
-    /// Error codes returned by functions in this API.
-    enum error {
-        /// No error occurred.
-        success,
-        /// Caller module passed an invalid argument.
-        invalid-argument,
-        /// Invalid encoding.
-        invalid-encoding,
-        /// Caller module is missing a memory export.
-        missing-memory,
-        /// Device or resource busy.
-        busy,
-        /// Runtime Error.
-        runtime-error,
-    }
-}
diff --git a/crates/wasi-nn/wit/world.wit b/crates/wasi-nn/wit/world.wit
deleted file mode 100644
index 42bffb93420c..000000000000
--- a/crates/wasi-nn/wit/world.wit
+++ /dev/null
@@ -1,20 +0,0 @@
-/// `wasi-nn` API
-///
-/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The
-/// API is not (yet) capable of performing ML training. WebAssembly programs
-/// that want to use a host's ML capabilities can access these capabilities
-/// through `wasi-nn`'s core abstractions: _backends_, _graphs_, and _tensors_.
-/// A user selects a _backend_ for inference and `load`s a model, instantiated
-/// as a _graph_, to use in the _backend_. Then, the user passes _tensor_ inputs
-/// to the _graph_, computes the inference, and retrieves the _tensor_ outputs.
-///
-/// This module draws inspiration from the inference side of
-/// [WebNN](https://webmachinelearning.github.io/webnn/#api). See the
-/// [README](https://github.com/WebAssembly/wasi-nn/blob/main/README.md) for
-/// more context about the design and goals of this API.
-
-package wasi:nn
-
-world ml {
-  import inference
-}
diff --git a/src/commands/run.rs b/src/commands/run.rs
index 2f0795fb57f3..2c9485d048a9 100644
--- a/src/commands/run.rs
+++ b/src/commands/run.rs
@@ -741,7 +741,7 @@ fn populate_with_wasi(
                 Arc::get_mut(host.wasi_nn.as_mut().unwrap())
                     .expect("wasi-nn is not implemented with multi-threading support")
             })?;
-            store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::new()?));
+            store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::default()));
         }
     }
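---

A few supplementary notes follow, each with a small standalone Rust sketch. These are illustrations written against the patch, not code from it.

First, the `Graph` and `ExecutionContext` types added to `crates/wasi-nn/src/lib.rs` are newtypes over boxed trait objects: `From` gives backends a one-line conversion (`box_.into()`), and `Deref`/`DerefMut` let callers invoke trait methods without unwrapping the box. A minimal sketch of the same shape, using a hypothetical `Inference` trait and `Context`/`Dummy` types rather than the crate's `BackendGraph` and `BackendExecutionContext`:

```rust
// Hypothetical stand-in for a backend trait such as `BackendExecutionContext`.
trait Inference {
    fn compute(&mut self) -> u32;
}

// Newtype over the boxed trait object, mirroring `ExecutionContext` above.
struct Context(Box<dyn Inference>);

impl From<Box<dyn Inference>> for Context {
    fn from(value: Box<dyn Inference>) -> Self {
        Self(value)
    }
}

impl std::ops::Deref for Context {
    type Target = dyn Inference;
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

impl std::ops::DerefMut for Context {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut()
    }
}

struct Dummy;
impl Inference for Dummy {
    fn compute(&mut self) -> u32 {
        42
    }
}

fn main() {
    // A backend constructs the boxed trait object and converts it with
    // `.into()`, as `OpenvinoBackend::load` does after this patch...
    let boxed: Box<dyn Inference> = Box::new(Dummy);
    let mut ctx: Context = boxed.into();
    // ...and callers reach trait methods directly through `Deref`/`DerefMut`.
    assert_eq!(ctx.compute(), 42);
}
```

The payoff is at call sites like `graph.init_execution_context()` in `wit.rs`: the wrapper derefs to the trait object, so `Box` no longer appears in signatures outside the `backend` module.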
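Second, `witx.rs` swaps the infallible `From<gen::types::GraphEncoding> for BackendKind` for a `TryFrom`, because the regenerated WITX `graph-encoding` enum now carries variants (ONNX, TensorFlow, and so on) with no registered backend, so the conversion can fail. A simplified sketch of that shape, with stand-in enums and a stand-in error type instead of the crate's generated ones:

```rust
// Stand-ins for the generated `GraphEncoding` and the crate's `BackendKind`.
#[derive(Debug, Clone, Copy)]
enum GraphEncoding {
    Openvino,
    Onnx,
}

#[derive(Debug, PartialEq)]
enum BackendKind {
    OpenVINO,
}

// Stand-in for `UsageError::InvalidEncoding`.
#[derive(Debug)]
struct InvalidEncoding(GraphEncoding);

impl TryFrom<GraphEncoding> for BackendKind {
    type Error = InvalidEncoding;
    fn try_from(value: GraphEncoding) -> Result<Self, Self::Error> {
        match value {
            GraphEncoding::Openvino => Ok(BackendKind::OpenVINO),
            // Encodings without a backend become errors rather than panics
            // or silent defaults.
            other => Err(InvalidEncoding(other)),
        }
    }
}

fn main() {
    assert_eq!(
        BackendKind::try_from(GraphEncoding::Openvino).unwrap(),
        BackendKind::OpenVINO
    );
    assert!(BackendKind::try_from(GraphEncoding::Onnx).is_err());
    // At the call site in `witx.rs`, `encoding.try_into()?` surfaces the
    // error to the guest instead of panicking in the host.
}
```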
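Third, the `WasiNnCtx::new()?` to `WasiNnCtx::default()` change in `src/commands/run.rs` follows from the new construction API in `ctx.rs`: `new` now takes the backend map explicitly and is infallible, while `Default` assembles the built-in list. A sketch of that pattern with stand-in types (the real `WasiNnCtx` also holds the graph and execution tables):

```rust
use std::collections::HashMap;

// Stand-ins for the crate's `Backend` trait and `BackendKind` enum.
trait Backend {
    fn name(&self) -> &str;
}

#[derive(Hash, PartialEq, Eq, Clone, Copy)]
enum BackendKind {
    OpenVINO,
}

struct OpenvinoBackend;
impl Backend for OpenvinoBackend {
    fn name(&self) -> &str {
        "openvino"
    }
}

type Backends = HashMap<BackendKind, Box<dyn Backend>>;

// Mirrors `backend::list()`: the built-in backends.
fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
    vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend))]
}

struct Ctx {
    backends: Backends,
}

impl Ctx {
    // `new` takes the backend map explicitly, so tests can inject mocks...
    fn new(backends: Backends) -> Self {
        Self { backends }
    }
}

// ...while `Default` builds the stock set, making construction infallible
// (hence `WasiNnCtx::default()` replacing the old `WasiNnCtx::new()?`).
impl Default for Ctx {
    fn default() -> Self {
        Ctx::new(list().into_iter().collect())
    }
}

fn main() {
    let ctx = Ctx::default();
    for backend in ctx.backends.values() {
        println!("available backend: {}", backend.name());
    }
}
```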
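Last, the `Table<GraphId, Graph>` and `Table<GraphExecutionContextId, ExecutionContext>` fields rely on a handle table whose definition is not part of this diff. The sketch below is a plausible reconstruction from the visible usage only (`Table::default()`, `get_mut`, `u32` ids); it stands in for, and does not reproduce, the crate's `Table`:

```rust
use std::collections::HashMap;
use std::hash::Hash;

// Guessed shape of the handle table: guest-visible keys mapped to host-side
// values, with a monotonically increasing next key.
struct Table<K, V> {
    entries: HashMap<K, V>,
    next_key: u32,
}

impl<K, V> Default for Table<K, V> {
    fn default() -> Self {
        Self {
            entries: HashMap::new(),
            next_key: 0,
        }
    }
}

impl<K: Eq + Hash + Copy + From<u32>, V> Table<K, V> {
    // Store a value and return the handle the guest uses to refer to it.
    fn insert(&mut self, value: V) -> K {
        let key = K::from(self.next_key);
        self.next_key += 1;
        self.entries.insert(key, value);
        key
    }

    fn get_mut(&mut self, key: K) -> Option<&mut V> {
        self.entries.get_mut(&key)
    }
}

fn main() {
    // `GraphId` is `u32` in `ctx.rs`, so `u32` works directly as the key.
    let mut graphs: Table<u32, String> = Table::default();
    let id = graphs.insert("an opaque graph".to_string());
    assert!(graphs.get_mut(id).is_some());
}
```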