From 1823f3539640d16493694f3888d4a64481eaa45d Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Mon, 10 May 2021 22:02:25 -0700 Subject: [PATCH] [Vulkan][Codegen] Spir-V codegen, correct labels/blocks in WhileNode. (#8013) Previously, the WhileNode assumes that evaluating the loop condition will not introduce any additional labels. If this assumption is violated, such as for a WhileNode whose condition is an if/else statement, then the OpLoopMerge instruction appears in the wrong block. The unittest added exercises this code path, but doesn't yet trigger a failure. Once spvValidate is enabled for all vulkan codegen, then this unit test will catch the failure mode. Co-authored-by: Eric Lunderberg --- src/target/spirv/codegen_spirv.cc | 9 +++- .../unittest/test_target_codegen_vulkan.py | 50 +++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 8188744ce687..0c6deb28dca9 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -549,6 +549,7 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { void CodeGenSPIRV::VisitStmt_(const WhileNode* op) { spirv::Label head_label = builder_->NewLabel(); + spirv::Label condition_label = builder_->NewLabel(); spirv::Label body_label = builder_->NewLabel(); spirv::Label continue_label = builder_->NewLabel(); spirv::Label merge_label = builder_->NewLabel(); @@ -556,9 +557,15 @@ void CodeGenSPIRV::VisitStmt_(const WhileNode* op) { // Loop head builder_->StartLabel(head_label); - spirv::Value loop_cond = MakeValue(op->condition); uint32_t control = spv::LoopControlMaskNone; builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); + builder_->MakeInst(spv::OpBranch, condition_label); + + // Loop condition evaluation. The condition could contain if/else + // blocks that introduce additional labels, so the condition cannot + // be in the loop head's block. + builder_->StartLabel(condition_label); + spirv::Value loop_cond = MakeValue(op->condition); builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label, weight_likely_branch_, 1); diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 9528741b6c52..56181db677d8 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -307,6 +307,56 @@ def test_scalar_params(num_int_params): test_scalar_params(2044) +@tvm.testing.parametrize_targets("vulkan") +def test_vulkan_while_if(target, dev): + def do_compute(A, B, n): + ib = tvm.tir.ir_builder.create() + A = ib.buffer_ptr(A) + B = ib.buffer_ptr(B) + + ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) + + iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + iterations[0] = 0 + B[0] = 0 + + # WhileNode's condition is re-evaluated every loop. The + # if_then_else block introduces additional labels/blocks that + # must be kept separate from the WhileNode's block. + loop_condition = iterations[0] < tvm.tir.if_then_else(A[0] > 0, 10, 20) + with ib.while_loop(loop_condition): + iterations[0] += 1 + B[0] += iterations[0] + + return ib.get() + + n = 1 + dtype = "int32" + A = te.placeholder((n,), name="A", dtype=dtype) + + B = te.extern( + A.shape, + [A], + lambda ins, outs: do_compute(ins[0], outs[0], n), + dtype=dtype, + ) + s = te.create_schedule(B.op) + + # Point of failure would be here, at tvm.build. + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [A, B], target) + + a = tvm.nd.array(np.array([5], dtype=A.dtype), dev) + b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), [55]) + + a = tvm.nd.array(np.array([-5], dtype=A.dtype), dev) + b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), [210]) + + if __name__ == "__main__": test_vector_comparison() test_vulkan_copy()