[TIR] [Analysis] Calculate allocated memory at module level (#14711)
* [TIR] [Analysis] Calculate allocated memory at module level

This patch modifies the existing analysis pass
`tir.calculate_allocated_bytes` to accept an IRModule as an argument and
return allocated bytes for all prim_funcs in the IRModule.

* Fix docstring and modify Python API to be consistent with C++
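
For context, a minimal usage sketch of the updated API (not part of the patch itself): the example function and the two lowering passes are taken from the tests added below, and the 128-byte expectation assumes the 128-element int8 `global.vtcm` buffer declared in the function.

```python
# Usage sketch for the updated analysis (mirrors the tests added in this patch).
import tvm
from tvm.script import tir as T


@T.prim_func
def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
    B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
    for i in T.serial(128):
        with T.block("B"):
            B[i] = a[i] * T.int8(2)
    for i in T.serial(128):
        with T.block("C"):
            c[i] = B[i] * T.int8(3)


# Lower just enough that buffer allocations become explicit Allocate nodes.
mod = tvm.IRModule.from_expr(scale_by_two_three.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)

# Passing a single PrimFunc: the result is now nested under the key "main".
per_func = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
vtcm_bytes = per_func["main"].get("global.vtcm", 0)  # expected: 128 bytes for buffer B

# Passing the whole IRModule: one entry per PrimFunc, keyed by its name.
per_module = tvm.tir.analysis.calculate_allocated_bytes(mod)
assert "main" in per_module
```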
quic-sanirudh authored Apr 25, 2023
1 parent 1a17139 commit f5ab3f0
Showing 4 changed files with 109 additions and 29 deletions.
12 changes: 11 additions & 1 deletion include/tvm/tir/analysis.h
@@ -266,8 +266,18 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
/*!
* \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc
 * \param func The TIR PrimFunc for which the allocated memory size is to be calculated
* \return Allocated memory size per scope in bytes inside the PrimFunc returned as a Map with
* key "main" and a Map of allocated sizes as values.
*/
TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func);
TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const PrimFunc& func);

/*!
* \brief Calculate the allocated memory per scope in bytes for each function inside the module
 * \param mod The IRModule for which the allocated memory size has to be calculated
* \return Allocated memory size per scope in bytes for each function in the IRModule returned as a
 * Map with function names as keys and a Map of allocated sizes as values.
*/
TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const IRModule& mod);

/*!
* \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
21 changes: 15 additions & 6 deletions python/tvm/tir/analysis/analysis.py
@@ -201,20 +201,29 @@ def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> in
return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment) # type: ignore


def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]:
def calculate_allocated_bytes(
func_or_mod: Union[PrimFunc, IRModule]
) -> Union[Dict[str, int], Dict[str, Dict[str, int]]]:
"""Calculate allocated memory per memory scope required by TIR PrimFuncs.
Parameters
----------
func: tvm.tir.PrimFunc
The function to be detected.
func_or_mod: Union[PrimFunc, IRModule]
The function or module to be analyzed. If a module is passed, allocated
memory is calculated for all PrimFuncs inside the module.
Returns
-------
result : Dict[String, int]
Allocated memory size per scope in bytes.
result : Union[Dict[str, int], Dict[str, Dict[str, int]]]
Allocated memory size per scope in bytes for each function in the IRModule returned as a
dict with function names as keys and a dict of allocated sizes as values. If a single
PrimFunc is passed, its result is keyed by "main".
"""
return _ffi_api.calculate_allocated_bytes(func) # type: ignore
if not isinstance(func_or_mod, (PrimFunc, IRModule)):
raise TypeError(
f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}"
)
return _ffi_api.calculate_allocated_bytes(func_or_mod) # type: ignore


def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
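
A quick sketch (not part of the diff) of the new input validation: anything other than a `PrimFunc` or `IRModule` is now rejected on the Python side before reaching the FFI.

```python
# Sketch: the Python wrapper now raises TypeError for unsupported inputs.
import pytest
import tvm

with pytest.raises(TypeError):
    tvm.tir.analysis.calculate_allocated_bytes("neither a PrimFunc nor an IRModule")
```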
36 changes: 29 additions & 7 deletions src/tir/analysis/calculate_allocated_memory.cc
@@ -79,16 +79,38 @@ void AllocationCalculator<T>::VisitStmt_(const T* op) {
_current_size[storage_scope] -= size;
}

tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) {
return AllocationCalculator<AllocateNode>()(func);
tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const PrimFunc& func) {
tvm::Map<String, tvm::Map<String, Integer> > results;
results.Set("main", AllocationCalculator<AllocateNode>()(func));
return results;
}

TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc func) {
return CalculateAllocatedBytes(func);
});
tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const IRModule& mod) {
tvm::Map<String, tvm::Map<String, Integer> > results;
for (const auto& kv : mod->functions) {
if (auto prim_func = kv.second.as<tir::PrimFunc>()) {
String func_name = kv.first->name_hint;
results.Set(func_name, AllocationCalculator<AllocateNode>()(prim_func.value()));
}
}
return results;
}

TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes")
.set_body_typed([](ObjectRef obj) -> tvm::Map<String, tvm::Map<String, Integer> > {
if (auto func = obj.as<PrimFunc>()) {
return CalculateAllocatedBytes(func.value());
} else if (auto mod = obj.as<IRModule>()) {
return CalculateAllocatedBytes(mod.value());
} else {
LOG(FATAL) << "TypeError: Expected the input to be either PrimFunc or IRModule, but got: "
<< obj->GetTypeKey();
throw;
}
});

bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
auto sizes = CalculateAllocatedBytes(func);
auto sizes = CalculateAllocatedBytes(func)["main"];
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
return false;
@@ -121,7 +143,7 @@ Pass VerifyVTCMLimit(Optional<Target> default_target) {
}

if (limit.has_value() && limit.value() > 0) {
auto sizes = CalculateAllocatedBytes(func);
auto sizes = CalculateAllocatedBytes(func)["main"];
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (vtcm_allocated.IntValue() > limit.value()) {
LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded "
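
The `VerifyVTCMLimit` checks above read the single-function result under its "main" key. For illustration, a hypothetical Python-side budget check built on the new module-level overload could look like the sketch below; the helper name, the default scope, and the limit value are illustrative assumptions, not part of this patch.

```python
# Hypothetical sketch: per-function memory budget check using the module-level result.
import tvm


def check_scope_budget(mod: tvm.IRModule, scope: str = "global.vtcm", limit: int = 128 * 1024):
    """Raise if any PrimFunc in `mod` allocates more than `limit` bytes in `scope`."""
    sizes_by_func = tvm.tir.analysis.calculate_allocated_bytes(mod)
    for func_name, per_scope in sizes_by_func.items():
        allocated = int(per_scope.get(scope, 0))
        if allocated > limit:
            raise RuntimeError(
                f"{func_name} allocates {allocated} bytes in {scope}, "
                f"exceeding the limit of {limit} bytes"
            )


# e.g. check_scope_budget(lowered_mod, "global.vtcm", 128 * 1024)
```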
@@ -14,32 +14,42 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import pytest

import tvm
from tvm import tir
from tvm.script import tir as T

# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks

@T.prim_func
def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
for i in T.serial(128):
with T.block("C"):
c[i] = a[i] * T.int8(2)
@tvm.script.ir_module
class Module:
@T.prim_func
def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
for i in T.serial(128):
with T.block("C"):
c[i] = a[i] * T.int8(2)


@T.prim_func
def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
for i in T.serial(128):
with T.block("B"):
B[i] = a[i] * T.int8(2)
for i in T.serial(128):
with T.block("C"):
c[i] = B[i] * T.int8(3)
@T.prim_func
def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
for i in T.serial(128):
with T.block("B"):
B[i] = a[i] * T.int8(2)
for i in T.serial(128):
with T.block("C"):
c[i] = B[i] * T.int8(3)

# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on


@pytest.mark.parametrize("primFunc,size", [(scale_by_two, 128), (scale_by_two_three, 256)])
@pytest.mark.parametrize(
"primFunc,size", [(Module["scale_by_two"], 128), (Module["scale_by_two_three"], 256)]
)
def test_scale_by(primFunc, size):
"""Test calculate allocated bytes per scope"""
mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main"))
@@ -53,6 +63,8 @@ def test_scale_by(primFunc, size):
mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
assert "main" in sizes, 'Calls with PrimFunc is expected to return with function key as "main"'
sizes = sizes["main"]
assert sizes.get("global.vtcm", 0) == size


@@ -94,8 +106,35 @@ def test_matmul_mix_scope(scope, size):
mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
assert "main" in sizes, 'Calls with PrimFunc is expected to return with function key as "main"'
sizes = sizes["main"]
assert sizes.get(scope, 0) == size


def test_full_mod_calculator():
def apply_schedule(sch, func_name):
sch.work_on(func_name)
block_c = sch.get_block("C")
sch.cache_read(block_c, 0, storage_scope="global.vtcm")

sch = tvm.tir.Schedule(Module, debug_mask="all")
apply_schedule(sch, "scale_by_two")
apply_schedule(sch, "scale_by_two_three")
mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
sizes = tvm.tir.analysis.calculate_allocated_bytes(mod)
assert "scale_by_two" in sizes, "Values for scale_by_two not found"
scale_by_two_sizes = sizes["scale_by_two"]
assert (
"global.vtcm" in scale_by_two_sizes
), "Expected global.vtcm allocation to be calculated scale_by_two"
assert scale_by_two_sizes["global.vtcm"] == 128, "Expected the calculated size to be 128"
scale_by_two_three_sizes = sizes["scale_by_two_three"]
assert (
"global.vtcm" in scale_by_two_three_sizes
), "Expected global.vtcm allocation to be calculated scale_by_two_three"
assert scale_by_two_three_sizes["global.vtcm"] == 256, "Expected the calculated size to be 256"


if __name__ == "__main__":
tvm.testing.main()
