Skip to content

Commit

Permalink
Update Rust code to work with newest runtime API
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Apr 12, 2020
1 parent 677ad6b commit 27a94ae
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 37 deletions.
6 changes: 3 additions & 3 deletions rust/common/src/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ macro_rules! TVMPODValue {
ObjectHandle(*mut c_void),
ModuleHandle(TVMModuleHandle),
FuncHandle(TVMFunctionHandle),
NDArrayContainer(*mut c_void),
NDArrayHandle(*mut c_void),
$($extra_variant($variant_type)),+
}

Expand All @@ -102,7 +102,7 @@ macro_rules! TVMPODValue {
TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle),
TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
$( $tvm_type => { $from_tvm_type } ),+
_ => unimplemented!("{}", type_code),
}
Expand Down Expand Up @@ -138,7 +138,7 @@ macro_rules! TVMPODValue {
TVMValue { v_handle: *val },
TVMTypeCode_kTVMPackedFuncHandle
),
NDArrayContainer(val) =>
NDArrayHandle(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
$( $self_type($val) => { $from_self_type } ),+
}
Expand Down
5 changes: 4 additions & 1 deletion rust/frontend/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback(
|| tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
{
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _));
check_call!(ffi::TVMCbArgToReturn(
&mut value as *mut _,
&mut tcode as *mut _
));
}
local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
}
Expand Down
74 changes: 48 additions & 26 deletions rust/frontend/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use failure::Error;
use num_traits::Num;
use rust_ndarray::{Array, ArrayD};
use std::convert::TryInto;
use std::ffi::c_void;
use tvm_common::ffi::DLTensor;
use tvm_common::{ffi, TVMType};

