Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support of training with NANOO fp8 GEMM on AMD MI300/MI325 GPUs. #1262

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dtype: "bfloat16"
# 'int8' for dynamic range quantization using 8-bits
# 'intmp' for mixed precision quantization for inference as described here: MaxText/configs/quantization/README.md
# 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs.
# 'nanoo_fp8' for 8-bit floating-point GeMMs on AMD MI300/MI325 GPUs.
quantization: ""
# Choose one of default, high, and highest.
# https://kolonist26-jax-kr.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
Expand Down
1 change: 1 addition & 0 deletions MaxText/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def main(argv: Sequence[str]) -> None:
tokens, true_length = tokenizer_model.encode(text, is_bos=True, prefill_lengths=[config.max_prefill_predict_length])
assert true_length <= config.max_prefill_predict_length, "can't take too many tokens"
assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in decode.py yet"

# Split RNG before calling prefill
rng, rng_prefill = jax.random.split(rng)
Expand Down
15 changes: 15 additions & 0 deletions MaxText/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,17 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
return nn.Fp8DotGeneralOp


@dataclass
class NANOOFp8Quantization(Quantization):
  """Configures NANOO (OCP MXFP8-style) fp8 quantization for AMD MI300/MI325 GPUs.

  Mirrors Fp8Quantization but swaps in the NANOO fp8 dot_general op so GEMMs
  run in 8-bit floating point on AMD hardware.
  """

  # Only training-mode quantization is supported for this scheme.
  quant_mode = "train"

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns the Flax NANOO fp8 dot_general op class (mesh_axes is unused here)."""
    return nn.NANOOFp8DotGeneralOp


def _get_int8_quant_config(config):
drhs_bits = None
drhs_accumulator_dtype = None
Expand Down Expand Up @@ -272,6 +283,8 @@ def _get_quant_config(config):
return _get_mixed_precision_quant_config(mixed_precision_config)
if config.quantization == "fp8":
return "fp8"
if config.quantization == "nanoo_fp8":
return "nanoo_fp8"
raise ValueError(f"Invalid value configured for quantization {config.quantization}.")


Expand Down Expand Up @@ -302,6 +315,8 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"):
if quant_cfg:
if quant_cfg == "fp8":
return Fp8Quantization()
elif quant_cfg == "nanoo_fp8":
return NANOOFp8Quantization()
quant_mode = get_quant_mode(quant_mode_str)
replicate_scale = config.replicate_quant_scale if config.replicate_quant_scale else False
return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode, replicate_scale=replicate_scale)
Expand Down
30 changes: 30 additions & 0 deletions MaxText/tests/pipeline_parallelism_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,36 @@ def test_full_train_fp8(self):
]
)

def test_full_train_nanoo_fp8(self):
  """Full train.py run with NANOO fp8 quantization under pipeline parallelism.

  NANOO fp8 adds extra variable collections (scales/amax history) that the
  pipeline-parallel training loop must carry through correctly.
  """
  # Tiny model + synthetic data keep the run fast; order of overrides is
  # preserved because later CLI flags win.
  train_argv = [
      None,
      "configs/base.yml",
      r"base_output_directory=gs://runner-maxtext-logs",
      "run_name=runner_pipeline_parallelism_nanoo_fp8_test",
      r"dataset_path=gs://maxtext-dataset",
      "base_emb_dim=28",
      "base_num_query_heads=4",
      "base_num_kv_heads=4",
      "base_mlp_dim=32",
      "base_num_decoder_layers=4",
      "head_dim=128",
      "per_device_batch_size=2",
      "max_target_length=1024",
      "vocab_size=32",
      "dataset_type=synthetic",
      "steps=3",
      "enable_checkpointing=False",
      "ici_pipeline_parallelism=4",
      "tokenizer_path=../assets/tokenizer.llama2",
      "quantization=nanoo_fp8",
      "scan_layers=False",
      "attention=dot_product",
  ]
  train_main(train_argv)


if __name__ == "__main__":
unittest.main()
15 changes: 15 additions & 0 deletions MaxText/tests/train_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
r"tokenizer_path=../assets/tokenizer.llama2",
],
"nanoo_fp8": [ # tests base config with nanoo_fp8
None,
"configs/base.yml",
r"base_output_directory=gs://runner-maxtext-logs",
"run_name=runner_test",
r"dataset_path=gs://maxtext-dataset",
"quantization=nanoo_fp8",
"steps=2",
"enable_checkpointing=False",
r"tokenizer_path=../assets/tokenizer.llama2",
],
"dropout": [ # tests base config with dropout
None,
"configs/base.yml",
Expand Down Expand Up @@ -148,6 +159,10 @@ def test_tpu_fp8(self):
def test_gpu_fp8(self):
train_main(TrainTests.CONFIGS["fp8"] + ["attention=dot_product"])

@pytest.mark.gpu_only
def test_gpu_nanoo_fp8(self):
  """Runs the base nanoo_fp8 config on GPU with dot-product attention."""
  argv = TrainTests.CONFIGS["nanoo_fp8"] + ["attention=dot_product"]
  train_main(argv)

@pytest.mark.tpu_only
def test_tpu_dropout(self):
train_main(TrainTests.CONFIGS["dropout"])
Expand Down
2 changes: 1 addition & 1 deletion MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def validate_train_config(config):
max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system")
assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."

if config.quantization == "fp8":
if config.quantization in ("fp8", "nanoo_fp8"):
# pylint: disable=line-too-long
assert (
config.gradient_accumulation_steps == 1
Expand Down
1 change: 1 addition & 0 deletions benchmarks/mmlu/mmlu_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def main(config):
tokens = tokens[:max_prefill_predict_length]
true_length = max_prefill_predict_length
assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in decode.py yet"

# Perform prefill
prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
Expand Down
Loading