Skip to content

Commit

Permalink
[TIR][DLight] Enable SimdGroup op for Metal
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Jun 23, 2024
1 parent e6bfaf8 commit 238366e
Show file tree
Hide file tree
Showing 11 changed files with 1,124 additions and 7 deletions.
44 changes: 43 additions & 1 deletion include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ TVM_DLL const Op& create_barriers();
TVM_DLL const Op& mma_store();

/*!
* \brief tvm intrinsic for zero-initalizing an MMA accumulation registor.
* \brief tvm intrinsic for zero-initializing an MMA accumulation register.
* For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
* m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
* 4 accumulation registers.
Expand All @@ -758,6 +758,48 @@ TVM_DLL const Op& mma_store();
*/
TVM_DLL const Op& mma_fill();

// Metal SimdGroup matrix intrinsics

/*!
* \brief tvm intrinsic for initializing and simdgroup with given value.
* \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params,
* keeping the similar interface with Metal Spec.
*
* void make_filled_simdgroup_matrix(Var d, PrimExpr index, PrimExpr value,
* int col = 8, int row = 8);
*/
TVM_DLL const Op& make_filled_simdgroup_matrix();

/*!
* \brief tvm intrinsic for loading data from device memory or threadgroup memory to simdgroup.
* \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params,
* keeping the similar interface with Metal Spec.
*
* void simdgroup_load(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
int col = 8, int row = 8, bool transpose_matrix = false);
*/
TVM_DLL const Op& simdgroup_load();

/*!
* \brief tvm intrinsic for storing data from simdgroup to device memory or threadgroup memory.
* \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params,
* keeping the similar interface with Metal Spec.
*
* void simdgroup_store(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
* int col = 8, int row = 8, bool transpose_matrix = false);
*/
TVM_DLL const Op& simdgroup_store();

/*!
* \brief tvm intrinsic for multiply and accumulate two matrices in simdgroup
* \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params,
* keeping the similar interface with Metal Spec.
*
* void simdgroup_mma(Var d, PrimExpr index_d, Var a, PrimExpr index_a,
* Var b, PrimExpr index_b, Var c, PrimExpr index_c);
*/
TVM_DLL const Op& simdgroup_multiply_accumulate();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
Expand Down
145 changes: 145 additions & 0 deletions python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,146 @@ def check_sm_version(arch: str) -> int:
return int(sm_version) if sm_version.isdigit() else -1


class MetalMatmul(GPUScheduleRule):
"""
The schedule rule for Metal matmul computation.
"""

def apply( # pylint: disable=too-many-locals,missing-docstring
self,
func: tir.PrimFunc,
target: Target,
_: bool,
) -> Optional[tir.Schedule]:
from tvm.tir.tensor_intrin.metal import ( # pylint: disable=import-outside-toplevel
get_simdgroup_intrin_group,
)

if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)

reduction_blocks = get_reduction_blocks(sch, blocks)
if reduction_blocks is None:
return None

main_block = reduction_blocks[0]
block_stmt = sch.get(main_block)
index_maps = get_index_map(block_stmt)
if index_maps is None:
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

# Step 0. Configs
block_size_x: int = 16
block_size_y: int = 16
block_size_k: int = 32
micro_size: int = 8
warp_size: int = 32
ty_len: int = 1
tz_len: int = 4
vector_size: int = 4

# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)

# Step 2. Padding for dynamic shape kernels
sch.pad_einsum(
main_block,
[
1,
ty_len * block_size_x,
tz_len * block_size_y,
block_size_k,
],
)