use crate::{errors, TVMByteArray, TVMContext};
Expand All @@ -60,31 +63,49 @@ use crate::{errors, TVMByteArray, TVMContext};
///
/// Wrapper around TVM array handle.
#[derive(Debug)]
pub struct NDArray {
pub(crate) handle: ffi::TVMArrayHandle,
is_view: bool,
pub enum NDArray {
Borrowed { handle: ffi::TVMArrayHandle },
Owned { handle: *mut c_void },
}

impl NDArray {
pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
NDArray {
handle,
is_view: true,
NDArray::Borrowed { handle }
}

pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self {
NDArray::Owned { handle }
}

pub fn as_dltensor(&self) -> &DLTensor {
unsafe {
match self {
NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
NDArray::Owned { ref handle } => std::mem::transmute(*handle),
}
}
}

/// Returns the underlying array handle.
pub fn handle(&self) -> ffi::TVMArrayHandle {
self.handle
pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor {
unsafe {
match self {
NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
NDArray::Owned { ref handle } => std::mem::transmute(*handle),
}
}
}

pub fn is_view(&self) -> bool {
self.is_view
if let &NDArray::Borrowed { .. } = self {
true
} else {
false
}
}

/// Returns the shape of the NDArray.
pub fn shape(&self) -> Option<&mut [usize]> {
let arr = unsafe { *(self.handle) };
let arr = self.as_dltensor();
if arr.shape.is_null() || arr.data.is_null() {
return None;
};
Expand All @@ -99,24 +120,28 @@ impl NDArray {

/// Returns the context which the NDArray was defined.
pub fn ctx(&self) -> TVMContext {
unsafe { (*self.handle).ctx.into() }
self.as_dltensor().ctx.into()
}

/// Returns the type of the entries of the NDArray.
pub fn dtype(&self) -> TVMType {
unsafe { (*self.handle).dtype }
self.as_dltensor().dtype
}

/// Returns the number of dimensions of the NDArray.
pub fn ndim(&self) -> usize {
unsafe { (*self.handle).ndim as usize }
self.as_dltensor()
.ndim
.try_into()
.expect("number of dimensions must always be positive")
}

/// Returns the strides of the underlying NDArray.
pub fn strides(&self) -> Option<&[usize]> {
unsafe {
let sz = self.ndim() * mem::size_of::<usize>();
let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz);
let strides_ptr = self.as_dltensor().strides as *const usize;
let slc = slice::from_raw_parts(strides_ptr, sz);
Some(slc)
}
}
Expand Down Expand Up @@ -146,7 +171,7 @@ impl NDArray {
}

pub fn byte_offset(&self) -> isize {
unsafe { (*self.handle).byte_offset as isize }
self.as_dltensor().byte_offset as isize
}

/// Flattens the NDArray to a `Vec` of the same type in cpu.
Expand All @@ -172,7 +197,7 @@ impl NDArray {
self.dtype(),
);
let target = self.copy_to_ndarray(earr)?;
let arr = unsafe { *(target.handle) };
let arr = target.as_dltensor();
let sz = self.size().ok_or(errors::MissingShapeError)?;
let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
unsafe {
Expand Down Expand Up @@ -207,7 +232,7 @@ impl NDArray {
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
check_call!(ffi::TVMArrayCopyFromBytes(
self.handle,
self.as_raw_dltensor(),
data.as_ptr() as *mut _,
data.len() * mem::size_of::<T>()
));
Expand All @@ -225,8 +250,8 @@ impl NDArray {
);
}
check_call!(ffi::TVMArrayCopyFromTo(
self.handle,
target.handle,
self.as_raw_dltensor(),
target.as_raw_dltensor(),
ptr::null_mut() as ffi::TVMStreamHandle
));
Ok(target)
Expand Down Expand Up @@ -272,10 +297,7 @@ impl NDArray {
ctx.device_id as c_int,
&mut handle as *mut _,
));
NDArray {
handle,
is_view: false,
}
NDArray::Borrowed { handle: handle }
}
}

Expand Down Expand Up @@ -313,8 +335,8 @@ impl_from_ndarray_rustndarray!(f32, "float");

impl Drop for NDArray {
fn drop(&mut self) {
if !self.is_view {
check_call!(ffi::TVMArrayFree(self.handle));
if let &mut NDArray::Owned { .. } = self {
check_call!(ffi::TVMArrayFree(self.as_raw_dltensor()));
}
}
}
Expand Down
63 changes: 57 additions & 6 deletions rust/frontend/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
//! `TVMRetValue` is the owned version of `TVMPODValue`.

use std::convert::TryFrom;
// use std::ffi::c_void;

use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
use tvm_common::{
errors::ValueDowncastError,
ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle},
ffi::{TVMFunctionHandle, TVMModuleHandle},
try_downcast,
};

use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};

macro_rules! impl_handle_val {
($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
impl<'a> From<&'a $type> for TVMArgValue<'a> {
Expand Down Expand Up @@ -76,9 +76,60 @@ macro_rules! impl_handle_val {

impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new);
// TODO(@jroesch): introduce NDArray handle on C++ side.
// impl_handle_val!(NDArray, NDArrayHandle, TVMObjectHandle, NDArray::new)

impl<'a> From<&'a NDArray> for TVMArgValue<'a> {
fn from(arg: &'a NDArray) -> Self {
match arg {
&NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
&NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
}
}
}

impl<'a> From<&'a mut NDArray> for TVMArgValue<'a> {
fn from(arg: &'a mut NDArray) -> Self {
match arg {
&mut NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
&mut NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
}
}
}

impl<'a> TryFrom<TVMArgValue<'a>> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
|TVMArgValue::ArrayHandle(val)| { NDArray::new(val) })
}
}

impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: &'a TVMArgValue<'v>) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) },
|TVMArgValue::ArrayHandle(val)| { NDArray::new(*val) })
}
}

impl From<NDArray> for TVMRetValue {
fn from(val: NDArray) -> TVMRetValue {
match val {
NDArray::Owned { handle } => TVMRetValue::NDArrayHandle(handle),
_ => panic!("NYI"),
}
}
}

impl TryFrom<TVMRetValue> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: TVMRetValue) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMRetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
|TVMRetValue::ArrayHandle(val)| { NDArray::new(val) })
}
}

#[cfg(test)]
mod tests {
Expand Down
2 changes: 1 addition & 1 deletion rust/frontend/tests/callback/src/bin/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ fn main() {
.unwrap()
.try_into()
.unwrap();
assert_eq!(ret, 14f32);
assert_eq!(ret, 7f32);
}

0 comments on commit 27a94ae

Please sign in to comment.