From 75d6ea6a044dc68b31f4fb5d8453e98f0abb385a Mon Sep 17 00:00:00 2001 From: SunDoge <384813529@qq.com> Date: Fri, 12 May 2023 11:29:17 +0800 Subject: [PATCH] update traits --- README.md | 15 ++++++--------- src/data_type.rs | 4 ++++ src/tensor.rs | 5 ++--- src/tensor/traits.rs | 25 +++++++++++++++++++++++-- 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 487ab1b..38e0888 100644 --- a/README.md +++ b/README.md @@ -31,20 +31,20 @@ impl HasData for PyRgbImage { } impl HasDevice for PyRgbImage { - fn device(&self) -> dlpark::ffi::Device { + fn device(&self) -> Device { Device::CPU } } impl HasDtype for PyRgbImage { - fn dtype(&self) -> dlpark::ffi::DataType { + fn dtype(&self) -> DataType { DataType::U8 } } impl HasShape for PyRgbImage { - fn shape(&self) -> dlpark::tensor::Shape { - dlpark::tensor::Shape::Owned( + fn shape(&self) -> Shape { + Shape::Owned( [self.0.height(), self.0.width(), 3] .map(|x| x as i64) .to_vec(), @@ -52,12 +52,9 @@ impl HasShape for PyRgbImage { } } +// Strides can be inferred from Shape since it's compact and row-major. 
impl HasStrides for PyRgbImage {} -impl HasByteOffset for PyRgbImage { - fn byte_offset(&self) -> u64 { - 0 - } -} +impl HasByteOffset for PyRgbImage {} ``` Then we can return a `ManagerCtx` diff --git a/src/data_type.rs b/src/data_type.rs index 5c35448..c0e0bbb 100644 --- a/src/data_type.rs +++ b/src/data_type.rs @@ -57,4 +57,8 @@ impl DataType { bits: 8, lanes: 1, }; + + pub fn size(&self) -> usize { + ((self.bits as u32 * self.lanes as u32 + 7) / 8) as usize + } } diff --git a/src/tensor.rs b/src/tensor.rs index cd3f977..4671b0c 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,8 +1,6 @@ pub mod impls; pub mod traits; -use std::ffi::c_void; - use traits::{HasByteOffset, HasData, HasDevice, HasDtype, HasShape, HasStrides}; use crate::ffi::{self, DataType, Device}; @@ -53,6 +51,7 @@ impl Shape { } } +/// If it is borrowed, the length should be `tensor.ndim()` #[derive(Debug)] pub enum Strides { Borrowed(*mut i64), @@ -149,7 +148,7 @@ where impl AsTensor for ffi::DLTensor { fn data(&self) -> *const T { - self.data as *const c_void as *const T + self.data as *const T } fn shape(&self) -> &[i64] { diff --git a/src/tensor/traits.rs b/src/tensor/traits.rs index 7b0ad41..f19d7cf 100644 --- a/src/tensor/traits.rs +++ b/src/tensor/traits.rs @@ -12,6 +12,7 @@ pub trait HasShape { fn shape(&self) -> Shape; } +/// Can be `None`, indicating the tensor is compact and row-major. 
pub trait HasStrides { fn strides(&self) -> Option<Strides> { None } @@ -41,10 +42,30 @@ pub trait AsTensor { fn ndim(&self) -> usize; fn device(&self) -> Device; fn dtype(&self) -> DataType; - fn byte_offset(&self) -> u64; + + fn byte_offset(&self) -> u64 { + 0 + } fn num_elements(&self) -> usize { - self.shape().iter().fold(1usize, |acc, &x| acc * x as usize) + self.shape().iter().product::<i64>() as usize + } + + /// For given DLTensor, the size of memory required to store the contents of + /// data is calculated as follows: + /// + /// ```c + /// static inline size_t GetDataSize(const DLTensor* t) { + /// size_t size = 1; + /// for (tvm_index_t i = 0; i < t->ndim; ++i) { + /// size *= t->shape[i]; + /// } + /// size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + /// return size; + /// } + /// ``` + fn data_size(&self) -> usize { + self.num_elements() * self.dtype().size() } }