From ba7b3c5b75b82b2e6a116c47322ae7f5e8e9d38d Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 11 Apr 2023 16:34:33 -0700
Subject: [PATCH 1/2] [TIR] Add CUDA int4 tensor core intrinsics

---
 python/tvm/tir/tensor_intrin/cuda.py | 216 ++++++++++++++++++++++-----
 1 file changed, 182 insertions(+), 34 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 3bc16f234fba..b3f8ad905f6a 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -20,6 +20,7 @@
 
 from typing_extensions import Literal
 
+import re
 from tvm.script import tir as T
 from tvm.tir.function import PrimFunc
 
@@ -43,6 +44,16 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
     return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
 
 
+def get_tensor_core_load_offset_factor(dtype):
+    """Get the offset factor for tensor core load intrins"""
+    bits = re.search("(\d+)", dtype).group(0)
+    bits = int(bits)
+    if bits <= 4:
+        # sub-byte operations have a different offset factor
+        return 128 // bits
+    return 256 // bits
+
+
 @register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
 def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
     i, j = ind[0], ind[1]
@@ -116,6 +127,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
         col_dim = k_dim
 
     shmem_shape = (row_dim, col_dim)
+    offset_factor = get_tensor_core_load_offset_factor(dtype)
 
     @T.prim_func
     def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
@@ -124,11 +136,16 @@ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
             shmem_shape,
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope=shared_scope,
         )
         warp = T.match_buffer(
-            warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
+            warp_handle,
+            (WARP_SIZE, local_size),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope="warp",
         )
 
         with T.block("root"):
@@ -153,12 +170,17 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
             shmem_shape,
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope=shared_scope,
             strides=[s0, s1],
         )
         warp = T.match_buffer(
-            warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
+            warp_handle,
+            (WARP_SIZE, local_size),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope="warp",
         )
 
         with T.block("root"):
@@ -222,16 +244,34 @@ def maybe_swap(i, j):
             return j, i
         return i, j
 
+    in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
+    out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
+
     @T.prim_func
     def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
-            a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+            a,
+            (WARP_SIZE, local_size),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="warp",
         )
         B = T.match_buffer(
-            b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+            b,
+            (WARP_SIZE, local_size),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="warp",
        )
         C = T.match_buffer(
-            c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
+            c,
+            (WARP_SIZE, local_size_out),
+            out_dtype,
+            align=64,
+            offset_factor=out_offset_factor,
+            scope="warp",
         )
 
         with T.block("root"):
@@ -265,13 +305,28 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
     @T.prim_func
     def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
-            a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+            a,
+            (WARP_SIZE, local_size),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="warp",
         )
         B = T.match_buffer(
-            b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+            b,
+            (WARP_SIZE, local_size),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="warp",
         )
         C = T.match_buffer(
-            c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
+            c,
+            (WARP_SIZE, local_size_out),
+            out_dtype,
+            align=64,
+            offset_factor=out_offset_factor,
+            scope="warp",
         )
 
         with T.block("root"):
@@ -513,17 +568,29 @@ def get_wmma_load_intrin(
     """Generator of wmma_load intrins"""
     wmma_fragment_scope = "wmma.matrix_{}".format("b" if is_b else "a")
     layout = "col_major" if is_col_major else "row_major"
+    offset_factor = get_tensor_core_load_offset_factor(dtype)
+
+    frag_m, frag_n = (k_dim, n_dim) if is_b else (m_dim, k_dim)
+    if is_col_major:
+        frag_m, frag_n = frag_n, frag_m
 
     @T.prim_func
     def wmma_load_desc(a: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=shared_scope)
+        A = T.match_buffer(
+            a, (frag_m, frag_n), dtype, align=64, offset_factor=offset_factor, scope=shared_scope
+        )
         C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=wmma_fragment_scope
+            c,
+            (frag_m, frag_n),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope=wmma_fragment_scope,
         )
         with T.block("root"):
-            T.reads(A[0:m_dim, 0:n_dim])
-            T.writes(C[0:m_dim, 0:n_dim])
-            for i, j in T.grid(m_dim, n_dim):
+            T.reads(A[0:frag_m, 0:frag_n])
+            T.writes(C[0:frag_m, 0:frag_n])
+            for i, j in T.grid(frag_m, frag_n):
                 with T.block("load"):
                     vii, vjj = T.axis.remap("SS", [i, j])
                     C[vii, vjj] = A[vii, vjj]
@@ -536,32 +603,32 @@ def wmma_load_impl(a: T.handle, c: T.handle) -> None:
         d0 = T.int32()
         A = T.match_buffer(
             a,
-            (m_dim, n_dim),
+            (frag_m, frag_n),
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope=shared_scope,
             strides=[s1, s0],
         )
         C = T.match_buffer(
             c,
-            (m_dim, n_dim),
+            (frag_m, frag_n),
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope=wmma_fragment_scope,
             strides=[d1, d0],
         )
         with T.block("root"):
-            T.reads(A[0:m_dim, 0:n_dim])
-            T.writes(C[0:m_dim, 0:n_dim])
+            T.reads(A[0:frag_m, 0:frag_n])
+            T.writes(C[0:frag_m, 0:frag_n])
             T.evaluate(
                 T.tvm_load_matrix_sync(
                     C.data,
                     m_dim,
                     n_dim,
                     k_dim,
-                    get_wmma_fragment_index(C, d1, m_dim, n_dim),
+                    get_wmma_fragment_index(C, d1, frag_m, frag_n),
                     A.access_ptr("r"),
                     s1,
                     layout,
@@ -577,11 +644,17 @@ def get_wmma_fill_intrin(
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_fill intrins"""
     zero = IntImm("int32", 0).astype(dtype)
+    offset_factor = get_tensor_core_load_offset_factor(dtype)
 
     @T.prim_func
     def wmma_fill_desc(c: T.handle) -> None:
         C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+            c,
+            (m_dim, n_dim),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope="wmma.accumulator",
         )
         with T.block("root"):
             T.reads()
@@ -600,7 +673,7 @@ def wmma_fill_impl(c: T.handle) -> None:
             (m_dim, n_dim),
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope="wmma.accumulator",
             strides=[d1, d0],
         )
@@ -626,13 +699,21 @@ def get_wmma_store_intrin(
     m_dim: int, n_dim: int, k_dim: int, dtype: str, scope: str
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_store intrins"""
+    offset_factor = get_tensor_core_load_offset_factor(dtype)
 
     @T.prim_func
     def wmma_store_desc(a: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
-            a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+            a,
+            (m_dim, n_dim),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope="wmma.accumulator",
+        )
+        C = T.match_buffer(
+            c, (m_dim, n_dim), dtype, align=64, offset_factor=offset_factor, scope=scope
         )
-        C = T.match_buffer(c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope)
         with T.block("root"):
             T.reads(A[0:m_dim, 0:n_dim])
             T.writes(C[0:m_dim, 0:n_dim])
@@ -652,12 +733,18 @@ def wmma_store_impl(a: T.handle, c: T.handle) -> None:
             (m_dim, n_dim),
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope="wmma.accumulator",
             strides=[d1, d0],
         )
         C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope, strides=[s1, s0]
+            c,
+            (m_dim, n_dim),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope=scope,
+            strides=[s1, s0],
         )
         with T.block("root"):
             T.reads(A[0:m_dim, 0:n_dim])
@@ -683,6 +770,8 @@ def get_wmma_sync_intrin(
     m_dim: int, n_dim: int, k_dim: int, in_dtype: str, out_dtype: str, b_transposed: bool
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_sync intrins"""
+    in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
+    out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
 
     def maybe_cast(v):
         if in_dtype != out_dtype:
@@ -699,18 +788,28 @@ def maybe_swap(i, j):
     @T.prim_func
     def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
-            a, (m_dim, k_dim), in_dtype, align=64, offset_factor=16, scope="wmma.matrix_a"
+            a,
+            (m_dim, k_dim),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="wmma.matrix_a",
         )
         B = T.match_buffer(
             b,
             maybe_swap(k_dim, n_dim),
             in_dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=in_offset_factor,
             scope="wmma.matrix_b",
         )
         C = T.match_buffer(
-            c, (m_dim, n_dim), out_dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+            c,
+            (m_dim, n_dim),
+            out_dtype,
+            align=64,
+            offset_factor=out_offset_factor,
+            scope="wmma.accumulator",
         )
 
         with T.block("root"):
@@ -738,7 +837,7 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
             (m_dim, k_dim),
             in_dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=in_offset_factor,
             scope="wmma.matrix_a",
             strides=[a1, a0],
         )
@@ -747,7 +846,7 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
             maybe_swap(k_dim, n_dim),
             in_dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=in_offset_factor,
             scope="wmma.matrix_b",
             strides=[b1, b0],
         )
@@ -756,7 +855,7 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
             (m_dim, n_dim),
             out_dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=out_offset_factor,
             scope="wmma.accumulator",
             strides=[c1, c0],
         )
@@ -817,6 +916,12 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True),
 )
 
+WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN = "wmma_sync_8x8x32_s4s4s32_trans"
+TensorIntrin.register(
+    WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN,
+    *get_wmma_sync_intrin(8, 8, 32, "int4", "int32", True),
+)
+
 WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_F16_A_INTRIN,
@@ -913,6 +1018,30 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, True),
 )
 
+WMMA_LOAD_8x8x32_S4_A_INTRIN = "wmma_load_8x8x32_s4_a_shared"
+TensorIntrin.register(
+    WMMA_LOAD_8x8x32_S4_A_INTRIN,
+    *get_wmma_load_intrin(8, 8, 32, "int4", "shared", False, False),
+)
+
+WMMA_LOAD_8x8x32_S4_A_DYN_INTRIN = "wmma_load_8x8x32_s4_a_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_8x8x32_S4_A_DYN_INTRIN,
+    *get_wmma_load_intrin(8, 8, 32, "int4", "shared.dyn", False, False),
+)
+
+WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN = "wmma_load_8x8x32_s4_b_trans_shared"
+TensorIntrin.register(
+    WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN,
+    *get_wmma_load_intrin(8, 8, 32, "int4", "shared", True, True),
+)
+
+WMMA_LOAD_8x8x32_S4_B_TRANS_DYN_INTRIN = "wmma_load_8x8x32_s4_b_trans_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_8x8x32_S4_B_TRANS_DYN_INTRIN,
+    *get_wmma_load_intrin(8, 8, 32, "int4", "shared.dyn", True, True),
+)
+
 WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32"
 TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float32"))
 
@@ -922,6 +1051,9 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 WMMA_FILL_16x16x16_S32_INTRIN = "wmma_fill_16x16x16_s32"
 TensorIntrin.register(WMMA_FILL_16x16x16_S32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "int32"))
 
+WMMA_FILL_8x8x32_S32_INTRIN = "wmma_fill_8x8x32_s32"
+TensorIntrin.register(WMMA_FILL_8x8x32_S32_INTRIN, *get_wmma_fill_intrin(8, 8, 32, "int32"))
+
 WMMA_STORE_16x16x16_F32_SHARED_INTRIN = "wmma_store_16x16x16_f32_shared"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "shared")
@@ -955,6 +1087,17 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     *get_wmma_store_intrin(16, 16, 16, "int32", "shared.dyn"),
 )
 
+WMMA_STORE_8x8x32_S32_SHARED_INTRIN = "wmma_store_8x8x32_s32_shared"
+TensorIntrin.register(
+    WMMA_STORE_8x8x32_S32_SHARED_INTRIN, *get_wmma_store_intrin(8, 8, 32, "int32", "shared")
+)
+
+WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN = "wmma_store_8x8x32_s32_shared_dyn"
+TensorIntrin.register(
+    WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN,
+    *get_wmma_store_intrin(8, 8, 32, "int32", "shared.dyn"),
+)
+
 WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "global")
@@ -970,6 +1113,11 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "global")
 )
 
+WMMA_STORE_8x8x32_S32_GLOBAL_INTRIN = "wmma_store_8x8x32_s32_global"
+TensorIntrin.register(
+    WMMA_STORE_8x8x32_S32_GLOBAL_INTRIN, *get_wmma_store_intrin(8, 8, 32, "int32", "global")
+)
+
 
 def get_wmma_intrin_group(
     load_scope: Literal["shared", "shared.dyn"],

From bd0aa90faee3b00b9ef8ec17c3839656f94e849d Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 11 Apr 2023 18:51:13 -0700
Subject: [PATCH 2/2] lint

---
 python/tvm/tir/tensor_intrin/cuda.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index b3f8ad905f6a..8d12a39ca79b 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,11 +16,11 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
+import re
 from typing import Dict, Tuple
 
 from typing_extensions import Literal
 
-import re
 from tvm.script import tir as T
 from tvm.tir.function import PrimFunc
 
@@ -46,7 +46,7 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
 
 def get_tensor_core_load_offset_factor(dtype):
     """Get the offset factor for tensor core load intrins"""
-    bits = re.search("(\d+)", dtype).group(0)
+    bits = re.search(r"(\d+)", dtype).group(0)
     bits = int(bits)
     if bits <= 4:
         # sub-byte operations have a different offset factor
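
Note on the new offset factors (an illustration, not part of the patches): the arithmetic in get_tensor_core_load_offset_factor works out to the number of elements in a 256-bit chunk for byte-and-wider dtypes, and in a 128-bit chunk for sub-byte dtypes. A standalone mirror of the patched helper, with worked values:

    import re

    def offset_factor(dtype: str) -> int:
        # Mirrors get_tensor_core_load_offset_factor from PATCH 1/2.
        bits = int(re.search(r"(\d+)", dtype).group(0))
        # 128 bits per chunk on the sub-byte path, 256 bits otherwise.
        return (128 if bits <= 4 else 256) // bits

    assert offset_factor("float16") == 16  # 256 // 16, the value previously hard-coded as 16
    assert offset_factor("int8") == 32     # 256 // 8
    assert offset_factor("int4") == 32     # 128 // 4, sub-byte path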
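A minimal sketch of how the new int4 registrations could be exercised, assuming a TVM build with PATCH 1/2 applied. The 8x8x32 workload below is illustrative rather than part of the patches, and the tensorize flow is left as comments because the block/loop handles depend on how fragments are staged into the wmma scopes in the surrounding schedule:

    from tvm import te, tir
    from tvm.tir.tensor_intrin.cuda import (
        WMMA_FILL_8x8x32_S32_INTRIN,
        WMMA_LOAD_8x8x32_S4_A_INTRIN,
        WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN,
        WMMA_STORE_8x8x32_S32_SHARED_INTRIN,
        WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN,
    )

    # One 8x8x32 int4 -> int32 matmul tile; B is stored transposed, which is
    # the layout the *_TRANS intrinsics above expect.
    A = te.placeholder((8, 32), dtype="int4", name="A")
    B = te.placeholder((8, 32), dtype="int4", name="B")
    k = te.reduce_axis((0, 32), name="k")
    C = te.compute(
        (8, 8),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
        name="C",
    )
    sch = tir.Schedule(te.create_prim_func([A, B, C]))

    # The registrations resolve like any other tensor intrinsic:
    assert tir.TensorIntrin.get(WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN) is not None

    # Typical tensorize flow (handles elided; they come from cache_read/cache_write
    # into "wmma.matrix_a" / "wmma.matrix_b" / "wmma.accumulator" plus loop tiling):
    #   sch.tensorize(init_block, WMMA_FILL_8x8x32_S32_INTRIN)
    #   sch.tensorize(load_a_block, WMMA_LOAD_8x8x32_S4_A_INTRIN)
    #   sch.tensorize(load_b_block, WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN)
    #   sch.tensorize(compute_block, WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN)
    #   sch.tensorize(store_block, WMMA_STORE_8x8x32_S32_SHARED_INTRIN)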