wasi-nn: refactor to allow preview2 access (bytecodealliance#6821)
* wasi-nn: refactor to allow `preview2` access

This change refactors the `wasmtime-wasi-nn` crate to allow access from
both `preview1` and `preview2` ABIs. Though the `wasi-nn` specification
has included a WIT description for some time, here we use some in-tree
files until WebAssembly/wasi-nn#38 is landed.
The `preview2` code is not exercised anywhere yet: ideally this would be
wired up once component model `resource`s are fully implemented in
Wasmtime.

prtest:full

* wasi-nn: use `preview1` linkage

prtest:full

* review: rename `preview*` to `wit*`

This is based on @pchickey's [comments] on ABI naming.

[comments]: https://bytecodealliance.zulipchat.com/#narrow/stream/266558-wasi-nn/topic/wasi-nn.20.2B.20preview2/near/383368292

* review: update README

* fix: remove broken doc links

* fix: replace typo `wit` with `gen`

* review: use `wit` types everywhere

This removes the crate-specific types in order to use the WIT-generated
types throughout the crate. The main effect of this is that the crate no
longer depends on `wasmtime` with the `component-model` feature
optionally; that feature is now required.

* review: move `BackendKind` conversion into `witx.rs`

* review: remove `<'a>`

* review: use `tracing` crate instead of `eprintln!`
abrown authored Aug 16, 2023
1 parent ca5a9db commit 6130395
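
For orientation, here is a minimal host-side sketch of the `witx` (preview1-style) linkage this change keeps in use. It assumes `WasiNnCtx` is re-exported at the crate root and that `witx::add_to_linker` follows the usual Wiggle accessor-closure shape; neither detail is taken verbatim from this diff.

```rust
use anyhow::Result;
use wasmtime::{Engine, Linker};
use wasmtime_wasi_nn::WasiNnCtx;

/// Host state owning the wasi-nn context alongside any other embedder data.
struct Host {
    wasi_nn: WasiNnCtx,
}

fn build_witx_linker(engine: &Engine) -> Result<Linker<Host>> {
    let mut linker = Linker::new(engine);
    // Core-module ("witx"/preview1-style) linkage: the closure hands the
    // generated glue a mutable view of this host's `WasiNnCtx`.
    wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |host: &mut Host| &mut host.wasi_nn)?;
    Ok(linker)
}

fn new_host() -> Result<Host> {
    // `WasiNnCtx::new()` probes the available backends (currently OpenVINO).
    Ok(Host { wasi_nn: WasiNnCtx::new()? })
}
```
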
Showing 15 changed files with 524 additions and 190 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock


6 changes: 5 additions & 1 deletion crates/wasi-nn/Cargo.toml
@@ -12,11 +12,15 @@ readme = "README.md"
edition.workspace = true

[dependencies]
# These dependencies are necessary for the witx-generation macros to work:
# These dependencies are necessary for the WITX-generation macros to work:
anyhow = { workspace = true }
wiggle = { workspace = true }

# This dependency is necessary for the WIT-generation macros to work:
wasmtime = { workspace = true, features = ["component-model"] }

# These dependencies are necessary for the wasi-nn implementation:
tracing = { workspace = true }
openvino = { version = "0.5.0", features = ["runtime-linking"] }
thiserror = { workspace = true }

37 changes: 22 additions & 15 deletions crates/wasi-nn/README.md
@@ -1,38 +1,45 @@
# wasmtime-wasi-nn

This crate enables support for the [wasi-nn] API in Wasmtime. Currently it contains an implementation of [wasi-nn] using
OpenVINO™ but in the future it could support multiple machine learning backends. Since the [wasi-nn] API is expected
to be an optional feature of WASI, this crate is currently separate from the [wasi-common] crate. This crate is
experimental and its API, functionality, and location could quickly change.
This crate enables support for the [wasi-nn] API in Wasmtime. Currently it
contains an implementation of [wasi-nn] using OpenVINO™ but in the future it
could support multiple machine learning backends. Since the [wasi-nn] API is
expected to be an optional feature of WASI, this crate is currently separate
from the [wasi-common] crate. This crate is experimental and its API,
functionality, and location could quickly change.

[examples]: examples
[openvino]: https://crates.io/crates/openvino
[wasi-nn]: https://github.com/WebAssembly/wasi-nn
[wasi-common]: ../wasi-common
[bindings]: https://crates.io/crates/wasi-nn

### Use

Use the Wasmtime APIs to instantiate a Wasm module and link in the `WasiNn` implementation as follows:
Use the Wasmtime APIs to instantiate a Wasm module and link in the `wasi-nn`
implementation as follows:

```
let wasi_nn = WasiNn::new(&store, WasiNnCtx::new()?);
wasi_nn.add_to_linker(&mut linker)?;
```rust
let wasi_nn = WasiNnCtx::new()?;
wasmtime_wasi_nn::witx::add_to_linker(...);
```

### Build

This crate should build as usual (i.e. `cargo build`) but note that using an existing installation of OpenVINO™, rather
than building from source, will drastically improve the build times. See the [openvino] crate for more information
```sh
$ cargo build
```

To use the WIT-based ABI, compile with `--features component-model` and use `wasmtime_wasi_nn::wit::add_to_linker`.
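
As a speculative sketch only: the commit notes say the `wit` path is not exercised anywhere yet, so the `component::Linker` target and the accessor-closure signature below are assumptions rather than the crate's confirmed API. It mirrors the core-module sketch shown earlier.

```rust
use anyhow::Result;
use wasmtime::component::Linker;
use wasmtime::Engine;
use wasmtime_wasi_nn::WasiNnCtx;

struct Host {
    wasi_nn: WasiNnCtx,
}

fn build_wit_linker(engine: &Engine) -> Result<Linker<Host>> {
    // Component-model ("wit"/preview2-style) linkage against a
    // `wasmtime::component::Linker` instead of a core-module `Linker`.
    let mut linker = Linker::new(engine);
    wasmtime_wasi_nn::wit::add_to_linker(&mut linker, |host: &mut Host| &mut host.wasi_nn)?;
    Ok(linker)
}
```
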

### Example

An end-to-end example demonstrating ML classification is included in [examples]:
- `tests/wasi-nn-rust-bindings` contains ergonomic bindings for writing Rust code against the [wasi-nn] APIs
- `tests/classification-example` contains a standalone Rust project that uses the [wasi-nn] APIs and is compiled to the
`wasm32-wasi` target using the `wasi-nn-rust-bindings`
`examples/classification-example` contains a standalone Rust project that uses
the [wasi-nn] APIs and is compiled to the `wasm32-wasi` target using the
high-level `wasi-nn` [bindings].

Run the example from the Wasmtime project directory:

```
ci/run-wasi-nn-example.sh
```sh
$ ci/run-wasi-nn-example.sh
```
4 changes: 3 additions & 1 deletion crates/wasi-nn/examples/classification-example/Cargo.lock


23 changes: 18 additions & 5 deletions crates/wasi-nn/src/api.rs → crates/wasi-nn/src/backend/mod.rs
@@ -1,17 +1,25 @@
//! Define the Rust interface a backend must implement in order to be used by
//! this crate. the `Box<dyn ...>` types returned by these interfaces allow
//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
//! implementations to maintain backend-specific state between calls.

use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor};
mod openvino;

use self::openvino::OpenvinoBackend;
use crate::wit::types::{ExecutionTarget, Tensor};
use thiserror::Error;
use wiggle::GuestError;

/// Return a list of all available backend frameworks.
pub(crate) fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))]
}

/// A [Backend] contains the necessary state to load [BackendGraph]s.
pub(crate) trait Backend: Send + Sync {
fn name(&self) -> &str;
fn load(
&mut self,
builders: &GraphBuilderArray<'_>,
builders: &[&[u8]],
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError>;
}
@@ -25,7 +33,7 @@ pub(crate) trait BackendGraph: Send + Sync {
/// A [BackendExecutionContext] performs the actual inference; this is the
/// backing implementation for a [crate::witx::types::GraphExecutionContext].
pub(crate) trait BackendExecutionContext: Send + Sync {
fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError>;
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>;
fn compute(&mut self) -> Result<(), BackendError>;
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
}
@@ -39,7 +47,12 @@ pub enum BackendError {
#[error("Failed while accessing guest module")]
GuestAccess(#[from] GuestError),
#[error("The backend expects {0} buffers, passed {1}")]
InvalidNumberOfBuilders(u32, u32),
InvalidNumberOfBuilders(usize, usize),
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(usize),
}

#[derive(Hash, PartialEq, Eq, Clone, Copy)]
pub(crate) enum BackendKind {
OpenVINO,
}
crates/wasi-nn/src/openvino.rs → crates/wasi-nn/src/backend/openvino.rs
@@ -1,13 +1,12 @@
//! Implements the wasi-nn API.
//! Implements a `wasi-nn` [`Backend`] using OpenVINO.

use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor, TensorType};
use super::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::wit::types::{ExecutionTarget, Tensor, TensorType};
use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
use std::sync::Arc;

#[derive(Default)]
pub(crate) struct OpenvinoBackend(Option<openvino::Core>);

unsafe impl Send for OpenvinoBackend {}
unsafe impl Sync for OpenvinoBackend {}

@@ -18,7 +17,7 @@ impl Backend for OpenvinoBackend {

fn load(
&mut self,
builders: &GraphBuilderArray<'_>,
builders: &[&[u8]],
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError> {
if builders.len() != 2 {
@@ -34,16 +33,8 @@ impl Backend for OpenvinoBackend {
}

// Read the guest array.
let builders = builders.as_ptr();
let xml = builders
.read()?
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
let weights = builders
.add(1)?
.read()?
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
let xml = &builders[0];
let weights = &builders[1];

// Construct OpenVINO graph structures: `cnn_network` contains the graph
// structure, `exec_network` can perform inference.
@@ -53,8 +44,9 @@ impl BackendGraph for OpenvinoGraph {
.expect("openvino::Core was previously constructed");
let mut cnn_network = core.read_network_from_buffer(&xml, &weights)?;

// TODO this is a temporary workaround. We need a more eligant way to specify the layout in the long run.
// However, without this newer versions of OpenVINO will fail due to parameter mismatch.
// TODO: this is a temporary workaround. We need a more elegant way to
// specify the layout in the long run. However, without this newer
// versions of OpenVINO will fail due to parameter mismatch.
for i in 0..cnn_network.get_inputs_len()? {
let name = cnn_network.get_input_name(i)?;
cnn_network.set_input_layout(&name, Layout::NHWC)?;
@@ -85,27 +77,19 @@ impl BackendGraph for OpenvinoGraph {
struct OpenvinoExecutionContext(Arc<openvino::CNNNetwork>, openvino::InferRequest);

impl BackendExecutionContext for OpenvinoExecutionContext {
fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError> {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> {
let input_name = self.0.get_input_name(index as usize)?;

// Construct the blob structure.
// Construct the blob structure. TODO: there must be some good way to
// discover the layout here; `desc` should not have to default to NHWC.
let precision = map_tensor_type_to_precision(tensor.tensor_type);
let dimensions = tensor
.dimensions
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)")
.iter()
.map(|d| *d as usize)
.map(|&d| d as usize)
.collect::<Vec<_>>();
let precision = map_tensor_type_to_precision(tensor.type_);

// TODO There must be some good way to discover the layout here; this
// should not have to default to NHWC.
let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision);
let data = tensor
.data
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
let blob = openvino::Blob::new(&desc, &data)?;
let blob = openvino::Blob::new(&desc, &tensor.data)?;

// Actually assign the blob to the request.
self.1.set_blob(&input_name, &blob)?;
@@ -157,8 +141,8 @@ fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str {
/// wasi-nn.
fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision {
match tensor_type {
TensorType::F16 => Precision::FP16,
TensorType::F32 => Precision::FP32,
TensorType::Fp16 => Precision::FP16,
TensorType::Fp32 => Precision::FP32,
TensorType::U8 => Precision::U8,
TensorType::I32 => Precision::I32,
}
46 changes: 31 additions & 15 deletions crates/wasi-nn/src/ctx.rs
@@ -1,31 +1,31 @@
//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the
//! implementation of the wasi-nn API.
use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::openvino::OpenvinoBackend;
use crate::r#impl::UsageError;
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
//! Implements the host state for the `wasi-nn` API: [WasiNnCtx].

use crate::backend::{
self, Backend, BackendError, BackendExecutionContext, BackendGraph, BackendKind,
};
use crate::wit::types::GraphEncoding;
use std::collections::HashMap;
use std::hash::Hash;
use thiserror::Error;
use wiggle::GuestError;

type GraphId = u32;
type GraphExecutionContextId = u32;

/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
pub(crate) backends: HashMap<BackendKind, Box<dyn Backend>>,
pub(crate) graphs: Table<GraphId, Box<dyn BackendGraph>>,
pub(crate) executions: Table<GraphExecutionContextId, Box<dyn BackendExecutionContext>>,
}

impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new() -> WasiNnResult<Self> {
let mut backends = HashMap::new();
backends.insert(
// This is necessary because Wiggle's variant types do not derive
// `Hash` and `Eq`.
GraphEncoding::Openvino.into(),
Box::new(OpenvinoBackend::default()) as Box<dyn Backend>,
);
for (kind, backend) in backend::list() {
backends.insert(kind, backend);
}
Ok(Self {
backends,
graphs: Table::default(),
@@ -45,6 +45,22 @@ pub enum WasiNnError {
UsageError(#[from] UsageError),
}

#[derive(Debug, Error)]
pub enum UsageError {
#[error("Invalid context; has the load function been called?")]
InvalidContext,
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
InvalidEncoding(GraphEncoding),
#[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")]
InvalidNumberOfBuilders(u32),
#[error("Invalid graph handle; has it been loaded?")]
InvalidGraphHandle,
#[error("Invalid execution context handle; has it been initialized?")]
InvalidExecutionContextHandle,
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(u32),
}

pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;

/// Record handle entries in a table.
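
The `Table` definition itself is truncated in this view; as a rough guide only, a minimal sketch of the handle-table pattern the doc comment describes might look like the following, assuming plain `u32`-backed handles such as the `GraphId` and `GraphExecutionContextId` aliases above (the crate's actual implementation may differ).

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Minimal handle table: values are stored under monotonically increasing
/// keys that double as the guest-visible handles.
struct Table<K, V> {
    entries: HashMap<K, V>,
    next_key: u32,
}

impl<K, V> Default for Table<K, V> {
    fn default() -> Self {
        Self { entries: HashMap::new(), next_key: 0 }
    }
}

impl<K: Eq + Hash + From<u32> + Copy, V> Table<K, V> {
    /// Store a value and return a fresh handle for it.
    fn insert(&mut self, value: V) -> K {
        let key = K::from(self.next_key);
        self.next_key += 1;
        self.entries.insert(key, value);
        key
    }

    /// Look up a previously stored value by handle.
    fn get(&self, key: K) -> Option<&V> {
        self.entries.get(&key)
    }
}
```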