|
21 | 21 |
|
def build_fp8_linear(model: nn.Module, job_config: JobConfig):
    """
    Swap the model's ``torch.nn.Linear`` layers with ``Float8Linear``.

    Note that today, only dynamic tensor scaling (the default) is supported.

    This will mutate the model inplace.

    Args:
        model: Model whose linear layers are swapped in place.
        job_config: Job configuration; ``job_config.training.fp8_linear``
            (truthy/falsy) controls whether the swap happens at all.

    Raises:
        ImportError: If fp8 linear layers are requested but
            ``float8_experimental`` is not installed.
    """
    if not job_config.training.fp8_linear:
        # fp8 is disabled: do nothing. In particular, do not import
        # float8_experimental (which may not be installed) and do not flip
        # its global FSDP all-gather flag.
        return

    try:
        from float8_experimental.float8_linear import Float8Linear
        from float8_experimental.float8_linear_utils import (
            swap_linear_with_float8_linear,
        )
        import float8_experimental.config as config
    except ImportError as exc:
        raise ImportError(
            "float8_experimental is not installed. Please install it to use fp8 linear layers."
        ) from exc

    # Only enable the fp8 all-gather FSDP path when fp8 linears are in use;
    # this is a process-global setting in float8_experimental.
    config.enable_fsdp_fp8_all_gather = True

    # Mutates the model inplace, replacing instances of torch.nn.Linear with
    # Float8Linear.
    swap_linear_with_float8_linear(model, Float8Linear)
    logger.info("Swapped to Float8Linear layers")
0 commit comments