diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py
index 0ba428014548..dbbf9e74903c 100644
--- a/python/tvm/topi/cuda/tensorcore_alter_op.py
+++ b/python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -48,14 +48,22 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
     x_tensor, y_tensor = arg_types[0], arg_types[1]
     dtype = x_tensor.dtype
 
+    if attrs.transpose_a:
+        B, K, M = x_tensor.shape
+    else:
+        B, M, K = x_tensor.shape
+
+    if attrs.transpose_b:
+        B, N, K = y_tensor.shape
+    else:
+        B, K, N = y_tensor.shape
+
     # Collect the output tensor.
     output_tensor = arg_types[2]
 
     # Collect the input exprs.
     x, y = inputs
 
-    B, M, K = x_tensor.shape
-    B, N, K = y_tensor.shape
     if (
         isinstance(B, tir.expr.Any)
         or isinstance(M, tir.expr.Any)
@@ -96,9 +104,23 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
         return None
 
     logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops)
-    x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x
-    y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y
-    out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype)
+
+    if attrs.transpose_a:
+        pad_width = ((0, 0), (0, dk), (0, dm))
+    else:
+        pad_width = ((0, 0), (0, dm), (0, dk))
+
+    x_ = relay.nn.pad(x, pad_width=pad_width) if dm or dk else x
+
+    if attrs.transpose_b:
+        pad_width = ((0, 0), (0, dn), (0, dk))
+    else:
+        pad_width = ((0, 0), (0, dk), (0, dn))
+
+    y_ = relay.nn.pad(y, pad_width=pad_width) if dn or dk else y
+
+    out_ = relay.nn.batch_matmul(x_, y_, **attrs)
+
     out = (
         relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape])
         if dm or dn
diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py
index 0e3c171d87da..c9782aec1b2c 100644
--- a/tests/python/relay/test_pass_legalize_tensorcore.py
+++ b/tests/python/relay/test_pass_legalize_tensorcore.py
@@ -277,17 +277,27 @@ def expected():
 
 @tvm.testing.uses_gpu
 def test_legalize_batch_matmul():
-    def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
+    def _test_legalize_batch_matmul(
+        data_shape, kernel_shape, pad_shape, dtype, do_pad=True, transpose_a=False, transpose_b=True
+    ):
         """test legalize dense to enable tensorcore"""
-        B, M, _ = data_shape
-        _, N, _ = kernel_shape
+        if transpose_a:
+            B, _, M = data_shape
+        else:
+            B, M, _ = data_shape
+
+        if transpose_b:
+            _, N, _ = kernel_shape
+        else:
+            _, _, N = kernel_shape
+
         out_shape = (B, M, N)
         dm, dk, dn = pad_shape
 
         def before():
             x = relay.var("x", shape=data_shape, dtype=dtype)
             weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
-            y = relay.nn.batch_matmul(x, weight)
+            y = relay.nn.batch_matmul(x, weight, transpose_a=transpose_a, transpose_b=transpose_b)
             y = relay.Function([x, weight], y)
             return y
 
@@ -298,19 +308,31 @@ def legalize_batch_matmul(attrs, inputs, types):
         def expected():
             if not do_pad:
                 return before()
+
             x = relay.var("x", shape=data_shape, dtype=dtype)
+            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
+
             if dm or dk:
-                x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
+                if transpose_a:
+                    x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dk), (0, dm)))
+                else:
+                    x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
             else:
                 x_pad = x
-            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
+
             if dn or dk:
-                weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
+                if transpose_b:
+                    weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
+                else:
+                    weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dk), (0, dn)))
             else:
                 weight_pad = weight
+
             y_pad = relay.nn.batch_matmul(
                 x_pad,
                 weight_pad,
+                transpose_a=transpose_a,
+                transpose_b=transpose_b,
             )
             if dm or dn:
                 y = relay.strided_slice(y_pad, begin=[0, 0, 0], end=out_shape)
@@ -343,6 +365,13 @@ def expected():
     _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 16, 0), "int4")
     _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), "int4", False)
 
+    _test_legalize_batch_matmul(
+        (16, 8, 16), (16, 16, 32), (0, 0, 0), "float16", False, transpose_b=False
+    )
+    _test_legalize_batch_matmul(
+        (16, 16, 8), (16, 32, 16), (0, 0, 0), "float16", False, transpose_a=True
+    )
+
 
 if __name__ == "__main__":
     test_legalize_conv2d_NHWC()
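
Note (illustrative sketch, not part of the patch): the layout convention encoded above is
that transpose_a=True stores the first operand as (B, K, M) and transpose_b=True stores the
second as (B, N, K), so the pad widths must follow the storage layout rather than the logical
(M, K) x (K, N) view. The pad-then-slice rewrite is numerically safe because zero-padded rows
and columns contribute nothing to the product. A standalone numpy check of that property,
with shapes picked here purely for illustration:

    import numpy as np

    # Default batch_matmul layouts: transpose_a=False, transpose_b=True.
    B, M, K, N = 16, 2, 16, 32   # M=2 is not a tensor-core-friendly size
    dm, dk, dn = 6, 0, 0         # pad M from 2 up to 8, as the legalize pass would

    x = np.random.rand(B, M, K).astype("float32")
    y = np.random.rand(B, N, K).astype("float32")

    # Zero-pad the non-batch axes, mirroring relay.nn.pad in the rewrite ...
    x_pad = np.pad(x, ((0, 0), (0, dm), (0, dk)))
    y_pad = np.pad(y, ((0, 0), (0, dn), (0, dk)))

    # ... run the (transpose_b) batch matmul on both, then slice the padded
    # result back to the original output shape, mirroring relay.strided_slice.
    out_pad = np.einsum("bmk,bnk->bmn", x_pad, y_pad)
    out = np.einsum("bmk,bnk->bmn", x, y)
    np.testing.assert_allclose(out_pad[:, :M, :N], out, rtol=1e-5)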