Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] [Analysis] Calculate allocated memory at module level #14711

Merged
merged 2 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,18 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
/*!
* \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc
 * \param func The TIR PrimFunc for which the allocated memory size is to be calculated
* \return Allocated memory size per scope in bytes inside the PrimFunc returned as a Map with
* key "main" and a Map of allocated sizes as values.
*/
TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func);
TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const PrimFunc& func);

/*!
* \brief Calculate the allocated memory per scope in bytes for each function inside the module
 * \param mod The IRModule for which the allocated memory size has to be calculated
* \return Allocated memory size per scope in bytes for each function in the IRModule returned as a
 * Map with function names as keys and a Map of allocated sizes as values.
*/
TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const IRModule& mod);

/*!
* \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
Expand Down
21 changes: 15 additions & 6 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,20 +201,29 @@ def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> in
return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment) # type: ignore


def calculate_allocated_bytes(
    func_or_mod: Union[PrimFunc, IRModule]
) -> Union[Dict[str, int], Dict[str, Dict[str, int]]]:
    """Calculate allocated memory per memory scope required by TIR PrimFuncs.

    Parameters
    ----------
    func_or_mod: Union[PrimFunc, IRModule]
        The function or module to be detected. If a module is passed, allocated
        memory is calculated for all PrimFuncs inside the module.

    Returns
    -------
    result : Union[Dict[str, int], Dict[str, Dict[str, int]]]
        Allocated memory size per scope in bytes for each function in the IRModule
        returned as a dict with function names as keys and a dict of allocated sizes
        as values. If a single PrimFunc is passed, its result is keyed as "main".

    Raises
    ------
    TypeError
        If ``func_or_mod`` is neither a PrimFunc nor an IRModule.
    """
    # Validate eagerly on the Python side so the caller gets a TypeError with a
    # readable message instead of an FFI-level failure.
    if not isinstance(func_or_mod, (PrimFunc, IRModule)):
        raise TypeError(
            f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}"
        )
    return _ffi_api.calculate_allocated_bytes(func_or_mod)  # type: ignore


def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
Expand Down
36 changes: 29 additions & 7 deletions src/tir/analysis/calculate_allocated_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,38 @@ void AllocationCalculator<T>::VisitStmt_(const T* op) {
_current_size[storage_scope] -= size;
}

/*!
 * \brief Calculate the allocated memory per scope of a single PrimFunc.
 * \param func The PrimFunc to analyze.
 * \return A map keyed by "main" (a bare PrimFunc carries no name) whose value maps
 *         each storage scope to its allocated size in bytes.
 */
tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const PrimFunc& func) {
  tvm::Map<String, tvm::Map<String, Integer> > results;
  // Key the single-function result as "main" so the return shape matches the
  // IRModule overload below.
  results.Set("main", AllocationCalculator<AllocateNode>()(func));
  return results;
}

/*!
 * \brief Calculate the allocated memory per scope for every PrimFunc in a module.
 * \param mod The IRModule to analyze.
 * \return A map from function name to a per-scope map of allocated sizes in bytes.
 */
tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const IRModule& mod) {
  tvm::Map<String, tvm::Map<String, Integer> > results;
  for (const auto& kv : mod->functions) {
    // Entries that are not tir::PrimFunc (e.g. other function kinds) are skipped.
    if (auto prim_func = kv.second.as<tir::PrimFunc>()) {
      results.Set(kv.first->name_hint, AllocationCalculator<AllocateNode>()(prim_func.value()));
    }
  }
  return results;
}

TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes")
    .set_body_typed([](ObjectRef obj) -> tvm::Map<String, tvm::Map<String, Integer> > {
      // Dispatch on the runtime type: a PrimFunc is reported under the key "main",
      // an IRModule reports one entry per contained PrimFunc.
      if (auto func = obj.as<PrimFunc>()) {
        return CalculateAllocatedBytes(func.value());
      } else if (auto mod = obj.as<IRModule>()) {
        return CalculateAllocatedBytes(mod.value());
      } else {
        LOG(FATAL) << "TypeError: Expect the input to be either PrimFunc or IRModule, but gets: "
                   << obj->GetTypeKey();
        // LOG(FATAL) throws InternalError in the default implementation, but
        // LogFatalImpl can be overridden by targets, so keep an explicit throw to
        // guarantee control never falls off the end of the lambda.
        throw;
      }
    });

bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
auto sizes = CalculateAllocatedBytes(func);
auto sizes = CalculateAllocatedBytes(func)["main"];
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
return false;
Expand Down Expand Up @@ -121,7 +143,7 @@ Pass VerifyVTCMLimit(Optional<Target> default_target) {
}

if (limit.has_value() && limit.value() > 0) {
auto sizes = CalculateAllocatedBytes(func);
auto sizes = CalculateAllocatedBytes(func)["main"];
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (vtcm_allocated.IntValue() > limit.value()) {
LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,42 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import pytest

import tvm
from tvm import tir
from tvm.script import tir as T

# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks

# TVMScript module with two PrimFuncs used as fixtures: scale_by_two multiplies the
# input by 2; scale_by_two_three additionally stages through a global.vtcm buffer.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
        for i in T.serial(128):
            with T.block("C"):
                c[i] = a[i] * T.int8(2)

    @T.prim_func
    def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
        # 128-byte intermediate buffer allocated in the global.vtcm scope.
        B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
        for i in T.serial(128):
            with T.block("B"):
                B[i] = a[i] * T.int8(2)
        for i in T.serial(128):
            with T.block("C"):
                c[i] = B[i] * T.int8(3)

# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on


@pytest.mark.parametrize("primFunc,size", [(scale_by_two, 128), (scale_by_two_three, 256)])
@pytest.mark.parametrize(
"primFunc,size", [(Module["scale_by_two"], 128), (Module["scale_by_two_three"], 256)]
)
def test_scale_by(primFunc, size):
"""Test calculate allocated bytes per scope"""
mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main"))
Expand All @@ -53,6 +63,8 @@ def test_scale_by(primFunc, size):
mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
assert "main" in sizes, 'Calls with PrimFunc is expected to return with function key as "main"'
sizes = sizes["main"]
assert sizes.get("global.vtcm", 0) == size


Expand Down Expand Up @@ -94,8 +106,35 @@ def test_matmul_mix_scope(scope, size):
mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
assert "main" in sizes, 'Calls with PrimFunc is expected to return with function key as "main"'
sizes = sizes["main"]
assert sizes.get(scope, 0) == size


def test_full_mod_calculator():
    """Calculate allocated bytes for every PrimFunc of a module in a single call."""

    def apply_schedule(sch, func_name):
        # Insert a global.vtcm cache-read stage (128 bytes) in front of block "C".
        sch.work_on(func_name)
        block_c = sch.get_block("C")
        sch.cache_read(block_c, 0, storage_scope="global.vtcm")

    sch = tvm.tir.Schedule(Module, debug_mask="all")
    apply_schedule(sch, "scale_by_two")
    apply_schedule(sch, "scale_by_two_three")
    mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod)
    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
    sizes = tvm.tir.analysis.calculate_allocated_bytes(mod)
    # scale_by_two allocates only the cache-read buffer (128 bytes);
    # scale_by_two_three adds its own 128-byte vtcm buffer on top (256 bytes total).
    expected = {"scale_by_two": 128, "scale_by_two_three": 256}
    for func_name, expected_size in expected.items():
        assert func_name in sizes, f"Values for {func_name} not found"
        func_sizes = sizes[func_name]
        assert (
            "global.vtcm" in func_sizes
        ), f"Expected global.vtcm allocation to be calculated for {func_name}"
        assert (
            func_sizes["global.vtcm"] == expected_size
        ), f"Expected the calculated size to be {expected_size}"


if __name__ == "__main__":
tvm.testing.main()