Skip to content

Commit

Permalink
[TIR] Expose WMMA-related TensorCore builtins (#12589)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 and yongwww authored Aug 25, 2022
1 parent 9aac161 commit b387384
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 0 deletions.
7 changes: 7 additions & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@
from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
from .op import address_of, lookup_param, assume, undef
from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
from .op import (
tvm_load_matrix_sync,
tvm_store_matrix_sync,
tvm_mma_sync,
tvm_bmma_sync,
tvm_fill_fragment,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
Expand Down
236 changes: 236 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,242 @@ def tvm_throw_last_error():
return call_intrin("handle", "tir.tvm_throw_last_error")


def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
"""TVM intrinsic for tensor core load operators
Parameters
----------
fragment : Var
The wmma fragment.
m : UIntImm
The shape of wmma fragment.
n : UIntImm
The shape of wmma fragment.
k : UIntImm
The shape of wmma fragment.
index : Expr
The fragment index.
buffer_ptr : Expr
The fragment buffer pointer.
stride : Expr
The fragment stride.
layout : Literal["row_major", "column_major"]
The fragment layout.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
"handle",
"tir.tvm_load_matrix_sync",
fragment,
m,
n,
k,
index,
buffer_ptr,
stride,
layout,
)


def tvm_mma_sync(
fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
):
"""TVM intrinsic for tensor core mma_sync operators
Parameters
----------
fragment_d : Var
The wmma fragment_d.
index_d : Expr
The fragment_d index.
fragment_a : Var
The wmma fragment_a.
index_a : Expr
The fragment_a index.
fragment_b : Var
The wmma fragment_b.
index_b : Expr
The fragment_b index.
fragment_c : Var
The wmma fragment_c.
index_c : Expr
The fragment_c index.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
"handle",
"tir.tvm_mma_sync",
fragment_d,
index_d,
fragment_a,
index_a,
fragment_b,
index_b,
fragment_c,
index_c,
)


def tvm_bmma_sync(
fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
):
"""TVM intrinsic for tensor core bmma_sync operators
Parameters
----------
fragment_d : Var
The bwmma fragment_d.
index_d : Expr
The fragment_d index.
fragment_a : Var
The bwmma fragment_a.
index_a : Expr
The fragment_a index.
fragment_b : Var
The bwmma fragment_b.
index_b : Expr
The fragment_b index.
fragment_c : Var
The bwmma fragment_c.
index_c : Expr
The fragment_c index.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
"handle",
"tir.tvm_bmma_sync",
fragment_d,
index_d,
fragment_a,
index_a,
fragment_b,
index_b,
fragment_c,
index_c,
)


def tvm_fill_fragment(fragment, m, n, k, index, value):
"""TVM intrinsic for tensor core fill_fragment operators
Parameters
----------
fragment : Var
The wmma fragment
m : UIntImm
The shape of wmma fragment.
n : UIntImm
The shape of wmma fragment.
k : UIntImm
The shape of wmma fragment.
index : Expr
The fragment index.
value : Expr
The value to be filled in fragment.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
"handle",
"tir.tvm_fill_fragment",
fragment,
m,
n,
k,
index,
value,
)


def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
"""TVM intrinsic for tensor core store operators
Parameters
----------
fragment : Var
The wmma fragment.
m : UIntImm
The shape of wmma fragment.
n : UIntImm
The shape of wmma fragment.
k : UIntImm
The shape of wmma fragment.
index : Expr
The fragment index.
buffer_ptr : Expr
The fragment buffer pointer.
stride : Expr
The fragment stride.
layout : Literal["row_major", "column_major"]
The fragment layout.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
"handle",
"tir.tvm_store_matrix_sync",
fragment,
m,
n,
k,
index,
buffer_ptr,
stride,
layout,
)


def vectorlow(dtype, vec):
"""Get the low level half of the vector
Expand Down
43 changes: 43 additions & 0 deletions tests/python/unittest/test_tir_op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,44 @@ def test_tir_op_tvm_throw_last_error():
assert expr.op.name == "tir.tvm_throw_last_error"


def test_tir_op_tvm_load_matrix_sync():
buffer = tir.decl_buffer((16, 16), "float32")
x = tir.Var("x", "handle")
expr = tir.tvm_load_matrix_sync(buffer.data, 16, 16, 16, 0, x, 128, "row_major")
assert expr.op.name == "tir.tvm_load_matrix_sync"


def test_tir_op_tvm_store_matrix_sync():
buffer = tir.decl_buffer((16, 16), "float32")
x = tir.Var("x", "handle")
expr = tir.tvm_store_matrix_sync(buffer.data, 16, 16, 16, 0, x, 128, "row_major")
assert expr.op.name == "tir.tvm_store_matrix_sync"


def test_tir_op_tvm_mma_sync():
buffer_0 = tir.decl_buffer((16, 16), "float32")
buffer_1 = tir.decl_buffer((16, 16), "float32")
buffer_2 = tir.decl_buffer((16, 16), "float32")
buffer_3 = tir.decl_buffer((16, 16), "float32")
expr = tir.tvm_mma_sync(buffer_0.data, 0, buffer_1.data, 0, buffer_2.data, 0, buffer_3.data, 0)
assert expr.op.name == "tir.tvm_mma_sync"


def test_tir_op_tvm_bmma_sync():
buffer_0 = tir.decl_buffer((16, 16), "float32")
buffer_1 = tir.decl_buffer((16, 16), "float32")
buffer_2 = tir.decl_buffer((16, 16), "float32")
buffer_3 = tir.decl_buffer((16, 16), "float32")
expr = tir.tvm_bmma_sync(buffer_0.data, 0, buffer_1.data, 0, buffer_2.data, 0, buffer_3.data, 0)
assert expr.op.name == "tir.tvm_bmma_sync"


def test_tir_op_tvm_fill_fragment():
buffer = tir.decl_buffer((16, 16), "float32")
expr = tir.tvm_fill_fragment(buffer.data, 16, 16, 16, 0, 0)
assert expr.op.name == "tir.tvm_fill_fragment"


def test_tir_op_vectorlow():
buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="int8x16")
Expand Down Expand Up @@ -165,6 +203,11 @@ def test_tir_op_TVMBackendFreeWorkspace():
test_tir_op_type_annotation()
test_tir_op_tvm_access_ptr()
test_tir_op_tvm_throw_last_error()
test_tir_op_tvm_load_matrix_sync(),
test_tir_op_tvm_store_matrix_sync(),
test_tir_op_tvm_mma_sync(),
test_tir_op_tvm_bmma_sync(),
test_tir_op_tvm_fill_fragment(),
test_tir_op_vectorlow()
test_tir_op_vectorhigh()
test_tir_op_vectorcombine()
Expand Down

0 comments on commit b387384

Please sign in to comment.