From dd4921e2c2752de4b0149eb88fe554f14e854c5b Mon Sep 17 00:00:00 2001
From: Joe Mayer
Date: Thu, 8 Jun 2023 22:36:42 +0000
Subject: [PATCH 1/6] first file added

---
 tests/unit/runtime/half_precision/test_fp8.py | 70 +++++++++++++++++++
 1 file changed, 70 insertions(+)
 create mode 100644 tests/unit/runtime/half_precision/test_fp8.py

diff --git a/tests/unit/runtime/half_precision/test_fp8.py b/tests/unit/runtime/half_precision/test_fp8.py
new file mode 100644
index 000000000000..981ad336e9a8
--- /dev/null
+++ b/tests/unit/runtime/half_precision/test_fp8.py
@@ -0,0 +1,70 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import deepspeed
+import pytest
+from unit.common import DistributedTest
+from unit.simple_model import random_dataloader
+import transformer_engine.pytorch as te
+from transformer_engine.common import recipe
+
+@pytest.mark.parametrize("base_datatype", ["fp16"])
+class TestFp8ComposabilityAcrossZero(DistributedTest):
+    world_size = 1
+
+    def test(self, base_datatype):
+        hidden_dim = 4096
+        if base_datatype == "fp16":
+            model_dtype = torch.float16
+        elif base_datatype == "bf16":
+            model_dtype = torch.bfloat16
+        else:
+            model_dtype = torch.float32
+
+        torch.random.manual_seed(15)
+        # TransformerEngine Model
+        model = te.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=model_dtype).cuda()
+
+        # Create FP8 recipe. Note: All input args are optional.
+        fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID,
+                                           amax_history_len=16,
+                                           amax_compute_algo="max")
+        # config
+        zero_stage = [0, 1, 2, 3]
+        for stage in zero_stage:
+            config = {
+                "train_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "optimizer": {
+                    "type": "Adam",
+                    "params": {
+                        "lr": 0.00001
+                    }
+                },
+                "zero_optimization": {
+                    "stage": stage,
+                },
+                "fp16": {
+                    "enabled": True,
+                    "loss_scale": 0.1
+                },
+                "bf16": {
+                    "enabled": False
+                }
+            }
+            # Init DeepSpeed
+            model, optimizer, _, _ = deepspeed.initialize(args=None, model=model,
+                                                          model_parameters=model.parameters(), config=config)
+
+            data = torch.randn(128, hidden_dim, device=model.device, dtype=model_dtype)
+            for datum in data:
+                # Enables autocasting for the forward pass
+                with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+                    out = model(datum)
+                    loss = out.mean()
+                model.backward(loss)
+                model.step()
+            print(loss)

From c2818eac6cabe972ec523653ed1fcc603edb8809 Mon Sep 17 00:00:00 2001
From: Joe Mayer
Date: Fri, 9 Jun 2023 18:48:32 +0000
Subject: [PATCH 2/6] full fp8 test

---
 tests/unit/runtime/half_precision/test_fp8.py | 74 +++++++++++--------
 1 file changed, 45 insertions(+), 29 deletions(-)

diff --git a/tests/unit/runtime/half_precision/test_fp8.py b/tests/unit/runtime/half_precision/test_fp8.py
index 981ad336e9a8..9575554f39e9 100644
--- a/tests/unit/runtime/half_precision/test_fp8.py
+++ b/tests/unit/runtime/half_precision/test_fp8.py
@@ -11,32 +11,32 @@
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 
-@pytest.mark.parametrize("base_datatype", ["fp16"])
+
+@pytest.mark.parametrize("base_datatype", ["fp16", "bf16", "fp32"])
 class TestFp8ComposabilityAcrossZero(DistributedTest):
     world_size = 1
 
     def test(self, base_datatype):
-        hidden_dim = 4096
-        if base_datatype == "fp16":
-            model_dtype = torch.float16
-        elif base_datatype == "bf16":
-            model_dtype = torch.bfloat16
-        else:
-            model_dtype = torch.float32
-
-        torch.random.manual_seed(15)
-        # TransformerEngine Model
-        model = te.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=model_dtype).cuda()
-
-        # Create FP8 recipe. Note: All input args are optional.
-        fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID,
-                                           amax_history_len=16,
-                                           amax_compute_algo="max")
-        # config
-        zero_stage = [0, 1, 2, 3]
-        for stage in zero_stage:
+        def run_zero(stage, model_dtype):
+            num_batches = 128
+            batch_size = 16
+            hidden_dim = 768
+            torch.random.manual_seed(42)
+            enable_fp16 = False
+            enable_bf16 = False
+            if model_dtype == torch.float16:
+                enable_fp16 = True
+            elif model_dtype == torch.bfloat16:
+                enable_bf16 = True
+            # TransformerEngine Model
+            model = te.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=model_dtype)
+
+            # Create FP8 recipe. Note: All input args are optional.
+            fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID,
+                                               amax_history_len=16,
+                                               amax_compute_algo="max")
             config = {
-                "train_batch_size": 8,
+                "train_batch_size": batch_size,
                 "gradient_accumulation_steps": 1,
                 "optimizer": {
                     "type": "Adam",
@@ -45,26 +45,42 @@ def test(self, base_datatype):
                     }
                 },
                 "zero_optimization": {
-                    "stage": stage,
+                    "stage": stage
                 },
                 "fp16": {
-                    "enabled": True,
+                    "enabled": enable_fp16,
                     "loss_scale": 0.1
                 },
                 "bf16": {
-                    "enabled": False
+                    "enabled": enable_bf16
                 }
             }
             # Init DeepSpeed
             model, optimizer, _, _ = deepspeed.initialize(args=None, model=model,
                                                           model_parameters=model.parameters(), config=config)
-
-            data = torch.randn(128, hidden_dim, device=model.device, dtype=model_dtype)
-            for datum in data:
+
+            batches = torch.randn(num_batches, batch_size, hidden_dim, device=model.device, dtype=model_dtype)
+            for batch in batches:
                 # Enables autocasting for the forward pass
                 with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-                    out = model(datum)
+                    out = model(batch)
                     loss = out.mean()
                 model.backward(loss)
                 model.step()
-            print(loss)
+            return loss
+
+        if base_datatype == "fp16":
+            model_dtype = torch.float16
+        elif base_datatype == "bf16":
+            model_dtype = torch.bfloat16
+        else:
+            model_dtype = torch.float32
+
+        # config
+        zero_stage = [0,1,2,3]
+        losses = []
+        for stage in zero_stage:
+            loss = run_zero(stage, model_dtype)
+            losses.append(loss)
+        all_equal = all(torch.allclose(loss, losses[0], 1e-07, 1e-05) for loss in losses)
+        assert (all_equal)
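
Patch 2 asserts that every ZeRO stage lands on the same final loss via torch.allclose(loss, losses[0], 1e-07, 1e-05), where the positional arguments are rtol and atol. A minimal sketch, not part of the patch, of how that tolerance behaves:

    import torch

    ref = torch.tensor(0.5)
    close = torch.tensor(0.500004)  # |close - ref| = 4e-6
    far = torch.tensor(0.51)        # |far - ref| = 1e-2

    # allclose passes when |a - b| <= atol + rtol * |b|, elementwise
    assert torch.allclose(close, ref, 1e-07, 1e-05)
    assert not torch.allclose(far, ref, 1e-07, 1e-05)

The comparison is elementwise, which is safe here because run_zero returns the scalar produced by out.mean().
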
From 7f263a63911cf9e043e974a0845674b8095dacec Mon Sep 17 00:00:00 2001
From: Joe Mayer
Date: Fri, 9 Jun 2023 18:52:58 +0000
Subject: [PATCH 3/6] te to transformer_engine for precommit

---
 tests/unit/runtime/half_precision/test_fp8.py | 25 +++++++++++--------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/tests/unit/runtime/half_precision/test_fp8.py b/tests/unit/runtime/half_precision/test_fp8.py
index 9575554f39e9..45cc197c879c 100644
--- a/tests/unit/runtime/half_precision/test_fp8.py
+++ b/tests/unit/runtime/half_precision/test_fp8.py
@@ -7,8 +7,7 @@
 import deepspeed
 import pytest
 from unit.common import DistributedTest
-from unit.simple_model import random_dataloader
-import transformer_engine.pytorch as te
+import transformer_engine.pytorch as transformer_engine
 from transformer_engine.common import recipe
 
 
@@ -17,10 +16,12 @@ class TestFp8ComposabilityAcrossZero(DistributedTest):
     world_size = 1
 
     def test(self, base_datatype):
+
         def run_zero(stage, model_dtype):
             num_batches = 128
             batch_size = 16
             hidden_dim = 768
+            # Have to set seed before model
             torch.random.manual_seed(42)
             enable_fp16 = False
             enable_bf16 = False
@@ -29,12 +30,12 @@ def run_zero(stage, model_dtype):
             elif model_dtype == torch.bfloat16:
                 enable_bf16 = True
             # TransformerEngine Model
-            model = te.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=model_dtype)
-
+            model = transformer_engine.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=model_dtype)
+
             # Create FP8 recipe. Note: All input args are optional.
             fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID,
-                                               amax_history_len=16,
-                                               amax_compute_algo="max")
+                                                       amax_history_len=16,
+                                                       amax_compute_algo="max")
             config = {
                 "train_batch_size": batch_size,
                 "gradient_accumulation_steps": 1,
@@ -56,13 +57,15 @@ def run_zero(stage, model_dtype):
                 }
             }
             # Init DeepSpeed
-            model, optimizer, _, _ = deepspeed.initialize(args=None, model=model,
-                                                          model_parameters=model.parameters(), config=config)
-
+            model, optimizer, _, _ = deepspeed.initialize(args=None,
+                                                          model=model,
+                                                          model_parameters=model.parameters(),
+                                                          config=config)
+
             batches = torch.randn(num_batches, batch_size, hidden_dim, device=model.device, dtype=model_dtype)
             for batch in batches:
                 # Enables autocasting for the forward pass
-                with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+                with transformer_engine.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                     out = model(batch)
                     loss = out.mean()
                 model.backward(loss)
                 model.step()
@@ -77,7 +80,7 @@ def run_zero(stage, model_dtype):
             model_dtype = torch.float32
 
         # config
-        zero_stage = [0,1,2,3]
+        zero_stage = [0, 1, 2, 3]
         losses = []
         for stage in zero_stage:
             loss = run_zero(stage, model_dtype)
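
The test pairs a DelayedScaling recipe with an fp8_autocast region. A minimal standalone sketch, assuming an FP8-capable GPU and an installed transformer_engine package (Format.HYBRID selects E4M3 for forward tensors and E5M2 for gradients):

    import torch
    import transformer_engine.pytorch as transformer_engine
    from transformer_engine.common import recipe

    # Delayed scaling derives FP8 scale factors from a rolling amax history
    fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID,
                                       amax_history_len=16,
                                       amax_compute_algo="max")
    model = transformer_engine.Linear(768, 768, bias=True, params_dtype=torch.float16).cuda()
    x = torch.randn(16, 768, device="cuda", dtype=torch.float16)

    # Only the matmuls inside the autocast region run in FP8; parameters
    # and the returned activations stay in params_dtype
    with transformer_engine.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(x)
    out.mean().backward()
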
From 2ffa289309d830bc642ac06d64e5d275341c6e51 Mon Sep 17 00:00:00 2001
From: Joe Mayer
Date: Fri, 9 Jun 2023 19:49:27 +0000
Subject: [PATCH 4/6] cleaner dtype check

---
 tests/unit/runtime/half_precision/test_fp8.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/tests/unit/runtime/half_precision/test_fp8.py b/tests/unit/runtime/half_precision/test_fp8.py
index 45cc197c879c..44d4b578b0d5 100644
--- a/tests/unit/runtime/half_precision/test_fp8.py
+++ b/tests/unit/runtime/half_precision/test_fp8.py
@@ -23,12 +23,8 @@ def run_zero(stage, model_dtype):
             hidden_dim = 768
             # Have to set seed before model
             torch.random.manual_seed(42)
-            enable_fp16 = False
-            enable_bf16 = False
-            if model_dtype == torch.float16:
-                enable_fp16 = True
-            elif model_dtype == torch.bfloat16:
-                enable_bf16 = True
+            enable_fp16 = model_dtype == torch.float16
+            enable_bf16 = model_dtype == torch.bfloat16
             # TransformerEngine Model
             model = transformer_engine.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=model_dtype)
 

From 255e60e129d7d72b9600bbe4fd0cd8942edfa0d9 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 9 Jun 2023 12:08:01 -0700
Subject: [PATCH 5/6] add skips for missing TE and older gpus

---
 tests/unit/runtime/half_precision/test_fp8.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tests/unit/runtime/half_precision/test_fp8.py b/tests/unit/runtime/half_precision/test_fp8.py
index 44d4b578b0d5..9f4814927d74 100644
--- a/tests/unit/runtime/half_precision/test_fp8.py
+++ b/tests/unit/runtime/half_precision/test_fp8.py
@@ -7,8 +7,13 @@
 import deepspeed
 import pytest
 from unit.common import DistributedTest
-import transformer_engine.pytorch as transformer_engine
-from transformer_engine.common import recipe
+from unit.util import skip_on_arch
+
+try:
+    import transformer_engine.pytorch as transformer_engine
+    from transformer_engine.common import recipe
+except ImportError:
+    pytest.skip("Transformer Engine package is missing, skipping tests")
 
 
 @pytest.mark.parametrize("base_datatype", ["fp16", "bf16", "fp32"])
@@ -16,6 +21,7 @@ class TestFp8ComposabilityAcrossZero(DistributedTest):
     world_size = 1
 
     def test(self, base_datatype):
+        skip_on_arch(min_arch=9)
 
         def run_zero(stage, model_dtype):
             num_batches = 128
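
Patch 5 gates the test two ways: skip_on_arch(min_arch=9) skips GPUs older than compute capability 9.x (Hopper-class, the first generation with FP8 tensor cores), and the try/except skips when TransformerEngine is not installed; the bare pytest.skip() call still needs the module-level flag, which the next patch adds. A hypothetical sketch of the arch check, assuming unit.util compares the CUDA compute capability major version (the real helper may differ):

    import pytest
    import torch

    def skip_on_arch(min_arch):
        # Hypothetical: skip when the device's compute capability is too old
        major, _ = torch.cuda.get_device_capability()
        if major < min_arch:
            pytest.skip(f"requires compute capability >= {min_arch}.0, found {major}.x")
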
From ffe4c17c5ba30f8fde68ca0ffda80dc21725ff0d Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 9 Jun 2023 12:21:42 -0700
Subject: [PATCH 6/6] add module skip flag

---
 tests/unit/runtime/half_precision/test_fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/runtime/half_precision/test_fp8.py b/tests/unit/runtime/half_precision/test_fp8.py
index 9f4814927d74..21217ed7dd82 100644
--- a/tests/unit/runtime/half_precision/test_fp8.py
+++ b/tests/unit/runtime/half_precision/test_fp8.py
@@ -13,7 +13,7 @@
     import transformer_engine.pytorch as transformer_engine
     from transformer_engine.common import recipe
 except ImportError:
-    pytest.skip("Transformer Engine package is missing, skipping tests")
+    pytest.skip("Transformer Engine package is missing, skipping tests", allow_module_level=True)
 
 
 @pytest.mark.parametrize("base_datatype", ["fp16", "bf16", "fp32"])
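
The allow_module_level=True flag matters because pytest rejects a bare skip() raised at import time rather than inside a test; with the flag, collection of the whole module is skipped cleanly. An equivalent idiom, shown for comparison rather than as part of the patch, marks the module instead of raising at import:

    import pytest

    try:
        import transformer_engine.pytorch as transformer_engine  # noqa: F401
        HAS_TE = True
    except ImportError:
        HAS_TE = False

    # Module-level mark: every test in this file is skipped when TE is absent
    pytestmark = pytest.mark.skipif(not HAS_TE, reason="Transformer Engine package is missing")
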