diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 8461885b38ce..811e205fb2b3 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -549,6 +549,9 @@ class StorageInfo(Node): type of the "virtual devices" the expressions are stored on, and the sizes of each storage element.""" + def __init__(self, sids, dev_types, sizes): + self.__init_handle_by_constructor__(_ffi_api.StorageInfo, sids, dev_types, sizes) + @property def storage_ids(self): return _ffi_api.StorageInfoStorageIds(self) @@ -560,3 +563,13 @@ def device_types(self): @property def storage_sizes(self): return _ffi_api.StorageInfoStorageSizes(self) + + +@tvm._ffi.register_object("relay.StaticMemoryPlan") +class StaticMemoryPlan(Node): + """StaticMemoryPlan + + The result of static memory planning.""" + + def __init__(self, expr_to_storage_info): + self.__init_handle_by_constructor__(_ffi_api.StaticMemoryPlan, expr_to_storage_info) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 507dd9371a97..8fd46817b817 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -74,6 +74,11 @@ class DenseAttrs(Attrs): """Attributes for nn.dense""" +@tvm._ffi.register_object("relay.attrs.DensePackAttrs") +class DensePackAttrs(Attrs): + """Attributes for nn.contrib_dense_pack""" + + @tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs") class BatchMatmulAttrs(Attrs): """Attributes for nn.batch_matmul""" diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index ea0ab093aa1d..07dfe1768790 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -41,6 +41,23 @@ StorageInfo::StorageInfo(std::vector storage_ids, std::vector& sids, const Array& dev_types, + const Array& sizes_in_bytes) { + std::vector sids_v, sizes_v; + std::vector dev_types_v; + for (auto s : sids) { + sids_v.push_back(s); + } + for (auto d : dev_types) { + dev_types_v.push_back(static_cast(static_cast(d))); + } + for (auto s : sizes_in_bytes) { + sizes_v.push_back(s); + } + return StorageInfo(sids_v, dev_types_v, sizes_v); + }); + TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageInfo si) { Array ids; for (auto id : si->storage_ids) { @@ -73,6 +90,11 @@ StaticMemoryPlan::StaticMemoryPlan(Map expr_to_storage_info) data_ = std::move(n); } +TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan") + .set_body_typed([](const Map& expr_to_storage_info) { + return StaticMemoryPlan(expr_to_storage_info); + }); + int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { if (expr_type->IsInstance()) { auto tuple_type = Downcast(expr_type);