diff --git a/extension/tensor/tensor_impl_ptr.h b/extension/tensor/tensor_impl_ptr.h index 8857dd1aca..89fc7ff1eb 100644 --- a/extension/tensor/tensor_impl_ptr.h +++ b/extension/tensor/tensor_impl_ptr.h @@ -97,23 +97,25 @@ inline TensorImplPtr make_tensor_impl_ptr( * specified properties. * * This template overload is specialized for cases where tensor data is provided - * as a vector. The scalar type is automatically deduced from the vector's data - * type. The deleter ensures that the data vector is properly managed, with its - * lifetime tied to the TensorImpl. + * as a vector. If the specified `type` differs from the deduced type of the + * vector's elements, and casting is allowed, the data will be cast to the + * specified `type`. This allows for flexible creation of tensors with data + * vectors of one type and a different scalar type. * * @tparam T The C++ type of the tensor elements, deduced from the vector. * @param sizes A vector specifying the size of each dimension. * @param data A vector containing the tensor's data. * @param dim_order A vector specifying the order of dimensions. * @param strides A vector specifying the strides of each dimension. - * @param type The scalar type of the tensor elements. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. * @param dynamism Specifies the mutability of the tensor's shape. * @return A TensorImplPtr that manages the newly created TensorImpl. */ template < typename T = float, exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType::value> -inline TensorImplPtr make_tensor_impl_ptr( +TensorImplPtr make_tensor_impl_ptr( std::vector sizes, std::vector data, std::vector dim_order = {}, @@ -121,7 +123,31 @@ inline TensorImplPtr make_tensor_impl_ptr( exec_aten::ScalarType type = deduced_type, exec_aten::TensorShapeDynamism dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { - ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type."); + if (type != deduced_type) { + ET_CHECK_MSG( + runtime::canCast(deduced_type, type), + "Cannot cast deduced type to specified type."); + std::vector casted_data(data.size() * runtime::elementSize(type)); + ET_SWITCH_REALHBBF16_TYPES( + type, nullptr, "make_tensor_impl_ptr", CTYPE, [&] { + std::transform( + data.begin(), + data.end(), + reinterpret_cast(casted_data.data()), + [](const T& val) { return static_cast(val); }); + }); + const auto raw_data_ptr = casted_data.data(); + auto data_ptr = + std::make_shared>(std::move(casted_data)); + return make_tensor_impl_ptr( + std::move(sizes), + raw_data_ptr, + std::move(dim_order), + std::move(strides), + type, + dynamism, + [data_ptr = std::move(data_ptr)](void*) {}); + } const auto raw_data_ptr = data.data(); auto data_ptr = std::make_shared>(std::move(data)); return make_tensor_impl_ptr( @@ -138,14 +164,16 @@ inline TensorImplPtr make_tensor_impl_ptr( * Creates a TensorImplPtr that manages a newly created TensorImpl with the * specified properties. * - * This template overload is specialized for cases where the tensor data is - * provided as a vector. The scalar type is automatically deduced from the - * vector's data type. The deleter ensures that the data vector is properly - * managed and its lifetime is tied to the TensorImpl. + * This template overload is specialized for cases where tensor data is provided + * as a vector. If the specified `type` differs from the deduced type of the + * vector's elements, and casting is allowed, the data will be cast to the + * specified `type`. This allows for flexible creation of tensors with data + * vectors of one type and a different scalar type. * * @tparam T The C++ type of the tensor elements, deduced from the vector. * @param data A vector containing the tensor's data. - * @param type The scalar type of the tensor elements. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. * @param dynamism Specifies the mutability of the tensor's shape. * @return A TensorImplPtr that manages the newly created TensorImpl. */ @@ -157,7 +185,6 @@ inline TensorImplPtr make_tensor_impl_ptr( exec_aten::ScalarType type = deduced_type, exec_aten::TensorShapeDynamism dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { - ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type."); std::vector sizes{exec_aten::SizesType(data.size())}; return make_tensor_impl_ptr( std::move(sizes), std::move(data), {0}, {1}, type, dynamism); @@ -168,9 +195,10 @@ inline TensorImplPtr make_tensor_impl_ptr( * specified properties. * * This template overload is specialized for cases where tensor data is provided - * as an initializer list. The scalar type is automatically deduced from the - * initializer list's data type. The deleter ensures that the data is properly - * managed, with its lifetime tied to the TensorImpl. + * as an initializer list. If the specified `type` differs from the deduced type + * of the initializer list's elements, and casting is allowed, the data will be + * cast to the specified `type`. This allows for flexible creation of tensors + * with data initializer list of one type and a different scalar type. * * @tparam T The C++ type of the tensor elements, deduced from the initializer * list. @@ -178,7 +206,8 @@ inline TensorImplPtr make_tensor_impl_ptr( * @param list An initializer list containing the tensor's data. * @param dim_order A vector specifying the order of dimensions. * @param strides A vector specifying the strides of each dimension. - * @param type The scalar type of the tensor elements. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. * @param dynamism Specifies the mutability of the tensor's shape. * @return A TensorImplPtr that manages the newly created TensorImpl. */ @@ -193,34 +222,30 @@ inline TensorImplPtr make_tensor_impl_ptr( exec_aten::ScalarType type = deduced_type, exec_aten::TensorShapeDynamism dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { - ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type."); - auto data = std::vector(std::move(list)); - const auto raw_data_ptr = data.data(); - auto data_ptr = std::make_shared>(std::move(data)); return make_tensor_impl_ptr( std::move(sizes), - raw_data_ptr, + std::vector(std::move(list)), std::move(dim_order), std::move(strides), type, - dynamism, - [data_ptr = std::move(data_ptr)](void*) {}); + dynamism); } /** * Creates a TensorImplPtr that manages a newly created TensorImpl with the * specified properties. * - * This template overload is specialized for cases where the tensor data is - * provided as an initializer list. The scalar type is automatically deduced - * from the initializer list's data type. The deleter ensures that the data is - * properly managed and its lifetime is tied to the TensorImpl. + * This template overload is specialized for cases where tensor data is provided + * as an initializer list. If the specified `type` differs from the deduced type + * of the initializer list's elements, and casting is allowed, the data will be + * cast to the specified `type`. This allows for flexible creation of tensors + * with data initializer list of one type and a different scalar type. * * @tparam T The C++ type of the tensor elements, deduced from the initializer * list. - * @param sizes A vector specifying the size of each dimension. * @param list An initializer list containing the tensor's data. - * @param type The scalar type of the tensor elements. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. * @param dynamism Specifies the mutability of the tensor's shape. * @return A TensorImplPtr that manages the newly created TensorImpl. */ @@ -232,7 +257,6 @@ inline TensorImplPtr make_tensor_impl_ptr( exec_aten::ScalarType type = deduced_type, exec_aten::TensorShapeDynamism dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { - ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type."); std::vector sizes{exec_aten::SizesType(list.size())}; return make_tensor_impl_ptr( std::move(sizes), std::move(list), {0}, {1}, type, dynamism); diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 41dc6282eb..e8a97be036 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -192,14 +192,18 @@ inline TensorPtr make_tensor_ptr( * * This template overload is specialized for cases where the tensor data is * provided as a vector. The scalar type is automatically deduced from the - * vector's data type. + * vector's data type. If the specified `type` differs from the deduced type of + * the vector's elements, and casting is allowed, the data will be cast to the + * specified `type`. This allows for flexible creation of tensors with data + * vectors of one type and a different scalar type. * * @tparam T The C++ type of the tensor elements, deduced from the vector. * @param sizes A vector specifying the size of each dimension. * @param data A vector containing the tensor's data. * @param dim_order A vector specifying the order of dimensions. * @param strides A vector specifying the strides of each dimension. - * @param type The scalar type of the tensor elements. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. * @param dynamism Specifies the mutability of the tensor's shape. * @return A TensorPtr that manages the newly created TensorImpl. */ @@ -228,10 +232,15 @@ inline TensorPtr make_tensor_ptr( * * This template overload is specialized for cases where the tensor data is * provided as a vector. The scalar type is automatically deduced from the - * vector's data type. + * vector's data type. If the specified `type` differs from the deduced type of + * the vector's elements, and casting is allowed, the data will be cast to the + * specified `type`. This allows for flexible creation of tensors with data + * vectors of one type and a different scalar type. * * @tparam T The C++ type of the tensor elements, deduced from the vector. * @param data A vector containing the tensor's data. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. * @param dynamism Specifies the mutability of the tensor's shape. * @return A TensorPtr that manages the newly created TensorImpl. */ @@ -251,7 +260,11 @@ inline TensorPtr make_tensor_ptr( * * This template overload is specialized for cases where the tensor data is * provided as an initializer list. The scalar type is automatically deduced - * from the initializer list's data type. + * from the initializer list's data type. If the specified `type` differs from + * the deduced type of the initializer list's elements, and casting is allowed, + * the data will be cast to the specified `type`. This allows for flexible + * creation of tensors with data vectors of one type and a different scalar + * type. * * @tparam T The C++ type of the tensor elements, deduced from the initializer * list. @@ -259,7 +272,8 @@ inline TensorPtr make_tensor_ptr( * @param list An initializer list containing the tensor's data. * @param dim_order A vector specifying the order of dimensions. * @param strides A vector specifying the strides of each dimension. - * @param type The scalar type of the tensor elements. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. * @param dynamism Specifies the mutability of the tensor's shape. * @return A TensorPtr that manages the newly created TensorImpl. */ @@ -288,11 +302,17 @@ inline TensorPtr make_tensor_ptr( * * This template overload allows creating a Tensor from an initializer list * of data. The scalar type is automatically deduced from the type of the - * initializer list's elements. + * initializer list's elements. If the specified `type` differs from + * the deduced type of the initializer list's elements, and casting is allowed, + * the data will be cast to the specified `type`. This allows for flexible + * creation of tensors with data vectors of one type and a different scalar + * type. * * @tparam T The C++ type of the tensor elements, deduced from the initializer * list. * @param list An initializer list containing the tensor's data. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. * @param dynamism Specifies the mutability of the tensor's shape. * @return A TensorPtr that manages the newly created TensorImpl. */ diff --git a/extension/tensor/test/tensor_impl_ptr_test.cpp b/extension/tensor/test/tensor_impl_ptr_test.cpp index b345258c2c..d3d827a495 100644 --- a/extension/tensor/test/tensor_impl_ptr_test.cpp +++ b/extension/tensor/test/tensor_impl_ptr_test.cpp @@ -366,3 +366,128 @@ TEST_F(TensorImplPtrTest, StridesAndDimOrderMustMatchSizes) { ET_EXPECT_DEATH( { auto _ = make_tensor_impl_ptr({3, 4}, data, {0}, {4, 1}); }, ""); } + +TEST_F(TensorImplPtrTest, TensorDataCastingFromIntToFloat) { + std::vector int_data = {1, 2, 3, 4, 5, 6}; + auto tensor_impl = make_tensor_impl_ptr( + {2, 3}, std::move(int_data), {}, {}, exec_aten::ScalarType::Float); + + EXPECT_EQ(tensor_impl->dim(), 2); + EXPECT_EQ(tensor_impl->size(0), 2); + EXPECT_EQ(tensor_impl->size(1), 3); + EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Float); + + auto data_ptr = static_cast(tensor_impl->data()); + EXPECT_FLOAT_EQ(data_ptr[0], 1.0f); + EXPECT_FLOAT_EQ(data_ptr[5], 6.0f); +} + +TEST_F(TensorImplPtrTest, TensorDataCastingFromIntToDouble) { + std::vector int_data = {1, 2, 3}; + auto tensor_impl = + make_tensor_impl_ptr(std::move(int_data), exec_aten::ScalarType::Double); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 3); + EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Double); + + auto data_ptr = static_cast(tensor_impl->data()); + EXPECT_DOUBLE_EQ(data_ptr[0], 1.0); + EXPECT_DOUBLE_EQ(data_ptr[1], 2.0); + EXPECT_DOUBLE_EQ(data_ptr[2], 3.0); +} + +TEST_F(TensorImplPtrTest, TensorDataCastingInvalidCast) { + std::vector float_data = {1.0f, 2.0f, 3.0f}; + ET_EXPECT_DEATH( + { + auto _ = make_tensor_impl_ptr( + std::move(float_data), exec_aten::ScalarType::Int); + }, + ""); +} + +TEST_F(TensorImplPtrTest, TensorDataCastingFromFloatToHalf) { + std::vector float_data = {1.0f, 2.0f, 3.0f}; + auto tensor_impl = + make_tensor_impl_ptr(std::move(float_data), exec_aten::ScalarType::Half); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 3); + EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Half); + + auto data_ptr = static_cast(tensor_impl->data()); + EXPECT_EQ(static_cast(data_ptr[0]), 1.0f); + EXPECT_EQ(static_cast(data_ptr[1]), 2.0f); + EXPECT_EQ(static_cast(data_ptr[2]), 3.0f); +} + +TEST_F(TensorImplPtrTest, TensorDataCastingFromDoubleToFloat) { + std::vector double_data = {1.1, 2.2, 3.3}; + auto tensor_impl = make_tensor_impl_ptr( + std::move(double_data), exec_aten::ScalarType::Float); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 3); + EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Float); + + auto data_ptr = static_cast(tensor_impl->data()); + EXPECT_FLOAT_EQ(data_ptr[0], 1.1f); + EXPECT_FLOAT_EQ(data_ptr[1], 2.2f); + EXPECT_FLOAT_EQ(data_ptr[2], 3.3f); +} + +TEST_F(TensorImplPtrTest, TensorDataCastingFromInt64ToInt32) { + std::vector int64_data = {10000000000, 20000000000, 30000000000}; + auto tensor_impl = + make_tensor_impl_ptr(std::move(int64_data), exec_aten::ScalarType::Int); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 3); + EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Int); + + auto data_ptr = static_cast(tensor_impl->data()); + // Since the values exceed int32_t range, they may overflow + // Here we just check that the cast was performed + EXPECT_NE(data_ptr[0], 10000000000); // Expected overflow +} + +TEST_F(TensorImplPtrTest, TensorDataCastingFromFloatToBFloat16) { + std::vector float_data = {1.0f, 2.0f, 3.0f}; + auto tensor_impl = make_tensor_impl_ptr( + std::move(float_data), exec_aten::ScalarType::BFloat16); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 3); + EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::BFloat16); + + auto data_ptr = static_cast(tensor_impl->data()); + EXPECT_EQ(static_cast(data_ptr[0]), 1.0f); + EXPECT_EQ(static_cast(data_ptr[1]), 2.0f); + EXPECT_EQ(static_cast(data_ptr[2]), 3.0f); +} + +TEST_F(TensorImplPtrTest, InitializerListDoubleToHalf) { + auto tensor_impl = make_tensor_impl_ptr( + {1.5, 2.7, 3.14}, exec_aten::ScalarType::Half); + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 3); + EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Half); + auto data_ptr = static_cast(tensor_impl->data()); + EXPECT_NEAR(static_cast(data_ptr[0]), 1.5f, 0.01); + EXPECT_NEAR(static_cast(data_ptr[1]), 2.7f, 0.01); + EXPECT_NEAR(static_cast(data_ptr[2]), 3.14f, 0.01); +} + +TEST_F(TensorImplPtrTest, InitializerListInt8ToInt64) { + auto tensor_impl = + make_tensor_impl_ptr({1, -2, 3, -4}, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Long); + auto data_ptr = static_cast(tensor_impl->data()); + EXPECT_EQ(data_ptr[0], 1); + EXPECT_EQ(data_ptr[1], -2); + EXPECT_EQ(data_ptr[2], 3); + EXPECT_EQ(data_ptr[3], -4); +} diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 00614f24eb..291d19e06b 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -477,3 +477,118 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromTensorPtrInt64) { EXPECT_EQ(cloned_tensor->const_data_ptr()[3], 400); EXPECT_EQ(cloned_tensor->scalar_type(), exec_aten::ScalarType::Long); } + +TEST_F(TensorPtrTest, TensorDataCastingFromIntToFloat) { + std::vector int_data = {1, 2, 3, 4, 5, 6}; + auto tensor = make_tensor_ptr( + {2, 3}, std::move(int_data), {}, {}, exec_aten::ScalarType::Float); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + + auto data_ptr = tensor->const_data_ptr(); + EXPECT_FLOAT_EQ(data_ptr[0], 1.0f); + EXPECT_FLOAT_EQ(data_ptr[5], 6.0f); +} + +TEST_F(TensorPtrTest, TensorDataCastingFromIntToDouble) { + std::vector int_data = {1, 2, 3}; + auto tensor = + make_tensor_ptr(std::move(int_data), exec_aten::ScalarType::Double); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Double); + + auto data_ptr = tensor->const_data_ptr(); + EXPECT_DOUBLE_EQ(data_ptr[0], 1.0); + EXPECT_DOUBLE_EQ(data_ptr[1], 2.0); + EXPECT_DOUBLE_EQ(data_ptr[2], 3.0); +} + +TEST_F(TensorPtrTest, TensorDataCastingFromFloatToHalf) { + std::vector float_data = {1.0f, 2.0f, 3.0f}; + auto tensor = + make_tensor_ptr(std::move(float_data), exec_aten::ScalarType::Half); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Half); + + auto data_ptr = tensor->const_data_ptr(); + EXPECT_EQ(static_cast(data_ptr[0]), 1.0f); + EXPECT_EQ(static_cast(data_ptr[1]), 2.0f); + EXPECT_EQ(static_cast(data_ptr[2]), 3.0f); +} + +TEST_F(TensorPtrTest, TensorDataCastingFromDoubleToFloat) { + std::vector double_data = {1.1, 2.2, 3.3}; + auto tensor = + make_tensor_ptr(std::move(double_data), exec_aten::ScalarType::Float); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + + auto data_ptr = tensor->const_data_ptr(); + EXPECT_FLOAT_EQ(data_ptr[0], 1.1f); + EXPECT_FLOAT_EQ(data_ptr[1], 2.2f); + EXPECT_FLOAT_EQ(data_ptr[2], 3.3f); +} + +TEST_F(TensorPtrTest, TensorDataCastingFromInt64ToInt32) { + std::vector int64_data = {10000000000, 20000000000, 30000000000}; + auto tensor = + make_tensor_ptr(std::move(int64_data), exec_aten::ScalarType::Int); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Int); + + auto data_ptr = tensor->const_data_ptr(); + EXPECT_NE(data_ptr[0], 10000000000); // Expected overflow +} + +TEST_F(TensorPtrTest, TensorDataCastingFromFloatToBFloat16) { + std::vector float_data = {1.0f, 2.0f, 3.0f}; + auto tensor = + make_tensor_ptr(std::move(float_data), exec_aten::ScalarType::BFloat16); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::BFloat16); + + auto data_ptr = tensor->const_data_ptr(); + EXPECT_EQ(static_cast(data_ptr[0]), 1.0f); + EXPECT_EQ(static_cast(data_ptr[1]), 2.0f); + EXPECT_EQ(static_cast(data_ptr[2]), 3.0f); +} + +TEST_F(TensorPtrTest, InitializerListDoubleToHalf) { + auto tensor = + make_tensor_ptr({1.5, 2.7, 3.14}, exec_aten::ScalarType::Half); + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Half); + + auto data_ptr = tensor->const_data_ptr(); + EXPECT_NEAR(static_cast(data_ptr[0]), 1.5f, 0.01); + EXPECT_NEAR(static_cast(data_ptr[1]), 2.7f, 0.01); + EXPECT_NEAR(static_cast(data_ptr[2]), 3.14f, 0.01); +} + +TEST_F(TensorPtrTest, InitializerListInt8ToInt64) { + auto tensor = + make_tensor_ptr({1, -2, 3, -4}, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Long); + + auto data_ptr = tensor->const_data_ptr(); + EXPECT_EQ(data_ptr[0], 1); + EXPECT_EQ(data_ptr[1], -2); + EXPECT_EQ(data_ptr[2], 3); + EXPECT_EQ(data_ptr[3], -4); +} diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 28cd46f6bc..e25c5e3692 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -1013,6 +1013,7 @@ inline ::executorch::aten::ScalarType promoteTypes( [&] { \ const auto& _st = TYPE; \ constexpr const char* et_switch_name = NAME; \ + (void)et_switch_name; /* Suppress unused var */ \ switch (_st) { \ __VA_ARGS__ \ default: \ diff --git a/runtime/core/portable_type/tensor.h b/runtime/core/portable_type/tensor.h index 149703a126..fb42e83710 100644 --- a/runtime/core/portable_type/tensor.h +++ b/runtime/core/portable_type/tensor.h @@ -85,6 +85,10 @@ class Tensor { return impl_->scalar_type(); } + inline ScalarType dtype() const { + return scalar_type(); + } + /// Returns the size in bytes of one element of the tensor. ssize_t element_size() const { return impl_->element_size(); diff --git a/runtime/core/portable_type/tensor_impl.h b/runtime/core/portable_type/tensor_impl.h index a6ae0ec6aa..c48149cd18 100644 --- a/runtime/core/portable_type/tensor_impl.h +++ b/runtime/core/portable_type/tensor_impl.h @@ -148,6 +148,10 @@ class TensorImpl { return type_; } + inline ScalarType dtype() const { + return scalar_type(); + } + /// Returns the size in bytes of one element of the tensor. ssize_t element_size() const;