
Commit cffbf1e

update float8 integration after UX changes
Summary:

float8_experimental landed various BC-breaking UX changes last week. This PR updates torchtitan to work with the version of float8_experimental after pytorch-labs/float8_experimental#332.

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
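As a quick orientation before the diff, here is a minimal sketch of the post-#332 conversion API this commit migrates to. The `float8_experimental` names (`Float8LinearConfig`, `CastConfig`, `ScalingType`, `convert_to_float8_training`, `module_filter_fn`) are taken from the change below; the toy model and flag values are illustrative assumptions, not torchtitan's actual setup.

```python
# Sketch of the new float8_experimental conversion flow used by this commit.
# Only the float8_experimental imports and argument names come from the diff;
# the toy model and flag values below are illustrative.
import torch.nn as nn

from float8_experimental import (
    CastConfig,
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
)

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# A single config object replaces the old module-level
# enable_fsdp_fp8_all_gather flag and the scaling_type_w kwarg of
# swap_linear_with_float8_linear.
float8_config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=False,  # illustrative; torchtitan gates this on dp_enabled
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

# module_filter_fn replaces the old skip_fqn_list; returning False for a
# module's fully qualified name skips the swap (torchtitan skips "output").
convert_to_float8_training(
    model,
    config=float8_config,
    module_filter_fn=lambda mod, fqn: fqn != "output",
)
```

As before, the converted model is meant to be run under `torch.compile` for competitive performance, per the note at the top of the file.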
1 parent 0f70507 commit cffbf1e

File tree: 1 file changed, +15 −25 lines


torchtitan/float8_linear.py (+15 −25)
```diff
@@ -12,7 +12,6 @@
 
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
-import contextlib
 import functools
 from typing import Optional
 
@@ -24,20 +23,6 @@
 from torchtitan.logging_utils import logger
 
 
-@contextlib.contextmanager
-def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
-    import float8_experimental.config as config
-
-    prev = config.enable_fsdp_fp8_all_gather
-    torch.distributed.barrier()
-    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
-    try:
-        yield
-    finally:
-        torch.distributed.barrier()
-        config.enable_fsdp_fp8_all_gather = prev
-
-
 @functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
@@ -63,21 +48,26 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental.float8_linear import TensorScalingType
-        from float8_experimental.float8_linear_utils import (
-            swap_linear_with_float8_linear,
+        from float8_experimental import (
+            CastConfig,
+            convert_to_float8_training,
+            Float8LinearConfig,
+            ScalingType,
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
             job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
-            swap_linear_with_float8_linear(
-                model,
-                scaling_type_w=TensorScalingType.DYNAMIC,
-                skip_fqn_list=["output"],
-            )
+        float8_config = Float8LinearConfig(
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+        )
+        convert_to_float8_training(
+            model,
+            config=float8_config,
+            module_filter_fn=lambda mod, fqn: fqn != "output",
+        )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
@@ -102,6 +92,6 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
             "Skipped precomputing fp8 scales because SM90 or later is not available",
         )
         return
-    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+    from float8_experimental import precompute_float8_dynamic_scale_for_fsdp
 
     precompute_float8_dynamic_scale_for_fsdp(model)
```
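The last hunk only changes the import path of `precompute_float8_dynamic_scale_for_fsdp`. For context, a hedged sketch of where that helper is meant to run: the once-per-iteration placement after the optimizer step is an assumption based on float8_experimental's dynamic-scaling workflow rather than something stated in this commit, and `after_optimizer_step` is a hypothetical name.

```python
# Hedged sketch: only the import path and the call
# precompute_float8_dynamic_scale_for_fsdp(model) come from the diff above.
from float8_experimental import precompute_float8_dynamic_scale_for_fsdp


def after_optimizer_step(model):
    # With enable_fsdp_float8_all_gather and dynamic scaling, this is assumed
    # to precompute the fp8 scales for all Float8Linear weights in one batched
    # call per iteration, rather than lazily during the next forward.
    precompute_float8_dynamic_scale_for_fsdp(model)
```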
