
Commit

fix pass
jinhongyii committed Jan 25, 2024
1 parent c09b4a2 commit 24e1527
Showing 1 changed file with 34 additions and 28 deletions.
python/mlc_chat/compiler_pass/fuse_add_norm.py: 62 changes (34 additions, 28 deletions)
@@ -8,10 +8,13 @@
 # mypy: disable-error-code="attr-defined,valid-type"
 # pylint: disable=too-many-locals,invalid-name

+TX = 1024
+

 def _get_add_rms_norm_decode(hidden_size: int, eps: float):
     inv_hidden_size = T.float32(1.0 / float(hidden_size))
     eps = T.float32(eps)
+    add_local_size = hidden_size // TX

     @T.prim_func(private=True)
     def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle):
@@ -22,49 +25,49 @@ def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd:
         C = T.match_buffer(pC, (hidden_size,), "float16")
         O = T.match_buffer(pO, (batch_size, 1, hidden_size), "float16")
         add = T.match_buffer(pAdd, (batch_size, 1, hidden_size), "float16")
-        add_local = T.alloc_buffer((4,), "float16", scope="local")
+        add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local")
         sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
-        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
+        sum_local = T.alloc_buffer((TX, batch_size, 1), scope="local")
         for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
             for v_tx in T.thread_binding(
-                1024,
+                TX,
                 thread="threadIdx.x",
                 annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
             ):
-                for i in range(4):
+                for i in range(add_local_size):
                     with T.block("T_add"):
                         bx = T.axis.spatial(batch_size, v_bx)
-                        h = T.axis.spatial(hidden_size, i * 1024 + v_tx)
-                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
+                        h = T.axis.spatial(hidden_size, i * TX + v_tx)
+                        add_local[h // TX] = A[bx, 0, h] + B[bx, 0, h]
                     with T.block("T_write_back"):
                         bx = T.axis.spatial(batch_size, v_bx)
                         v_ax1 = T.axis.spatial(1, 0)
-                        h = T.axis.spatial(hidden_size, i * 1024 + v_tx)
-                        add[bx, v_ax1, h] = add_local[h // 1024]
+                        h = T.axis.spatial(hidden_size, i * TX + v_tx)
+                        add[bx, v_ax1, h] = add_local[h // TX]
                 with T.block("T_multiply_red_rf_init"):
                     tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                     sum_local[tx, bx, 0] = T.float32(0)
-                for v_i, _j in T.grid(4, 1):
+                for v_i, _j in T.grid(add_local_size, 1):
                     with T.block("T_multiply_red_rf_update"):
                         tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                         sum_local[tx, bx, 0] += T.float32(add_local[i]) * T.float32(add_local[i])
             for _j in range(1):
-                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
+                for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                     with T.block("T_multiply_red"):
                         tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                         T.reads(sum_local[tx, bx, 0])
                         T.writes(sum_shared[bx, 0])
                         with T.init():
                             sum_shared[bx, 0] = T.float32(0)
                         sum_shared[bx, 0] += sum_local[tx, bx, 0]
-            for i in range(4):
-                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
+            for i in range(add_local_size):
+                for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                     with T.block("T_cast_2"):
                         bx = T.axis.spatial(batch_size, v_bx)
-                        h = T.axis.spatial(hidden_size, i * 1024 + v_tx_2)
+                        h = T.axis.spatial(hidden_size, i * TX + v_tx_2)
                         O[bx, 0, h] = T.float16(
                             T.rsqrt(sum_shared[bx, 0] * inv_hidden_size + eps)
-                            * T.float32(add_local[h // 1024])
+                            * T.float32(add_local[h // TX])
                             * T.float32(C[h])
                         )

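For reference, the kernel above fuses the residual add with RMSNorm: it materializes add = A + B and computes O = (A + B) * rsqrt(mean((A + B)^2, axis=-1) + eps) * C, accumulating the squared sum in float32 and casting the result back to float16. A minimal NumPy sketch of the same math; the function name and shapes are illustrative only and not part of this commit:

import numpy as np

def add_rms_norm_reference(A, B, C, eps):
    # Residual add; the kernel keeps this in fp16 and also writes it to `add`.
    add = (A + B).astype(np.float16)
    acc = add.astype(np.float32)
    # Mean of squares over the hidden dimension, accumulated in fp32.
    mean_sq = np.mean(acc * acc, axis=-1, keepdims=True)
    # RMSNorm with weight C, cast back to fp16 as in block "T_cast_2".
    O = (acc * (1.0 / np.sqrt(mean_sq + eps)) * C.astype(np.float32)).astype(np.float16)
    return O, add

The prefill variant below computes the identical math, only over a (1, seq_len, hidden_size) layout instead of (batch_size, 1, hidden_size).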
@@ -74,6 +77,7 @@ def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd:
 def _get_add_rms_norm_prefill(hidden_size: int, eps: float):
     inv_hidden_size = T.float32(1.0 / float(hidden_size))
     eps = T.float32(eps)
+    add_local_size = hidden_size // TX

     @T.prim_func(private=True)
     def prefill_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle):
@@ -84,46 +88,46 @@ def prefill_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd
         C = T.match_buffer(pC, (hidden_size,), "float16")
         O = T.match_buffer(pO, (1, seq_len, hidden_size), "float16")
         add = T.match_buffer(pAdd, (1, seq_len, hidden_size), "float16")
-        add_local = T.alloc_buffer((4,), "float16", scope="local")
+        add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local")
         sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
-        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
+        sum_local = T.alloc_buffer((TX, 1, seq_len), scope="local")
         for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
             for v_tx in T.thread_binding(
-                1024,
+                TX,
                 thread="threadIdx.x",
                 annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
             ):
-                for v_i in range(4):
+                for v_i in range(add_local_size):
                     with T.block("T_add"):
                         bx = T.axis.spatial(seq_len, v_bx)
-                        h = T.axis.spatial(hidden_size, v_i * 1024 + v_tx)
-                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
+                        h = T.axis.spatial(hidden_size, v_i * TX + v_tx)
+                        add_local[h // TX] = A[0, bx, h] + B[0, bx, h]
                     with T.block("T_write_back"):
                         bx = T.axis.spatial(seq_len, v_bx)
-                        h = T.axis.spatial(hidden_size, v_i * 1024 + v_tx)
-                        add[0, bx, h] = add_local[h // 1024]
+                        h = T.axis.spatial(hidden_size, v_i * TX + v_tx)
+                        add[0, bx, h] = add_local[h // TX]
                 with T.block("T_multiply_red_rf_init"):
                     tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                     sum_local[tx, 0, bx] = T.float32(0)
-                for v_i, _j in T.grid(4, 1):
+                for v_i, _j in T.grid(add_local_size, 1):
                     with T.block("T_multiply_red_rf_update"):
                         tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                         sum_local[tx, 0, bx] += T.float32(add_local[i]) * T.float32(add_local[i])
             for _j in range(1):
-                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
+                for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                     with T.block("T_multiply_red"):
                         tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                         with T.init():
                             sum_shared[0, bx] = T.float32(0)
                         sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
-            for v_i in range(4):
-                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
+            for v_i in range(add_local_size):
+                for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                     with T.block("T_cast_2"):
                         bx = T.axis.spatial(seq_len, v_bx)
-                        v1 = T.axis.spatial(hidden_size, v_i * 1024 + v_tx_2)
+                        v1 = T.axis.spatial(hidden_size, v_i * TX + v_tx_2)
                         O[0, bx, v1] = T.float16(
                             T.rsqrt(sum_shared[0, bx] * inv_hidden_size + eps)
-                            * T.float32(add_local[v1 // 1024])
+                            * T.float32(add_local[v1 // TX])
                             * T.float32(C[v1])
                         )

@@ -157,6 +161,8 @@ def rewriter(matchings, bindings):

     if all(gv.name_hint != func_name for gv in mod.functions):
         h = int(h)
+        if h % TX != 0:
+            return {}
         if n == 1:
             func = _get_add_rms_norm_prefill(h, eps)
         else:
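The new guard in rewriter returns an empty replacement map when the hidden size is not a multiple of TX, since both kernels assign exactly hidden_size // TX elements to each of the TX threads; models with such hidden sizes keep the unfused add and norm ops. A quick illustration of the divisibility check, with example sizes that are not from the commit:

TX = 1024  # thread count assumed by the fused kernels

for hidden_size in (2048, 4096, 5120, 2560, 3000):
    fused = hidden_size % TX == 0
    print(f"hidden_size={hidden_size}: {'fused add+RMSNorm' if fused else 'left unfused'}")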
