Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] Fix batch_matmul tensorcore legalize for transpose_b = False case #13618

Merged
merged 3 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,22 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
x_tensor, y_tensor = arg_types[0], arg_types[1]
dtype = x_tensor.dtype

if attrs.transpose_a:
B, K, M = x_tensor.shape
else:
B, M, K = x_tensor.shape

if attrs.transpose_b:
B, N, K = y_tensor.shape
else:
B, K, N = y_tensor.shape

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
x, y = inputs

B, M, K = x_tensor.shape
B, N, K = y_tensor.shape
if (
isinstance(B, tir.expr.Any)
or isinstance(M, tir.expr.Any)
Expand Down Expand Up @@ -96,9 +104,23 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
return None

logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops)
x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x
y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y
out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype)

if attrs.transpose_a:
pad_width = ((0, 0), (0, dk), (0, dm))
else:
pad_width = ((0, 0), (0, dm), (0, dk))

x_ = relay.nn.pad(x, pad_width=pad_width) if dm or dk else x

if attrs.transpose_b:
pad_width = ((0, 0), (0, dn), (0, dk))
else:
pad_width = ((0, 0), (0, dk), (0, dn))

y_ = relay.nn.pad(y, pad_width=pad_width) if dn or dk else y

out_ = relay.nn.batch_matmul(x_, y_, **attrs)

out = (
relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape])
if dm or dn
Expand Down
43 changes: 36 additions & 7 deletions tests/python/relay/test_pass_legalize_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,17 +277,27 @@ def expected():

@tvm.testing.uses_gpu
def test_legalize_batch_matmul():
def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
def _test_legalize_batch_matmul(
data_shape, kernel_shape, pad_shape, dtype, do_pad=True, transpose_a=False, transpose_b=True
):
"""test legalize dense to enable tensorcore"""
B, M, _ = data_shape
_, N, _ = kernel_shape
if transpose_a:
B, _, M = data_shape
else:
B, M, _ = data_shape

if transpose_b:
_, N, _ = kernel_shape
else:
_, _, N = kernel_shape

out_shape = (B, M, N)
dm, dk, dn = pad_shape

def before():
x = relay.var("x", shape=data_shape, dtype=dtype)
weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
y = relay.nn.batch_matmul(x, weight)
y = relay.nn.batch_matmul(x, weight, transpose_a=transpose_a, transpose_b=transpose_b)
y = relay.Function([x, weight], y)
return y

Expand All @@ -298,19 +308,31 @@ def legalize_batch_matmul(attrs, inputs, types):
def expected():
if not do_pad:
return before()

x = relay.var("x", shape=data_shape, dtype=dtype)
weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)

if dm or dk:
x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
if transpose_a:
x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dk), (0, dm)))
else:
x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
else:
x_pad = x
weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)

if dn or dk:
weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
if transpose_b:
weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
else:
weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dk), (0, dn)))
else:
weight_pad = weight

y_pad = relay.nn.batch_matmul(
x_pad,
weight_pad,
transpose_a=transpose_a,
transpose_b=transpose_b,
)
if dm or dn:
y = relay.strided_slice(y_pad, begin=[0, 0, 0], end=out_shape)
Expand Down Expand Up @@ -343,6 +365,13 @@ def expected():
_test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 16, 0), "int4")
_test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), "int4", False)

_test_legalize_batch_matmul(
(16, 8, 16), (16, 16, 32), (0, 0, 0), "float16", False, transpose_b=False
)
_test_legalize_batch_matmul(
(16, 16, 8), (16, 32, 16), (0, 0, 0), "float16", False, transpose_a=True
)


if __name__ == "__main__":
test_legalize_conv2d_NHWC()
Expand Down