Skip to content

Commit

Permalink
[Fix][TIR] UnifyThreadBinding creating unit loop with annotation (#14588)
Browse files Browse the repository at this point in the history

This PR fixes a bug in the UnifyThreadBinding pass, which in one place
assumes that a returned value is always a ForNode. That assumption does not
always hold.

To be more specific, when a thread-binding loop has an annotation,
the current behavior assumes that the post-recursive-mutation value
is also a ForNode, and applies the previous annotation directly to the new
loop. However, the post-recursive-mutation value may not be a
ForNode. In that case, the current behavior is incorrect.

This PR creates a new unit-length loop in this case to preserve the
annotation.

Thanks Bohan for catching this issue.

Co-authored-by: Bohan Hou <spectrometerh@gmail.com>
  • Loading branch information
MasterJH5574 and spectrometerHBH authored Apr 12, 2023
1 parent ca7c3d8 commit 40af75b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/tir/transforms/unify_thread_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,20 @@ class ThreadBindingUnifier : public StmtExprMutator {
if (annotations.empty()) {
return stmt;
}
For new_loop = Downcast<For>(stmt);
new_loop.CopyOnWrite()->annotations = std::move(annotations);
return std::move(new_loop);
if (const auto* loop = stmt.as<ForNode>()) {
For new_loop = GetRef<For>(loop);
new_loop.CopyOnWrite()->annotations = std::move(annotations);
return std::move(new_loop);
} else {
// Create a new unit loop with the annotation.
DataType dtype = op->loop_var->dtype;
return For(/*loop_var=*/Var("var", dtype), //
/*min=*/IntImm(dtype, 0), //
/*extent=*/IntImm(dtype, 1), //
/*kind=*/ForKind::kSerial, stmt, //
/*thread_binding=*/NullOpt, //
/*annotation=*/std::move(annotations));
}
}

template <typename Node>
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_tir_transform_unify_thread_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,31 @@ def test_implicit_block():
_check(element_wise_implicit_block, unified_element_wise_implicit_block)


def test_inner_binding_with_annotation():
    """Check an annotated thread-binding loop nested inside another thread binding.

    The input binds ``threadIdx.x`` with the annotation ``{"my_annotation": 1}``.
    The expected output keeps both thread bindings without the annotation and
    carries the annotation on a new extent-1 serial loop placed inside them.
    """

    # Input: the inner ``threadIdx.x`` binding carries a user annotation.
    @T.prim_func
    def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")):
        for bx in T.thread_binding(32, "blockIdx.x"):
            for tx in T.thread_binding(2, "threadIdx.x", annotations={"my_annotation": 1}):
                with T.block("block"):
                    v = T.axis.spatial(64, bx * 2 + tx)
                    B[v] = A[v]

    # Expected: thread bindings have no annotation; the annotation sits on a
    # unit serial loop wrapped around the block.
    @T.prim_func
    def unified_inner_binding_with_annotation(
        A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")
    ):
        for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
            for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
                for var in T.serial(1, annotations={"my_annotation": 1}):
                    with T.block("block"):
                        v = T.axis.spatial(64, blockIdx_x * 2 + threadIdx_x)
                        T.reads(A[v])
                        T.writes(B[v])
                        B[v] = A[v]

    # NOTE(review): `_check` is defined outside this view; presumably it runs
    # UnifyThreadBinding on the first function and compares the result with
    # the second -- confirm against its definition.
    _check(inner_binding_with_annotation, unified_inner_binding_with_annotation)


def test_lower_te():
a = te.placeholder((32, 2, 2))
b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0)
Expand Down

0 comments on commit 40af75b

Please sign in to comment.