From 24e152762a39f1185a30777a73343dbaf5937464 Mon Sep 17 00:00:00 2001
From: Hongyi Jin
Date: Thu, 25 Jan 2024 14:57:53 -0500
Subject: [PATCH] Fix fuse_add_norm pass: parameterize the thread extent and
 guard indivisible hidden sizes

Replace the hardcoded thread extent (1024) and per-thread local buffer
size (4) with a single TX constant, derive the local buffer size as
hidden_size // TX, and skip the rewrite entirely when hidden_size is
not a multiple of TX.
---
 .../mlc_chat/compiler_pass/fuse_add_norm.py   | 62 ++++++++++---------
 1 file changed, 34 insertions(+), 28 deletions(-)

diff --git a/python/mlc_chat/compiler_pass/fuse_add_norm.py b/python/mlc_chat/compiler_pass/fuse_add_norm.py
index a8b12c9f53..c477d6b297 100644
--- a/python/mlc_chat/compiler_pass/fuse_add_norm.py
+++ b/python/mlc_chat/compiler_pass/fuse_add_norm.py
@@ -8,10 +8,13 @@
 # mypy: disable-error-code="attr-defined,valid-type"
 # pylint: disable=too-many-locals,invalid-name
 
+TX = 1024
+
 
 def _get_add_rms_norm_decode(hidden_size: int, eps: float):
     inv_hidden_size = T.float32(1.0 / float(hidden_size))
     eps = T.float32(eps)
+    add_local_size = hidden_size // TX
 
     @T.prim_func(private=True)
     def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle):
@@ -22,34 +25,34 @@ def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd:
         C = T.match_buffer(pC, (hidden_size,), "float16")
         O = T.match_buffer(pO, (batch_size, 1, hidden_size), "float16")
         add = T.match_buffer(pAdd, (batch_size, 1, hidden_size), "float16")
-        add_local = T.alloc_buffer((4,), "float16", scope="local")
+        add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local")
         sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
-        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
+        sum_local = T.alloc_buffer((TX, batch_size, 1), scope="local")
         for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
             for v_tx in T.thread_binding(
-                1024,
+                TX,
                 thread="threadIdx.x",
                 annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
             ):
-                for i in range(4):
+                for i in range(add_local_size):
                     with T.block("T_add"):
                         bx = T.axis.spatial(batch_size, v_bx)
-                        h = T.axis.spatial(hidden_size, i * 1024 + v_tx)
-                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
+                        h = T.axis.spatial(hidden_size, i * TX + v_tx)
+                        add_local[h // TX] = A[bx, 0, h] + B[bx, 0, h]
                     with T.block("T_write_back"):
                         bx = T.axis.spatial(batch_size, v_bx)
                         v_ax1 = T.axis.spatial(1, 0)
-                        h = T.axis.spatial(hidden_size, i * 1024 + v_tx)
-                        add[bx, v_ax1, h] = add_local[h // 1024]
+                        h = T.axis.spatial(hidden_size, i * TX + v_tx)
+                        add[bx, v_ax1, h] = add_local[h // TX]
                 with T.block("T_multiply_red_rf_init"):
                     tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                     sum_local[tx, bx, 0] = T.float32(0)
-                for v_i, _j in T.grid(4, 1):
+                for v_i, _j in T.grid(add_local_size, 1):
                     with T.block("T_multiply_red_rf_update"):
                         tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                         sum_local[tx, bx, 0] += T.float32(add_local[i]) * T.float32(add_local[i])
         for _j in range(1):
-            for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
+            for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                 with T.block("T_multiply_red"):
                     tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                     T.reads(sum_local[tx, bx, 0])
@@ -57,14 +60,14 @@ def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd:
                     with T.init():
                         sum_shared[bx, 0] = T.float32(0)
                     sum_shared[bx, 0] += sum_local[tx, bx, 0]
-        for i in range(4):
-            for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
+        for i in range(add_local_size):
+            for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                 with T.block("T_cast_2"):
                     bx = T.axis.spatial(batch_size, v_bx)
-                    h = T.axis.spatial(hidden_size, i * 1024 + v_tx_2)
+                    h = T.axis.spatial(hidden_size, i * TX + v_tx_2)
                     O[bx, 0, h] = T.float16(
                         T.rsqrt(sum_shared[bx, 0] * inv_hidden_size + eps)
-                        * T.float32(add_local[h // 1024])
+                        * T.float32(add_local[h // TX])
                         * T.float32(C[h])
                     )
 
@@ -74,6 +77,7 @@ def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd:
 def _get_add_rms_norm_prefill(hidden_size: int, eps: float):
     inv_hidden_size = T.float32(1.0 / float(hidden_size))
     eps = T.float32(eps)
+    add_local_size = hidden_size // TX
 
     @T.prim_func(private=True)
     def prefill_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle):
@@ -84,46 +88,46 @@ def prefill_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd
         C = T.match_buffer(pC, (hidden_size,), "float16")
         O = T.match_buffer(pO, (1, seq_len, hidden_size), "float16")
         add = T.match_buffer(pAdd, (1, seq_len, hidden_size), "float16")
-        add_local = T.alloc_buffer((4,), "float16", scope="local")
+        add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local")
         sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
-        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
+        sum_local = T.alloc_buffer((TX, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
             for v_tx in T.thread_binding(
-                1024,
+                TX,
                 thread="threadIdx.x",
                 annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
             ):
-                for v_i in range(4):
+                for v_i in range(add_local_size):
                     with T.block("T_add"):
                         bx = T.axis.spatial(seq_len, v_bx)
-                        h = T.axis.spatial(hidden_size, v_i * 1024 + v_tx)
-                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
+                        h = T.axis.spatial(hidden_size, v_i * TX + v_tx)
+                        add_local[h // TX] = A[0, bx, h] + B[0, bx, h]
                     with T.block("T_write_back"):
                         bx = T.axis.spatial(seq_len, v_bx)
-                        h = T.axis.spatial(hidden_size, v_i * 1024 + v_tx)
-                        add[0, bx, h] = add_local[h // 1024]
+                        h = T.axis.spatial(hidden_size, v_i * TX + v_tx)
+                        add[0, bx, h] = add_local[h // TX]
                 with T.block("T_multiply_red_rf_init"):
                     tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                     sum_local[tx, 0, bx] = T.float32(0)
-                for v_i, _j in T.grid(4, 1):
+                for v_i, _j in T.grid(add_local_size, 1):
                     with T.block("T_multiply_red_rf_update"):
                         tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                         sum_local[tx, 0, bx] += T.float32(add_local[i]) * T.float32(add_local[i])
         for _j in range(1):
-            for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
+            for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                 with T.block("T_multiply_red"):
                     tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                     with T.init():
                         sum_shared[0, bx] = T.float32(0)
                     sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
-        for v_i in range(4):
-            for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
+        for v_i in range(add_local_size):
+            for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                 with T.block("T_cast_2"):
                     bx = T.axis.spatial(seq_len, v_bx)
-                    v1 = T.axis.spatial(hidden_size, v_i * 1024 + v_tx_2)
+                    v1 = T.axis.spatial(hidden_size, v_i * TX + v_tx_2)
                     O[0, bx, v1] = T.float16(
                         T.rsqrt(sum_shared[0, bx] * inv_hidden_size + eps)
-                        * T.float32(add_local[v1 // 1024])
+                        * T.float32(add_local[v1 // TX])
                         * T.float32(C[v1])
                     )
 
@@ -157,6 +161,8 @@ def rewriter(matchings, bindings):
         if all(gv.name_hint != func_name for gv in mod.functions):
             h = int(h)
+            if h % TX != 0:
+                return {}
             if n == 1:
                 func = _get_add_rms_norm_prefill(h, eps)
             else:
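Reviewer note (not part of the patch): below is a minimal NumPy sketch of
the computation both fused TIR kernels implement, useful for checking a
compiled module. The float16 residual sum is written back through `add`,
and `O` is its RMS norm scaled by the weight `C`, with the squared-sum
reduction accumulated in float32 as in the kernels above. The function
name `add_rms_norm_ref` and the use of NumPy are illustrative assumptions,
not part of mlc_chat.

    import numpy as np

    TX = 1024  # thread extent assumed by the fused kernels


    def add_rms_norm_ref(A, B, C, eps):
        """Reference for the fused add + RMSNorm kernels (illustrative only).

        A, B: float16 residual inputs of shape (batch, 1, hidden_size)
        C:    float16 norm weight of shape (hidden_size,)
        Returns (O, add), mirroring the kernels' two outputs.
        """
        hidden_size = A.shape[-1]
        # The pass only rewrites when hidden_size is a multiple of TX.
        assert hidden_size % TX == 0
        add = A + B  # float16 sum, as staged in add_local and written to `add`
        sq_sum = np.square(add.astype(np.float32)).sum(axis=-1, keepdims=True)
        # Matches T.rsqrt(sum_shared * inv_hidden_size + eps) in the kernels.
        inv_rms = 1.0 / np.sqrt(sq_sum / hidden_size + eps)
        O = (inv_rms * add.astype(np.float32) * C.astype(np.float32)).astype(np.float16)
        return O, add

For example, hidden_size = 4096 gives add_local_size = 4, the value the
kernels previously hardcoded, so the patched pass reproduces the old
behavior there; any hidden size that is not a multiple of 1024 now leaves
the module unrewritten instead of producing incorrect code.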