Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast the vector from deduced type to desired type if needed. #5416

Merged
merged 8 commits into from
Sep 17, 2024
84 changes: 54 additions & 30 deletions extension/tensor/tensor_impl_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,31 +97,57 @@ inline TensorImplPtr make_tensor_impl_ptr(
* specified properties.
*
* This template overload is specialized for cases where tensor data is provided
* as a vector. The scalar type is automatically deduced from the vector's data
* type. The deleter ensures that the data vector is properly managed, with its
* lifetime tied to the TensorImpl.
* as a vector. If the specified `type` differs from the deduced type of the
* vector's elements, and casting is allowed, the data will be cast to the
* specified `type`. This allows for flexible creation of tensors with data
* vectors of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
std::vector<T> data,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
if (type != deduced_type) {
ET_CHECK_MSG(
runtime::canCast(deduced_type, type),
"Cannot cast deduced type to specified type.");
std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
ET_SWITCH_REALHBBF16_TYPES(
type, nullptr, "make_tensor_impl_ptr", CTYPE, [&] {
std::transform(
data.begin(),
data.end(),
reinterpret_cast<CTYPE*>(casted_data.data()),
[](const T& val) { return static_cast<CTYPE>(val); });
});
const auto raw_data_ptr = casted_data.data();
auto data_ptr =
std::make_shared<std::vector<uint8_t>>(std::move(casted_data));
return make_tensor_impl_ptr(
std::move(sizes),
raw_data_ptr,
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
}
const auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
return make_tensor_impl_ptr(
Expand All @@ -138,14 +164,16 @@ inline TensorImplPtr make_tensor_impl_ptr(
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where the tensor data is
* provided as a vector. The scalar type is automatically deduced from the
* vector's data type. The deleter ensures that the data vector is properly
* managed and its lifetime is tied to the TensorImpl.
* This template overload is specialized for cases where tensor data is provided
* as a vector. If the specified `type` differs from the deduced type of the
* vector's elements, and casting is allowed, the data will be cast to the
* specified `type`. This allows for flexible creation of tensors with data
* vectors of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param data A vector containing the tensor's data.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
Expand All @@ -157,7 +185,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
return make_tensor_impl_ptr(
std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
Expand All @@ -168,17 +195,19 @@ inline TensorImplPtr make_tensor_impl_ptr(
* specified properties.
*
* This template overload is specialized for cases where tensor data is provided
* as an initializer list. The scalar type is automatically deduced from the
* initializer list's data type. The deleter ensures that the data is properly
* managed, with its lifetime tied to the TensorImpl.
* as an initializer list. If the specified `type` differs from the deduced type
* of the initializer list's elements, and casting is allowed, the data will be
* cast to the specified `type`. This allows for flexible creation of tensors
with an initializer list of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
Expand All @@ -193,34 +222,30 @@ inline TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
auto data = std::vector<T>(std::move(list));
const auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
return make_tensor_impl_ptr(
std::move(sizes),
raw_data_ptr,
std::vector<T>(std::move(list)),
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
dynamism);
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where the tensor data is
* provided as an initializer list. The scalar type is automatically deduced
* from the initializer list's data type. The deleter ensures that the data is
* properly managed and its lifetime is tied to the TensorImpl.
* This template overload is specialized for cases where tensor data is provided
* as an initializer list. If the specified `type` differs from the deduced type
* of the initializer list's elements, and casting is allowed, the data will be
* cast to the specified `type`. This allows for flexible creation of tensors
with an initializer list of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
Expand All @@ -232,7 +257,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(list.size())};
return make_tensor_impl_ptr(
std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
Expand Down
32 changes: 26 additions & 6 deletions extension/tensor/tensor_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,18 @@ inline TensorPtr make_tensor_ptr(
*
* This template overload is specialized for cases where the tensor data is
* provided as a vector. The scalar type is automatically deduced from the
* vector's data type.
* vector's data type. If the specified `type` differs from the deduced type of
* the vector's elements, and casting is allowed, the data will be cast to the
* specified `type`. This allows for flexible creation of tensors with data
* vectors of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorPtr that manages the newly created TensorImpl.
*/
Expand Down Expand Up @@ -228,10 +232,15 @@ inline TensorPtr make_tensor_ptr(
*
* This template overload is specialized for cases where the tensor data is
* provided as a vector. The scalar type is automatically deduced from the
* vector's data type.
* vector's data type. If the specified `type` differs from the deduced type of
* the vector's elements, and casting is allowed, the data will be cast to the
* specified `type`. This allows for flexible creation of tensors with data
* vectors of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param data A vector containing the tensor's data.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorPtr that manages the newly created TensorImpl.
*/
Expand All @@ -251,15 +260,20 @@ inline TensorPtr make_tensor_ptr(
*
* This template overload is specialized for cases where the tensor data is
* provided as an initializer list. The scalar type is automatically deduced
* from the initializer list's data type.
* from the initializer list's data type. If the specified `type` differs from
* the deduced type of the initializer list's elements, and casting is allowed,
* the data will be cast to the specified `type`. This allows for flexible
creation of tensors with initializer lists of one type and a different scalar
* type.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorPtr that manages the newly created TensorImpl.
*/
Expand Down Expand Up @@ -288,11 +302,17 @@ inline TensorPtr make_tensor_ptr(
*
* This template overload allows creating a Tensor from an initializer list
* of data. The scalar type is automatically deduced from the type of the
* initializer list's elements.
* initializer list's elements. If the specified `type` differs from
* the deduced type of the initializer list's elements, and casting is allowed,
* the data will be cast to the specified `type`. This allows for flexible
creation of tensors with initializer lists of one type and a different scalar
* type.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param list An initializer list containing the tensor's data.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorPtr that manages the newly created TensorImpl.
*/
Expand Down
125 changes: 125 additions & 0 deletions extension/tensor/test/tensor_impl_ptr_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,128 @@ TEST_F(TensorImplPtrTest, StridesAndDimOrderMustMatchSizes) {
ET_EXPECT_DEATH(
{ auto _ = make_tensor_impl_ptr({3, 4}, data, {0}, {4, 1}); }, "");
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromIntToFloat) {
  // A 2x3 int32 buffer requested with a Float dtype should be cast on
  // creation rather than rejected.
  std::vector<int32_t> source = {1, 2, 3, 4, 5, 6};
  auto impl = make_tensor_impl_ptr(
      {2, 3}, std::move(source), {}, {}, exec_aten::ScalarType::Float);

  EXPECT_EQ(impl->dim(), 2);
  EXPECT_EQ(impl->size(0), 2);
  EXPECT_EQ(impl->size(1), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Float);

  // Spot-check the first and last elements of the cast buffer.
  const auto* values = static_cast<const float*>(impl->data());
  EXPECT_FLOAT_EQ(values[0], 1.0f);
  EXPECT_FLOAT_EQ(values[5], 6.0f);
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromIntToDouble) {
  // int32 input with a Double dtype: every element is widened to double.
  std::vector<int32_t> source = {1, 2, 3};
  auto impl =
      make_tensor_impl_ptr(std::move(source), exec_aten::ScalarType::Double);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Double);

  const auto* values = static_cast<const double*>(impl->data());
  for (int i = 0; i < 3; ++i) {
    EXPECT_DOUBLE_EQ(values[i], static_cast<double>(i + 1));
  }
}

// Floating-point to integral casts are disallowed (the implementation checks
// runtime::canCast before converting), so requesting an Int tensor from float
// data must abort via ET_CHECK_MSG.
TEST_F(TensorImplPtrTest, TensorDataCastingInvalidCast) {
  std::vector<float> float_data = {1.0f, 2.0f, 3.0f};
  ET_EXPECT_DEATH(
      {
        auto _ = make_tensor_impl_ptr(
            std::move(float_data), exec_aten::ScalarType::Int);
      },
      "");
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromFloatToHalf) {
  // float input narrowed to Half; small integers are exactly representable,
  // so equality (not approximate) comparison is valid here.
  std::vector<float> source = {1.0f, 2.0f, 3.0f};
  auto impl =
      make_tensor_impl_ptr(std::move(source), exec_aten::ScalarType::Half);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Half);

  const auto* values = static_cast<const exec_aten::Half*>(impl->data());
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(static_cast<float>(values[i]), static_cast<float>(i + 1));
  }
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromDoubleToFloat) {
  // double input narrowed to Float; values survive within float precision.
  std::vector<double> source = {1.1, 2.2, 3.3};
  auto impl =
      make_tensor_impl_ptr(std::move(source), exec_aten::ScalarType::Float);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Float);

  const auto* values = static_cast<const float*>(impl->data());
  EXPECT_FLOAT_EQ(values[0], 1.1f);
  EXPECT_FLOAT_EQ(values[1], 2.2f);
  EXPECT_FLOAT_EQ(values[2], 3.3f);
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromInt64ToInt32) {
  // Values deliberately exceed the int32_t range to exercise narrowing.
  std::vector<int64_t> int64_data = {10000000000, 20000000000, 30000000000};
  // Keep a copy of the inputs: the vector is consumed by std::move below.
  const std::vector<int64_t> original = int64_data;
  auto tensor_impl =
      make_tensor_impl_ptr(std::move(int64_data), exec_aten::ScalarType::Int);

  EXPECT_EQ(tensor_impl->dim(), 1);
  EXPECT_EQ(tensor_impl->size(0), 3);
  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Int);

  auto data_ptr = static_cast<const int32_t*>(tensor_impl->data());
  // The implementation performs static_cast<int32_t> per element; applying
  // the same conversion here pins the exact truncated value for every
  // element instead of merely asserting inequality with the original.
  for (size_t i = 0; i < original.size(); ++i) {
    EXPECT_EQ(data_ptr[i], static_cast<int32_t>(original[i]));
  }
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromFloatToBFloat16) {
  // float input narrowed to BFloat16; small integers round-trip exactly.
  std::vector<float> source = {1.0f, 2.0f, 3.0f};
  auto impl =
      make_tensor_impl_ptr(std::move(source), exec_aten::ScalarType::BFloat16);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::BFloat16);

  const auto* values = static_cast<const exec_aten::BFloat16*>(impl->data());
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(static_cast<float>(values[i]), static_cast<float>(i + 1));
  }
}

TEST_F(TensorImplPtrTest, InitializerListDoubleToHalf) {
  // Explicit <double> template argument with a Half dtype: the initializer
  // list data is narrowed on construction.
  auto impl = make_tensor_impl_ptr<double>(
      {1.5, 2.7, 3.14}, exec_aten::ScalarType::Half);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Half);

  // Half has limited precision, so compare with a tolerance.
  const auto* values = static_cast<const exec_aten::Half*>(impl->data());
  EXPECT_NEAR(static_cast<float>(values[0]), 1.5f, 0.01);
  EXPECT_NEAR(static_cast<float>(values[1]), 2.7f, 0.01);
  EXPECT_NEAR(static_cast<float>(values[2]), 3.14f, 0.01);
}

TEST_F(TensorImplPtrTest, InitializerListInt8ToInt64) {
  // int8_t initializer list widened to Long; sign must be preserved.
  auto impl =
      make_tensor_impl_ptr<int8_t>({1, -2, 3, -4}, exec_aten::ScalarType::Long);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 4);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Long);

  const auto* values = static_cast<const int64_t*>(impl->data());
  const std::vector<int64_t> expected = {1, -2, 3, -4};
  for (size_t i = 0; i < expected.size(); ++i) {
    EXPECT_EQ(values[i], expected[i]);
  }
}
Loading
Loading