From 63f9cd6523bd827ea297c22cbbb74eaef9def931 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Wed, 26 Jun 2024 08:43:12 -0700
Subject: [PATCH] [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor (#17110)

* [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor

This makes the allocation go through memory planning and makes it
compatible with CUDA graphs.

* lint

* lint
---
 python/tvm/relax/testing/matmul.py         |  3 +-
 src/relax/op/op_common.h                   |  3 ++
 src/relax/transform/allocate_workspace.cc  |  3 +-
 tests/python/relax/test_codegen_cutlass.py | 31 ++++++++++---------
 .../test_transform_allocate_workspace.py   | 10 +++---
 5 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relax/testing/matmul.py b/python/tvm/relax/testing/matmul.py
index 0ce1225e7d3c..760ad1bdefab 100644
--- a/python/tvm/relax/testing/matmul.py
+++ b/python/tvm/relax/testing/matmul.py
@@ -25,7 +25,7 @@ def get_relax_matmul_module(
     x_shape,
     y_shape,
     in_dtype,
-    out_dtype,
+    out_dtype=None,
     transposed_y=False,
     bias_shape=None,
     activation=None,
@@ -33,6 +33,7 @@ def get_relax_matmul_module(
     residual_activation=None,
 ):
     """Create a matmul op followed by epilogue operations."""
+    out_dtype = out_dtype if out_dtype is not None else in_dtype
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")

diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 94474ce78444..ed6725e27012 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -558,6 +558,9 @@ Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm d
                         StringImm storage_scope = StringImm("global"));
 Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype);
 
+Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index,
+                     StringImm storage_scope = StringImm("global"));
+
 /**
  * \brief Return the argument of the call.
  * Note: If this is a call_tir, return the arguments passed to the TIR func

diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc
index 4b26b590ef9a..fcfbf187714e 100644
--- a/src/relax/transform/allocate_workspace.cc
+++ b/src/relax/transform/allocate_workspace.cc
@@ -144,8 +144,7 @@ class WorkspaceProvider : ExprMutator {
     if (!workspace_var_main_.defined()) {
       auto shape = ShapeExpr({Integer(max_workspace_size_)});
       auto ty = DataTypeImm(DataType::UInt(8));
-      auto storage = MakeVMAllocStorage(shape, PrimValue::Int64(0), ty);
-      auto workspace = MakeVMAllocTensor(storage, PrimValue::Int64(0), shape, ty);
+      auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0));
       workspace_var_main_ = builder_->Emit(workspace, "workspace_main");
     }
     for (const auto& binding : block_node->bindings) {
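With the pass now emitting a single R.builtin.alloc_tensor binding, the
workspace flows through the same lowering as every other allocation. A minimal
sketch of the relevant pass ordering (illustrative only; the exact default
build pipeline may differ):

from tvm import transform
from tvm import relax

# Illustrative ordering: AllocateWorkspace emits R.builtin.alloc_tensor,
# StaticPlanBlockMemory folds it into the planned storage pool, and
# RewriteCUDAGraph can then capture it alongside other planned buffers.
seq = transform.Sequential(
    [
        relax.transform.AllocateWorkspace(),
        relax.transform.StaticPlanBlockMemory(),
        relax.transform.RewriteCUDAGraph(),
    ]
)
# mod = seq(mod)  # "mod" is assumed to be a partitioned BYOC module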
" + str(mod["main"]) codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}}) mod = codegen_pass(mod) @@ -714,7 +716,7 @@ def test_attention_offload(attention_size, attention_dtype): v_shape = (b, s_kv, n, h_v) mod = get_relax_attention_module(q_shape, k_shape, v_shape, dtype=attention_dtype) - out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -751,7 +753,7 @@ def test_attention_bias_offload(attention_bias_size): mod = get_relax_attention_module( q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32" ) - out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -786,9 +788,9 @@ def test_attention_scale_offload(attention_scale_size, attention_scale): q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, qk_scale=attention_scale ) if bias is None: - out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2) else: - out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -829,9 +831,9 @@ def test_attention_causal_offload(attention_causal_size, attention_causal): ) if bias is None: - out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2) else: - out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -932,9 +934,9 @@ def test_stacked_attention_split_offload(stacked_attention_size): ) if bias is None: - out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2) else: - out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=2) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -950,9 +952,9 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size): qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape ) if bias is None: - out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2) else: - out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3) + out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=2) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -1311,9 +1313,8 @@ def main( R.func_attr({"num_input": 4}) cls = Expected with R.dataflow(): - lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8")) - workspace_main = R.vm.alloc_tensor( - lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8") + workspace_main = R.builtin.alloc_tensor( + R.shape([65536]), R.dtype("uint8"), R.prim_value(0) ) lv_1 = 
diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py
index aca6ea2fe83a..1198642d3f35 100644
--- a/tests/python/relax/test_transform_allocate_workspace.py
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -126,9 +126,8 @@ def entry_a(
     ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
         cls = Expected
         with R.dataflow():
-            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-            workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main: R.Tensor((65536,), dtype="uint8") = R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1(
                 q, k, v, workspace_main
@@ -144,9 +143,8 @@ def entry_b(
     ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
         cls = Expected
         with R.dataflow():
-            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-            workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main: R.Tensor((65536,), dtype="uint8") = R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1(
                 q, k, v, workspace_main
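For reference, a self-contained TVMScript sketch of the post-pass workspace
binding, mirroring the Expected modules above (illustrative; not output
generated by the pass):

from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class WorkspaceSketch:
    @R.function
    def main() -> R.Tensor((65536,), dtype="uint8"):
        with R.dataflow():
            # One builtin allocation replaces the former
            # vm.alloc_storage + vm.alloc_tensor pair, so memory
            # planning can pool or reuse this buffer.
            workspace: R.Tensor((65536,), dtype="uint8") = R.builtin.alloc_tensor(
                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
            )
            R.output(workspace)
        return workspace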