diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml new file mode 100644 index 000000000000..7897e21046fa --- /dev/null +++ b/rust/tvm-sys/Cargo.toml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-sys" +version = "0.1.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +edition = "2018" + +[features] +bindings = [] + +[dependencies] +thiserror = "^1.0" +anyhow = "^1.0" +ndarray = "0.12" + +[build-dependencies] +bindgen = "0.51" diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs new file mode 100644 index 000000000000..915827bf95f0 --- /dev/null +++ b/rust/tvm-sys/build.rs @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate bindgen; + +use std::path::PathBuf; + +// extern crate cmake; + +use std::env; +// use std::path::Path; +// use std::process::Command; +// use cmake::Config; + +// fn main() { +// if !Path::new("tvm/.git").exists() { +// let _ = Command::new("git") +// .args(&["submodule", "update", "--recursive", "--init"]) +// .status(); +// } + +// let dst = Config::new("tvm") +// .very_verbose(true) +// .build(); + +// // let dst = dst.join("build"); + +// let out_dir = env::var("OUT_DIR").unwrap(); + +// println!("{}", out_dir); +// // let _ = Command::new("mv") +// // .args(&[format!("{}/build/libtvm.dylib", dst.display()), out_dir]) +// // .status(); + +// println!("cargo:rustc-link-search=native={}/lib", dst.display()); +// // TODO(@jroesch): hack for dylib behavior +// for lib in &[/* "tvm", */ "tvm_runtime", /* "tvm_topi" */] { +// // let src = format!("{}/lib/lib{}.dylib", out_dir, lib); +// // let dst = format!("{}/../../../deps", out_dir); +// // let _ = Command::new("mv") +// // .args(&[src, dst]) +// // .status(); +// println!("cargo:rustc-link-lib=dylib={}", lib); +// } +// // "-Wl,-rpath,/scratch/library/" +// println!("cargo:rustc-env=TVM_HOME={}/build", dst.display()); +// // panic!(""); +// // cc::Build::new() +// // .cpp(true) +// // .flag("-std=c++11") +// // .flag("-Wno-ignored-qualifiers") +// // .flag("-Wno-unused-parameter") +// // .include("/Users/jroesch/Git/tvm/include") +// // .include("/Users/jroesch/Git/tvm/3rdparty/dmlc-core/include") +// // .include("/Users/jroesch/Git/tvm/3rdparty/dlpack/include") +// // .include("/Users/jroesch/Git/tvm/3rdparty/HalideIR/src") +// // .file("tvm_wrapper.cc") +// // .compile("tvm_ffi"); +// // println!("cargo:rustc-link-lib=dylib=tvm"); +// // println!("cargo:rustc-link-search=/Users/jroesch/Git/tvm/build"); +// } + +fn main() { + let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ + let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .canonicalize() + .unwrap(); + crate_dir + .parent() + .unwrap() + .parent() + .unwrap() + .to_str() + .unwrap() + .to_string() + }); + + if cfg!(feature = "bindings") { + println!("cargo:rerun-if-env-changed=TVM_HOME"); + // println!("cargo:rustc-link-lib=dylib=tvm_runtime"); + // TODO: move to core + // println!("cargo:rustc-link-lib=dylib=tvm_runtime"); + println!("cargo:rustc-link-lib=dylib=tvm"); + println!("cargo:rustc-link-search={}/build", tvm_home); + } + + // @see rust-bindgen#550 for `blacklist_type` + bindgen::Builder::default() + .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) + .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) + .clang_arg(format!("-I{}/include/", tvm_home)) + .blacklist_type("max_align_t") + .layout_tests(false) + .derive_partialeq(true) + .derive_eq(true) + .generate() + .expect("unable to generate bindings") + .write_to_file(PathBuf::from("src/c_runtime_api.rs")) + .expect("can not write the bindings!"); +} diff --git a/rust/tvm-sys/src/array.rs b/rust/tvm-sys/src/array.rs new file mode 100644 index 000000000000..be7b4916b732 --- /dev/null +++ b/rust/tvm-sys/src/array.rs @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + mem, + os::raw::{c_int, c_void}, +}; + +use crate::ffi::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, + DLDeviceType_kDLCPU, DLTensor, +}; + +/// `From` conversions to `DLTensor` for `ndarray::Array`. +/// Takes a reference to the `ndarray` since `DLTensor` is not owned. +macro_rules! impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext { + device_type: DLDeviceType_kDLCPU, + device_id: 0, + }, + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const isize as *mut i64, + byte_offset: 0, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs new file mode 100644 index 000000000000..b72c91b1c637 --- /dev/null +++ b/rust/tvm-sys/src/byte_array.rs @@ -0,0 +1,64 @@ +use std::os::raw::c_char; + +use crate::ffi::TVMByteArray; + +/// A struct holding TVM byte-array. +/// +/// ## Example +/// +/// ``` +/// let v = b"hello"; +/// let barr = tvm_sys::ByteArray::from(&v); +/// assert_eq!(barr.len(), v.len()); +/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); +/// ``` +pub type ByteArray = TVMByteArray; + +impl ByteArray { + /// Gets the underlying byte-array + pub fn data(&self) -> &'static [u8] { + unsafe { std::slice::from_raw_parts(self.data as *const u8, self.size) } + } + + /// Gets the length of the underlying byte-array + pub fn len(&self) -> usize { + self.size + } + + /// Converts the underlying byte-array to `Vec` + pub fn to_vec(&self) -> Vec { + self.data().to_vec() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +// Needs AsRef for Vec +impl> From for ByteArray { + fn from(arg: T) -> Self { + let arg = arg.as_ref(); + ByteArray { + data: arg.as_ptr() as *const c_char, + size: arg.len(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn convert() { + let v = vec![1u8, 2, 3]; + let barr = ByteArray::from(&v); + assert_eq!(barr.len(), v.len()); + assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); + let v = b"hello"; + let barr = ByteArray::from(&v); + assert_eq!(barr.len(), v.len()); + assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); + } +} diff --git a/rust/tvm-sys/src/context.rs b/rust/tvm-sys/src/context.rs new file mode 100644 index 000000000000..1952b32f62ed --- /dev/null +++ b/rust/tvm-sys/src/context.rs @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Provides [`Context`] and related device queries. +//! +//! Create a new context for device type and device id. +//! +//! # Example +//! +//! ``` +//! # use tvm_sys::{TVMDeviceType, Context}; +//! let cpu = TVMDeviceType::from("cpu"); +//! let ctx = Context::new(cpu , 0); +//! let cpu0 = Context::cpu(0); +//! assert_eq!(ctx, cpu0); +//! ``` +//! +//! Or from a supported device name. +//! +//! ``` +//! use tvm_sys::Context; +//! let cpu0 = Context::from("cpu"); +//! println!("{}", cpu0); +//! ``` + +use crate::ffi::{self, *}; +use crate::packed_func::{ArgValue, RetValue}; + +use std::convert::TryFrom; +use std::str::FromStr; +use thiserror::Error; + +use std::fmt::{self, Display, Formatter}; + +use anyhow::Result; + +/// Device type can be from a supported device name. See the supported devices +/// in [TVM](https://github.com/apache/incubator-tvm). +/// +/// ## Example +/// +/// ``` +/// use tvm_sys::TVMDeviceType; +/// let cpu = TVMDeviceType::from("cpu"); +/// println!("device is: {}", cpu); +///``` + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TVMDeviceType(pub i64); + +impl Default for TVMDeviceType { + /// default device is cpu. + fn default() -> Self { + TVMDeviceType(1) + } +} + +impl From for ffi::DLDeviceType { + fn from(device_type: TVMDeviceType) -> Self { + match device_type.0 { + 1 => ffi::DLDeviceType_kDLCPU, + 2 => ffi::DLDeviceType_kDLGPU, + 3 => ffi::DLDeviceType_kDLCPUPinned, + 4 => ffi::DLDeviceType_kDLOpenCL, + 7 => ffi::DLDeviceType_kDLVulkan, + 8 => ffi::DLDeviceType_kDLMetal, + 9 => ffi::DLDeviceType_kDLVPI, + 10 => ffi::DLDeviceType_kDLROCM, + 12 => ffi::DLDeviceType_kDLExtDev, + _ => panic!("device type not found!"), + } + } +} + +impl From for TVMDeviceType { + fn from(device_type: ffi::DLDeviceType) -> Self { + match device_type { + ffi::DLDeviceType_kDLCPU => TVMDeviceType(1), + ffi::DLDeviceType_kDLGPU => TVMDeviceType(2), + ffi::DLDeviceType_kDLCPUPinned => TVMDeviceType(3), + ffi::DLDeviceType_kDLOpenCL => TVMDeviceType(4), + ffi::DLDeviceType_kDLVulkan => TVMDeviceType(7), + ffi::DLDeviceType_kDLMetal => TVMDeviceType(8), + ffi::DLDeviceType_kDLVPI => TVMDeviceType(9), + ffi::DLDeviceType_kDLROCM => TVMDeviceType(10), + ffi::DLDeviceType_kDLExtDev => TVMDeviceType(12), + _ => panic!("device type not found!"), + } + } +} + +impl Display for TVMDeviceType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "{}", + match self { + TVMDeviceType(1) => "cpu", + TVMDeviceType(2) => "gpu", + TVMDeviceType(3) => "cpu_pinned", + TVMDeviceType(4) => "opencl", + TVMDeviceType(8) => "meta", + TVMDeviceType(9) => "vpi", + TVMDeviceType(10) => "rocm", + TVMDeviceType(_) => "rpc", + } + ) + } +} + +impl<'a> From<&'a str> for TVMDeviceType { + fn from(type_str: &'a str) -> Self { + match type_str { + "cpu" => TVMDeviceType(1), + "llvm" => TVMDeviceType(1), + "stackvm" => TVMDeviceType(1), + "gpu" => TVMDeviceType(2), + "cuda" => TVMDeviceType(2), + "nvptx" => TVMDeviceType(2), + "cl" => TVMDeviceType(4), + "opencl" => TVMDeviceType(4), + "metal" => TVMDeviceType(8), + "vpi" => TVMDeviceType(9), + "rocm" => TVMDeviceType(10), + _ => panic!("{:?} not supported!", type_str), + } + } +} + +impl<'a> From<&TVMDeviceType> for ArgValue<'a> { + fn from(dev: &TVMDeviceType) -> Self { + Self::Int(dev.0) + } +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct Context { + pub device_type: TVMDeviceType, + pub device_id: usize, +} + +impl Context { + pub fn new(device_type: TVMDeviceType, device_id: usize) -> Context { + Context { + device_type, + device_id, + } + } +} + +impl<'a> From<&'a Context> for DLContext { + fn from(ctx: &'a Context) -> Self { + Self { + device_type: ctx.device_type.into(), + device_id: ctx.device_id as i32, + } + } +} + +impl Default for Context { + fn default() -> Self { + Self { + device_type: DLDeviceType_kDLCPU.into(), + device_id: 0, + } + } +} + +#[derive(Debug, Error)] +#[error("unsupported device: {0}")] +pub struct UnsupportedDeviceError(String); + +macro_rules! impl_tvm_context { + ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { + /// Creates a Context from a string (e.g., "cpu", "gpu", "ext_dev") + impl FromStr for Context { + type Err = UnsupportedDeviceError; + fn from_str(type_str: &str) -> Result { + Ok(Self { + device_type: match type_str { + $( $( stringify!($dev_name) )|+ => $dev_type.into()),+, + _ => return Err(UnsupportedDeviceError(type_str.to_string())), + }, + device_id: 0, + }) + } + } + + impl Context { + $( + $( + pub fn $dev_name(device_id: usize) -> Self { + Self { + device_type: $dev_type.into(), + device_id: device_id, + } + } + )+ + )+ + } + }; +} + +impl_tvm_context!( + DLDeviceType_kDLCPU: [cpu, llvm, stackvm], + DLDeviceType_kDLGPU: [gpu, cuda, nvptx], + DLDeviceType_kDLOpenCL: [cl], + DLDeviceType_kDLMetal: [metal], + DLDeviceType_kDLVPI: [vpi], + DLDeviceType_kDLROCM: [rocm], + DLDeviceType_kDLExtDev: [ext_dev] +); + +impl<'a> From<&'a str> for Context { + fn from(target: &str) -> Self { + Context::new(TVMDeviceType::from(target), 0) + } +} + +impl From for Context { + fn from(ctx: ffi::DLContext) -> Self { + Context { + device_type: TVMDeviceType::from(ctx.device_type), + device_id: ctx.device_id as usize, + } + } +} + +impl From for ffi::DLContext { + fn from(ctx: Context) -> Self { + ffi::DLContext { + device_type: ctx.device_type.into(), + device_id: ctx.device_id as i32, + } + } +} + +impl Display for Context { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}({})", self.device_type, self.device_id) + } +} + +impl From for RetValue { + fn from(ret_value: Context) -> RetValue { + RetValue::Context(ret_value.into()) + } +} + +impl TryFrom for Context { + type Error = anyhow::Error; + fn try_from(ret_value: RetValue) -> anyhow::Result { + match ret_value { + RetValue::Context(dt) => Ok(dt.into()), + // TODO(@jroesch): improve + _ => Err(anyhow::anyhow!("unable to convert datatype from ...")), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn context() { + let ctx = Context::cpu(0); + println!("ctx: {}", ctx); + let default_ctx = Context::new(TVMDeviceType(1), 0); + assert_eq!(ctx.clone(), default_ctx); + assert_ne!(ctx, Context::gpu(0)); + + let str_ctx = Context::new(TVMDeviceType::from("gpu"), 0); + assert_eq!(str_ctx.clone(), str_ctx); + assert_ne!(str_ctx, Context::new(TVMDeviceType::from("cpu"), 0)); + } +} diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs new file mode 100644 index 000000000000..9969d3bcc928 --- /dev/null +++ b/rust/tvm-sys/src/datatype.rs @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::any::TypeId; + +use std::convert::TryFrom; +use std::str::FromStr; + +use crate::packed_func::RetValue; + +use thiserror::Error; + +use crate::ffi::DLDataType; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct DataType { + pub code: u8, + pub bits: u8, + pub lanes: u16, +} + +impl DataType { + pub fn new(code: u8, bits: u8, lanes: u16) -> DataType { + DataType { code, bits, lanes } + } + + /// Returns the number of bytes occupied by an element of this `DataType`. + pub fn itemsize(&self) -> usize { + (self.bits as usize * self.lanes as usize) >> 3 + } + + /// Returns whether this `DataType` represents primitive type `T`. + pub fn is_type(&self) -> bool { + if self.lanes != 1 { + return false; + } + let typ = TypeId::of::(); + (typ == TypeId::of::() && self.code == 0 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 0 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 64) + } + + pub fn code(&self) -> usize { + self.code as usize + } + + pub fn bits(&self) -> usize { + self.bits as usize + } + + pub fn lanes(&self) -> usize { + self.lanes as usize + } +} + +impl<'a> From<&'a DataType> for DLDataType { + fn from(dtype: &'a DataType) -> Self { + Self { + code: dtype.code as u8, + bits: dtype.bits as u8, + lanes: dtype.lanes as u16, + } + } +} + +impl From for DataType { + fn from(dtype: DLDataType) -> Self { + Self { + code: dtype.code, + bits: dtype.bits, + lanes: dtype.lanes, + } + } +} + +#[derive(Debug, Error)] +pub enum ParseTvmTypeError { + #[error("invalid number: {0}")] + InvalidNumber(std::num::ParseIntError), + #[error("unknown type: {0}")] + UnknownType(String), +} + +/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` +/// such as "int32", "float32" or with lane "float32x1". +impl FromStr for DataType { + type Err = ParseTvmTypeError; + fn from_str(type_str: &str) -> Result { + if type_str == "bool" { + return Ok(DataType::new(1, 1, 1)); + } + + let mut type_lanes = type_str.split('x'); + let typ = type_lanes.next().expect("Missing dtype"); + let lanes = type_lanes + .next() + .map(|l| ::from_str_radix(l, 10)) + .unwrap_or(Ok(1)) + .map_err(ParseTvmTypeError::InvalidNumber)?; + let (type_name, bits) = match typ.find(char::is_numeric) { + Some(idx) => { + let (name, bits_str) = typ.split_at(idx); + ( + name, + u8::from_str_radix(bits_str, 10).map_err(ParseTvmTypeError::InvalidNumber)?, + ) + } + None => (typ, 32), + }; + + let type_code = match type_name { + "int" => 0, + "uint" => 1, + "float" => 2, + "handle" => 3, + _ => return Err(ParseTvmTypeError::UnknownType(type_name.to_string())), + }; + + Ok(DataType::new(type_code, bits, lanes)) + } +} + +impl std::fmt::Display for DataType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.bits == 1 && self.lanes == 1 { + return write!(f, "bool"); + } + let mut type_str = match self.code { + 0 => "int", + 1 => "uint", + 2 => "float", + 4 => "handle", + _ => "unknown", + } + .to_string(); + + type_str += &self.bits.to_string(); + if self.lanes > 1 { + type_str += &format!("x{}", self.lanes); + } + f.write_str(&type_str) + } +} + +impl From for RetValue { + fn from(dt: DataType) -> RetValue { + RetValue::DataType((&dt).into()) + } +} + +impl TryFrom for DataType { + type Error = anyhow::Error; + fn try_from(ret_value: RetValue) -> anyhow::Result { + match ret_value { + RetValue::DataType(dt) => Ok(dt.into()), + // TODO(@jroesch): improve + _ => Err(anyhow::anyhow!("unable to convert datatype from ...")), + } + } +} diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-sys/src/errors.rs new file mode 100644 index 000000000000..8479ec62f19f --- /dev/null +++ b/rust/tvm-sys/src/errors.rs @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use thiserror::Error; + +#[derive(Error, Debug)] +#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")] +pub struct ValueDowncastError { + pub actual_type: String, + pub expected_type: &'static str, +} + +#[derive(Error, Debug)] +#[error("Function call `{context:?}` returned error: {message:?}")] +pub struct FuncCallError { + context: String, + message: String, +} + +impl FuncCallError { + pub fn get_with_context(context: String) -> Self { + Self { + context, + message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) } + .to_str() + .expect("double fault") + .to_owned(), + } + } +} diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs new file mode 100644 index 000000000000..bd3ad415bab0 --- /dev/null +++ b/rust/tvm-sys/src/lib.rs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This crate contains the minimal interface over TVM's +//! C runtime API. +//! +//! These common bindings are useful to both runtimes +//! written in Rust, as well as higher level API bindings. +//! +//! See the `tvm-rt` or `tvm` crates for full bindings to +//! the TVM API. + +/// The low-level C runtime FFI API for TVM. +pub mod ffi { + #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] + + use std::os::raw::{c_char, c_int, c_void}; + + include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); + + pub type BackendPackedCFunc = + extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; +} + +pub mod array; +pub mod byte_array; +pub mod context; +pub mod datatype; +pub mod errors; +#[macro_use] +pub mod packed_func; +pub mod value; + +pub use byte_array::ByteArray; +pub use context::{Context, TVMDeviceType}; +pub use datatype::DataType; +pub use errors::*; +pub use packed_func::{ArgValue, RetValue}; diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs new file mode 100644 index 000000000000..75d37b61beb4 --- /dev/null +++ b/rust/tvm-sys/src/packed_func.rs @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + convert::TryFrom, + ffi::{CStr, CString}, + os::raw::c_void, +}; + +pub use crate::ffi::TVMValue; +use crate::{errors::ValueDowncastError, ffi::*}; + +pub trait PackedFunc: + Fn(&[ArgValue]) -> Result + Send + Sync +{ +} + +impl PackedFunc for T where + T: Fn(&[ArgValue]) -> Result + Send + Sync +{ +} + +/// Calls a packed function and returns a `RetValue`. +/// +/// # Example +/// +/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` +#[macro_export] +macro_rules! call_packed { + ($fn:expr, $($args:expr),+) => { + $fn(&[$($args.into(),)+]) + }; + ($fn:expr) => { + $fn(&Vec::new()) + }; +} + +/// Constructs a derivative of a TVMPodValue. +macro_rules! TVMPODValue { + { + $(#[$m:meta])+ + $name:ident $(<$a:lifetime>)? { + $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)? + }, + match $value:ident { + $($tvm_type:ident => { $from_tvm_type:expr })+ + }, + match &self { + $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+ + } + $(,)? + } => { + $(#[$m])+ + #[derive(Clone, Debug)] + pub enum $name $(<$a>)? { + Int(i64), + UInt(i64), + Float(f64), + Null, + DataType(DLDataType), + String(CString), + Context(TVMContext), + Handle(*mut c_void), + ArrayHandle(TVMArrayHandle), + ObjectHandle(*mut c_void), + ModuleHandle(TVMModuleHandle), + FuncHandle(TVMFunctionHandle), + NDArrayHandle(*mut c_void), + $($extra_variant($variant_type)),+ + } + + impl $(<$a>)? $name $(<$a>)? { + pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self { + use $name::*; + #[allow(non_upper_case_globals)] + unsafe { + match type_code as _ { + DLDataTypeCode_kDLInt => Int($value.v_int64), + DLDataTypeCode_kDLUInt => UInt($value.v_int64), + DLDataTypeCode_kDLFloat => Float($value.v_float64), + TVMTypeCode_kTVMNullptr => Null, + TVMTypeCode_kTVMDataType => DataType($value.v_type), + TVMTypeCode_kTVMContext => Context($value.v_ctx), + TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), + TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), + TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), + TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), + TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), + $( $tvm_type => { $from_tvm_type } ),+ + _ => unimplemented!("{}", type_code), + } + } + } + + pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) { + use $name::*; + match self { + Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), + UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), + Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), + Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr), + DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType), + Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), + String(val) => { + ( + TVMValue { v_handle: val.as_ptr() as *mut c_void }, + TVMTypeCode_kTVMStr, + ) + } + Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle), + ArrayHandle(val) => { + ( + TVMValue { v_handle: *val as *const _ as *mut c_void }, + TVMTypeCode_kTVMNDArrayHandle, + ) + }, + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle), + ModuleHandle(val) => + (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle), + FuncHandle(val) => ( + TVMValue { v_handle: *val }, + TVMTypeCode_kTVMPackedFuncHandle + ), + NDArrayHandle(val) => + (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle), + $( $self_type($val) => { $from_self_type } ),+ + } + } + } + } +} + +TVMPODValue! { + /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way + /// to obtain a `ArgValue` is automatically via `call_packed!`. + ArgValue<'a> { + Bytes(&'a TVMByteArray), + Str(&'a CStr), + }, + match value { + TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } + TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } + }, + match &self { + Bytes(val) => { + (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes) + } + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) } + } +} + +TVMPODValue! { + /// An owned TVMPODValue. Can be converted from a variety of primitive and object types. + /// Can be downcasted using `try_from` if it contains the desired type. + /// + /// # Example + /// + /// ``` + /// use std::convert::{TryFrom, TryInto}; + /// use tvm_sys::RetValue; + /// + /// let a = 42u32; + /// let b: u32 = tvm_sys::RetValue::from(a).try_into().unwrap(); + /// + /// let s = "hello, world!"; + /// let t: RetValue = s.to_string().into(); + /// assert_eq!(String::try_from(t).unwrap(), s); + /// ``` + RetValue { + Bytes(TVMByteArray), + Str(&'static CStr), + }, + match value { + TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } + TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } + }, + match &self { + Bytes(val) => + { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) } + Str(val) => + { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) } + } +} + +#[macro_export] +macro_rules! try_downcast { + ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => { + match $val { + $( $pat => { Ok($converter) } )+ + _ => Err($crate::errors::ValueDowncastError { + actual_type: format!("{:?}", $val), + expected_type: stringify!($into), + }), + } + }; +} + +/// Creates a conversion to a `ArgValue` for a primitive type and DLDataTypeCode. +macro_rules! impl_pod_value { + ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => { + $( + impl<'a> From<$type> for ArgValue<'a> { + fn from(val: $type) -> Self { + Self::$variant(val as $inner_ty) + } + } + + impl<'a, 'v> From<&'a $type> for ArgValue<'v> { + fn from(val: &'a $type) -> Self { + Self::$variant(*val as $inner_ty) + } + } + + impl<'a> TryFrom> for $type { + type Error = $crate::errors::ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { val as $type }) + } + } + + impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { + type Error = $crate::errors::ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { *val as $type }) + } + } + + impl From<$type> for RetValue { + fn from(val: $type) -> Self { + Self::$variant(val as $inner_ty) + } + } + + impl TryFrom for $type { + type Error = $crate::errors::ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> $type, |RetValue::$variant(val)| { val as $type }) + } + } + )+ + }; +} + +impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); +impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); +impl_pod_value!(Float, f64, [f32, f64]); +impl_pod_value!(DataType, DLDataType, [DLDataType]); +impl_pod_value!(Context, TVMContext, [TVMContext]); + +impl<'a> From<&'a str> for ArgValue<'a> { + fn from(s: &'a str) -> Self { + Self::String(CString::new(s).unwrap()) + } +} + +impl<'a> From for ArgValue<'a> { + fn from(s: String) -> Self { + Self::String(CString::new(s).unwrap()) + } +} + +impl<'a> From<&'a CStr> for ArgValue<'a> { + fn from(s: &'a CStr) -> Self { + Self::Str(s) + } +} + + +impl<'a> From for ArgValue<'a> { + fn from(s: CString) -> Self { + Self::String(s) + } +} + +impl<'a> From<&'a TVMByteArray> for ArgValue<'a> { + fn from(s: &'a TVMByteArray) -> Self { + Self::Bytes(s) + } +} + +impl<'a> TryFrom> for &'a str { + type Error = ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() }) + } +} + +impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { + type Error = ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result { + try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() }) + } +} + +/// Converts an unspecialized handle to a ArgValue. +impl From<*const T> for ArgValue<'static> { + fn from(ptr: *const T) -> Self { + Self::Handle(ptr as *mut c_void) + } +} + +/// Converts an unspecialized mutable handle to a ArgValue. +impl From<*mut T> for ArgValue<'static> { + fn from(ptr: *mut T) -> Self { + Self::Handle(ptr as *mut c_void) + } +} + +impl<'a> From<&'a mut DLTensor> for ArgValue<'a> { + fn from(arr: &'a mut DLTensor) -> Self { + Self::ArrayHandle(arr as *mut DLTensor) + } +} + +impl<'a> From<&'a DLTensor> for ArgValue<'a> { + fn from(arr: &'a DLTensor) -> Self { + Self::ArrayHandle(arr as *const _ as *mut DLTensor) + } +} + +impl TryFrom for String { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!( + val -> String, + |RetValue::String(s)| { s.into_string().unwrap() }, + |RetValue::Str(s)| { s.to_str().unwrap().to_string() } + ) + } +} + +impl From for RetValue { + fn from(s: String) -> Self { + Self::String(std::ffi::CString::new(s).unwrap()) + } +} + +impl From for RetValue { + fn from(arr: TVMByteArray) -> Self { + Self::Bytes(arr) + } +} + +impl TryFrom for TVMByteArray { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> TVMByteArray, |RetValue::Bytes(val)| { val }) + } +} + +impl Default for RetValue { + fn default() -> Self { + Self::Int(0) + } +} + +impl TryFrom for std::ffi::CString { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> std::ffi::CString, + |RetValue::Str(val)| { val.into() }) + } +} diff --git a/rust/tvm-sys/src/value.rs b/rust/tvm-sys/src/value.rs new file mode 100644 index 000000000000..a468e43d4528 --- /dev/null +++ b/rust/tvm-sys/src/value.rs @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::str::FromStr; + +use thiserror::Error; + +use crate::ffi::*; + +macro_rules! impl_pod_tvm_value { + ($field:ident, $field_ty:ty, $( $ty:ty ),+) => { + $( + impl From<$ty> for TVMValue { + fn from(val: $ty) -> Self { + TVMValue { $field: val as $field_ty } + } + } + + impl From for $ty { + fn from(val: TVMValue) -> Self { + unsafe { val.$field as $ty } + } + } + )+ + }; + ($field:ident, $ty:ty) => { + impl_pod_tvm_value!($field, $ty, $ty); + } +} + +impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize); +impl_pod_tvm_value!(v_float64, f64, f32, f64); +impl_pod_tvm_value!(v_type, DLDataType); +impl_pod_tvm_value!(v_ctx, TVMContext); + +#[derive(Debug, Error)] +#[error("unsupported device: {0}")] +pub struct UnsupportedDeviceError(String); + +macro_rules! impl_tvm_context { + ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { + /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev") + impl FromStr for TVMContext { + type Err = UnsupportedDeviceError; + fn from_str(type_str: &str) -> Result { + Ok(Self { + device_type: match type_str { + $( $( stringify!($dev_name) )|+ => $dev_type ),+, + _ => return Err(UnsupportedDeviceError(type_str.to_string())), + }, + device_id: 0, + }) + } + } + + impl TVMContext { + $( + $( + pub fn $dev_name(device_id: usize) -> Self { + Self { + device_type: $dev_type, + device_id: device_id as i32, + } + } + )+ + )+ + } + }; +} + +impl_tvm_context!( + DLDeviceType_kDLCPU: [cpu, llvm, stackvm], + DLDeviceType_kDLGPU: [gpu, cuda, nvptx], + DLDeviceType_kDLOpenCL: [cl], + DLDeviceType_kDLMetal: [metal], + DLDeviceType_kDLVPI: [vpi], + DLDeviceType_kDLROCM: [rocm], + DLDeviceType_kDLExtDev: [ext_dev] +);