From cb69421321b924a9b21690785c7c20996aae7929 Mon Sep 17 00:00:00 2001
From: kaixih
Date: Fri, 20 Dec 2024 18:44:02 +0000
Subject: [PATCH] Support fp8 quant for MoE layer

---
 MaxText/layers/linears.py       |  4 +-
 MaxText/layers/quantizations.py | 72 +++++++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py
index 1820a26bc..94f0be272 100644
--- a/MaxText/layers/linears.py
+++ b/MaxText/layers/linears.py
@@ -581,7 +581,9 @@ def get_einsum(self, rhs_mesh_axes: Tuple[Optional[str], ...] = (), einsum_name=
 
       def aqt_einsum(*args, **kwargs):
         # simply skip kwargs, since aqt einsum doesn't support any kwargs like precision
-        return self.quant.einsum(rhs_mesh_axes)(*args)
+        is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization)
+        kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype}
+        return self.quant.einsum(**kw)(*args)
 
       einsum_op = aqt_einsum
     else:
diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py
index 8bfebf021..05ed3954b 100644
--- a/MaxText/layers/quantizations.py
+++ b/MaxText/layers/quantizations.py
@@ -25,6 +25,8 @@
 from aqt.jax.v2 import calibration
 import common_types
 from dataclasses import dataclass
+from flax.linen import fp8_ops
+from flax.linen import initializers as flax_initializers
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -45,6 +47,7 @@
 
 Array = common_types.Array
 Config = common_types.Config
+DType = common_types.DType
 AxisIdxes = common_types.AxisIdxes
 AxisNames = common_types.AxisNames
 CACHE_HEADS = common_types.CACHE_HEADS
@@ -60,6 +63,10 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
     """Placeholder for dot_general implementation in subclasses."""
     pass
 
+  def einsum(self, dtype: DType = jnp.float32):
+    """Placeholder for einsum implementation in subclasses."""
+    pass
+
 
 def _tiling_fn(lhs, rhs, dimension_numbers, tile_size):
   del lhs, rhs
@@ -201,6 +208,71 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
     """Returns dot_general configured with aqt params."""
     return nn.Fp8DotGeneralOp
 
+  def einsum(self, dtype: DType = jnp.float32):
+    return Fp8Einsum(dtype=dtype)
+
+
+class Fp8Einsum(nn.Module):
+  """An fp8 einsum op.
+
+  Attributes:
+    amax_history_length: size of the amax history.
+    e4m3_dtype: e4m3 variants, e.g., e4m3fn, e4m3fnuz.
+    e5m2_dtype: e5m2 variants, e.g., e5m2, e5m2fnuz.
+    dtype: computation dtype.
+ """ + + amax_history_length: int = 1024 + e4m3_dtype: DType = jnp.float8_e4m3fn + e5m2_dtype: DType = jnp.float8_e5m2 + dtype: DType = jnp.float32 + + def setup(self) -> None: + scale_args = ( + flax_initializers.ones_init(), + jax.random.PRNGKey(0), + (1,), + jnp.float32, + ) + amax_history_args = ( + flax_initializers.zeros_init(), + jax.random.PRNGKey(0), + (self.amax_history_length,), + jnp.float32, + ) + + OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" + self.input_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "input_amax_history", *amax_history_args) + self.kernel_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "kernel_amax_history", *amax_history_args) + self.output_grad_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "output_grad_amax_history", *amax_history_args) + + self.input_scale = self.variable(OVERWRITE_WITH_GRADIENT, "input_scale", *scale_args) + self.kernel_scale = self.variable(OVERWRITE_WITH_GRADIENT, "kernel_scale", *scale_args) + self.output_grad_scale = self.variable(OVERWRITE_WITH_GRADIENT, "output_grad_scale", *scale_args) + + def __call__(self, eqn, *args, **kwargs): + assert len(args) == 2 + x = args[0] + k = args[1] + + comp_dtype = self.dtype + k = jnp.asarray(k, comp_dtype) + x = jnp.asarray(x, comp_dtype) + + x_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value) + k_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value) + + y_qdq = jnp.einsum(eqn, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision) + + y = fp8_ops.out_qdq( + comp_dtype, + self.e5m2_dtype, + y_qdq, + self.output_grad_scale.value, + self.output_grad_amax_history.value, + ) + return y + def _get_int8_quant_config(config): drhs_bits = None