[SLM] Fuse Add and RMSNorm
This PR adds a fusion pass that fuses a binary "add" with the following "RMSNorm".
It is a temporary workaround that lets us fuse the "add" into RMSNorm when it is
not already fused into a GEMM epilogue.
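
Roughly, the rewrite performed by the new pass looks like this (an illustrative sketch, not the exact Relax IR):

# Before: two separate Relax ops
#   y = relax.add(x1, x2)             # residual add
#   o = relax.nn.rms_norm(y, weight)  # normalization
#
# After: a single fused TIR kernel that returns both the normalized output
# and the add result (still needed as the next residual input):
#   lv = relax.call_tir(fuse_add_norm_prefill, (x1, x2, weight),
#                       out_sinfo=[x1.struct_info, x1.struct_info])
#   o = lv[0]
#   y = lv[1]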
junrushao committed Jan 25, 2024
1 parent b9605ab commit c09b4a2
Showing 5 changed files with 198 additions and 3 deletions.
193 changes: 193 additions & 0 deletions python/mlc_chat/compiler_pass/fuse_add_norm.py
@@ -0,0 +1,193 @@
"""A compiler pass that fuses add + rms_norm."""
import tvm
from tvm import relax
from tvm.relax.dpl import PatternContext, rewrite_bindings
from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.script import tir as T

# mypy: disable-error-code="attr-defined,valid-type"
# pylint: disable=too-many-locals,invalid-name


def _get_add_rms_norm_decode(hidden_size: int, eps: float):
    # Fused "add + rms_norm" TIR kernel for decode: one GPU thread block per batch element,
    # 1024 threads per block, 4 hidden elements per thread (the hard-coded 4 x 1024 split
    # assumes hidden_size == 4096).
    inv_hidden_size = T.float32(1.0 / float(hidden_size))
    eps = T.float32(eps)

    @T.prim_func(private=True)
    def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle):
        T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
        batch_size = T.int32()
        A = T.match_buffer(pA, (batch_size, 1, hidden_size), "float16")
        B = T.match_buffer(pB, (batch_size, 1, hidden_size), "float16")
        C = T.match_buffer(pC, (hidden_size,), "float16")
        O = T.match_buffer(pO, (batch_size, 1, hidden_size), "float16")
        add = T.match_buffer(pAdd, (batch_size, 1, hidden_size), "float16")
        add_local = T.alloc_buffer((4,), "float16", scope="local")
        sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
        for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            for v_tx in T.thread_binding(
                1024,
                thread="threadIdx.x",
                annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
            ):
                # Elementwise add, kept in registers and also written back so the add
                # result can be returned for the next residual connection.
                for i in range(4):
                    with T.block("T_add"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(hidden_size, i * 1024 + v_tx)
                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        v_ax1 = T.axis.spatial(1, 0)
                        h = T.axis.spatial(hidden_size, i * 1024 + v_tx)
                        add[bx, v_ax1, h] = add_local[h // 1024]
                # Per-thread partial sum of squares, accumulated in fp32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    sum_local[tx, bx, 0] = T.float32(0)
                for v_i, _j in T.grid(4, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        sum_local[tx, bx, 0] += T.float32(add_local[i]) * T.float32(add_local[i])
            # Cross-thread reduction of the partial sums into shared memory.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, bx, 0])
                        T.writes(sum_shared[bx, 0])
                        with T.init():
                            sum_shared[bx, 0] = T.float32(0)
                        sum_shared[bx, 0] += sum_local[tx, bx, 0]
            # Normalize: add * rsqrt(mean(add^2) + eps) * weight, cast back to fp16.
            for i in range(4):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(hidden_size, i * 1024 + v_tx_2)
                        O[bx, 0, h] = T.float16(
                            T.rsqrt(sum_shared[bx, 0] * inv_hidden_size + eps)
                            * T.float32(add_local[h // 1024])
                            * T.float32(C[h])
                        )

    return decode_add_rms


def _get_add_rms_norm_prefill(hidden_size: int, eps: float):
    # Same fused kernel for prefill: one GPU thread block per token (sequence position),
    # 1024 threads per block, 4 hidden elements per thread.
    inv_hidden_size = T.float32(1.0 / float(hidden_size))
    eps = T.float32(eps)

    @T.prim_func(private=True)
    def prefill_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle):
        T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
        seq_len = T.int32()
        A = T.match_buffer(pA, (1, seq_len, hidden_size), "float16")
        B = T.match_buffer(pB, (1, seq_len, hidden_size), "float16")
        C = T.match_buffer(pC, (hidden_size,), "float16")
        O = T.match_buffer(pO, (1, seq_len, hidden_size), "float16")
        add = T.match_buffer(pAdd, (1, seq_len, hidden_size), "float16")
        add_local = T.alloc_buffer((4,), "float16", scope="local")
        sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
            for v_tx in T.thread_binding(
                1024,
                thread="threadIdx.x",
                annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
            ):
                # Elementwise add, kept in registers and written back for the next residual.
                for v_i in range(4):
                    with T.block("T_add"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(hidden_size, v_i * 1024 + v_tx)
                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(hidden_size, v_i * 1024 + v_tx)
                        add[0, bx, h] = add_local[h // 1024]
                # Per-thread partial sum of squares, accumulated in fp32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    sum_local[tx, 0, bx] = T.float32(0)
                for v_i, _j in T.grid(4, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        sum_local[tx, 0, bx] += T.float32(add_local[i]) * T.float32(add_local[i])
            # Cross-thread reduction of the partial sums into shared memory.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        with T.init():
                            sum_shared[0, bx] = T.float32(0)
                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
            # Normalize: add * rsqrt(mean(add^2) + eps) * weight, cast back to fp16.
            for v_i in range(4):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        v1 = T.axis.spatial(hidden_size, v_i * 1024 + v_tx_2)
                        O[0, bx, v1] = T.float16(
                            T.rsqrt(sum_shared[0, bx] * inv_hidden_size + eps)
                            * T.float32(add_local[v1 // 1024])
                            * T.float32(C[v1])
                        )

    return prefill_add_rms


@tvm.transform.module_pass(opt_level=0, name="FuseAddRMSNorm")
class FuseAddRMSNorm:  # pylint: disable=too-few-public-methods
    """A compiler pass that fuses add + rms_norm."""

    def transform_module(self, mod: tvm.IRModule, _ctx: tvm.transform.PassContext) -> tvm.IRModule:
        """IRModule-level transformation."""
        # Match y = relax.add(x1, x2) followed by o = relax.nn.rms_norm(y, w).
        with PatternContext() as ctx:
            pat_x1 = wildcard()
            pat_x2 = wildcard()
            pat_y = is_op("relax.add")(pat_x1, pat_x2)
            pat_w = wildcard()
            pat_o = is_op("relax.nn.rms_norm")(pat_y, pat_w)

        def rewriter(matchings, bindings):
            x1 = matchings[pat_x1]
            x2 = matchings[pat_x2]
            weight = matchings[pat_w]
            y = matchings[pat_y]
            o = matchings[pat_o]
            eps = bindings[o].attrs.epsilon
            # The hand-written kernels only support fp16; skip other dtypes.
            if x1.struct_info.dtype != "float16":
                return {}
            n, _, h = x1.struct_info.shape
            # (1, seq_len, hidden) is the prefill layout; (batch, 1, hidden) is decode.
            func_name = "fuse_add_norm_prefill" if n == 1 else "fuse_add_norm_decode"

            # Create the fused TIR kernel on first use and add it to the module.
            if all(gv.name_hint != func_name for gv in mod.functions):
                h = int(h)
                if n == 1:
                    func = _get_add_rms_norm_prefill(h, eps)
                else:
                    func = _get_add_rms_norm_decode(h, eps)
                mod[func_name] = func
                gvar = mod.get_global_var(func_name)
                relax.expr._update_struct_info(  # pylint: disable=protected-access
                    gvar,
                    relax.ObjectStructInfo(),
                )
            else:
                gvar = mod.get_global_var(func_name)
            # Replace both bindings with a single call_tir that returns
            # (normalized output, add result).
            o_y_tuple = relax.call_tir(
                gvar,
                [x1, x2, weight],
                out_sinfo=[x1.struct_info, x1.struct_info],
            )
            return {
                o: relax.TupleGetItem(o_y_tuple, 0),
                y: relax.TupleGetItem(o_y_tuple, 1),
            }

        new_mod = {}
        for gvar, func in mod.functions.items():
            if isinstance(func, relax.Function):
                func = rewrite_bindings(ctx, rewriter, func)
            new_mod[gvar] = func

        # Pick up the TIR kernels that the rewriter added to `mod` along the way.
        for gvar, func in mod.functions.items():
            if isinstance(func, tvm.tir.PrimFunc) and gvar not in new_mod:
                new_mod[gvar] = func

        new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos)
        return new_mod
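
For reference, both TIR kernels above compute the same math as this NumPy sketch (a simplified restatement of fused add + RMSNorm with fp32 accumulation; the function name is illustrative):

import numpy as np

def add_rms_norm_reference(a, b, weight, eps):
    # a, b: fp16 activations of shape (..., hidden); weight: fp16 of shape (hidden,)
    add = a.astype(np.float16) + b.astype(np.float16)      # residual add, kept and returned
    acc = add.astype(np.float32)                           # accumulate squares in fp32
    mean_sq = np.mean(acc * acc, axis=-1, keepdims=True)   # mean of squares over hidden dim
    out = acc / np.sqrt(mean_sq + eps) * weight.astype(np.float32)
    return out.astype(np.float16), add                     # (normalized output, add result)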
2 changes: 2 additions & 0 deletions python/mlc_chat/compiler_pass/pipeline.py
@@ -18,6 +1,7 @@
 from .clean_up_tir_attrs import CleanUpTIRAttrs
 from .cublas_dispatch import CublasDispatch
 from .estimate_memory_usage import AttachMetadataWithMemoryUsage
+from .fuse_add_norm import FuseAddRMSNorm
 from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise
 from .fuse_dequantize_take import FuseDequantizeTake
 from .fuse_dequantize_transpose import FuseDequantizeTranspose
@@ -92,6 +93,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
                 FuseFTDequantizeEpilogue(),
                 FuseDequantizeTranspose(),
                 CublasDispatch() if cublas_gemm else tvm.transform.Sequential([]),
+                FuseAddRMSNorm(),
                 FuseTransposeMatmul(),
                 _DebugDump("debug-phase1.py", debug_dump, show_meta=False),
                 # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
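
Besides running as part of the pipeline above, the pass can also be applied to an IRModule on its own. A minimal sketch (the helper name is illustrative; `mod` is assumed to be a Relax IRModule containing fp16 add + rms_norm pairs):

import tvm
from mlc_chat.compiler_pass.fuse_add_norm import FuseAddRMSNorm

def apply_fuse_add_norm(mod: tvm.IRModule) -> tvm.IRModule:
    # A module_pass instance is callable on an IRModule.
    with tvm.transform.PassContext(opt_level=3):
        return FuseAddRMSNorm()(mod)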
2 changes: 1 addition & 1 deletion python/mlc_chat/model/llama/llama_model.py
@@ -208,7 +208,7 @@ def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, lay

     def _apply_residual(self, out, residual):
         if self.tensor_parallel_shards > 1:
-            return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum")
+            return op.ccl_allreduce(out, "sum") + residual
         return out + residual


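
The `_apply_residual` change above (repeated in the Mistral and Mixtral models below) keeps the residual add outside the allreduce so that FuseAddRMSNorm can fuse it with the following rms_norm. With the residual replicated on every shard, the two forms compute the same value up to floating-point rounding, as this NumPy sketch illustrates (shard count and shapes are illustrative):

import numpy as np

num_shards = 4
out_shards = [np.random.rand(2, 8).astype(np.float32) for _ in range(num_shards)]
residual = np.random.rand(2, 8).astype(np.float32)  # replicated on every shard

# Old form: each shard adds residual / num_shards, then the shards are summed (allreduce).
old = sum(o + residual / num_shards for o in out_shards)
# New form: sum (allreduce) first, then add the full residual once.
new = sum(out_shards) + residual

assert np.allclose(old, new)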
2 changes: 1 addition & 1 deletion python/mlc_chat/model/mistral/mistral_model.py
@@ -331,7 +331,7 @@ def forward( # pylint: disable=too-many-arguments

         def _apply_residual(out, residual):
             if self.tensor_parallel_shards > 1:
-                return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum")
+                return op.ccl_allreduce(out, "sum") + residual
             return out + residual

         out = self.self_attn(
2 changes: 1 addition & 1 deletion python/mlc_chat/model/mixtral/mixtral_model.py
@@ -139,7 +139,7 @@ def forward( # pylint: disable=too-many-arguments

         def _apply_residual(out, residual):
             if self.tensor_parallel_shards > 1:
-                return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum")
+                return op.ccl_allreduce(out, "sum") + residual
             return out + residual

         out = self.self_attn(
