This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 11d6db7

delete Float8DynamicLinear
Summary:

We are standardizing on `Float8Linear` as the only float8 linear object:

1. The stack ending with #300 moved all of the functionality of `Float8DynamicLinear` to `Float8Linear`. The default settings of `Float8Linear` are to use dynamic scaling.
2. This PR deletes `Float8DynamicLinear` from the codebase and patches the relevant callsites in fbsource.

Test Plan:

```
// all tests pass
./test_everything.sh

// also run all benchmarks and verify correctness
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 8ab483377124960fec2f133c0e27fbbaab204528
Pull Request resolved: #304
1 parent d4cf2ad commit 11d6db7

16 files changed: +182 −536 lines
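
For downstream callers, the migration implied by this commit looks roughly like the sketch below. The keyword names come from the callsites patched in the diffs that follow, and the dynamic-scaling default comes from the summary above; treat the exact signature as an assumption rather than documented API.

```python
# Hypothetical migration sketch; argument names mirror the patched callsites
# below and the dynamic-scaling default is assumed from the commit summary.
import copy

import torch
from float8_experimental.float8_linear import Float8Linear, TensorScalingType

m = torch.nn.Linear(4096, 4096, bias=False)

# Before this commit: Float8DynamicLinear.from_float(m, emulate=False)
# After: Float8Linear with its default (dynamic) scaling covers the same case.
m_fp8_dynamic = Float8Linear.from_float(copy.deepcopy(m), emulate=False)

# Delayed scaling is still available by passing the scaling types explicitly.
m_fp8_delayed = Float8Linear.from_float(
    copy.deepcopy(m),
    emulate=False,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)
```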

benchmarks/bench_linear_float8.py

+15 −50

```diff
@@ -14,11 +14,9 @@

 import torch
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import TensorScalingType
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
-    get_float8_linear,
     linear_requires_sync,
-    LinearType,
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_tensor import ScaledMMConfig
@@ -69,7 +67,6 @@ class Experiment:
     dtype: torch.dtype
     compiled: bool
     use_fast_accum: bool
-    linear_type: str
     scaling_repr: str

     # 3 Times since we are calculating forward backward
@@ -98,7 +95,6 @@ def main(
     n_limit: Optional[int] = None,
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
-    linear_type_filter: Optional[str] = None,
     scaling_type_x: str = "delayed",
     scaling_type_w: str = "delayed",
     scaling_type_dL_dY: str = "delayed",
@@ -123,44 +119,28 @@ def main(
         use_fast_accum = [fast_accum_filter]
     else:
         use_fast_accum = [True, False]
-    if linear_type_filter is not None:
-        linear_types = [linear_type_filter]
-    else:
-        linear_types = ["delayed", "dynamic"]
     if shape_name_filter is not None:
         k = shape_name_filter
         name_to_shapes_70b = {k: name_to_shapes_70b[k]}
     experiment_list: List[Experiment] = []
     dtype = torch.bfloat16
-    for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate(
-        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types)))
+    for idx, (fast_accum, (name, (K, N))) in enumerate(
+        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
     ):
         if n_limit is not None and idx >= n_limit:
             break
         linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
             device=device, dtype=dtype
         )
-        linear_type_enum = (
-            LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
-        )

-        if linear_type == "delayed":
-            linear_float8 = get_float8_linear(
-                linear_type_enum,
-                copy.deepcopy(linear_ref),
-                emulate=False,
-                scaling_type_x=scaling_type_x,
-                scaling_type_w=scaling_type_w,
-                scaling_type_dL_dY=scaling_type_dL_dY,
-            )
-            scaling_repr = linear_float8.scaling_repr()
-        else:
-            linear_float8 = get_float8_linear(
-                linear_type_enum,
-                copy.deepcopy(linear_ref),
-                emulate=False,
-            )
-            scaling_repr = None
+        linear_float8 = Float8Linear.from_float(
+            copy.deepcopy(linear_ref),
+            emulate=False,
+            scaling_type_x=scaling_type_x,
+            scaling_type_w=scaling_type_w,
+            scaling_type_dL_dY=scaling_type_dL_dY,
+        )
+        scaling_repr = linear_float8.scaling_repr()

         if fast_accum:
             linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -172,19 +152,10 @@ def main(
         input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
         ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

-        if linear_type_enum == LinearType.DELAYED:
-
-            def float8_forw_backward():
-                if linear_requires_sync(
-                    linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY
-                ):
-                    sync_float8_amax_and_scale_history(linear_float8)
-                linear_float8(input_tensor).sum().backward()
-
-        else:
-
-            def float8_forw_backward():
-                linear_float8(input_tensor).sum().backward()
+        def float8_forw_backward():
+            if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
+                sync_float8_amax_and_scale_history(linear_float8)
+            linear_float8(input_tensor).sum().backward()

         def n_times(n, fn, *args, **kwargs):
             def wrapper(*args, **kwargs):
@@ -224,7 +195,6 @@ def wrapper(*args, **kwargs):
             dtype,
             compile,
             use_fast_accum=fast_accum,
-            linear_type=linear_type,
             scaling_repr=scaling_repr,
         )
         print(experiment)
@@ -237,7 +207,6 @@ def wrapper(*args, **kwargs):
        "M",
        "K",
        "N",
-       "linear_type",
        "scaling_repr",
        "ref_dtype",
        "compiled",
@@ -257,7 +226,6 @@ def wrapper(*args, **kwargs):
                experiment.shape[0],
                experiment.shape[1],
                experiment.shape[2],
-               experiment.linear_type,
                experiment.scaling_repr,
                experiment.dtype,
                experiment.compiled,
@@ -287,7 +255,6 @@ def wrapper(*args, **kwargs):
        [
            "name",
            "shape",
-           "linear_type",
            "scaling_repr",
            "compiled",
            "use_fast_accum",
@@ -311,7 +278,6 @@ def invoke_main() -> None:
     parser.add_argument("-n", "--n_limit", type=int, required=False)
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
-    parser.add_argument("--linear_type_filter", type=str, required=False)
     parser.add_argument("--scaling_type_x", type=str, required=False)
     parser.add_argument("--scaling_type_w", type=str, required=False)
     parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
@@ -330,7 +296,6 @@ def invoke_main() -> None:
        args.n_limit,
        args.fast_accum_filter,
        args.shape_name_filter,
-       args.linear_type_filter,
        **kwargs,
    )
```
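
One note on the unchanged `fast_accum` handling in this file: `ScaledMMConfig(False, True, False)` is constructed positionally. A hedged reading of those fields, inferred from the keyword-argument usage of `ScaledMMConfig` in `float8_dynamic_utils.py` further down, is:

```python
# Field meanings inferred from the keyword usage of ScaledMMConfig in
# float8_dynamic_utils.py below; treat this as an assumption, not documented API.
from float8_experimental.float8_tensor import ScaledMMConfig

fast_accum_config = ScaledMMConfig(
    False,  # emulate: run real fp8 scaled matmuls rather than float32 emulation
    True,   # use_fast_accum: enable fast accumulation in the forward matmul
    False,  # fp8_output: keep the matmul output in higher precision
)
# The benchmark assigns this to the forward path only:
#     linear_float8.forward_config = fast_accum_config
```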

benchmarks/bench_multi_gpu.py

+8 −2

```diff
@@ -14,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
@@ -65,7 +65,13 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
         modules.append(nn.ReLU())
     m = nn.Sequential(*modules)
     if is_fp8:
-        swap_linear_with_float8_linear(m, Float8Linear, emulate=False)
+        swap_linear_with_float8_linear(
+            m,
+            emulate=False,
+            scaling_type_x=TensorScalingType.DELAYED,
+            scaling_type_w=TensorScalingType.DELAYED,
+            scaling_type_dL_dY=TensorScalingType.DELAYED,
+        )
     return m


```
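
The same module-swap API, exercised standalone: the explicit module-class argument is gone, and omitting the scaling types is assumed to fall back to the dynamic-scaling defaults described in the commit summary.

```python
# Sketch of the updated swap API; dynamic defaults assumed per the commit summary.
import copy

import torch.nn as nn
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# Dynamic scaling: rely on the defaults, no scaling types needed.
model_dynamic = copy.deepcopy(model)
swap_linear_with_float8_linear(model_dynamic, emulate=False)

# Delayed scaling: spell out the scaling type for x, w, and dL_dY.
model_delayed = copy.deepcopy(model)
swap_linear_with_float8_linear(
    model_delayed,
    emulate=False,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)
```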

benchmarks/profile_linear_float8.py

+23 −26

```diff
@@ -18,11 +18,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
-    LinearType,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -206,19 +204,25 @@ def profile_function(
 def main(
     profile_path_prefix: Path,
     compile: bool = True,
-    linear_type: str = "dynamic",
-    scaling_type_x: str = "delayed",
-    scaling_type_w: str = "delayed",
-    scaling_type_dL_dY: str = "delayed",
+    scaling_type_x: str = "dynamic",
+    scaling_type_w: str = "dynamic",
+    scaling_type_dL_dY: str = "dynamic",
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")

-    print(f"Compile is set to | {compile}")
-    print(f"Using Linear type: | {linear_type}")
-    print(f"model_type is set to | {model_type}")
+    scaling_type_x = TensorScalingType(scaling_type_x)
+    scaling_type_w = TensorScalingType(scaling_type_w)
+    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+    scaling_repr = "_".join(
+        [s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)]
+    )
+
+    print(f"Compile is set to | {compile}")
+    print(f"model_type is set to | {model_type}")
+    print(f"scaling_repr is set to | {scaling_repr}")

     device = "cuda"
     ref_dtype = torch.bfloat16
@@ -249,21 +253,14 @@ def main(

     m_ref = m_ref.to(device).to(ref_dtype)

-    linear_type = LinearType[linear_type.upper()]
-    linear_cls = (
-        Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
-    )
-    extra_kwargs = {}
-    scaling_type_x = TensorScalingType(scaling_type_x)
-    scaling_type_w = TensorScalingType(scaling_type_w)
-    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
-    if linear_type is LinearType.DELAYED:
-        extra_kwargs["scaling_type_x"] = scaling_type_x
-        extra_kwargs["scaling_type_w"] = scaling_type_w
-        extra_kwargs["scaling_type_dL_dY"] = scaling_type_dL_dY
+    extra_kwargs = {
+        "scaling_type_x": scaling_type_x,
+        "scaling_type_w": scaling_type_w,
+        "scaling_type_dL_dY": scaling_type_dL_dY,
+    }

     m_float8 = copy.deepcopy(m_ref)
-    swap_linear_with_float8_linear(m_float8, linear_cls, **extra_kwargs)
+    swap_linear_with_float8_linear(m_float8, **extra_kwargs)

     def ref_forw_backward(x):
         out = m_ref(x)
@@ -281,9 +278,7 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(
-            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
-        ):
+        if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)
@@ -345,7 +340,9 @@ def float8_forw_backward_wrapper(x):
     if dtype_filter != "bfloat16":
         # Profile Float8 Model
         print("profiling float8")
-        float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
+        float8_suffix = (
+            f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
+        )
         float8_path = profile_path_prefix + float8_suffix
         profile_config = ProfileConfig(
             float8_path,
```
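
For reference, the profiler now derives its filename suffix from the scaling configuration instead of the removed `linear_type`. A small sketch of that naming, with the `short_str()` values ("dyn"/"del") assumed from the `scaling_repr` example in `float8_linear.py` below:

```python
# Illustration of the new profile-name suffix; short_str() outputs are assumed
# to be "dyn"/"del" based on the scaling_repr example in float8_linear.py.
from float8_experimental.float8_linear import TensorScalingType

scaling_type_x = TensorScalingType("dynamic")
scaling_type_w = TensorScalingType("dynamic")
scaling_type_dL_dY = TensorScalingType("delayed")
scaling_repr = "_".join(
    [s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)]
)
# scaling_repr == "dyn_dyn_del", so with model_type="linear" and compile=True the
# profile would be written to "<prefix>_linear_float8_compile_True_dyn_dyn_del.json"
```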

float8_experimental/float8_dynamic_linear.py → float8_experimental/float8_dynamic_utils.py (renamed)

−58

```diff
@@ -53,64 +53,6 @@ def backward(ctx, gradY):
         return fp8_tensor, None


-class Float8DynamicLinear(torch.nn.Linear):
-    """
-    A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
-    conversion to fp8 of the input and weight tensors.
-    """
-
-    def __init__(self, **super_kwargs):
-        super().__init__(**super_kwargs)
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config)
-        if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-            w_fp8 = self.weight
-        else:
-            w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
-        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
-        return y
-
-    @classmethod
-    def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
-        """
-        Create an nn.Linear with fp8 compute from a regular nn.Linear
-
-        Args:
-            mod (torch.nn.Linear): nn.Linear to convert
-            emulate (bool): whether to emulate fp8 matmul logic in float32
-        """
-        with torch.device("meta"):
-            super_kwargs = {
-                "in_features": mod.in_features,
-                "out_features": mod.out_features,
-                "bias": False,
-            }
-            new_mod = cls(**super_kwargs)
-
-        new_mod.forward_config = ScaledMMConfig(
-            emulate=emulate,
-            use_fast_accum=not bool(emulate),
-            fp8_output=False,
-            pad_inner_dim=config.pad_inner_dim,
-        )
-        new_mod.backward_config = ScaledMMConfig(
-            emulate=emulate,
-            use_fast_accum=False,
-            fp8_output=False,
-            pad_inner_dim=config.pad_inner_dim,
-        )
-        if config.enable_fsdp_fp8_all_gather:
-            new_mod.weight = nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
-            )
-        else:
-            new_mod.weight = mod.weight
-        new_mod.bias = mod.bias
-        return new_mod
-
-
 def cast_to_float8_e4m3_dynamic(
     inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
 ) -> Float8Tensor:
```

float8_experimental/float8_linear.py

+3 −3

```diff
@@ -16,7 +16,7 @@

 import torch

-from float8_experimental.float8_dynamic_linear import (
+from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
     WeightWithDynamicFloat8CastTensor,
@@ -402,8 +402,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

     def scaling_repr(self):
         # add scaling settings without using too many characters
-        # example: "x:del,w:del,dldy:dyn"
-        return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
+        # example: "x_del_w_del_dldy_dyn"
+        return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"

     def extra_repr(self):
         s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
```
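
Since `scaling_repr()` now feeds directly into the profiler filenames above, the switch from `x:del,w:del,dldy:dyn` to the underscore form presumably keeps that suffix filesystem-friendly. A quick hedged check of the new format, assuming the dynamic defaults described in the commit summary:

```python
# Hedged example of the new scaling_repr() output; dynamic defaults for x and w
# are assumed from the commit summary, dL_dY is set to delayed explicitly.
import torch
from float8_experimental.float8_linear import Float8Linear, TensorScalingType

m = Float8Linear.from_float(
    torch.nn.Linear(64, 64, bias=False),
    emulate=False,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)
print(m.scaling_repr())  # expected: x_dyn_w_dyn_dldy_del
```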
