wasi-nn: update upstream specification (bytecodealliance#6853)
This change removes the temporary `*.wit` files and bumps the `spec`
directory to the latest [wasi-nn] commit. This is now possible because
the upstream `spec` repository has all of the updated WIT and WITX bits.
The `load_by_name` implementations are left as TODOs for now and will be
included in a subsequent PR.

One other change is a refactoring: we wrap two types, `Graph` and
`ExecutionContext`, in newtypes to avoid passing around `Box<dyn ...>`. This
simplifies some of the code but should not change behavior.

Run all tests: prtest:full

[wasi-nn]: https://github.com/WebAssembly/wasi-nn
abrown authored Aug 16, 2023
1 parent 6130395 commit 11af6af
Showing 11 changed files with 131 additions and 201 deletions.
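
The refactoring described in the commit message hinges on a newtype-plus-`Deref` pattern (added in `lib.rs` below). Here is a minimal, self-contained sketch of that pattern; the `DummyGraph` backend is hypothetical and stands in for a real implementation such as `OpenvinoGraph`:

```rust
// Minimal sketch of the newtype pattern: `Graph` owns a `Box<dyn BackendGraph>`,
// and `Deref` lets callers invoke trait methods on the wrapper directly instead
// of passing `Box<dyn ...>` around.
trait BackendGraph {
    fn name(&self) -> &str;
}

struct Graph(Box<dyn BackendGraph>);

impl std::ops::Deref for Graph {
    type Target = dyn BackendGraph;
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

// Hypothetical backend graph, standing in for e.g. `OpenvinoGraph`.
struct DummyGraph;
impl BackendGraph for DummyGraph {
    fn name(&self) -> &str {
        "dummy"
    }
}

fn main() {
    let graph = Graph(Box::new(DummyGraph));
    // Thanks to `Deref`, the trait method is callable on the wrapper itself.
    assert_eq!(graph.name(), "dummy");
}
```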
2 changes: 1 addition & 1 deletion crates/wasi-nn/spec
Submodule spec updated 120 files
21 changes: 9 additions & 12 deletions crates/wasi-nn/src/backend/mod.rs
@@ -6,33 +6,30 @@ mod openvino;
 
 use self::openvino::OpenvinoBackend;
 use crate::wit::types::{ExecutionTarget, Tensor};
+use crate::{ExecutionContext, Graph};
 use thiserror::Error;
 use wiggle::GuestError;
 
 /// Return a list of all available backend frameworks.
-pub(crate) fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
+pub fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
     vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))]
 }
 
 /// A [Backend] contains the necessary state to load [BackendGraph]s.
-pub(crate) trait Backend: Send + Sync {
+pub trait Backend: Send + Sync {
     fn name(&self) -> &str;
-    fn load(
-        &mut self,
-        builders: &[&[u8]],
-        target: ExecutionTarget,
-    ) -> Result<Box<dyn BackendGraph>, BackendError>;
+    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
 }
 
 /// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
 /// implementation for a [crate::witx::types::Graph].
-pub(crate) trait BackendGraph: Send + Sync {
-    fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError>;
+pub trait BackendGraph: Send + Sync {
+    fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError>;
 }
 
 /// A [BackendExecutionContext] performs the actual inference; this is the
 /// backing implementation for a [crate::witx::types::GraphExecutionContext].
-pub(crate) trait BackendExecutionContext: Send + Sync {
+pub trait BackendExecutionContext: Send + Sync {
     fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>;
     fn compute(&mut self) -> Result<(), BackendError>;
     fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
@@ -52,7 +49,7 @@ pub enum BackendError {
     NotEnoughMemory(usize),
 }
 
-#[derive(Hash, PartialEq, Eq, Clone, Copy)]
-pub(crate) enum BackendKind {
+#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)]
+pub enum BackendKind {
     OpenVINO,
 }
22 changes: 10 additions & 12 deletions crates/wasi-nn/src/backend/openvino.rs
@@ -2,6 +2,7 @@
 
 use super::{Backend, BackendError, BackendExecutionContext, BackendGraph};
 use crate::wit::types::{ExecutionTarget, Tensor, TensorType};
+use crate::{ExecutionContext, Graph};
 use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
 use std::sync::Arc;
@@ -15,11 +16,7 @@ impl Backend for OpenvinoBackend {
         "openvino"
     }
 
-    fn load(
-        &mut self,
-        builders: &[&[u8]],
-        target: ExecutionTarget,
-    ) -> Result<Box<dyn BackendGraph>, BackendError> {
+    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
         if builders.len() != 2 {
             return Err(BackendError::InvalidNumberOfBuilders(2, builders.len()).into());
         }
@@ -54,8 +51,9 @@ impl Backend for OpenvinoBackend {
 
         let exec_network =
            core.load_network(&cnn_network, map_execution_target_to_string(target))?;
-
-        Ok(Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network)))
+        let box_: Box<dyn BackendGraph> =
+            Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network));
+        Ok(box_.into())
     }
 }

@@ -65,12 +63,11 @@ unsafe impl Send for OpenvinoGraph {}
 unsafe impl Sync for OpenvinoGraph {}
 
 impl BackendGraph for OpenvinoGraph {
-    fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError> {
+    fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError> {
         let infer_request = self.1.create_infer_request()?;
-        Ok(Box::new(OpenvinoExecutionContext(
-            self.0.clone(),
-            infer_request,
-        )))
+        let box_: Box<dyn BackendExecutionContext> =
+            Box::new(OpenvinoExecutionContext(self.0.clone(), infer_request));
+        Ok(box_.into())
     }
 }

@@ -145,5 +142,6 @@ fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision
         TensorType::Fp32 => Precision::FP32,
         TensorType::U8 => Precision::U8,
         TensorType::I32 => Precision::I32,
+        TensorType::Bf16 => todo!("not yet supported in `openvino` bindings"),
    }
 }
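
The helper `map_execution_target_to_string` called in the hunk above lies outside the lines shown. As a hedged sketch only, assuming it maps the WIT execution-target variants to the device-name strings OpenVINO's `load_network` expects, it might look like this (the stand-in enum and the TPU handling are assumptions, not taken from this commit):

```rust
// Stand-in for the generated `crate::wit::types::ExecutionTarget` enum.
#[derive(Clone, Copy)]
enum ExecutionTarget {
    Cpu,
    Gpu,
    Tpu,
}

// Hypothetical sketch of the helper: map the execution target to the
// device-name string that OpenVINO's `load_network` expects.
fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str {
    match target {
        ExecutionTarget::Cpu => "CPU",
        ExecutionTarget::Gpu => "GPU",
        // OpenVINO exposes no TPU device; the real helper may handle this differently.
        ExecutionTarget::Tpu => unimplemented!("no TPU device in OpenVINO"),
    }
}

fn main() {
    assert_eq!(map_execution_target_to_string(ExecutionTarget::Cpu), "CPU");
}
```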
34 changes: 18 additions & 16 deletions crates/wasi-nn/src/ctx.rs
@@ -1,36 +1,36 @@
 //! Implements the host state for the `wasi-nn` API: [WasiNnCtx].
 
-use crate::backend::{
-    self, Backend, BackendError, BackendExecutionContext, BackendGraph, BackendKind,
-};
+use crate::backend::{self, Backend, BackendError, BackendKind};
 use crate::wit::types::GraphEncoding;
-use std::collections::HashMap;
-use std::hash::Hash;
+use crate::{ExecutionContext, Graph};
+use std::{collections::HashMap, hash::Hash};
 use thiserror::Error;
 use wiggle::GuestError;
 
+type Backends = HashMap<BackendKind, Box<dyn Backend>>;
 type GraphId = u32;
 type GraphExecutionContextId = u32;
 
 /// Capture the state necessary for calling into the backend ML libraries.
 pub struct WasiNnCtx {
-    pub(crate) backends: HashMap<BackendKind, Box<dyn Backend>>,
-    pub(crate) graphs: Table<GraphId, Box<dyn BackendGraph>>,
-    pub(crate) executions: Table<GraphExecutionContextId, Box<dyn BackendExecutionContext>>,
+    pub(crate) backends: Backends,
+    pub(crate) graphs: Table<GraphId, Graph>,
+    pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
 }
 
 impl WasiNnCtx {
     /// Make a new context from the default state.
-    pub fn new() -> WasiNnResult<Self> {
-        let mut backends = HashMap::new();
-        for (kind, backend) in backend::list() {
-            backends.insert(kind, backend);
-        }
-        Ok(Self {
+    pub fn new(backends: Backends) -> Self {
+        Self {
             backends,
             graphs: Table::default(),
             executions: Table::default(),
-        })
+        }
     }
 }
+impl Default for WasiNnCtx {
+    fn default() -> Self {
+        WasiNnCtx::new(backend::list().into_iter().collect())
+    }
+}

@@ -59,6 +59,8 @@ pub enum UsageError {
     InvalidExecutionContextHandle,
     #[error("Not enough memory to copy tensor data of size: {0}")]
     NotEnoughMemory(u32),
+    #[error("No graph found with name: {0}")]
+    NotFound(String),
 }
 
 pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
@@ -105,6 +107,6 @@ mod test {
 
     #[test]
     fn instantiate() {
-        WasiNnCtx::new().unwrap();
+        WasiNnCtx::default();
     }
 }
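
The `Table` type used above is defined elsewhere in `ctx.rs` and does not appear in these hunks. A hedged sketch of what such an id-allocating table plausibly looks like (hypothetical, purely to illustrate why `Table::default()` and `get_mut(id)` appear in the diff):

```rust
use std::collections::HashMap;
use std::hash::Hash;

// Hypothetical sketch: a map that hands out monotonically increasing
// u32-convertible keys, so graphs and execution contexts can be referred to
// by opaque handles across the wasi-nn API boundary.
struct Table<K, V> {
    entries: HashMap<K, V>,
    next_key: u32,
}

impl<K, V> Default for Table<K, V> {
    fn default() -> Self {
        Self { entries: HashMap::new(), next_key: 0 }
    }
}

impl<K: Eq + Hash + From<u32> + Copy, V> Table<K, V> {
    // Store a value and return the freshly allocated key.
    fn insert(&mut self, value: V) -> K {
        let key = K::from(self.next_key);
        self.next_key += 1;
        self.entries.insert(key, value);
        key
    }

    fn get_mut(&mut self, key: K) -> Option<&mut V> {
        self.entries.get_mut(&key)
    }
}

fn main() {
    let mut graphs: Table<u32, String> = Table::default();
    let id = graphs.insert("my-graph".to_string());
    assert!(graphs.get_mut(id).is_some());
}
```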
38 changes: 38 additions & 0 deletions crates/wasi-nn/src/lib.rs
@@ -4,3 +4,41 @@ mod ctx;
 pub use ctx::WasiNnCtx;
 pub mod wit;
 pub mod witx;
+
+/// A backend-defined graph (i.e., ML model).
+pub struct Graph(Box<dyn backend::BackendGraph>);
+impl From<Box<dyn backend::BackendGraph>> for Graph {
+    fn from(value: Box<dyn backend::BackendGraph>) -> Self {
+        Self(value)
+    }
+}
+impl std::ops::Deref for Graph {
+    type Target = dyn backend::BackendGraph;
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+impl std::ops::DerefMut for Graph {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.0.as_mut()
+    }
+}
+
+/// A backend-defined execution context.
+pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
+impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {
+    fn from(value: Box<dyn backend::BackendExecutionContext>) -> Self {
+        Self(value)
+    }
+}
+impl std::ops::Deref for ExecutionContext {
+    type Target = dyn backend::BackendExecutionContext;
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+impl std::ops::DerefMut for ExecutionContext {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.0.as_mut()
+    }
+}
57 changes: 36 additions & 21 deletions crates/wasi-nn/src/wit.rs
@@ -17,23 +17,29 @@
 
 use crate::{backend::BackendKind, ctx::UsageError, WasiNnCtx};
 
-pub use gen::types;
-pub use gen_::Ml as ML;
-
 /// Generate the traits and types from the `wasi-nn` WIT specification.
 mod gen_ {
-    wasmtime::component::bindgen!("ml");
+    wasmtime::component::bindgen!("ml" in "spec/wit/wasi-nn.wit");
 }
 use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need.
 
-impl gen::inference::Host for WasiNnCtx {
+// Export the `types` used in this crate as well as `ML::add_to_linker`.
+pub mod types {
+    use super::gen;
+    pub use gen::graph::{ExecutionTarget, Graph, GraphEncoding};
+    pub use gen::inference::GraphExecutionContext;
+    pub use gen::tensor::{Tensor, TensorType};
+}
+pub use gen_::Ml as ML;
+
+impl gen::graph::Host for WasiNnCtx {
     /// Load an opaque sequence of bytes to use for inference.
     fn load(
         &mut self,
-        builders: gen::types::GraphBuilderArray,
-        encoding: gen::types::GraphEncoding,
-        target: gen::types::ExecutionTarget,
-    ) -> wasmtime::Result<Result<gen::types::Graph, gen::types::Error>> {
+        builders: Vec<gen::graph::GraphBuilder>,
+        encoding: gen::graph::GraphEncoding,
+        target: gen::graph::ExecutionTarget,
+    ) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
         let backend_kind: BackendKind = encoding.try_into()?;
         let graph = if let Some(backend) = self.backends.get_mut(&backend_kind) {
             let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
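
For orientation, here is a hedged sketch of how an embedder might wire these bindings into a component `Linker`; `ML::add_to_linker` is the entry point the comment in the hunk above exports, while the `Host` struct, engine setup, and `guest.wasm` path are hypothetical and not part of this commit:

```rust
use wasmtime::component::{Component, Linker};
use wasmtime::{Config, Engine, Store};
use wasmtime_wasi_nn::{wit::ML, WasiNnCtx};

// Hypothetical host state holding the wasi-nn context.
struct Host {
    wasi_nn: WasiNnCtx,
}

fn main() -> wasmtime::Result<()> {
    let mut config = Config::new();
    config.wasm_component_model(true);
    let engine = Engine::new(&config)?;

    // Make the generated `ml` world's imports available to guest components.
    let mut linker: Linker<Host> = Linker::new(&engine);
    ML::add_to_linker(&mut linker, |host: &mut Host| &mut host.wasi_nn)?;

    // Instantiate a (hypothetical) guest component against those imports.
    let component = Component::from_file(&engine, "guest.wasm")?;
    let mut store = Store::new(&engine, Host { wasi_nn: WasiNnCtx::default() });
    let _instance = linker.instantiate(&mut store, &component)?;
    Ok(())
}
```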
@@ -45,13 +51,22 @@ impl gen::inference::Host for WasiNnCtx {
         Ok(Ok(graph_id))
     }
 
+    fn load_by_name(
+        &mut self,
+        _name: String,
+    ) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
+        todo!()
+    }
+}
+
+impl gen::inference::Host for WasiNnCtx {
     /// Create an execution instance of a loaded graph.
     ///
     /// TODO: remove completely?
     fn init_execution_context(
         &mut self,
-        graph_id: gen::types::Graph,
-    ) -> wasmtime::Result<Result<gen::types::GraphExecutionContext, gen::types::Error>> {
+        graph_id: gen::graph::Graph,
+    ) -> wasmtime::Result<Result<gen::inference::GraphExecutionContext, gen::errors::Error>> {
         let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
             graph.init_execution_context()?
         } else {
@@ -65,10 +80,10 @@ impl gen::inference::Host for WasiNnCtx {
     /// Define the inputs to use for inference.
     fn set_input(
         &mut self,
-        exec_context_id: gen::types::GraphExecutionContext,
+        exec_context_id: gen::inference::GraphExecutionContext,
         index: u32,
-        tensor: gen::types::Tensor,
-    ) -> wasmtime::Result<Result<(), gen::types::Error>> {
+        tensor: gen::tensor::Tensor,
+    ) -> wasmtime::Result<Result<(), gen::errors::Error>> {
         if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
             exec_context.set_input(index, &tensor)?;
             Ok(Ok(()))
@@ -82,8 +97,8 @@ impl gen::inference::Host for WasiNnCtx {
     /// TODO: refactor to compute(list<tensor>) -> result<list<tensor>, error>
     fn compute(
         &mut self,
-        exec_context_id: gen::types::GraphExecutionContext,
-    ) -> wasmtime::Result<Result<(), gen::types::Error>> {
+        exec_context_id: gen::inference::GraphExecutionContext,
+    ) -> wasmtime::Result<Result<(), gen::errors::Error>> {
         if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
             exec_context.compute()?;
             Ok(Ok(()))
@@ -95,9 +110,9 @@ impl gen::inference::Host for WasiNnCtx {
     /// Extract the outputs after inference.
     fn get_output(
         &mut self,
-        exec_context_id: gen::types::GraphExecutionContext,
+        exec_context_id: gen::inference::GraphExecutionContext,
         index: u32,
-    ) -> wasmtime::Result<Result<gen::types::TensorData, gen::types::Error>> {
+    ) -> wasmtime::Result<Result<gen::tensor::TensorData, gen::errors::Error>> {
         if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
             // Read the output bytes. TODO: this involves a hard-coded upper
             // limit on the tensor size that is necessary because there is no
@@ -113,11 +128,11 @@ impl gen::inference::Host for WasiNnCtx {
     }
 }
 
-impl TryFrom<gen::types::GraphEncoding> for crate::backend::BackendKind {
+impl TryFrom<gen::graph::GraphEncoding> for crate::backend::BackendKind {
     type Error = UsageError;
-    fn try_from(value: gen::types::GraphEncoding) -> Result<Self, Self::Error> {
+    fn try_from(value: gen::graph::GraphEncoding) -> Result<Self, Self::Error> {
         match value {
-            gen::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
+            gen::graph::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
             _ => Err(UsageError::InvalidEncoding(value.into())),
         }
     }
