
Commit: Fix fp8-all-gather buck errors
Differential Revision: D63048850
y-sq authored and facebook-github-bot committed Sep 20, 2024
1 parent 0bdde92 commit 09da3c7
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 8 additions & 1 deletion test/float8/test_fsdp2/test_fsdp2.py
@@ -18,7 +18,6 @@
 from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import DTensor
 from torch.testing._internal.common_cuda import TEST_CUDA
@@ -36,6 +35,14 @@
     TransformerBlock,
 )
 
+# OSS and fbcode need different import statements
+# TODO: fix the issue and remove the try-except block.
+try:
+    from test_fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
+except ImportError:
+    from .test_fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
+
+
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 if not is_cuda_8_9:
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
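The new try/except block addresses the note in the added comment that OSS and fbcode need different import statements: the absolute import resolves when the test file is collected as a top-level module, while the relative import is needed when the tests run as part of a package (presumably the internal Buck target, per the commit title). A minimal sketch of the same fallback idiom, using a hypothetical sibling module named helpers:

# Sketch of the import-fallback idiom used above; "helpers" is a hypothetical
# module name standing in for the real sibling test utility file.
try:
    from helpers import check_parity_no_mp  # test directory is on sys.path
except ImportError:
    from .helpers import check_parity_no_mp  # file imported as part of a package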
The second changed file (3 additions & 3 deletions) is the module that provides check_parity_no_mp and check_parity_bf16_mp:

@@ -49,9 +49,9 @@ def check_parity_no_mp(
                 precompute_float8_dynamic_scale_for_fsdp(model)
 
         if compile_transformer_block:
-            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
+            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4, msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
         else:
-            test_cls.assertEqual(losses[0], losses[1])
+            test_cls.assertEqual(losses[0], losses[1], msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
 
 
 def check_parity_bf16_mp(
@@ -86,4 +86,4 @@ def check_parity_bf16_mp(
             ref_model.parameters(), ref_model_bf16.parameters()
         ):
            param_bf16.detach().copy_(param_fp32)
-        test_cls.assertEqual(losses[0], losses[1])
+        test_cls.assertEqual(losses[0], losses[1], msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
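The change to the parity helpers only threads a msg argument through the existing assertions, so a mismatch reports the iteration index and both loss values instead of just the raw comparison (the test class's assertEqual also accepts the atol/rtol seen above). A minimal illustration with plain unittest; the class name and loss values below are made up for the example:

import unittest


class ParityMessageExample(unittest.TestCase):
    def test_losses_match(self):
        # Made-up stand-ins for the reference loss and the fp8 loss.
        losses = [1.00, 1.25]
        iter_idx = 3
        # On failure, unittest appends the msg string to the default
        # "1.0 != 1.25" output, pinpointing which iteration diverged.
        self.assertEqual(
            losses[0],
            losses[1],
            msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
        )


if __name__ == "__main__":
    unittest.main()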
