This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 025929c

Update on "one more delayed -> dynamic default update"
Summary: missed this in #300
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
[ghstack-poisoned]
2 parents 87dcf39 + 2a9433a commit 025929c

13 files changed (+134 -31 lines changed)

README.md (+14 -2)

@@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from float8_experimental.float8_linear import Float8Linear

 # create model
@@ -51,7 +52,18 @@ model = FSDP(model, use_orig_params=True)
 # optional: enable torch.compile for improved performance
 m = torch.compile(m)

-# train/finetune (not shown)
+# toy training loop
+for _ in range(N_ITER):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+
+    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
+    # this method is optional but is highly recommended for performance
+    # it calculates scales for all parameters in a single all-reduce
+    precompute_float8_dynamic_scale_for_fsdp(model)
 ```

 ## float8 linear with delayed scaling
@@ -71,7 +83,7 @@ m = Model(...)
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
 swap_linear_with_float8_linear(
-    m,
+    m,
     Float8Linear,
     scaling_type_x=TensorScalingType.DELAYED,
     scaling_type_w=TensorScalingType.DELAYED,
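The delayed-scaling hunk above stops at the `swap_linear_with_float8_linear(...)` call, so the rest of that recipe is not visible in this diff. Below is a minimal sketch, not part of this commit, of how the delayed-scaling training loop is typically completed, assuming the `sync_float8_amax_and_scale_history` helper that the benchmarks and tests in this commit import (`m`, `x`, `optimizer`, and `N_ITER` follow the README snippet above):

```python
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    # delayed scaling only: reduce the recorded amaxes across ranks and
    # recompute scales for all Float8Linear modules in a single call
    sync_float8_amax_and_scale_history(m)
    optimizer.step()
```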

benchmarks/bench_multi_gpu.py (+1 -1)

@@ -14,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,

benchmarks/profile_linear_float8.py (+1 -1)

@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,

float8_experimental/float8_dynamic_utils.py (+36 -10)

@@ -9,10 +9,7 @@

 from typing import Any, Optional, Tuple

-import float8_experimental.config as config
-
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree

 from float8_experimental.float8_tensor import (
@@ -85,7 +82,12 @@ def cast_to_float8_e5m2_dynamic_bw(

 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -99,9 +101,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )

-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        # for dynamic scaling
+        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+        # for all float8 parameters after optimizer step
+        self._precomputed_scale = precomputed_scale

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -130,20 +141,35 @@ def unwrap(t):
         )

     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
+        else:
+            return ["_tensor"], self._mm_config

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"],
+            mm_config,
+            getattr(inner_tensors, "_precomputed_scale", None),
+        )

     def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._precomputed_scale is not None:
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3_dynamic(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
         return (float8_tensor._data,), (float8_tensor._scale,)

     def fsdp_post_all_gather(
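To make the intent of the two branches in `fsdp_pre_all_gather` concrete, here is a hypothetical, self-contained stand-in written with plain tensor ops. It does not use `Float8Tensor.to_float8` or `cast_to_float8_e4m3_dynamic`, and the `1e-12` epsilon is an assumption standing in for the library's `EPS`:

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def cast_for_all_gather(weight: torch.Tensor, precomputed_scale=None):
    # stand-in for WeightWithDynamicFloat8CastTensor.fsdp_pre_all_gather
    if precomputed_scale is not None:
        # precomputed path: the scale was produced for all weights at once
        # after the optimizer step, so no amax computation is needed here
        scale = precomputed_scale
    else:
        # dynamic path: recompute amax from the current (sharded) weight;
        # with reduce_amax=True this is where FSDP pays one small
        # all-reduce per parameter before every all-gather
        scale = E4M3_MAX / weight.abs().max().clamp(min=1e-12)
    data = (weight * scale).to(torch.float8_e4m3fn)
    return (data,), (scale,)
```

Precomputing the scale therefore moves the amax reduction out of the per-parameter pre-all-gather path and into the single batched all-reduce added in `fsdp_utils.py` below.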

float8_experimental/float8_linear_utils.py (+1 -3)

@@ -3,10 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import copy
 import logging
-from enum import auto, Enum
-from typing import Callable, List, Optional, Type, Union
+from typing import Callable, List, Optional

 import torch
 import torch.distributed as dist

float8_experimental/fsdp_utils.py (new file, +52)

@@ -0,0 +1,52 @@
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_utils import EPS
+
+
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
+    """
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    scales for all float8 weights.
+    Example usage:
+        model(input).sum().backward()
+        optim.step()
+        precompute_float8_dynamic_scale_for_fsdp(model)
+    """
+    from torch.distributed._tensor import DTensor
+
+    if any(
+        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
+        for m in module.modules()
+    ):
+        raise NotImplementedError("Only supports delayed scaling")
+    float8_linears: List[Float8Linear] = [
+        m
+        for m in module.modules()
+        if isinstance(m, Float8Linear)
+        and isinstance(m.weight, DTensor)
+        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+    ]
+    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+    if not weights:
+        return
+
+    # inf-norm is equivalent to max(abs(w))
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    amax_tensor = torch.vstack(max_weights)  # Partial
+    # clamp is dispatched through DTensor
+    # it will issue a single all-reduce
+    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if amax_tensor.dtype is torch.float16:
+        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    scales = torch.split(scale_tensor, 1)  # Replicate
+    for scale, float8_linear in zip(scales, float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
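The comments above claim that the per-tensor inf-norm equals `max(abs(w))` and that one stacked amax tensor yields all the per-weight scales. A hypothetical single-process check of that arithmetic (no DTensor or all-reduce here; the `1e-12` epsilon and the toy shapes are assumptions):

```python
import math

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
EPS = 1e-12  # assumed to match float8_experimental.float8_utils.EPS

weights = [torch.randn(32, 32), torch.randn(64, 32), torch.randn(32, 8)]

# per-tensor inf-norm is the same as max(abs(w)) per tensor
max_weights = torch._foreach_norm(weights, ord=math.inf)
assert all(
    torch.allclose(n, w.abs().max()) for n, w in zip(max_weights, weights)
)

# stack all amaxes into one tensor; in the distributed version this is the
# DTensor whose clamp issues the single all-reduce
amax_tensor = torch.clamp(torch.vstack(max_weights), EPS)
scale_tensor = E4M3_MAX / amax_tensor
scales = torch.split(scale_tensor, 1)  # one (1, 1) scale per float8 weight
```

In the real function the list holds sharded DTensor weights, so each entry of `max_weights` is partial per rank, the clamp reduces them once across ranks, and each resulting scale is written back to `_precomputed_scale` on the corresponding `WeightWithDynamicFloat8CastTensor`.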

test/test_dtensor.py (+1 -1)

@@ -15,7 +15,7 @@
 import torch.nn.functional as F

 from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (

test/test_fsdp.py (+2 -2)

@@ -21,7 +21,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
@@ -149,7 +149,7 @@ def forward_backward(model, optim, is_fp8, i):
             model_fp8 = torch.compile(model_fp8)
         y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
         y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
-        local_sqnr = compute_error(y_local, y_local_fp8)
+        local_sqnr = compute_error(y_local, y_local_fp8)  # noqa: F841

         # get global y
         y_global = [

test/test_fsdp2/test_fsdp2_common.py (+5 -2)

@@ -1,12 +1,12 @@
 import contextlib
-from typing import List, Type
+from typing import List

 import float8_experimental.config as config

 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


 def check_parity_no_mp(
@@ -16,6 +16,7 @@ def check_parity_no_mp(
     fsdp_model: nn.Module,
     fsdp_optim: torch.optim.Optimizer,
     local_inp: torch.Tensor,
+    precompute: bool = False,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -29,6 +30,8 @@ def check_parity_no_mp(
                     param.grad.div_(dist.get_world_size())
             # TODO(future): add amax syncing once delayed scaling is supported
             optim.step()
+            if model is fsdp_model and precompute:
+                precompute_float8_dynamic_scale_for_fsdp(model)
         test_cls.assertEqual(losses[0], losses[1])
test/test_fsdp2/test_fsdp2_eager.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import copy
2-
import itertools
32
import threading
43
import unittest
54
from typing import Any, List
@@ -9,7 +8,7 @@
98
import torch.distributed as dist
109
import torch.nn as nn
1110
from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
12-
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
11+
from float8_experimental.float8_linear import TensorScalingType
1312
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
1413
from test_fsdp2_common import (
1514
check_parity_bf16_mp,
@@ -87,10 +86,21 @@ def world_size(self) -> int:
8786

8887
@skip_if_lt_x_gpu(2)
8988
def test_transformer_parity_dynamic(self):
90-
for enable_fsdp_fp8_all_gather in [False, True]:
91-
self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
89+
self.run_subtests(
90+
{
91+
"enable_fsdp_fp8_all_gather": [False, True],
92+
"precompute": [False, True],
93+
},
94+
self._test_transformer_parity_dynamic,
95+
)
9296

93-
def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
97+
def _test_transformer_parity_dynamic(
98+
self,
99+
enable_fsdp_fp8_all_gather: bool,
100+
precompute: bool,
101+
):
102+
if not enable_fsdp_fp8_all_gather and precompute:
103+
return
94104
# NOTE: Weight-tying does not compose with fp8 all-gather because the
95105
# embedding weight and output linear weight are tied but only the
96106
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
@@ -110,7 +120,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
110120
local_inp = torch.randint(
111121
0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
112122
)
113-
check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp)
123+
check_parity_no_mp(
124+
self, ref_module, ref_optim, module, optim, local_inp, precompute
125+
)
114126

115127
@skip_if_lt_x_gpu(2)
116128
def test_transformer_memory(self):

test/test_fsdp_compile.py (+1 -1)

@@ -18,7 +18,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,

test/test_inference_flows.py (+1 -1)

@@ -13,7 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error

test/test_numerics_integration.py (+1 -1)

@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
