Skip to content

Commit

Permalink
update traits
Browse files Browse the repository at this point in the history
  • Loading branch information
SunDoge committed May 12, 2023
1 parent 9b45714 commit 75d6ea6
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
15 changes: 6 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,30 @@ impl HasData for PyRgbImage {
}

impl HasDevice for PyRgbImage {
fn device(&self) -> dlpark::ffi::Device {
fn device(&self) -> Device {
Device::CPU
}
}

impl HasDtype for PyRgbImage {
fn dtype(&self) -> dlpark::ffi::DataType {
fn dtype(&self) -> DataType {
DataType::U8
}
}

impl HasShape for PyRgbImage {
fn shape(&self) -> dlpark::tensor::Shape {
dlpark::tensor::Shape::Owned(
fn shape(&self) -> Shape {
Shape::Owned(
[self.0.height(), self.0.width(), 3]
.map(|x| x as i64)
.to_vec(),
)
}
}

// Strides can be infered from Shape since it's compact and row-majored.
impl HasStrides for PyRgbImage {}
impl HasByteOffset for PyRgbImage {
fn byte_offset(&self) -> u64 {
0
}
}
impl HasByteOffset for PyRgbImage {}
```

Then we can return a `ManagerCtx<PyRgbImage>`
Expand Down
4 changes: 4 additions & 0 deletions src/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,8 @@ impl DataType {
bits: 8,
lanes: 1,
};

pub fn size(&self) -> usize {
((self.bits as u32 * self.lanes as u32 + 7) / 8) as usize
}
}
5 changes: 2 additions & 3 deletions src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
pub mod impls;
pub mod traits;

use std::ffi::c_void;

use traits::{HasByteOffset, HasData, HasDevice, HasDtype, HasShape, HasStrides};

use crate::ffi::{self, DataType, Device};
Expand Down Expand Up @@ -53,6 +51,7 @@ impl Shape {
}
}

/// If it is borrowed, the length should be `tensor.ndim()`
#[derive(Debug)]
pub enum Strides {
Borrowed(*mut i64),
Expand Down Expand Up @@ -149,7 +148,7 @@ where

impl AsTensor for ffi::DLTensor {
fn data<T>(&self) -> *const T {
self.data as *const c_void as *const T
self.data as *const T
}

fn shape(&self) -> &[i64] {
Expand Down
25 changes: 23 additions & 2 deletions src/tensor/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait HasShape {
fn shape(&self) -> Shape;
}

/// Can be `None`, indicating tensor is compact and row-majored.
pub trait HasStrides {
fn strides(&self) -> Option<Strides> {
None
Expand Down Expand Up @@ -41,10 +42,30 @@ pub trait AsTensor {
fn ndim(&self) -> usize;
fn device(&self) -> Device;
fn dtype(&self) -> DataType;
fn byte_offset(&self) -> u64;

fn byte_offset(&self) -> u64 {
0
}

fn num_elements(&self) -> usize {
self.shape().iter().fold(1usize, |acc, &x| acc * x as usize)
self.shape().iter().product::<i64>() as usize
}

/// For given DLTensor, the size of memory required to store the contents of
/// data is calculated as follows:
///
/// ```c
/// static inline size_t GetDataSize(const DLTensor* t) {
/// size_t size = 1;
/// for (tvm_index_t i = 0; i < t->ndim; ++i) {
/// size *= t->shape[i];
/// }
/// size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
/// return size;
/// }
/// ```
fn data_size(&self) -> usize {
self.num_elements() * self.dtype().size()
}
}

Expand Down

0 comments on commit 75d6ea6

Please sign in to comment.