Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SLM] Fuse Add and RMSNorm #1627

Merged
merged 2 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 199 additions & 0 deletions python/mlc_chat/compiler_pass/fuse_add_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""A compiler pass that fuses add + rms_norm."""
import tvm
from tvm import relax
from tvm.relax.dpl import PatternContext, rewrite_bindings
from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.script import tir as T

# mypy: disable-error-code="attr-defined,valid-type"
# pylint: disable=too-many-locals,invalid-name

TX = 1024


def _get_add_rms_norm_decode(hidden_size: int, eps: float):
    """Build a hand-scheduled TIR kernel fusing add + rms_norm for decode.

    Inputs A and B are (batch_size, 1, hidden_size) float16; C is the rms_norm
    weight. The kernel produces TWO outputs: O = rms_norm(A + B, C) and
    add = A + B. Per the PR discussion, the two-output fused compute is why
    this is written as pre-scheduled TIR instead of TE + Dlight:
    compute_inline / compute_at / reverse_compute_at all fail on such a
    pattern, so the schedule is spelled out by hand.

    Parameters
    ----------
    hidden_size : int
        Model hidden dimension; must be a multiple of TX (the caller checks
        this before instantiating the kernel).
    eps : float
        rms_norm epsilon.
    """
    # 1 / hidden_size as an fp32 immediate, used to turn the reduced sum of
    # squares into a mean.
    inv_hidden_size = T.float32(1.0 / float(hidden_size))
    eps = T.float32(eps)
    # Number of elements each of the TX threads accumulates privately.
    add_local_size = hidden_size // TX

    @T.prim_func(private=True)
    def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle):
        # "tir.is_scheduled" marks the function as already scheduled so
        # downstream scheduling passes leave it alone.
        T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
        batch_size = T.int32()  # dynamic: one GPU block per sequence
        A = T.match_buffer(pA, (batch_size, 1, hidden_size), "float16")
        B = T.match_buffer(pB, (batch_size, 1, hidden_size), "float16")
        C = T.match_buffer(pC, (hidden_size,), "float16")  # rms_norm weight
        O = T.match_buffer(pO, (batch_size, 1, hidden_size), "float16")
        add = T.match_buffer(pAdd, (batch_size, 1, hidden_size), "float16")
        # Per-thread register cache of this thread's strided slice of (A + B).
        add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local")
        # Cross-thread sum of squares lives in shared memory; per-thread
        # partial sums in sum_local (rfactor-style reduction).
        sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        sum_local = T.alloc_buffer((TX, batch_size, 1), scope="local")
        for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            for v_tx in T.thread_binding(
                TX,
                thread="threadIdx.x",
                annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
            ):
                # Phase 1: elementwise add; keep the result in registers and
                # also write it back as the kernel's second output.
                for i in range(add_local_size):
                    with T.block("T_add"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(hidden_size, i * TX + v_tx)
                        # h // TX == i since v_tx < TX.
                        add_local[h // TX] = A[bx, 0, h] + B[bx, 0, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        v_ax1 = T.axis.spatial(1, 0)
                        h = T.axis.spatial(hidden_size, i * TX + v_tx)
                        add[bx, v_ax1, h] = add_local[h // TX]
                # Phase 2a: per-thread partial sum of squares, accumulated in
                # fp32 for accuracy.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    sum_local[tx, bx, 0] = T.float32(0)
                for v_i, _j in T.grid(add_local_size, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        sum_local[tx, bx, 0] += T.float32(add_local[i]) * T.float32(add_local[i])
            # Phase 2b: reduce the TX partial sums across threads into shared
            # memory.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, bx, 0])
                        T.writes(sum_shared[bx, 0])
                        with T.init():
                            sum_shared[bx, 0] = T.float32(0)
                        sum_shared[bx, 0] += sum_local[tx, bx, 0]
            # Phase 3: O = (A + B) * rsqrt(mean(square) + eps) * C, cast back
            # to fp16.
            for i in range(add_local_size):
                for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(hidden_size, i * TX + v_tx_2)
                        O[bx, 0, h] = T.float16(
                            T.rsqrt(sum_shared[bx, 0] * inv_hidden_size + eps)
                            * T.float32(add_local[h // TX])
                            * T.float32(C[h])
                        )

    return decode_add_rms


def _get_add_rms_norm_prefill(hidden_size: int, eps: float):
    """Build a hand-scheduled TIR kernel fusing add + rms_norm for prefill.

    Same computation as the decode variant, but the layout is
    (1, seq_len, hidden_size) with a dynamic sequence length instead of a
    dynamic batch; one GPU block is launched per token position. Produces two
    outputs: O = rms_norm(A + B, C) and add = A + B (two outputs are why this
    is hand-scheduled TIR rather than TE — see the decode variant).

    Parameters
    ----------
    hidden_size : int
        Model hidden dimension; must be a multiple of TX (checked by caller).
    eps : float
        rms_norm epsilon.
    """
    # 1 / hidden_size as an fp32 immediate for the mean of squares.
    inv_hidden_size = T.float32(1.0 / float(hidden_size))
    eps = T.float32(eps)
    # Number of elements each of the TX threads accumulates privately.
    add_local_size = hidden_size // TX

    @T.prim_func(private=True)
    def prefill_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle):
        # Marked scheduled so downstream scheduling passes skip it.
        T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
        seq_len = T.int32()  # dynamic: one GPU block per token position
        A = T.match_buffer(pA, (1, seq_len, hidden_size), "float16")
        B = T.match_buffer(pB, (1, seq_len, hidden_size), "float16")
        C = T.match_buffer(pC, (hidden_size,), "float16")  # rms_norm weight
        O = T.match_buffer(pO, (1, seq_len, hidden_size), "float16")
        add = T.match_buffer(pAdd, (1, seq_len, hidden_size), "float16")
        # Per-thread register cache of this thread's slice of (A + B).
        add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local")
        # rfactor-style reduction: per-thread partials, then shared memory.
        sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
        sum_local = T.alloc_buffer((TX, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
            for v_tx in T.thread_binding(
                TX,
                thread="threadIdx.x",
                annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
            ):
                # Phase 1: elementwise add, kept in registers and also written
                # out as the second kernel output.
                for v_i in range(add_local_size):
                    with T.block("T_add"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(hidden_size, v_i * TX + v_tx)
                        # h // TX == v_i since v_tx < TX.
                        add_local[h // TX] = A[0, bx, h] + B[0, bx, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(hidden_size, v_i * TX + v_tx)
                        add[0, bx, h] = add_local[h // TX]
                # Phase 2a: per-thread fp32 partial sum of squares.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    sum_local[tx, 0, bx] = T.float32(0)
                for v_i, _j in T.grid(add_local_size, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        sum_local[tx, 0, bx] += T.float32(add_local[i]) * T.float32(add_local[i])
            # Phase 2b: cross-thread reduction into shared memory.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        with T.init():
                            sum_shared[0, bx] = T.float32(0)
                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
            # Phase 3: O = (A + B) * rsqrt(mean(square) + eps) * C in fp16.
            for v_i in range(add_local_size):
                for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        v1 = T.axis.spatial(hidden_size, v_i * TX + v_tx_2)
                        O[0, bx, v1] = T.float16(
                            T.rsqrt(sum_shared[0, bx] * inv_hidden_size + eps)
                            * T.float32(add_local[v1 // TX])
                            * T.float32(C[v1])
                        )

    return prefill_add_rms


@tvm.transform.module_pass(opt_level=0, name="FuseAddRMSNorm")
class FuseAddRMSNorm:  # pylint: disable=too-few-public-methods
    """A compiler pass that fuses add + rms_norm.

    Matches ``rms_norm(add(x1, x2), w)`` in every relax function and replaces
    it with a call to a hand-scheduled TIR kernel that returns both the
    normalized result and the add result, so the residual value stays
    available to later uses of ``y``.
    """

    def transform_module(self, mod: tvm.IRModule, _ctx: tvm.transform.PassContext) -> tvm.IRModule:
        """IRModule-level transformation."""
        # Pattern: y = relax.add(x1, x2); o = relax.nn.rms_norm(y, w).
        with PatternContext() as ctx:
            pat_x1 = wildcard()
            pat_x2 = wildcard()
            pat_y = is_op("relax.add")(pat_x1, pat_x2)
            pat_w = wildcard()
            pat_o = is_op("relax.nn.rms_norm")(pat_y, pat_w)

        def rewriter(matchings, bindings):
            # Map pattern nodes back to the matched relax expressions.
            x1 = matchings[pat_x1]
            x2 = matchings[pat_x2]
            weight = matchings[pat_w]
            y = matchings[pat_y]
            o = matchings[pat_o]
            eps = bindings[o].attrs.epsilon
            # The TIR kernels are written for float16 only; skip other dtypes.
            if x1.struct_info.dtype != "float16":
                return {}
            n, _, h = x1.struct_info.shape
            # Layout heuristic matching the two kernels: prefill inputs are
            # (1, seq_len, h); decode inputs are (batch, 1, h). So dim 0 == 1
            # selects the prefill kernel. (A batch-1 decode shape (1, 1, h)
            # also runs correctly through the prefill kernel.)
            func_name = "fuse_add_norm_prefill" if n == 1 else "fuse_add_norm_decode"

            # Lazily instantiate the kernel on first use and insert it into
            # the module being transformed.
            if all(gv.name_hint != func_name for gv in mod.functions):
                h = int(h)
                # The kernels split the hidden dim across TX threads and
                # require it to divide evenly.
                if h % TX != 0:
                    return {}
                if n == 1:
                    func = _get_add_rms_norm_prefill(h, eps)
                else:
                    func = _get_add_rms_norm_decode(h, eps)
                mod[func_name] = func
                gvar = mod.get_global_var(func_name)
                # NOTE(review): the global var's struct info is erased to
                # ObjectStructInfo — presumably so the call_tir below (whose
                # result type comes from out_sinfo) type-checks against the
                # fresh PrimFunc; confirm against relax internals.
                relax.expr._update_struct_info(  # pylint: disable=protected-access
                    gvar,
                    relax.ObjectStructInfo(),
                )
            else:
                gvar = mod.get_global_var(func_name)
            # Kernel outputs are (O, add): slot 0 replaces the rms_norm
            # result `o`, slot 1 replaces the add result `y`.
            o_y_tuple = relax.call_tir(
                gvar,
                [x1, x2, weight],
                out_sinfo=[x1.struct_info, x1.struct_info],
            )
            return {
                o: relax.TupleGetItem(o_y_tuple, 0),
                y: relax.TupleGetItem(o_y_tuple, 1),
            }

        # Rewrite every relax function; non-relax members are copied as-is.
        new_mod = {}
        for gvar, func in mod.functions.items():
            if isinstance(func, relax.Function):
                func = rewrite_bindings(ctx, rewriter, func)
            new_mod[gvar] = func

        # Second sweep: pick up PrimFuncs the rewriter inserted into `mod`
        # while the first loop was iterating its snapshot (the lazily added
        # fuse_add_norm_* kernels).
        for gvar, func in mod.functions.items():
            if isinstance(func, tvm.tir.PrimFunc) and gvar not in new_mod:
                new_mod[gvar] = func

        new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos)
        return new_mod
2 changes: 2 additions & 0 deletions python/mlc_chat/compiler_pass/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .clean_up_tir_attrs import CleanUpTIRAttrs
from .cublas_dispatch import CublasDispatch
from .estimate_memory_usage import AttachMetadataWithMemoryUsage
from .fuse_add_norm import FuseAddRMSNorm
from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise
from .fuse_dequantize_take import FuseDequantizeTake
from .fuse_dequantize_transpose import FuseDequantizeTranspose
Expand Down Expand Up @@ -92,6 +93,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
FuseFTDequantizeEpilogue(),
FuseDequantizeTranspose(),
CublasDispatch() if cublas_gemm else tvm.transform.Sequential([]),
FuseAddRMSNorm(),
FuseTransposeMatmul(),
_DebugDump("debug-phase1.py", debug_dump, show_meta=False),
# Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
Expand Down
14 changes: 12 additions & 2 deletions python/mlc_chat/interface/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from mlc_chat.model import MODELS
from mlc_chat.support import logging
from mlc_chat.support.auto_device import device2str
from mlc_chat.support.constants import MLC_CACHE_DIR, MLC_JIT_POLICY, MLC_TEMP_DIR
from mlc_chat.support.constants import (
MLC_CACHE_DIR,
MLC_DSO_SUFFIX,
MLC_JIT_POLICY,
MLC_TEMP_DIR,
)
from mlc_chat.support.style import blue, bold

from .compiler_flags import ModelConfigOverride, OptimizationFlags
Expand All @@ -26,6 +31,11 @@

def jit(model_path: Path, chat_config: Dict[str, Any], device: Device) -> Path:
"""Just-in-time compile a MLC-Chat model."""
logger.info(
"%s = %s. Can be one of: ON, OFF, REDO, READONLY",
bold("MLC_JIT_POLICY"),
MLC_JIT_POLICY,
)
if MLC_JIT_POLICY == "OFF":
raise RuntimeError("JIT is disabled by MLC_JIT_POLICY=OFF")

Expand Down Expand Up @@ -64,7 +74,7 @@ def _get_model_config() -> Dict[str, Any]:

def _run_jit(opt: str, overrides: str, device: str, dst: str):
with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir:
dso_path = os.path.join(tmp_dir, "lib.so")
dso_path = os.path.join(tmp_dir, f"lib.{MLC_DSO_SUFFIX}")
cmd = [
sys.executable,
"-m",
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_chat/model/llama/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, lay

def _apply_residual(self, out, residual):
if self.tensor_parallel_shards > 1:
return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum")
return op.ccl_allreduce(out, "sum") + residual
return out + residual


Expand Down
2 changes: 1 addition & 1 deletion python/mlc_chat/model/mistral/mistral_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def forward( # pylint: disable=too-many-arguments

def _apply_residual(out, residual):
if self.tensor_parallel_shards > 1:
return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum")
return op.ccl_allreduce(out, "sum") + residual
return out + residual

out = self.self_attn(
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_chat/model/mixtral/mixtral_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def forward( # pylint: disable=too-many-arguments

def _apply_residual(out, residual):
if self.tensor_parallel_shards > 1:
return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum")
return op.ccl_allreduce(out, "sum") + residual
return out + residual

out = self.self_attn(
Expand Down
11 changes: 11 additions & 0 deletions python/mlc_chat/support/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,21 @@ def _get_cache_dir() -> Path:
return result


def _get_dso_suffix() -> str:
if "MLC_DSO_SUFFIX" in os.environ:
return os.environ["MLC_DSO_SUFFIX"]
if sys.platform == "win32":
return "dll"
if sys.platform == "darwin":
return "dylib"
return "so"


# Optional override for where temporary working directories are created.
MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)
# NOTE(review): raw env value passed through unparsed — presumably a list of
# target architectures for multi-arch builds; confirm against its consumers.
MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None)
# Root directory for cached compiled artifacts.
MLC_CACHE_DIR: Path = _get_cache_dir()
# JIT compilation policy; one of ON / OFF / REDO / READONLY, defaulting to ON.
MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON")
# Platform-appropriate shared-library suffix ("so" / "dylib" / "dll").
MLC_DSO_SUFFIX = _get_dso_suffix()


_check()