[Relax] Alloc BYOC workspace with R.builtin.alloc_tensor (#17110)
* [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor

This makes the allocation go through memory planning and makes it compatible with CUDA graph.

* lint

* lint
vinx13 authored Jun 26, 2024
1 parent 02fe0c5 commit 63f9cd6
Showing 5 changed files with 26 additions and 24 deletions.
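At a glance, the BYOC workspace used to be allocated through a pair of VM-level bindings (R.vm.alloc_storage followed by R.vm.alloc_tensor), which memory planning does not see. It is now emitted as a single R.builtin.alloc_tensor binding. A minimal before/after sketch in Relax TVMScript, mirroring the expected modules in the tests below (the 65536-byte size is just the workspace used in those tests):

    # Before: two VM-level bindings, opaque to memory planning.
    lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
    workspace_main = R.vm.alloc_tensor(lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8"))

    # After: one high-level binding that later passes can plan and, on CUDA, capture in a graph.
    workspace_main = R.builtin.alloc_tensor(R.shape([65536]), R.dtype("uint8"), R.prim_value(0))

Since one binding replaces two, the CUTLASS tests now expect num_final_bindings=2 where they previously expected 3.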
3 changes: 2 additions & 1 deletion python/tvm/relax/testing/matmul.py
@@ -25,14 +25,15 @@ def get_relax_matmul_module(
     x_shape,
     y_shape,
     in_dtype,
-    out_dtype,
+    out_dtype=None,
     transposed_y=False,
     bias_shape=None,
     activation=None,
     residual_bin_op=None,
     residual_activation=None,
 ):
     """Create a matmul op followd by epilogue operations."""
+    out_dtype = out_dtype if out_dtype is not None else in_dtype
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
3 changes: 3 additions & 0 deletions src/relax/op/op_common.h
@@ -558,6 +558,9 @@ Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm d
                          StringImm storage_scope = StringImm("global"));
 Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype);
 
+Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index,
+                     StringImm storage_scope = StringImm("global"));
+
 /**
  * \brief Return the argument of the call.
  * Note: If this is a call_tir, return the arguments passed to the TIR func
3 changes: 1 addition & 2 deletions src/relax/transform/allocate_workspace.cc
@@ -144,8 +144,7 @@ class WorkspaceProvider : ExprMutator {
     if (!workspace_var_main_.defined()) {
       auto shape = ShapeExpr({Integer(max_workspace_size_)});
       auto ty = DataTypeImm(DataType::UInt(8));
-      auto storage = MakeVMAllocStorage(shape, PrimValue::Int64(0), ty);
-      auto workspace = MakeVMAllocTensor(storage, PrimValue::Int64(0), shape, ty);
+      auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0));
       workspace_var_main_ = builder_->Emit(workspace, "workspace_main");
     }
     for (const auto& binding : block_node->bindings) {
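For downstream context (not part of this diff): once the workspace is expressed as R.builtin.alloc_tensor, it can be handled by the same lowering passes as any other Relax allocation. A rough sketch under that assumption — the passes named below exist in relax.transform, but this is illustrative and not the exact build pipeline:

    import tvm
    from tvm import relax

    def plan_workspace(mod: tvm.IRModule) -> tvm.IRModule:
        """Illustrative only: plan memory, then rewrite for CUDA graph capture."""
        seq = tvm.transform.Sequential(
            [
                # Memory planning can now account for the workspace, since it is an
                # R.builtin.alloc_tensor rather than a raw R.vm.alloc_storage pair.
                relax.transform.StaticPlanBlockMemory(),
                # CUDA graph rewriting can then capture the planned allocations.
                relax.transform.RewriteCUDAGraph(),
            ]
        )
        return seq(mod)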
31 changes: 16 additions & 15 deletions tests/python/relax/test_codegen_cutlass.py
@@ -104,7 +104,9 @@ def build_cutlass(mod, assert_all_bindings_fused=True, num_final_bindings=1):
     mod = partition_for_cutlass(mod)
 
     if assert_all_bindings_fused:
-        assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+        assert (
+            len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+        ), "Not all bindings are fused. " + str(mod["main"])
 
     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
     mod = codegen_pass(mod)
@@ -714,7 +716,7 @@ def test_attention_offload(attention_size, attention_dtype):
     v_shape = (b, s_kv, n, h_v)
 
     mod = get_relax_attention_module(q_shape, k_shape, v_shape, dtype=attention_dtype)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
@@ -751,7 +753,7 @@ def test_attention_bias_offload(attention_bias_size):
     mod = get_relax_attention_module(
         q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
     )
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
@@ -786,9 +788,9 @@ def test_attention_scale_offload(attention_scale_size, attention_scale):
         q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, qk_scale=attention_scale
     )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -829,9 +831,9 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
     )
 
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -932,9 +934,9 @@ def test_stacked_attention_split_offload(stacked_attention_size):
     )
 
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -950,9 +952,9 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size):
         qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape
     )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -1311,9 +1313,8 @@ def main(
         R.func_attr({"num_input": 4})
         cls = Expected
         with R.dataflow():
-            lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-            workspace_main = R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main = R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             lv_1 = R.reshape(bias, R.shape([128, 16, 8]))
             lv1 = R.reshape(lv_1, R.shape([4, 32, 16, 8]))
@@ -2419,7 +2420,7 @@ def test_sliding_window():
         1, 64, 64, 16, 8, 8, "none", "none", causal, "float16", window_size=window_size
     )
 
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
10 changes: 4 additions & 6 deletions tests/python/relax/test_transform_allocate_workspace.py
@@ -126,9 +126,8 @@ def entry_a(
     ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
         cls = Expected
         with R.dataflow():
-            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-            workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main: R.Tensor((65536,), dtype="uint8") = R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1(
                 q, k, v, workspace_main
@@ -144,9 +143,8 @@ def entry_b(
     ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
         cls = Expected
         with R.dataflow():
-            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-            workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main: R.Tensor((65536,), dtype="uint8") = R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1(
                 q, k, v, workspace_main
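The Expected modules above are what the workspace-allocation pass should produce; a hypothetical invocation of the pass (the Module/Expected names mirror the test file, the rest is an assumption about how such transform tests are typically driven):

    # Apply the pass to the partitioned module and compare structurally.
    after = relax.transform.AllocateWorkspace()(Module)
    tvm.ir.assert_structural_equal(after, Expected)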
