Skip to content

Commit

Permalink
[CUDA] Fix dense tensorcore legalize type error when units is specifi…
Browse files Browse the repository at this point in the history
…ed (apache#9030)

* Fix dense tensorcore legalize type error when units is specified

* revert black change due to different version from CI
  • Loading branch information
masahi authored and ylc committed Jan 13, 2022
1 parent 2b1c050 commit 6a9d4ae
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
6 changes: 6 additions & 0 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def _dense_legalize(attrs, inputs, arg_types):

x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) if dm or dk else x
y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) if dn or dk else y

# If units is explicitly specified, it is used to compute the output shape.
# We need to update units after padding to prevent a type error.
if attrs["units"] is not None:
new_attrs["units"] = N + dn

out_ = relay.nn.dense(x_, y_, **new_attrs)
out = (
relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape])
Expand Down
12 changes: 6 additions & 6 deletions tests/python/relay/test_pass_legalize_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def expected():

@tvm.testing.uses_gpu
def test_legalize_dense():
def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True, units=None):
"""test legalize dense to enable tensorcore"""
M, K = data_shape
N, _ = kernel_shape
Expand All @@ -216,7 +216,7 @@ def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True
def before():
x = relay.var("x", shape=data_shape, dtype=dtype)
weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
y = relay.nn.dense(x, weight)
y = relay.nn.dense(x, weight, units)
y = relay.Function([x, weight], y)
return y

Expand All @@ -237,10 +237,7 @@ def expected():
weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk)))
else:
weight_pad = weight
y_pad = relay.nn.dense(
x_pad,
weight_pad,
)
y_pad = relay.nn.dense(x_pad, weight_pad, units=N + dn if units else None)
if dm or dn:
y = relay.strided_slice(y_pad, begin=[0, 0], end=out_shape)
else:
Expand All @@ -264,6 +261,9 @@ def expected():
_test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype)
_test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False)

# Test if units parameter is correctly updated
_test_legalize_dense((8, 16), (30, 16), (0, 0, 2), "float16", units=30)

_test_legalize_dense((8, 32), (32, 32), (0, 0, 0), "int4", False)
_test_legalize_dense((7, 32), (32, 32), (1, 0, 0), "int4")
_test_legalize_dense((8, 31), (32, 31), (0, 1, 0), "int4")
Expand Down

0 comments on commit 6a9d4ae

Please sign in to comment.