Commit

wip
masahi committed May 17, 2022
1 parent 848de63 commit 4b13b85
Showing 1 changed file with 67 additions and 77 deletions.
144 changes: 67 additions & 77 deletions tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py
@@ -8,6 +8,22 @@
import numpy as np


def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)


def shared_16x32_to_ldmatrix_32x16_layout(i, j):
thread_id = 4 * (i % 8) + (j % 16) // 4
return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4


@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
return tvm.runtime.convert([thread_id, local_id])
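
# Not part of this commit -- an illustrative sanity check (hypothetical helper
# name) that the two maps above are bijections from a row-major tile onto the
# per-warp register layout: 16x16 -> 32 threads x 8 int32 slots for the
# accumulator, 16x32 -> 32 threads x 16 int8 slots for the operands.
def _check_ldmatrix_layouts():
    c_slots = {shared_16x16_to_ldmatrix_32x8_layout(i, j) for i in range(16) for j in range(16)}
    assert c_slots == {(t, l) for t in range(32) for l in range(8)}

    ab_slots = {shared_16x32_to_ldmatrix_32x16_layout(i, j) for i in range(16) for j in range(32)}
    assert ab_slots == {(t, l) for t in range(32) for l in range(16)}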


@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
A_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
@@ -21,10 +37,9 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
with T.block("A_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A_shared[v0, v1])
T.writes(A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[
v0, v1
]
thread_id, local_id = shared_16x32_to_ldmatrix_32x16_layout(v0, v1)
T.writes(A_warp[thread_id, local_id])
A_warp[thread_id, local_id] = A_shared[v0, v1]


@T.prim_func
@@ -74,10 +89,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
with T.block("B_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(B_shared[v0, v1])
T.writes(B_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
B_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = B_shared[
v0, v1
]
thread_id, local_id = shared_16x32_to_ldmatrix_32x16_layout(v0, v1)
T.writes(B_warp[thread_id, local_id])
B_warp[thread_id, local_id] = B_shared[v0, v1]


@T.prim_func
@@ -126,10 +140,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
for i, j, k in T.grid(16, 16, 32):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
T.reads(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2], A[i % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4], B[j % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + j % 16 // 8 * 4 + k % 4])
T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] + T.cast(A[i % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4], "int32") * T.cast(B[j % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + j % 16 // 8 * 4 + k % 4], "int32")

thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
thread_id_A, local_id_A = shared_16x32_to_ldmatrix_32x16_layout(i, k)
thread_id_B, local_id_B = shared_16x32_to_ldmatrix_32x16_layout(j, k)

T.reads(
C[thread_id_C, local_id_C],
A[thread_id_A, local_id_A],
B[thread_id_B, local_id_B],
)
T.writes(C[thread_id_C, local_id_C])
C[thread_id_C, local_id_C] += T.cast(A[thread_id_A, local_id_A], "int32") * T.cast(
B[thread_id_B, local_id_B], "int32"
)
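
# Not part of this commit -- a hypothetical NumPy replay of the desc above.
# In warp-layout terms it computes C[i, j] += A[i, k] * B[j, k], i.e. an int8
# matmul against a transposed B, so the result can be checked against a @ b.T.
def _emulate_mma_sync_desc():
    import numpy as np

    a = np.random.randint(-4, 4, (16, 32)).astype("int8")
    b = np.random.randint(-4, 4, (16, 32)).astype("int8")

    # Pack both operands into the (32, 16) warp layout, as ldmatrix would.
    a_warp = np.zeros((32, 16), "int8")
    b_warp = np.zeros((32, 16), "int8")
    for r in range(16):
        for c in range(32):
            t, l = shared_16x32_to_ldmatrix_32x16_layout(r, c)
            a_warp[t, l] = a[r, c]
            b_warp[t, l] = b[r, c]

    # Accumulate through the warp-layout indices, exactly as the desc does.
    c_warp = np.zeros((32, 8), "int32")
    for i in range(16):
        for j in range(16):
            for k in range(32):
                tc, lc = shared_16x16_to_ldmatrix_32x8_layout(i, j)
                ta, la = shared_16x32_to_ldmatrix_32x16_layout(i, k)
                tb, lb = shared_16x32_to_ldmatrix_32x16_layout(j, k)
                c_warp[tc, lc] += int(a_warp[ta, la]) * int(b_warp[tb, lb])

    ref = a.astype("int32") @ b.astype("int32").T
    for i in range(16):
        for j in range(16):
            t, l = shared_16x16_to_ldmatrix_32x8_layout(i, j)
            assert c_warp[t, l] == ref[i, j]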

@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
@@ -190,14 +213,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
with T.block("root"):
T.reads(C_warp[0:32, 0:8])
T.writes(C[0:16, 0:16])
for ax1_0, i0, i1 in T.grid(2, 32, 4):
for i0, i1 in T.grid(16, 16):
with T.block("C_warp"):
v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)

T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
v0, v1 = T.axis.remap("SS", [i0, i1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.reads(C_warp[thread_id, local_id])
T.writes(C[v0, v1])
C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
C[v0, v1] = C_warp[thread_id, local_id]


@T.prim_func
@@ -230,21 +252,13 @@ def mma_fill_desc(a: T.handle) -> None:
with T.block("root"):
T.reads()
T.writes(C_warp[0:32, 0:8])
for i0, i1 in T.grid(32, 8):
for i0, i1 in T.grid(16, 16):
with T.block("C_warp"):
i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
i_init, j_init = T.axis.remap("SS", [i0, i1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
T.reads()
T.writes(
C_warp[
i_init % 8 * 4 + j_init % 8 // 2,
j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
]
)
C_warp[
i_init % 8 * 4 + j_init % 8 // 2,
j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
] = T.int32(0)
T.writes(C_warp[thread_id, local_id])
C_warp[thread_id, local_id] = T.int32(0)


@T.prim_func
@@ -370,8 +384,6 @@ def fetch_to_shared(block, idx, ndim, vec=False):
A_sh = fetch_to_shared(block_outer, 0, 2, True)
B_sh = fetch_to_shared(block_outer, 1, 2, True)

loop = sch.get_loops(block_outer)[-1]

A_warp = sch.cache_read(block_outer, 0, "warp")
B_warp = sch.cache_read(block_outer, 1, "warp")

@@ -386,7 +398,8 @@ def fetch_to_shared(block, idx, ndim, vec=False):
jo, ji = sch.split(jj, factors=[None, 16])
sch.reorder(io, jo, ii, ji)

block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
block_init_c = sch.get_block("C_init")

def tile_wmma_fragment(block_read, height, width):
i, j = sch.get_loops(block_read)[-2:]
@@ -395,57 +408,34 @@ def tile_wmma_fragment(block_read, height, width):
sch.reorder(i0, j0, i1, j1)
return i1

def shared_16x16_to_ldmatrix_32x8_layout(i, j):
i_0 = i // 16
j_0 = j // 16

i = i % 16
j = j % 16

thread_id = 4 * (i % 8) + (j % 8) // 2
return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
loop_a = tile_wmma_fragment(A_warp, 16, 32)
loop_b = tile_wmma_fragment(B_warp, 16, 32)

def shared_16x32_to_ldmatrix_32x16_layout(i, j):
i_0 = i // 16
j_0 = j // 32

i = i % 16
j = j % 32
def index_map_A_B(i, j):
return (
i // 16,
j // 32,
*shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
)

thread_id = 4 * (i % 8) + (j % 16) // 4
return i_0, j_0, thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4

loop_a = tile_wmma_fragment(A_warp, 16, 32)
loop_b = tile_wmma_fragment(B_warp, 16, 32)
def index_map_C(i, j):
return (
i // 16,
j // 16,
*shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
)

sch.transform_layout(A_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)
sch.transform_layout(B_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
sch.transform_layout(A_warp, 0, "write", index_map_A_B)
sch.transform_layout(B_warp, 0, "write", index_map_A_B)
sch.transform_layout(C_warp, 0, "read", index_map_C)
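
# Not part of this commit -- with plain ints, index_map_C splits a global
# coordinate into its 16x16 tile index plus the warp-layout position inside
# the tile (transform_layout builds the same expressions over TIR variables).
# Element (17, 21) lies in tile (1, 1); (1, 5) within the tile maps to
# thread 6, register slot 1.
assert index_map_C(17, 21) == (1, 1, 6, 1)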

sch.tensorize(loop_a, "mma.ldmatrix_a")
sch.tensorize(loop_b, "mma.ldmatrix_b")

mma_loop = sch.get_loops(block_inner)[-3]
sch.tensorize(mma_loop, "mma_sync")

block_init_c = sch.get_block("C_init")
init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
sch.reorder(f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(fused_1, "mma_fill")

warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
sch.reorder(outer, f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(outer, "mma_store")
# print(sch.mod.script())
# return
sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
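
# Not part of this commit: printing the scheduled module here (the earlier
# revision kept a commented-out version of this) is a quick way to confirm
# that the innermost two dimensions of the warp caches are now 32 x 16 (int8)
# and 32 x 8 (int32) and that the inner tiles have been replaced by the
# registered mma intrinsics.
print(sch.mod.script())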


ir_module = tvm.IRModule({"main": workload})
