Skip to content

Commit

Permalink
Cast the vector from deduced type to desired type if needed. (#5409)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #5409

.

Reviewed By: kirklandsign

Differential Revision: D62807302

fbshipit-source-id: ce71b88c7588367def22e3baf5b835e21b42c8bf
  • Loading branch information
shoumikhin authored and facebook-github-bot committed Sep 17, 2024
1 parent c605bae commit e8a557c
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 36 deletions.
84 changes: 54 additions & 30 deletions extension/tensor/tensor_impl_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,31 +97,57 @@ inline TensorImplPtr make_tensor_impl_ptr(
* specified properties.
*
* This template overload is specialized for cases where tensor data is provided
* as a vector. The scalar type is automatically deduced from the vector's data
* type. The deleter ensures that the data vector is properly managed, with its
* lifetime tied to the TensorImpl.
* as a vector. If the specified `type` differs from the deduced type of the
* vector's elements, and casting is allowed, the data will be cast to the
* specified `type`. This allows for flexible creation of tensors with data
* vectors of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
std::vector<T> data,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
if (type != deduced_type) {
ET_CHECK_MSG(
runtime::canCast(deduced_type, type),
"Cannot cast deduced type to specified type.");
std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
ET_SWITCH_REALHBBF16_TYPES(
type, nullptr, "make_tensor_impl_ptr", CTYPE, [&] {
std::transform(
data.begin(),
data.end(),
reinterpret_cast<CTYPE*>(casted_data.data()),
[](const T& val) { return static_cast<CTYPE>(val); });
});
const auto raw_data_ptr = casted_data.data();
auto data_ptr =
std::make_shared<std::vector<uint8_t>>(std::move(casted_data));
return make_tensor_impl_ptr(
std::move(sizes),
raw_data_ptr,
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
}
const auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
return make_tensor_impl_ptr(
Expand All @@ -138,14 +164,16 @@ inline TensorImplPtr make_tensor_impl_ptr(
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where the tensor data is
* provided as a vector. The scalar type is automatically deduced from the
* vector's data type. The deleter ensures that the data vector is properly
* managed and its lifetime is tied to the TensorImpl.
* This template overload is specialized for cases where tensor data is provided
* as a vector. If the specified `type` differs from the deduced type of the
* vector's elements, and casting is allowed, the data will be cast to the
* specified `type`. This allows for flexible creation of tensors with data
* vectors of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param data A vector containing the tensor's data.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
Expand All @@ -157,7 +185,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
return make_tensor_impl_ptr(
std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
Expand All @@ -168,17 +195,19 @@ inline TensorImplPtr make_tensor_impl_ptr(
* specified properties.
*
* This template overload is specialized for cases where tensor data is provided
* as an initializer list. The scalar type is automatically deduced from the
* initializer list's data type. The deleter ensures that the data is properly
* managed, with its lifetime tied to the TensorImpl.
* as an initializer list. If the specified `type` differs from the deduced type
* of the initializer list's elements, and casting is allowed, the data will be
 * cast to the specified `type`. This allows for flexible creation of tensors
 * with data initializer lists of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
Expand All @@ -193,34 +222,30 @@ inline TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
auto data = std::vector<T>(std::move(list));
const auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
return make_tensor_impl_ptr(
std::move(sizes),
raw_data_ptr,
std::vector<T>(std::move(list)),
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
dynamism);
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where the tensor data is
* provided as an initializer list. The scalar type is automatically deduced
* from the initializer list's data type. The deleter ensures that the data is
* properly managed and its lifetime is tied to the TensorImpl.
* This template overload is specialized for cases where tensor data is provided
* as an initializer list. If the specified `type` differs from the deduced type
* of the initializer list's elements, and casting is allowed, the data will be
 * cast to the specified `type`. This allows for flexible creation of tensors
 * with data initializer lists of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
Expand All @@ -232,7 +257,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(list.size())};
return make_tensor_impl_ptr(
std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
Expand Down
32 changes: 26 additions & 6 deletions extension/tensor/tensor_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,18 @@ inline TensorPtr make_tensor_ptr(
*
* This template overload is specialized for cases where the tensor data is
* provided as a vector. The scalar type is automatically deduced from the
* vector's data type.
* vector's data type. If the specified `type` differs from the deduced type of
* the vector's elements, and casting is allowed, the data will be cast to the
* specified `type`. This allows for flexible creation of tensors with data
* vectors of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorPtr that manages the newly created TensorImpl.
*/
Expand Down Expand Up @@ -228,10 +232,15 @@ inline TensorPtr make_tensor_ptr(
*
* This template overload is specialized for cases where the tensor data is
* provided as a vector. The scalar type is automatically deduced from the
* vector's data type.
* vector's data type. If the specified `type` differs from the deduced type of
* the vector's elements, and casting is allowed, the data will be cast to the
* specified `type`. This allows for flexible creation of tensors with data
* vectors of one type and a different scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param data A vector containing the tensor's data.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorPtr that manages the newly created TensorImpl.
*/
Expand All @@ -251,15 +260,20 @@ inline TensorPtr make_tensor_ptr(
*
* This template overload is specialized for cases where the tensor data is
* provided as an initializer list. The scalar type is automatically deduced
* from the initializer list's data type.
* from the initializer list's data type. If the specified `type` differs from
* the deduced type of the initializer list's elements, and casting is allowed,
* the data will be cast to the specified `type`. This allows for flexible
 * creation of tensors with initializer lists of one type and a different
 * scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorPtr that manages the newly created TensorImpl.
*/
Expand Down Expand Up @@ -288,11 +302,17 @@ inline TensorPtr make_tensor_ptr(
*
* This template overload allows creating a Tensor from an initializer list
* of data. The scalar type is automatically deduced from the type of the
* initializer list's elements.
* initializer list's elements. If the specified `type` differs from
* the deduced type of the initializer list's elements, and casting is allowed,
* the data will be cast to the specified `type`. This allows for flexible
 * creation of tensors with initializer lists of one type and a different
 * scalar type.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param list An initializer list containing the tensor's data.
* @param type The scalar type of the tensor elements. If it differs from the
* deduced type, the data will be cast to this type if allowed.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorPtr that manages the newly created TensorImpl.
*/
Expand Down
125 changes: 125 additions & 0 deletions extension/tensor/test/tensor_impl_ptr_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,128 @@ TEST_F(TensorImplPtrTest, StridesAndDimOrderMustMatchSizes) {
ET_EXPECT_DEATH(
{ auto _ = make_tensor_impl_ptr({3, 4}, data, {0}, {4, 1}); }, "");
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromIntToFloat) {
  // Build a 2x3 Float tensor from int32 source data; the factory performs
  // the int -> float cast because the requested dtype differs.
  std::vector<int32_t> source = {1, 2, 3, 4, 5, 6};
  auto impl = make_tensor_impl_ptr(
      {2, 3}, std::move(source), {}, {}, exec_aten::ScalarType::Float);

  EXPECT_EQ(impl->dim(), 2);
  EXPECT_EQ(impl->size(0), 2);
  EXPECT_EQ(impl->size(1), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Float);

  // Spot-check the first and last elements of the casted buffer.
  const auto* floats = static_cast<const float*>(impl->data());
  EXPECT_FLOAT_EQ(floats[0], 1.0f);
  EXPECT_FLOAT_EQ(floats[5], 6.0f);
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromIntToDouble) {
  // Only the data vector is given, so a 1-D tensor of its length is implied.
  std::vector<int32_t> source = {1, 2, 3};
  auto impl =
      make_tensor_impl_ptr(std::move(source), exec_aten::ScalarType::Double);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Double);

  // Each int value must round-trip exactly into a double.
  const auto* values = static_cast<const double*>(impl->data());
  for (int i = 0; i < 3; ++i) {
    EXPECT_DOUBLE_EQ(values[i], static_cast<double>(i + 1));
  }
}

TEST_F(TensorImplPtrTest, TensorDataCastingInvalidCast) {
  // Requesting an Int tensor from float data is expected to abort: the
  // factory only casts when the runtime deems the conversion allowed, and
  // float -> Int is rejected here (hence the death expectation).
  std::vector<float> float_data = {1.0f, 2.0f, 3.0f};
  ET_EXPECT_DEATH(
      {
        auto _ = make_tensor_impl_ptr(
            std::move(float_data), exec_aten::ScalarType::Int);
      },
      "");
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromFloatToHalf) {
  // Small integers are exactly representable in half precision, so the
  // cast must be lossless for these inputs.
  std::vector<float> source = {1.0f, 2.0f, 3.0f};
  auto impl =
      make_tensor_impl_ptr(std::move(source), exec_aten::ScalarType::Half);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Half);

  const auto* halves = static_cast<const exec_aten::Half*>(impl->data());
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(static_cast<float>(halves[i]), static_cast<float>(i + 1));
  }
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromDoubleToFloat) {
  // Narrowing double -> float is an allowed cast; values are compared with
  // EXPECT_FLOAT_EQ to tolerate the precision loss.
  std::vector<double> source = {1.1, 2.2, 3.3};
  auto impl =
      make_tensor_impl_ptr(std::move(source), exec_aten::ScalarType::Float);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Float);

  const auto* floats = static_cast<const float*>(impl->data());
  EXPECT_FLOAT_EQ(floats[0], 1.1f);
  EXPECT_FLOAT_EQ(floats[1], 2.2f);
  EXPECT_FLOAT_EQ(floats[2], 3.3f);
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromInt64ToInt32) {
  // None of these values fit in int32_t, so the narrowing cast truncates.
  std::vector<int64_t> source = {10000000000, 20000000000, 30000000000};
  auto impl =
      make_tensor_impl_ptr(std::move(source), exec_aten::ScalarType::Int);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Int);

  // Only verify that a (lossy) cast happened; the exact wrapped value is
  // not pinned down, so just assert the original value is gone.
  const auto* ints = static_cast<const int32_t*>(impl->data());
  EXPECT_NE(ints[0], 10000000000);
}

TEST_F(TensorImplPtrTest, TensorDataCastingFromFloatToBFloat16) {
  // Small integers are exactly representable in bfloat16, so the cast is
  // lossless for these inputs.
  std::vector<float> source = {1.0f, 2.0f, 3.0f};
  auto impl = make_tensor_impl_ptr(
      std::move(source), exec_aten::ScalarType::BFloat16);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::BFloat16);

  const auto* values = static_cast<const exec_aten::BFloat16*>(impl->data());
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(static_cast<float>(values[i]), static_cast<float>(i + 1));
  }
}

TEST_F(TensorImplPtrTest, InitializerListDoubleToHalf) {
  // Initializer-list overload with an explicit element type; the factory
  // casts double -> half.
  auto impl = make_tensor_impl_ptr<double>(
      {1.5, 2.7, 3.14}, exec_aten::ScalarType::Half);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 3);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Half);

  // Half precision is coarse; compare within a loose tolerance.
  const auto* halves = static_cast<const exec_aten::Half*>(impl->data());
  EXPECT_NEAR(static_cast<float>(halves[0]), 1.5f, 0.01);
  EXPECT_NEAR(static_cast<float>(halves[1]), 2.7f, 0.01);
  EXPECT_NEAR(static_cast<float>(halves[2]), 3.14f, 0.01);
}

TEST_F(TensorImplPtrTest, InitializerListInt8ToInt64) {
  // Widening int8 -> int64 must preserve every value, including negatives.
  auto impl =
      make_tensor_impl_ptr<int8_t>({1, -2, 3, -4}, exec_aten::ScalarType::Long);

  EXPECT_EQ(impl->dim(), 1);
  EXPECT_EQ(impl->size(0), 4);
  EXPECT_EQ(impl->dtype(), exec_aten::ScalarType::Long);

  const int64_t expected[] = {1, -2, 3, -4};
  const auto* longs = static_cast<const int64_t*>(impl->data());
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(longs[i], expected[i]);
  }
}
Loading

0 comments on commit e8a557c

Please sign in to comment.