diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index 1453128eeb677..c4a09d79cb12b 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -186,6 +186,16 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.rocm.schedule_dense),
         name="dense.rocm",
     )
+    data, weights = inputs
+    if (data.dtype == "int8"
+        and weights.dtype == "int8"
+        and out_type.dtype == "int32"
+    ):
+        strategy.add_implementation(
+            wrap_compute_dense(topi.cuda.dense_int8),
+            wrap_topi_schedule(topi.cuda.schedule_dense_int8),
+            name="dense_int8.rocm",
+        )
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py
--- a/tests/python/topi/python/test_topi_conv2d_int8.py
+++ b/tests/python/topi/python/test_topi_conv2d_int8.py
@@ -361,6 +361,13 @@ def get_ref_data():
         #     4,
         #     False,
         # ),
+        (
+            "rocm",
+            lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+            topi.cuda.schedule_conv2d_NCHWc_int8,
+            4,
+            False,
+        ),
     ]
 
     build_only_aarch64 = platform.machine() != "aarch64"