From 120fd96e80307b4301ee3fc93e6793e0b40485f0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 08:27:27 +0900 Subject: [PATCH] use buffer syntax sugar --- python/tvm/tir/tensor_intrin/x86.py | 18 ++++++++++-------- .../unittest/test_meta_schedule_tune_relay.py | 2 +- .../unittest/test_tir_schedule_tensorize.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py index 84b86ed6b202..6fda9484df42 100644 --- a/python/tvm/tir/tensor_intrin/x86.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -23,11 +23,9 @@ @T.prim_func -def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1) - B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) - C = T.match_buffer(c, (16,), "int32", offset_factor=1) - +def dot_product_16x4_u8i8i32_desc( + A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"] +) -> None: with T.block("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) T.writes(C[0:16]) @@ -41,7 +39,9 @@ def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func -def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None: +def dot_product_16x4_u8i8i32_vnni_impl( + A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"] +) -> None: A = T.match_buffer(a, (4,), "uint8", offset_factor=1) B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) C = T.match_buffer(c, (16,), "int32", offset_factor=1) @@ -66,6 +66,8 @@ def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None: ) -VNNI_INTRIN = "dot_16x4_vnni" +VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni" -TensorIntrin.register(VNNI_INTRIN, dot_product_16x4_desc, dot_product_16x4_vnni_impl) +TensorIntrin.register( + VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni_impl +) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index fa59badc5da8..a9da41f7e6aa 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -42,7 +42,7 @@ from tvm.target.target import Target from tvm.tir.schedule import BlockRV, Schedule from tvm.tir.schedule.trace import Trace -from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN logging.basicConfig() diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 548543630145..3abdb0e93c61 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -22,7 +22,7 @@ from tvm import tir, te from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip -from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks