
Commit faed733

switch float8 logic from Float8DynamicLinear to Float8Linear
Summary:

In the stack ending in pytorch-labs/float8_experimental#300, we are unifying `Float8DynamicLinear` and `Float8Linear` in `float8_experimental`, with a follow-up PR planned to delete the `Float8DynamicLinear` object. After pytorch-labs/float8_experimental#300, `Float8Linear` with default settings is equivalent to `Float8DynamicLinear`. This PR changes `torchtitan` to use `Float8Linear`.

To better support the new `float8_experimental` UX, I also switched the `fp8_linear` configuration to a boolean that controls whether to swap the linears or not. New options for configuring each linear (scaling type, scaling granularity, etc.) can be added later; that is saved for a future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent b0ed7f0 commit faed733

9 files changed (+25, -36 lines)

torchtitan/config_manager.py (+4, -8)

@@ -339,15 +339,11 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--training.fp8_linear",
-            type=str,
-            default="",
-            choices=[
-                "dynamic",
-                "",
-            ],  # TODO: add "delayed" option back in when supported
+            action="store_true",
             help="""
-                Type of fp8 linear quantization to apply to the model ['', 'dynamic'].
-                This features requires you to install 'float8_experimental' which can be found
+                If true, swaps `torch.nn.Linear` with `Float8Linear` with
+                default settings (dynamic scaling).
+                This feature requires you to install 'float8_experimental' which can be found
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
         )
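As a rough illustration of what the `store_true` change means for the CLI, here is a standalone argparse sketch. It is not torchtitan's actual `JobConfig` plumbing (the real parser has many more options and its own prefix handling), but the flag semantics are the same.

```python
# Simplified sketch of the new boolean flag semantics; torchtitan's real
# parser lives in torchtitan/config_manager.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--training.fp8_linear",
    action="store_true",
    help="If true, swaps `torch.nn.Linear` with `Float8Linear` (dynamic scaling).",
)

# Omitting the flag leaves it False; passing it sets it True.
print(vars(parser.parse_args([])))                         # {'training.fp8_linear': False}
print(vars(parser.parse_args(["--training.fp8_linear"])))  # {'training.fp8_linear': True}
```

Previously, `--training.fp8_linear dynamic` selected a linear type by name; now the flag is a plain on/off switch and the scaling settings come from `Float8Linear`'s defaults.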

torchtitan/float8_linear.py (+11, -19)

@@ -21,34 +21,26 @@
 
 def build_fp8_linear(model: nn.Module, job_config: JobConfig):
     """
-    This function converts the linear layers to one of the fp8 types:
-    - Float8DynamicLinear: Dynamic quantization of the weights and the activations
-    - [Not Yet Supported] Float8Linear: Uses a history of amaxs to quantize the weights and activations
+    This function converts the linear layers to `Float8Linear`. Note that today,
+    only dynamic tensor scaling (the default) is supported.
 
     This will mutate the model inplace.
     """
-    linear_type = job_config.training.fp8_linear.lower()
+    use_fp8_linear = job_config.training.fp8_linear
     try:
-        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        # from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 
-        # from float8_experimental.float8_linear import Float8Linear
+        from float8_experimental.float8_linear import Float8Linear
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
+        import float8_experimental.config as config
+        config.enable_fsdp_fp8_all_gather = True
     except ImportError as exc:
         raise ImportError(
            "float8_experimental is not installed. Please install it to use fp8 linear layers."
        ) from exc
-    if linear_type:
-        linear_type_map = {
-            # "delayed": Float8Linear,  # TODO: add "delayed" option back in when supported
-            "dynamic": Float8DynamicLinear,
-        }
-        assert (
-            linear_type in linear_type_map
-        ), f"Invalid fp8 linear type: {linear_type}, supported types: {', '.join(linear_type_map.keys())}."
-        float8_linear_type = linear_type_map[linear_type.lower()]
-
-        # Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type
-        swap_linear_with_float8_linear(model, float8_linear_type)
-        logger.info(f"Swapped to {linear_type} float8 linear layers")
+    if use_fp8_linear:
+        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        swap_linear_with_float8_linear(model, Float8Linear)
+        logger.info(f"Swapped to Float8Linear layers")

torchtitan/parallelisms/parallelize_llama.py (+4, -3)

@@ -441,9 +441,10 @@ def apply_compile(model, job_config: JobConfig):
     ac_config = job_config.activation_checkpoint
     if ac_config.mode == "selective" and ac_config.selective_ac_option == "op":
         # some temp flags for torch.compile enablement + SAC
-        torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
-            True
-        )
+        # torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
+        #     True
+        # )
+        pass
 
     logger.info("Compiled each TransformerBlock with torch.compile")
     return model

train_configs/debug_model.toml (+1, -1)

@@ -37,7 +37,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)
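For completeness, a small sketch (using `tomllib` from the Python 3.11+ standard library) showing that the new value round-trips as a real boolean rather than a sentinel string; the snippet only parses the two lines shown, not a full training config.

```python
# Sketch: the training-config value is now a TOML boolean, not a string.
import tomllib

cfg = tomllib.loads('fp8_linear = false\ncompile = false\n')
print(cfg["fp8_linear"], type(cfg["fp8_linear"]))  # False <class 'bool'>
```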

train_configs/llama2_13b.toml (+1, -1)

@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"

train_configs/llama2_70b.toml (+1, -1)

@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8  # 8-way TP
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
3939

train_configs/llama2_7b.toml (+1, -1)

@@ -32,7 +32,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1  # dp-only would be sufficient for 7B
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"

train_configs/llama3_70b.toml (+1, -1)

@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8  # 8-way TP
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
3939

train_configs/llama3_8b.toml (+1, -1)

@@ -33,7 +33,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