# Step 3. Schedule matmul to use simdgroup intrinsics
batch, i, j, k = sch.get_loops(main_block)
bx, ty, i0, i1 = sch.split(i, [None, ty_len, block_size_x // micro_size, micro_size])
by, tz, j0, j1 = sch.split(j, [None, tz_len, block_size_y // micro_size, micro_size])
k0, k1, k2 = sch.split(k, [None, block_size_k // micro_size, micro_size])
sch.reorder(bx, by, ty, tz, k0, k1, i0, j0, i1, j1, k2)
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")
sch.bind(batch, "blockIdx.z")
sch.bind(ty, "threadIdx.y")
sch.bind(tz, "threadIdx.z")

def fetch_to_shared(block, idx):
block_read = sch.cache_read(block, idx, "shared")
sch.compute_at(block_read, k0, preserve_unit_loops=True)
fused = sch.fuse(*sch.get_loops(block_read)[-2:])
_, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])

sch.bind(_tz, "threadIdx.z")
sch.bind(_ty, "threadIdx.y")
sch.bind(_tx, "threadIdx.x")
sch.vectorize(vec)

return block_read

a_g2s = fetch_to_shared(main_block, 0)
b_g2s = fetch_to_shared(main_block, 1)

auto_inline_producers(sch, a_g2s)
auto_inline_producers(sch, b_g2s)

# create read cache to load matrix from shared memory to wmma fragments
A_simdgroup = sch.cache_read(main_block, 0, "metal.simdgroup")
B_simdgroup = sch.cache_read(main_block, 1, "metal.simdgroup")
sch.compute_at(A_simdgroup, k1)
sch.compute_at(B_simdgroup, k1)

C_simd2s = sch.cache_write(main_block, 0, "metal.simdgroup")
C_s2g = sch.cache_write(C_simd2s, 0, "shared")
sch.reverse_compute_at(C_simd2s, tz, preserve_unit_loops=True)
sch.reverse_compute_at(C_s2g, by, preserve_unit_loops=True)

intrin_group = get_simdgroup_intrin_group(
load_scope="shared",
store_scope="shared",
dtype="float16",
trans_a=False,
trans_b=True,
)
sch.transform_layout(B_simdgroup, ("write", 0), lambda s, i, j: (s, j, i))

def tensorize_block(block: tir.schedule.BlockRV, intrin: str):
*_, i, j = sch.get_loops(block)
io, ii = sch.split(i, [None, micro_size])
jo, ji = sch.split(j, [None, micro_size])
sch.reorder(io, jo, ii, ji)
sch.tensorize(ii, intrin)

C_init = sch.decompose_reduction(main_block, k0)
tensorize_block(A_simdgroup, intrin_group["load_a"])
tensorize_block(B_simdgroup, intrin_group["load_b"])
tensorize_block(C_simd2s, intrin_group["store"])
tensorize_block(C_init, intrin_group["init"])

*_, i, j, k = sch.get_loops(main_block)
sch.tensorize(i, intrin_group["compute"])

auto_inline_consumer_chain(sch, C_s2g)
fused = sch.fuse(*sch.get_loops(C_s2g)[-2:])
_, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
sch.bind(_tz, "threadIdx.z")
sch.bind(_ty, "threadIdx.y")
sch.bind(_tx, "threadIdx.x")
sch.vectorize(vec)

return sch


class MatmulTensorization(GPUScheduleRule):
"""
The schedule rule for float16 tensor core matmul computation.
Expand Down Expand Up @@ -848,6 +988,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
tensorize_sch = MatmulTensorization().apply(func, target, _)
if tensorize_sch is not None:
return tensorize_sch
elif target.kind.name == "metal":
try:
return MetalMatmul().apply(func, target, _)
except: # pylint: disable=bare-except
pass

# Step 2. Get schedule config.
config = self.get_configs(target)
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,10 @@ def wrapped(*args, **kwargs):
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix)
simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
Expand Down Expand Up @@ -2177,6 +2181,10 @@ def wrapped(*args, **kwargs):
"ptx_arrive_barrier",
"ptx_arrive_barrier_expect_tx",
"ptx_wait_barrier",
"make_filled_simdgroup_matrix",
"simdgroup_load",
"simdgroup_store",
"simdgroup_multiply_accumulate",
"create_barriers",
"mma_store",
"mma_fill",
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@
ptx_wait_barrier,
create_barriers,
)
from .op import (
make_filled_simdgroup_matrix,
simdgroup_load,
simdgroup_multiply_accumulate,
simdgroup_store,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
Expand Down
Loading

0 comments on commit 238366e

Please sign in to comment.