Skip to content

Commit

Permalink
[TOPI] Fix batch_matmul tensorcore legalize for transpose_b = False case (apache#13618)
Browse files Browse the repository at this point in the history

* fixed tensor core batch_matmul legalize for transpose_b = False case

* add test

* clean up
  • Loading branch information
masahi authored and Mikael Sevenier committed Dec 29, 2022
1 parent f53dc45 commit 5b028dc
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 12 deletions.
32 changes: 27 additions & 5 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,22 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
x_tensor, y_tensor = arg_types[0], arg_types[1]
dtype = x_tensor.dtype

if attrs.transpose_a:
B, K, M = x_tensor.shape
else:
B, M, K = x_tensor.shape

if attrs.transpose_b:
B, N, K = y_tensor.shape
else:
B, K, N = y_tensor.shape

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
x, y = inputs

B, M, K = x_tensor.shape
B, N, K = y_tensor.shape
if (
isinstance(B, tir.expr.Any)
or isinstance(M, tir.expr.Any)
Expand Down Expand Up @@ -96,9 +104,23 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
return None

logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops)
x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x
y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y
out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype)

if attrs.transpose_a:
pad_width = ((0, 0), (0, dk), (0, dm))
else:
pad_width = ((0, 0), (0, dm), (0, dk))

x_ = relay.nn.pad(x, pad_width=pad_width) if dm or dk else x

if attrs.transpose_b:
pad_width = ((0, 0), (0, dn), (0, dk))
else:
pad_width = ((0, 0), (0, dk), (0, dn))

y_ = relay.nn.pad(y, pad_width=pad_width) if dn or dk else y

out_ = relay.nn.batch_matmul(x_, y_, **attrs)

out = (
relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape])
if dm or dn
Expand Down
43 changes: 36 additions & 7 deletions tests/python/relay/test_pass_legalize_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,17 +277,27 @@ def expected():

@tvm.testing.uses_gpu
def test_legalize_batch_matmul():
def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
def _test_legalize_batch_matmul(
data_shape, kernel_shape, pad_shape, dtype, do_pad=True, transpose_a=False, transpose_b=True
):
"""test legalize dense to enable tensorcore"""
B, M, _ = data_shape
_, N, _ = kernel_shape
if transpose_a:
B, _, M = data_shape
else:
B, M, _ = data_shape

if transpose_b:
_, N, _ = kernel_shape
else:
_, _, N = kernel_shape

out_shape = (B, M, N)
dm, dk, dn = pad_shape

def before():
x = relay.var("x", shape=data_shape, dtype=dtype)
weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
y = relay.nn.batch_matmul(x, weight)
y = relay.nn.batch_matmul(x, weight, transpose_a=transpose_a, transpose_b=transpose_b)
y = relay.Function([x, weight], y)
return y

Expand All @@ -298,19 +308,31 @@ def legalize_batch_matmul(attrs, inputs, types):
def expected():
if not do_pad:
return before()

x = relay.var("x", shape=data_shape, dtype=dtype)
weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)

if dm or dk:
x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
if transpose_a:
x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dk), (0, dm)))
else:
x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
else:
x_pad = x
weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)

if dn or dk:
weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
if transpose_b:
weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
else:
weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dk), (0, dn)))
else:
weight_pad = weight

y_pad = relay.nn.batch_matmul(
x_pad,
weight_pad,
transpose_a=transpose_a,
transpose_b=transpose_b,
)
if dm or dn:
y = relay.strided_slice(y_pad, begin=[0, 0, 0], end=out_shape)
Expand Down Expand Up @@ -343,6 +365,13 @@ def expected():
_test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 16, 0), "int4")
_test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), "int4", False)

_test_legalize_batch_matmul(
(16, 8, 16), (16, 16, 32), (0, 0, 0), "float16", False, transpose_b=False
)
_test_legalize_batch_matmul(
(16, 16, 8), (16, 32, 16), (0, 0, 0), "float16", False, transpose_a=True
)


if __name__ == "__main__":
test_legalize_conv2d_NHWC()
Expand Down

0 comments on commit 5b028dc

Please sign in to comment.