Skip to content

Commit

Permalink
Move type arg to the end to match Aten constructors. (#5379)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #5379

.

Reviewed By: kirklandsign

Differential Revision: D62701089

fbshipit-source-id: 3f05961a43db9e6e372ee039c2d832227951fbf6
  • Loading branch information
shoumikhin authored and facebook-github-bot committed Sep 16, 2024
1 parent 0a501eb commit c252553
Show file tree
Hide file tree
Showing 7 changed files with 388 additions and 178 deletions.
8 changes: 4 additions & 4 deletions extension/tensor/tensor_impl_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ struct TensorImplPtrDeleter final {
} // namespace

TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type,
std::vector<exec_aten::SizesType> sizes,
void* data,
std::vector<exec_aten::DimOrderType> dim_order,
std::vector<exec_aten::StridesType> strides,
exec_aten::ScalarType type,
exec_aten::TensorShapeDynamism dynamism,
std::function<void(void*)> deleter) {
const auto dim = sizes.size();
Expand Down Expand Up @@ -129,24 +129,24 @@ TensorImplPtr make_tensor_impl_ptr(
}

TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType scalar_type,
std::vector<exec_aten::SizesType> sizes,
std::vector<uint8_t> data,
std::vector<exec_aten::DimOrderType> dim_order,
std::vector<exec_aten::StridesType> strides,
exec_aten::ScalarType type,
exec_aten::TensorShapeDynamism dynamism) {
ET_CHECK_MSG(
data.size() >= exec_aten::compute_numel(sizes.data(), sizes.size()) *
exec_aten::elementSize(scalar_type),
exec_aten::elementSize(type),
"Data size is smaller than required by sizes and scalar type.");
auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<uint8_t>>(std::move(data));
return make_tensor_impl_ptr(
scalar_type,
std::move(sizes),
raw_data_ptr,
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
}
Expand Down
204 changes: 173 additions & 31 deletions extension/tensor/tensor_impl_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ namespace extension {

#ifndef USE_ATEN_LIB
/**
* A smart pointer type for managing the lifecycle of a TensorImpl.
* A smart pointer for managing the lifecycle of a TensorImpl.
*
* TensorImplPtr uses a shared pointer because multiple Tensor objects might
* share the same underlying data and metadata. This shared ownership model
* ensures that the TensorImpl is only destroyed when all references to it are
* gone, providing a safe and efficient way to manage shared tensor
* implementations. This abstraction is designed to be a safer and more
* convenient alternative to the original TensorImpl, which does not
* manage metadata by design.
* TensorImplPtr uses a shared pointer since multiple Tensor objects may
* share the same underlying data and metadata. This shared ownership ensures
* that the TensorImpl is destroyed only when all references to it are gone,
* providing a safe and efficient way to manage shared tensor implementations.
* It serves as a safer, more convenient alternative to the original TensorImpl,
* which does not manage its metadata by design.
*/
using TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>;
#else
Expand All @@ -48,23 +47,23 @@ using TensorImplPtr =
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* @param type The scalar type of the tensor elements.
* @param sizes A vector specifying the size of each dimension.
* @param data A pointer to the data buffer.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @param deleter A custom deleter function for managing the lifetime of the
* data buffer. If provided, this deleter will be called when the managed
* TensorImpl object is destroyed.
* data buffer. If provided, this deleter is called when the managed TensorImpl
* is destroyed.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type,
std::vector<exec_aten::SizesType> sizes,
void* data,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
std::vector<exec_aten::DimOrderType> dim_order,
std::vector<exec_aten::StridesType> strides,
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
std::function<void(void*)> deleter = nullptr);
Expand All @@ -73,37 +72,64 @@ TensorImplPtr make_tensor_impl_ptr(
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where the tensor data is
* provided as a vector. The scalar type is automatically deduced from the
* vector's data type. The deleter ensures that the data vector is properly
* managed and its lifetime is tied to the TensorImpl.
* @param sizes A vector specifying the size of each dimension.
* @param data A pointer to the data buffer.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @param deleter A custom deleter function for managing the lifetime of the
* data buffer. If provided, this deleter is called when the managed TensorImpl
* is destroyed.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
void* data,
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
std::function<void(void*)> deleter = nullptr) {
return make_tensor_impl_ptr(
std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where tensor data is provided
* as a vector. The scalar type is automatically deduced from the vector's data
* type. The deleter ensures that the data vector is properly managed, with its
* lifetime tied to the TensorImpl.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <typename T = float>
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
std::vector<T> data,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
constexpr exec_aten::ScalarType scalar_type =
runtime::CppTypeToScalarType<T>::value;
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
const auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
return make_tensor_impl_ptr(
scalar_type,
std::move(sizes),
raw_data_ptr,
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
}
Expand All @@ -119,43 +145,159 @@ inline TensorImplPtr make_tensor_impl_ptr(
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param data A vector containing the tensor's data.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <typename T = float>
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<T> data,
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
return make_tensor_impl_ptr(
std::move(sizes), std::move(data), {0}, {1}, dynamism);
std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where tensor data is provided
* as an initializer list. The scalar type is automatically deduced from the
* initializer list's data type. The deleter ensures that the data is properly
* managed, with its lifetime tied to the TensorImpl.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
std::initializer_list<T> list,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
auto data = std::vector<T>(std::move(list));
const auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
return make_tensor_impl_ptr(
std::move(sizes),
raw_data_ptr,
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where the tensor data is
* provided as an initializer list. The scalar type is automatically deduced
* from the initializer list's data type. The deleter ensures that the data is
* properly managed and its lifetime is tied to the TensorImpl.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
std::initializer_list<T> list,
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(list.size())};
return make_tensor_impl_ptr(
std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
}

/**
* Creates a TensorImplPtr to manage a Tensor with a single scalar value.
*
* @tparam T The C++ type of the scalar value.
* @param value The scalar value used for the Tensor.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
template <typename T>
inline TensorImplPtr make_tensor_impl_ptr(T value) {
return make_tensor_impl_ptr({}, std::vector<T>{value});
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
* and a scalar type to interpret the data. The vector is managed, and the
* memory's lifetime is tied to the TensorImpl.
* and a scalar type to interpret the data. The vector is managed, and its
* lifetime is tied to the TensorImpl.
*
* @param scalar_type The scalar type of the tensor elements.
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the raw memory for the tensor's data.
* @param data A vector containing the raw memory buffer for the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType scalar_type,
std::vector<exec_aten::SizesType> sizes,
std::vector<uint8_t> data,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
std::vector<exec_aten::DimOrderType> dim_order,
std::vector<exec_aten::StridesType> strides,
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
* and a scalar type to interpret the data. The vector is managed, and the
* memory's lifetime is tied to the TensorImpl.
*
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the raw memory for the tensor's data.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
std::vector<uint8_t> data,
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
return make_tensor_impl_ptr(
std::move(sizes), std::move(data), {}, {}, type, dynamism);
}

} // namespace extension
} // namespace executorch
Loading

0 comments on commit c252553

Please sign in to comment.