Skip to content

Commit

Permalink
wasi-nn: track upstream specification
Browse files Browse the repository at this point in the history
In bytecodealliance#8873, we stopped tracking the wasi-nn's upstream WIT files
temporarily because it was not clear to me at the time how to implement
errors as CM resources. This PR fixes that, resuming tracking in the
`vendor-wit.sh` and implementing what is needed in the wasi-nn crate.

This leaves several threads unresolved, though:
- it looks like the `vendor-wit.sh` has a new mechanism for retrieving
  files into `wit/deps`--at some point wasi-nn should migrate to use
  that (?)
- it's not clear to me that "errors as resources" is even the best
  approach here; I've opened [bytecodealliance#75] to consider the possibility of using
  "errors as records" instead.

[bytecodealliance#75]: WebAssembly/wasi-nn#75

prtest:full
  • Loading branch information
abrown committed Aug 1, 2024
1 parent ba864e9 commit d7a02b6
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 49 deletions.
8 changes: 2 additions & 6 deletions ci/vendor-wit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ rm -rf $cache_dir
# Separately (for now), vendor the `wasi-nn` WIT files since their retrieval is
# slightly different than above.
repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn
revision=e2310b
revision=0.2.0-rc-2024-06-25
curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx
# TODO: the in-tree `wasi-nn` implementation does not yet fully support the
# latest WIT specification on `main`. To create a baseline for moving forward,
# the in-tree WIT incorporates some but not all of the upstream changes. This
# TODO can be removed once the implementation catches up with the spec.
# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
2 changes: 2 additions & 0 deletions crates/wasi-nn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ impl std::ops::DerefMut for ExecutionContext {
}
}



/// A container for graphs.
pub struct Registry(Box<dyn GraphRegistry>);
impl std::ops::Deref for Registry {
Expand Down
141 changes: 101 additions & 40 deletions crates/wasi-nn/src/wit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use crate::backend::Id;
use crate::{Backend, Registry};
use anyhow::anyhow;
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
Expand Down Expand Up @@ -54,7 +55,26 @@ impl<'a> WasiNnView<'a> {
}
}

pub enum Error {
#[derive(Debug)]
pub struct Error {
code: ErrorCode,
data: anyhow::Error,
}

macro_rules! bail {
($self:ident, $code:expr, $data:expr) => {
let e = Error {
code: $code,
data: $data.into(),
};
tracing::error!("failure: {e:?}");
let r = $self.table.push(e)?;
return Ok(Err(r));
};
}

#[derive(Debug)]
pub enum ErrorCode {
/// Caller module passed an invalid argument.
InvalidArgument,
/// Invalid encoding.
Expand All @@ -70,12 +90,15 @@ pub enum Error {
/// Graph not found.
NotFound,
/// A runtime error occurred that we should trap on; see `StreamError`.
Trap(anyhow::Error),
Trap,
}

impl From<wasmtime::component::ResourceTableError> for Error {
fn from(error: wasmtime::component::ResourceTableError) -> Self {
Self::Trap(error.into())
Self {
code: ErrorCode::Trap,
data: error.into(),
}
}
}

Expand All @@ -91,6 +114,7 @@ mod gen_ {
"wasi:nn/graph/graph": crate::Graph,
"wasi:nn/tensor/tensor": crate::Tensor,
"wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
"wasi:nn/errors/error": super::Error,
},
trappable_error_type: {
"wasi:nn/errors/error" => super::Error,
Expand Down Expand Up @@ -131,36 +155,45 @@ impl gen::graph::Host for WasiNnView<'_> {
builders: Vec<GraphBuilder>,
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Resource<crate::Graph>, Error> {
) -> Result<Result<Resource<crate::Graph>, Resource<Error>>, anyhow::Error> {
tracing::debug!("load {encoding:?} {target:?}");
if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
match backend.load(&slices, target.into()) {
Ok(graph) => {
let graph = self.table.push(graph)?;
Ok(graph)
Ok(Ok(graph))
}
Err(error) => {
tracing::error!("failed to load graph: {error:?}");
Err(Error::RuntimeError)
bail!(self, ErrorCode::RuntimeError, error);
}
}
} else {
Err(Error::InvalidEncoding)
bail!(
self,
ErrorCode::InvalidEncoding,
anyhow!("unable to find a backend for this encoding")
);
}
}

fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> {
fn load_by_name(
&mut self,
name: String,
) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
use core::result::Result::*;
tracing::debug!("load by name {name:?}");
let registry = &self.ctx.registry;
if let Some(graph) = registry.get(&name) {
let graph = graph.clone();
let graph = self.table.push(graph)?;
Ok(graph)
Ok(Ok(graph))
} else {
tracing::error!("failed to find graph with name: {name}");
Err(Error::NotFound)
bail!(
self,
ErrorCode::NotFound,
anyhow!("failed to find graph with name: {name}")
);
}
}
}
Expand All @@ -169,18 +202,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
fn init_execution_context(
&mut self,
graph: Resource<Graph>,
) -> Result<Resource<GraphExecutionContext>, Error> {
) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
use core::result::Result::*;
tracing::debug!("initialize execution context");
let graph = self.table.get(&graph)?;
match graph.init_execution_context() {
Ok(exec_context) => {
let exec_context = self.table.push(exec_context)?;
Ok(exec_context)
Ok(Ok(exec_context))
}
Err(error) => {
tracing::error!("failed to initialize execution context: {error:?}");
Err(Error::RuntimeError)
bail!(self, ErrorCode::RuntimeError, error);
}
}
}
Expand All @@ -197,27 +229,28 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
exec_context: Resource<GraphExecutionContext>,
name: String,
tensor: Resource<Tensor>,
) -> Result<(), Error> {
) -> wasmtime::Result<Result<(), Resource<Error>>> {
let tensor = self.table.get(&tensor)?;
tracing::debug!("set input {name:?}: {tensor:?}");
let tensor = tensor.clone(); // TODO: avoid copying the tensor
let exec_context = self.table.get_mut(&exec_context)?;
if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) {
tracing::error!("failed to set input: {e:?}");
Err(Error::InvalidArgument)
if let Err(error) = exec_context.set_input(Id::Name(name), &tensor) {
bail!(self, ErrorCode::InvalidArgument, error);
} else {
Ok(())
Ok(Ok(()))
}
}

fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> {
fn compute(
&mut self,
exec_context: Resource<GraphExecutionContext>,
) -> wasmtime::Result<Result<(), Resource<Error>>> {
let exec_context = &mut self.table.get_mut(&exec_context)?;
tracing::debug!("compute");
match exec_context.compute() {
Ok(()) => Ok(()),
Ok(()) => Ok(Ok(())),
Err(error) => {
tracing::error!("failed to compute: {error:?}");
Err(Error::RuntimeError)
bail!(self, ErrorCode::RuntimeError, error);
}
}
}
Expand All @@ -227,17 +260,16 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
&mut self,
exec_context: Resource<GraphExecutionContext>,
name: String,
) -> Result<Resource<Tensor>, Error> {
) -> wasmtime::Result<Result<Resource<Tensor>, Resource<Error>>> {
let exec_context = self.table.get_mut(&exec_context)?;
tracing::debug!("get output {name:?}");
match exec_context.get_output(Id::Name(name)) {
Ok(tensor) => {
let tensor = self.table.push(tensor)?;
Ok(tensor)
Ok(Ok(tensor))
}
Err(error) => {
tracing::error!("failed to get output: {error:?}");
Err(Error::RuntimeError)
bail!(self, ErrorCode::RuntimeError, error);
}
}
}
Expand Down Expand Up @@ -285,21 +317,50 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
}
}

impl gen::tensor::Host for WasiNnView<'_> {}
impl gen::errors::HostError for WasiNnView<'_> {
fn new(
&mut self,
_code: gen::errors::ErrorCode,
_data: wasmtime::component::__internal::String,
) -> wasmtime::Result<wasmtime::component::Resource<gen::errors::Error>> {
unimplemented!()
}

fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<gen::errors::ErrorCode> {
let error = self.table.get(&error)?;
match error.code {
ErrorCode::InvalidArgument => Ok(gen::errors::ErrorCode::InvalidArgument),
ErrorCode::InvalidEncoding => Ok(gen::errors::ErrorCode::InvalidEncoding),
ErrorCode::Timeout => Ok(gen::errors::ErrorCode::Timeout),
ErrorCode::RuntimeError => Ok(gen::errors::ErrorCode::RuntimeError),
ErrorCode::UnsupportedOperation => Ok(gen::errors::ErrorCode::UnsupportedOperation),
ErrorCode::TooLarge => Ok(gen::errors::ErrorCode::TooLarge),
ErrorCode::NotFound => Ok(gen::errors::ErrorCode::NotFound),
ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
}
}

fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
let error = self.table.get(&error)?;
Ok(error.data.to_string())
}

fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
self.table.delete(error)?;
Ok(())
}
}

impl gen::errors::Host for WasiNnView<'_> {
fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> {
match err {
Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument),
Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding),
Error::Timeout => Ok(gen::errors::Error::Timeout),
Error::RuntimeError => Ok(gen::errors::Error::RuntimeError),
Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation),
Error::TooLarge => Ok(gen::errors::Error::TooLarge),
Error::NotFound => Ok(gen::errors::Error::NotFound),
Error::Trap(e) => Err(e),
fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
if matches!(err.code, ErrorCode::Trap) {
Err(err.data)
} else {
Ok(err)
}
}
}
impl gen::tensor::Host for WasiNnView<'_> {}
impl gen::inference::Host for WasiNnView<'_> {}

impl Hash for gen::graph::GraphEncoding {
Expand Down
21 changes: 18 additions & 3 deletions crates/wasi-nn/wit/wasi-nn.wit
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package wasi:nn;
package wasi:nn@0.2.0-rc-2024-06-25;

/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet)
/// capable of performing ML training. WebAssembly programs that want to use a host's ML
Expand Down Expand Up @@ -134,7 +134,7 @@ interface inference {

/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
interface errors {
enum error {
enum error-code {
// Caller module passed an invalid argument.
invalid-argument,
// Invalid encoding.
Expand All @@ -148,6 +148,21 @@ interface errors {
// Graph is too large.
too-large,
// Graph not found.
not-found
not-found,
// The operation is insecure or has insufficient privilege to be performed.
// e.g., cannot access a hardware feature requested
security,
// The operation failed for an unspecified reason.
unknown
}

resource error {
constructor(code: error-code, data: string);

/// Return the error code.
code: func() -> error-code;

/// Errors can propagated with backend specific status through a string value.
data: func() -> string;
}
}

0 comments on commit d7a02b6

Please sign in to comment.