diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 94efe6e1abfe..04ab7f80daa9 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -52,6 +52,13 @@ from .op import tvm_tuple, tvm_struct_get, tvm_struct_set from .op import address_of, lookup_param, assume, undef from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error +from .op import ( + tvm_load_matrix_sync, + tvm_store_matrix_sync, + tvm_mma_sync, + tvm_bmma_sync, + tvm_fill_fragment, +) from .op import vectorlow, vectorhigh, vectorcombine from .op import infinity, reinterpret from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 4f26b0f94765..cf7985e8f489 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -595,6 +595,242 @@ def tvm_throw_last_error(): return call_intrin("handle", "tir.tvm_throw_last_error") +def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): + """TVM intrinsic for tensor core load operators + + Parameters + ---------- + fragment : Var + The wmma fragment. + + m : UIntImm + The shape of wmma fragment. + + n : UIntImm + The shape of wmma fragment. + + k : UIntImm + The shape of wmma fragment. + + index : Expr + The fragment index. + + buffer_ptr : Expr + The fragment buffer pointer. + + stride : Expr + The fragment stride. + + layout : Literal["row_major", "column_major"] + The fragment layout. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", + "tir.tvm_load_matrix_sync", + fragment, + m, + n, + k, + index, + buffer_ptr, + stride, + layout, + ) + + +def tvm_mma_sync( + fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c +): + """TVM intrinsic for tensor core mma_sync operators + + Parameters + ---------- + fragment_d : Var + The wmma fragment_d. + + index_d : Expr + The fragment_d index. + + fragment_a : Var + The wmma fragment_a. + + index_a : Expr + The fragment_a index. + + fragment_b : Var + The wmma fragment_b. + + index_b : Expr + The fragment_b index. + + fragment_c : Var + The wmma fragment_c. + + index_c : Expr + The fragment_c index. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", + "tir.tvm_mma_sync", + fragment_d, + index_d, + fragment_a, + index_a, + fragment_b, + index_b, + fragment_c, + index_c, + ) + + +def tvm_bmma_sync( + fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c +): + """TVM intrinsic for tensor core bmma_sync operators + + Parameters + ---------- + fragment_d : Var + The bwmma fragment_d. + + index_d : Expr + The fragment_d index. + + fragment_a : Var + The bwmma fragment_a. + + index_a : Expr + The fragment_a index. + + fragment_b : Var + The bwmma fragment_b. + + index_b : Expr + The fragment_b index. + + fragment_c : Var + The bwmma fragment_c. + + index_c : Expr + The fragment_c index. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", + "tir.tvm_bmma_sync", + fragment_d, + index_d, + fragment_a, + index_a, + fragment_b, + index_b, + fragment_c, + index_c, + ) + + +def tvm_fill_fragment(fragment, m, n, k, index, value): + """TVM intrinsic for tensor core fill_fragment operators + + Parameters + ---------- + fragment : Var + The wmma fragment + + m : UIntImm + The shape of wmma fragment. + + n : UIntImm + The shape of wmma fragment. + + k : UIntImm + The shape of wmma fragment. + + index : Expr + The fragment index. + + value : Expr + The value to be filled in fragment. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", + "tir.tvm_fill_fragment", + fragment, + m, + n, + k, + index, + value, + ) + + +def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): + """TVM intrinsic for tensor core store operators + + Parameters + ---------- + fragment : Var + The wmma fragment. + + m : UIntImm + The shape of wmma fragment. + + n : UIntImm + The shape of wmma fragment. + + k : UIntImm + The shape of wmma fragment. + + index : Expr + The fragment index. + + buffer_ptr : Expr + The fragment buffer pointer. + + stride : Expr + The fragment stride. + + layout : Literal["row_major", "column_major"] + The fragment layout. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", + "tir.tvm_store_matrix_sync", + fragment, + m, + n, + k, + index, + buffer_ptr, + stride, + layout, + ) + + def vectorlow(dtype, vec): """Get the low level half of the vector diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py index 835a397ee3b2..5254e7326e24 100644 --- a/tests/python/unittest/test_tir_op_types.py +++ b/tests/python/unittest/test_tir_op_types.py @@ -104,6 +104,44 @@ def test_tir_op_tvm_throw_last_error(): assert expr.op.name == "tir.tvm_throw_last_error" +def test_tir_op_tvm_load_matrix_sync(): + buffer = tir.decl_buffer((16, 16), "float32") + x = tir.Var("x", "handle") + expr = tir.tvm_load_matrix_sync(buffer.data, 16, 16, 16, 0, x, 128, "row_major") + assert expr.op.name == "tir.tvm_load_matrix_sync" + + +def test_tir_op_tvm_store_matrix_sync(): + buffer = tir.decl_buffer((16, 16), "float32") + x = tir.Var("x", "handle") + expr = tir.tvm_store_matrix_sync(buffer.data, 16, 16, 16, 0, x, 128, "row_major") + assert expr.op.name == "tir.tvm_store_matrix_sync" + + +def test_tir_op_tvm_mma_sync(): + buffer_0 = tir.decl_buffer((16, 16), "float32") + buffer_1 = tir.decl_buffer((16, 16), "float32") + buffer_2 = tir.decl_buffer((16, 16), "float32") + buffer_3 = tir.decl_buffer((16, 16), "float32") + expr = tir.tvm_mma_sync(buffer_0.data, 0, buffer_1.data, 0, buffer_2.data, 0, buffer_3.data, 0) + assert expr.op.name == "tir.tvm_mma_sync" + + +def test_tir_op_tvm_bmma_sync(): + buffer_0 = tir.decl_buffer((16, 16), "float32") + buffer_1 = tir.decl_buffer((16, 16), "float32") + buffer_2 = tir.decl_buffer((16, 16), "float32") + buffer_3 = tir.decl_buffer((16, 16), "float32") + expr = tir.tvm_bmma_sync(buffer_0.data, 0, buffer_1.data, 0, buffer_2.data, 0, buffer_3.data, 0) + assert expr.op.name == "tir.tvm_bmma_sync" + + +def test_tir_op_tvm_fill_fragment(): + buffer = tir.decl_buffer((16, 16), "float32") + expr = tir.tvm_fill_fragment(buffer.data, 16, 16, 16, 0, 0) + assert expr.op.name == "tir.tvm_fill_fragment" + + def test_tir_op_vectorlow(): buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1) vec = buffer.vload([0, 0], dtype="int8x16") @@ -165,6 +203,11 @@ def test_tir_op_TVMBackendFreeWorkspace(): test_tir_op_type_annotation() test_tir_op_tvm_access_ptr() test_tir_op_tvm_throw_last_error() + test_tir_op_tvm_load_matrix_sync(), + test_tir_op_tvm_store_matrix_sync(), + test_tir_op_tvm_mma_sync(), + test_tir_op_tvm_bmma_sync(), + test_tir_op_tvm_fill_fragment(), test_tir_op_vectorlow() test_tir_op_vectorhigh() test_tir_op_vectorcombine()