Skip to content

Commit

Permalink
[Vulkan][Codegen] Spir-V codegen, correct labels/blocks in WhileNode. (
Browse files Browse the repository at this point in the history
…apache#8013)

Previously, the WhileNode assumes that evaluating the loop condition
will not introduce any additional labels.  If this assumption is
violated, such as for a WhileNode whose condition is an if/else
statement, then the OpLoopMerge instruction appears in the wrong
block.

The unittest added exercises this code path, but doesn't yet trigger a
failure.  Once spvValidate is enabled for all vulkan codegen, then
this unit test will catch the failure mode.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
2 people authored and Trevor Morris committed Jun 17, 2021
1 parent 542bd7a commit 1823f35
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -549,16 +549,23 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {

void CodeGenSPIRV::VisitStmt_(const WhileNode* op) {
spirv::Label head_label = builder_->NewLabel();
spirv::Label condition_label = builder_->NewLabel();
spirv::Label body_label = builder_->NewLabel();
spirv::Label continue_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
builder_->MakeInst(spv::OpBranch, head_label);

// Loop head
builder_->StartLabel(head_label);
spirv::Value loop_cond = MakeValue(op->condition);
uint32_t control = spv::LoopControlMaskNone;
builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
builder_->MakeInst(spv::OpBranch, condition_label);

// Loop condition evaluation. The condition could contain if/else
// blocks that introduce additional labels, so the condition cannot
// be in the loop head's block.
builder_->StartLabel(condition_label);
spirv::Value loop_cond = MakeValue(op->condition);
builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label,
weight_likely_branch_, 1);

Expand Down
50 changes: 50 additions & 0 deletions tests/python/unittest/test_target_codegen_vulkan.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,56 @@ def test_scalar_params(num_int_params):
test_scalar_params(2044)


@tvm.testing.parametrize_targets("vulkan")
def test_vulkan_while_if(target, dev):
def do_compute(A, B, n):
ib = tvm.tir.ir_builder.create()
A = ib.buffer_ptr(A)
B = ib.buffer_ptr(B)

ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)

iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
iterations[0] = 0
B[0] = 0

# WhileNode's condition is re-evaluated every loop. The
# if_then_else block introduces additional labels/blocks that
# must be kept separate from the WhileNode's block.
loop_condition = iterations[0] < tvm.tir.if_then_else(A[0] > 0, 10, 20)
with ib.while_loop(loop_condition):
iterations[0] += 1
B[0] += iterations[0]

return ib.get()

n = 1
dtype = "int32"
A = te.placeholder((n,), name="A", dtype=dtype)

B = te.extern(
A.shape,
[A],
lambda ins, outs: do_compute(ins[0], outs[0], n),
dtype=dtype,
)
s = te.create_schedule(B.op)

# Point of failure would be here, at tvm.build.
with tvm.transform.PassContext(opt_level=3):
func = tvm.build(s, [A, B], target)

a = tvm.nd.array(np.array([5], dtype=A.dtype), dev)
b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
func(a, b)
tvm.testing.assert_allclose(b.asnumpy(), [55])

a = tvm.nd.array(np.array([-5], dtype=A.dtype), dev)
b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
func(a, b)
tvm.testing.assert_allclose(b.asnumpy(), [210])


if __name__ == "__main__":
test_vector_comparison()
test_vulkan_copy()
Expand Down

0 comments on commit 1823f35

Please sign in to comment.