Skip to content

Commit

Permalink
Add TensorBase::try_from_data
Browse files Browse the repository at this point in the history
This is a fallible version of `from_data` which returns an error instead of
panicking of the data length and shape don't match. This is expected to be
useful eg. in Ocrs's WebAssembly API.
  • Loading branch information
robertknight committed Jan 31, 2024
1 parent 8517dcc commit 1881790
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 9 deletions.
5 changes: 5 additions & 0 deletions rten-tensor/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ pub enum FromDataError {
/// Some indices will map to offsets that are beyond the end of the storage.
StorageTooShort,

/// The storage length was expected to exactly match the product of the
/// shape, and it did not.
StorageLengthMismatch,

/// Some indices will map to the same offset within the storage.
///
/// This error can only occur when the storage is mutable.
Expand All @@ -29,6 +33,7 @@ impl Display for FromDataError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
FromDataError::StorageTooShort => write!(f, "Data too short"),
FromDataError::StorageLengthMismatch => write!(f, "Data length mismatch"),
FromDataError::MayOverlap => write!(f, "May have internal overlap"),
}
}
Expand Down
51 changes: 42 additions & 9 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,19 +283,38 @@ pub trait AsView: Layout {

impl<T, S: AsRef<[T]>, L: MutLayout> TensorBase<T, S, L> {
/// Construct a new tensor from a given shape and storage.
pub fn from_data(shape: L::Index<'_>, data: S) -> TensorBase<T, S, L> {
///
/// Panics if the data length does not match the product of `shape`.
pub fn from_data(shape: L::Index<'_>, data: S) -> TensorBase<T, S, L>
where
for<'a> L::Index<'a>: Clone,
{
let len = data.as_ref().len();
Self::try_from_data(shape.clone(), data).unwrap_or_else(|_| {
panic!(
"data length {} does not match shape {:?}",
len,
shape.as_ref(),
);
})
}

/// Construct a new tensor from a given shape and storage.
///
/// This will fail if the data length does not match the product of `shape`.
pub fn try_from_data(
shape: L::Index<'_>,
data: S,
) -> Result<TensorBase<T, S, L>, FromDataError> {
let layout = L::from_shape(shape);
assert!(
data.as_ref().len() == layout.len(),
"data length {} does not match shape {:?}",
data.as_ref().len(),
layout.shape().as_ref(),
);
TensorBase {
if layout.min_data_len() != data.as_ref().len() {
return Err(FromDataError::StorageLengthMismatch);
}
Ok(TensorBase {
data,
layout,
element_type: PhantomData,
}
})
}

/// Construct a new tensor from a given shape and storage, and custom
Expand Down Expand Up @@ -2452,6 +2471,20 @@ mod tests {
assert_eq!(permuted.to_vec(), &[1., 4., 2., 5., 3., 6.]);
}

#[test]
fn test_try_from_data() {
let x = NdTensor::try_from_data([1, 2, 2], vec![1, 2, 3, 4]);
assert!(x.is_ok());
if let Ok(x) = x {
assert_eq!(x.shape(), [1, 2, 2]);
assert_eq!(x.strides(), [4, 2, 1]);
assert_eq!(x.to_vec(), [1, 2, 3, 4]);
}

let x = NdTensor::try_from_data([1, 2, 2], vec![1]);
assert_eq!(x, Err(FromDataError::StorageLengthMismatch));
}

#[test]
fn test_try_slice() {
let data = vec![1., 2., 3., 4.];
Expand Down

0 comments on commit 1881790

Please sign in to comment.