Skip to content

Commit

Permalink
Merge pull request #388 from robertknight/faster-noop-cast
Browse files Browse the repository at this point in the history
Optimize Cast op when source and dest types are the same, add u8/i8 cast tests
  • Loading branch information
robertknight authored Oct 17, 2024
2 parents b76661c + 59b1c5b commit bebe8f3
Showing 1 changed file with 77 additions and 82 deletions.
159 changes: 77 additions & 82 deletions src/ops/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,28 @@ use crate::tensor_pool::TensorPool;
fn cast(pool: &TensorPool, input: Input, dtype: DataType) -> Result<Output, OpError> {
match dtype {
DataType::Int32 => match input {
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::Int32Tensor(t) => Ok(t.to_tensor_in(pool).into()),
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
},
DataType::Float => match input {
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::FloatTensor(t) => Ok(t.to_tensor_in(pool).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
},
DataType::Int8 => match input {
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::Int8Tensor(t) => Ok(t.to_tensor_in(pool).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
},
DataType::UInt8 => match input {
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::UInt8Tensor(t) => Ok(t.to_tensor_in(pool).into()),
},
}
}
Expand Down Expand Up @@ -60,6 +60,8 @@ impl Operator for Cast {
match (input, self.to) {
(Output::Int32Tensor(t), DataType::Int32) => Ok(t.into()),
(Output::FloatTensor(t), DataType::Float) => Ok(t.into()),
(Output::Int8Tensor(t), DataType::Int8) => Ok(t.into()),
(Output::UInt8Tensor(t), DataType::UInt8) => Ok(t.into()),
(input, _) => {
let converted = cast(pool, input.as_input(), self.to)?;
input.add_to_pool(pool);
Expand All @@ -73,93 +75,86 @@ impl Operator for Cast {
mod tests {
use std::error::Error;

use rten_tensor::test_util::expect_equal;
use rten_tensor::Tensor;

use crate::ops::tests::new_pool;
use crate::ops::{Cast, DataType, Operator};
use crate::ops::{Cast, DataType, Operator, Output};

#[test]
fn test_cast() -> Result<(), Box<dyn Error>> {
let pool = new_pool();
let int_input = Tensor::from([1, 2, 3]);
let float_input = Tensor::from([1.0, 2.0, 3.0]);

// No-op cast from int32 => int32
let cast_to_int = Cast {
to: DataType::Int32,
};
let result = cast_to_int
.run(&pool, (&int_input).into())
.unwrap()
.remove(0)
.into_tensor::<i32>()
.unwrap();

// Flooring cast from float => int32
assert_eq!(result, int_input);
let result = cast_to_int
.run(&pool, (&float_input).into())
.unwrap()
.remove(0)
.into_tensor::<i32>()
.unwrap();
assert_eq!(&result, &int_input);

// No-op cast from float => float
let cast_to_float = Cast {
to: DataType::Float,
};
let result = cast_to_float
.run(&pool, (&float_input).into())
.unwrap()
.remove(0)
.into_tensor::<f32>()
.unwrap();
expect_equal(&result, &float_input)?;

// Cast from int32 => float
let result = cast_to_float
.run(&pool, (&int_input).into())
.unwrap()
.remove(0)
.into_tensor::<f32>()
.unwrap();
expect_equal(&result, &float_input)?;
struct Case {
input: Output,
dtype: DataType,
expected: Output,
}

Ok(())
}
let cases = [
// i32 -> f32
Case {
input: Tensor::from([1, 2, 3]).into(),
dtype: DataType::Float,
expected: Tensor::from([1., 2., 3.]).into(),
},
// i32 -> i32
Case {
input: Tensor::from([1, 2, 3]).into(),
dtype: DataType::Int32,
expected: Tensor::from([1, 2, 3]).into(),
},
// i32 -> i8
Case {
input: Tensor::from([i8::MIN as i32, 0, i8::MAX as i32]).into(),
dtype: DataType::Int8,
expected: Tensor::from([i8::MIN, 0, i8::MAX]).into(),
},
// i32 -> u8
Case {
input: Tensor::from([u8::MIN as i32, 0, u8::MAX as i32]).into(),
dtype: DataType::UInt8,
expected: Tensor::from([u8::MIN, 0, u8::MAX]).into(),
},
// f32 -> i32
Case {
input: Tensor::from([1., 2., 3.]).into(),
dtype: DataType::Int32,
expected: Tensor::from([1, 2, 3]).into(),
},
// f32 -> f32
Case {
input: Tensor::from([1., 2., 3.]).into(),
dtype: DataType::Float,
expected: Tensor::from([1., 2., 3.]).into(),
},
// Int -> float out of range. This will lose precision.
Case {
input: Tensor::from([i32::MIN, i32::MAX]).into(),
dtype: DataType::Float,
expected: Tensor::from([-2147483600.0, 2147483600.0]).into(),
},
// Float -> int out of range.
//
// In RTen this saturates following the behavior of Rust's `as`
// operator. This is different than C++ / PyTorch / NumPy where
// the behavior of such conversions is undefined.
// See https://github.com/robertknight/rten/pull/387#issuecomment-2420343989.
Case {
input: Tensor::from([f32::MIN, f32::MAX]).into(),
dtype: DataType::Int32,
expected: Tensor::from([i32::MIN, i32::MAX]).into(),
},
];

#[test]
fn test_cast_out_of_range() -> Result<(), Box<dyn Error>> {
let pool = new_pool();
let int_input = Tensor::from([i32::MIN, i32::MAX]);

// Out-of-range cast from int => float. This will simply lose some
// significant digits.
let cast_to_float = Cast {
to: DataType::Float,
};
let result = cast_to_float
.run(&pool, (&int_input).into())
.unwrap()
.remove(0)
.into_tensor::<f32>()
.unwrap();
expect_equal(&result, &Tensor::from([-2147483600.0, 2147483600.0]))?;

// Out-of-range cast from float => int.
let float_input = Tensor::from([f32::MIN, f32::MAX]);
let cast_to_int = Cast {
to: DataType::Int32,
};
let result = cast_to_int
.run(&pool, (&float_input).into())
.unwrap()
.remove(0)
.into_tensor::<i32>()
.unwrap();
assert_eq!(&result, &Tensor::from([i32::MIN, i32::MAX]));
for Case {
input,
dtype,
expected,
} in cases
{
let cast_op = Cast { to: dtype };
let result = cast_op.run(&pool, (&input).into()).unwrap().remove(0);
assert_eq!(result, expected);
}

Ok(())
}
Expand Down

0 comments on commit bebe8f3

Please sign in to comment.