From ee72601953bdf592458c35a6dbc364d1adb3eb52 Mon Sep 17 00:00:00 2001
From: Anirudh Sundar
Date: Mon, 24 Apr 2023 12:41:02 +0530
Subject: [PATCH 1/2] [TIR] [Analysis] Calculate allocated memory at module level

This patch modifies the existing analysis pass
`tir.calculate_allocated_bytes` to accept an IRModule as an argument and
return allocated bytes for all prim_funcs in the IRModule.
---
 include/tvm/tir/analysis.h                    |  8 ++-
 python/tvm/tir/analysis/analysis.py           | 24 +++++--
 .../analysis/calculate_allocated_memory.cc    | 36 ++++++++--
 ...tir_analysis_calculate_allocated_memory.py | 65 ++++++++++++++-----
 4 files changed, 104 insertions(+), 29 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 4ed164e5ad45..bae65845e0a5 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -267,7 +267,13 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
  * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc
  * \param func The TIR PrimFunc for which the the allocated memory size to be calculated
  */
-TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func);
+TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const PrimFunc& func);
+
+/*!
+ * \brief Calculate the allocated memory per scope in bytes for each function inside the module
+ * \param mod The IRModule for which the allocated memory size has to be calculated
+ */
+TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const IRModule& mod);
 
 /*!
  * \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 5feb630e4892..bb0a0a1b4121 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -201,20 +201,32 @@ def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> in
     return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment)  # type: ignore
 
 
-def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]:
+def calculate_allocated_bytes(
+    func_or_mod: Union[PrimFunc, IRModule]
+) -> Union[Dict[str, int], Dict[str, Dict[str, int]]]:
     """Calculate allocated memory per memory scope required by TIR PrimFuncs.
 
     Parameters
     ----------
-    func: tvm.tir.PrimFunc
-        The function to be detected.
+    func_or_mod: Union[PrimFunc, IRModule]
+        The function or module to be detected. If a module is passed, allocated
+        memory is calculated for all PrimFuncs inside the module
 
     Returns
     -------
-    result : Dict[String, int]
-        Allocated memory size per scope in bytes.
+    result : Union[Dict[str, int], Dict[str, Dict[str, int]]]
+        Allocated memory size per scope in bytes for each function in the IRModule returned as a
+        dict with function names as keys and a dictionary of allocated sizes as values. If a single
+        PrimFunc is passed, the function name is returned as "main"
     """
-    return _ffi_api.calculate_allocated_bytes(func)  # type: ignore
+    if not isinstance(func_or_mod, (PrimFunc, IRModule)):
+        raise TypeError(
+            f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}"
+        )
+    allocated_mem = _ffi_api.calculate_allocated_bytes(func_or_mod)  # type: ignore
+    if isinstance(func_or_mod, PrimFunc):
+        return allocated_mem["main"]
+    return allocated_mem
 
 
 def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc
index ffdfc1f80162..8680f57e4cfd 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -79,16 +79,38 @@ void AllocationCalculator::VisitStmt_(const T* op) {
   _current_size[storage_scope] -= size;
 }
 
-tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) {
-  return AllocationCalculator<Integer>()(func);
+tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const PrimFunc& func) {
+  tvm::Map<String, tvm::Map<String, Integer> > results;
+  results.Set("main", AllocationCalculator<Integer>()(func));
+  return results;
 }
 
-TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc func) {
-  return CalculateAllocatedBytes(func);
-});
+tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const IRModule& mod) {
+  tvm::Map<String, tvm::Map<String, Integer> > results;
+  for (const auto& kv : mod->functions) {
+    if (auto prim_func = kv.second.as<PrimFunc>()) {
+      String func_name = kv.first->name_hint;
+      results.Set(func_name, AllocationCalculator<Integer>()(prim_func.value()));
+    }
+  }
+  return results;
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes")
+    .set_body_typed([](ObjectRef obj) -> tvm::Map<String, tvm::Map<String, Integer> > {
+      if (auto func = obj.as<PrimFunc>()) {
+        return CalculateAllocatedBytes(func.value());
+      } else if (auto mod = obj.as<IRModule>()) {
+        return CalculateAllocatedBytes(mod.value());
+      } else {
+        LOG(FATAL) << "TypeError: Expect the input to be either PrimFunc or IRModule, but gets: "
+                   << obj->GetTypeKey();
+        throw;
+      }
+    });
 
 bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
-  auto sizes = CalculateAllocatedBytes(func);
+  auto sizes = CalculateAllocatedBytes(func)["main"];
   const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
   if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
     return false;
@@ -121,7 +143,7 @@ Pass VerifyVTCMLimit(Optional<Target> default_target) {
     }
 
     if (limit.has_value() && limit.value() > 0) {
-      auto sizes = CalculateAllocatedBytes(func);
+      auto sizes = CalculateAllocatedBytes(func)["main"];
       const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
       if (vtcm_allocated.IntValue() > limit.value()) {
         LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded "
diff --git a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
index 2311bfbbef3c..217280cb5601 100644
--- a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
+++ b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
@@ -14,32 +14,42 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import pytest
 
 import tvm
 from tvm import tir
 from tvm.script import tir as T
 
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 
-@T.prim_func
-def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
-    for i in T.serial(128):
-        with T.block("C"):
-            c[i] = a[i] * T.int8(2)
+@tvm.script.ir_module
+class Module:
+    @T.prim_func
+    def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
+        for i in T.serial(128):
+            with T.block("C"):
+                c[i] = a[i] * T.int8(2)
 
 
-@T.prim_func
-def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
-    B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
-    for i in T.serial(128):
-        with T.block("B"):
-            B[i] = a[i] * T.int8(2)
-    for i in T.serial(128):
-        with T.block("C"):
-            c[i] = B[i] * T.int8(3)
+    @T.prim_func
+    def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
+        B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
+        for i in T.serial(128):
+            with T.block("B"):
+                B[i] = a[i] * T.int8(2)
+        for i in T.serial(128):
+            with T.block("C"):
+                c[i] = B[i] * T.int8(3)
+
+# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
 
 
-@pytest.mark.parametrize("primFunc,size", [(scale_by_two, 128), (scale_by_two_three, 256)])
+@pytest.mark.parametrize(
+    "primFunc,size", [(Module["scale_by_two"], 128), (Module["scale_by_two_three"], 256)]
+)
 def test_scale_by(primFunc, size):
     """Test calculate allocated bytes per scope"""
     mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main"))
@@ -97,5 +107,30 @@ def test_matmul_mix_scope(scope, size):
     assert sizes.get(scope, 0) == size
 
 
+def test_full_mod_calculator():
+    def apply_schedule(sch, func_name):
+        sch.work_on(func_name)
+        block_c = sch.get_block("C")
+        sch.cache_read(block_c, 0, storage_scope="global.vtcm")
+
+    sch = tvm.tir.Schedule(Module, debug_mask="all")
+    apply_schedule(sch, "scale_by_two")
+    apply_schedule(sch, "scale_by_two_three")
+    mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod)
+    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+    sizes = tvm.tir.analysis.calculate_allocated_bytes(mod)
+    assert "scale_by_two" in sizes, "Values for scale_by_two not found"
+    scale_by_two_sizes = sizes["scale_by_two"]
+    assert (
+        "global.vtcm" in scale_by_two_sizes
+    ), "Expected global.vtcm allocation to be calculated for scale_by_two"
+    assert scale_by_two_sizes["global.vtcm"] == 128, "Expected the calculated size to be 128"
+    scale_by_two_three_sizes = sizes["scale_by_two_three"]
+    assert (
+        "global.vtcm" in scale_by_two_three_sizes
+    ), "Expected global.vtcm allocation to be calculated for scale_by_two_three"
+    assert scale_by_two_three_sizes["global.vtcm"] == 256, "Expected the calculated size to be 256"
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 794274caa2c1b9511432178b2d9ba31b10116399 Mon Sep 17 00:00:00 2001
From: Anirudh Sundar
Date: Mon, 24 Apr 2023 18:36:33 +0530
Subject: [PATCH 2/2] Fix docstring and modify Python API to be consistent with C++

---
 include/tvm/tir/analysis.h                           | 4 ++++
 python/tvm/tir/analysis/analysis.py                  | 7 ++-----
 .../test_tir_analysis_calculate_allocated_memory.py  | 4 ++++
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index bae65845e0a5..3b5959e7816d 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -266,12 +266,16 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
 /*!
  * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc
  * \param func The TIR PrimFunc for which the the allocated memory size to be calculated
+ * \return Allocated memory size per scope in bytes inside the PrimFunc returned as a Map with
+ * key "main" and a Map of allocated sizes as values.
  */
 TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const PrimFunc& func);
 
 /*!
  * \brief Calculate the allocated memory per scope in bytes for each function inside the module
  * \param mod The IRModule for which the allocated memory size has to be calculated
+ * \return Allocated memory size per scope in bytes for each function in the IRModule returned as a
+          Map with function names as keys and a Map of allocated sizes as values.
  */
 TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const IRModule& mod);
 
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index bb0a0a1b4121..387ea0498015 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -216,17 +216,14 @@ def calculate_allocated_bytes(
     -------
     result : Union[Dict[str, int], Dict[str, Dict[str, int]]]
         Allocated memory size per scope in bytes for each function in the IRModule returned as a
-        dict with function names as keys and a dictionary of allocated sizes as values. If a single
+        dict with function names as keys and a dict of allocated sizes as values. If a single
         PrimFunc is passed, the function name is returned as "main"
     """
     if not isinstance(func_or_mod, (PrimFunc, IRModule)):
         raise TypeError(
             f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}"
         )
-    allocated_mem = _ffi_api.calculate_allocated_bytes(func_or_mod)  # type: ignore
-    if isinstance(func_or_mod, PrimFunc):
-        return allocated_mem["main"]
-    return allocated_mem
+    return _ffi_api.calculate_allocated_bytes(func_or_mod)  # type: ignore
 
 
 def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
diff --git a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
index 217280cb5601..cb3a663c0379 100644
--- a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
+++ b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
@@ -63,6 +63,8 @@ def test_scale_by(primFunc, size):
     mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
     mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+    assert "main" in sizes, 'Calls with a PrimFunc are expected to return with function key "main"'
+    sizes = sizes["main"]
     assert sizes.get("global.vtcm", 0) == size
 
 
@@ -104,6 +106,8 @@ def test_matmul_mix_scope(scope, size):
     mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
     mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+    assert "main" in sizes, 'Calls with a PrimFunc are expected to return with function key "main"'
+    sizes = sizes["main"]
     assert sizes.get(scope, 0) == size
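
Usage sketch for reviewers: a minimal illustration of the two call forms after both patches are applied, adapted from test_full_mod_calculator above. The module, the cache_read schedule into global.vtcm, and the 128-byte size all come from that test; treat this as a sketch of the intended API, not as additional tested code.

import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class Module:
    @T.prim_func
    def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
        for i in T.serial(128):
            with T.block("C"):
                c[i] = a[i] * T.int8(2)


# Stage the input of block "C" into VTCM, mirroring apply_schedule in the test.
sch = tvm.tir.Schedule(Module)
sch.work_on("scale_by_two")
sch.cache_read(sch.get_block("C"), 0, storage_scope="global.vtcm")

# The analysis is run after lowering the blocks, exactly as the tests do.
mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)

# IRModule argument: one entry per PrimFunc, each mapping storage scope to bytes.
per_func = tvm.tir.analysis.calculate_allocated_bytes(mod)
assert per_func["scale_by_two"]["global.vtcm"] == 128

# PrimFunc argument: after patch 2 the result is keyed by "main" regardless of the
# function's actual name, matching the C++ overload.
per_scope = tvm.tir.analysis.calculate_allocated_bytes(mod["scale_by_two"])
assert per_scope["main"]["global.vtcm"] == 128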