Skip to content

Commit

Permalink
[TVMScript] Fix printing ForNode annotations (apache#8891)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and ylc committed Jan 13, 2022
1 parent def4a57 commit 0074ec9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) {
res << Print(loop->thread_binding.value()->thread_tag);
}
if (!loop->annotations.empty()) {
res << ", annotation = {";
res << ", annotations = {";
res << PrintAnnotations(loop->annotations);
res << "}";
}
Expand Down
5 changes: 4 additions & 1 deletion tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2803,7 +2803,9 @@ def for_thread_binding(a: ty.handle, b: ty.handle) -> None:
B = tir.match_buffer(b, (16, 16), "float32")

for i in tir.thread_binding(0, 16, thread="threadIdx.x"):
for j in tir.thread_binding(0, 16, thread="threadIdx.y"):
for j in tir.thread_binding(
0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"}
):
A[i, j] = B[i, j] + tir.float32(1)


Expand All @@ -2818,6 +2820,7 @@ def test_for_thread_binding():
assert isinstance(rt_func.body.body, tir.stmt.For)
assert rt_func.body.body.kind == 4
assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y"
assert rt_func.body.body.annotations["attr_key"] == "attr_value"


@tvm.script.tir
Expand Down

0 comments on commit 0074ec9

Please sign in to comment.