This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[9/x]: make dynamic scaling default in Float8Linear #300

Closed · wants to merge 1 commit
31 changes: 20 additions & 11 deletions README.md
@@ -27,21 +27,23 @@ pip install -e ".[dev]"

# User API

-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).

-## float8 linear with dynamic scaling
+## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`

+This is the most accurate recipe as every tensor is scaled dynamically.

```python
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
)
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear

# create model
m = Model(...)

-# convert all `torch.nn.Linear` modules to `Float8DynamicLinear`
-swap_linear_with_float8_linear(m, Float8DynamicLinear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`
+swap_linear_with_float8_linear(m, Float8Linear)

# optional: use FSDP
model = FSDP(model, use_orig_params=True)
@@ -54,18 +56,27 @@ m = torch.compile(m)

## float8 linear with delayed scaling

+This is theoretically the most performant recipe as it minimizes memory reads.

```python
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType

# create model
m = Model(...)

-# convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m, Float8Linear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
+# type
+swap_linear_with_float8_linear(
+    m,
+    Float8Linear,
+    scaling_type_x=TensorScalingType.DELAYED,
+    scaling_type_w=TensorScalingType.DELAYED,
+    scaling_type_dL_dY=TensorScalingType.DELAYED,
+)

# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
@@ -93,9 +104,7 @@ for _ in range(N_ITER):
# 🧭 Code Organization

* `float8_experimental/float8_linear.py`
-  - `Float8Linear` (main user facing entry point for delayed scaling)
-* `float8_experimental/float8_dynamic_linear.py`
-  - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
+  - `Float8Linear` (main user facing entry point for Float8Linear)
* `float8_experimental/float8_tensor.py`
  - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
  - `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass
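Note on the delayed-scaling recipe above: `sync_float8_amax_and_scale_history` is intended to be called once per iteration before the optimizer step; the README's training loop is elided by the hunk boundary (`for _ in range(N_ITER):`). A minimal sketch under assumptions — a toy model standing in for `Model(...)`, an arbitrary `N_ITER`, and float8-capable hardware (or the library's `emulate` option) to actually run it:

```python
import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# toy stand-in for the README's `Model(...)` (assumption)
m = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
swap_linear_with_float8_linear(
    m,
    Float8Linear,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)

optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
N_ITER = 10  # assumed iteration count
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(torch.randn(8, 16))
    y.sum().backward()
    # with delayed scaling, sync amaxes and scales across the model once per
    # iteration, before stepping the optimizer
    sync_float8_amax_and_scale_history(m)
    optimizer.step()
```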
6 changes: 3 additions & 3 deletions float8_experimental/float8_linear_utils.py
@@ -191,9 +191,9 @@ def swap_linear_with_float8_linear(
    skip_fqn_list: Optional[List[str]] = None,
    emulate: bool = False,
    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
) -> Optional[nn.Module]:
    """
    Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`.
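The signature change above is the core of this PR: the `scaling_type_*` defaults flip from `DELAYED` to `DYNAMIC`, so a bare swap now selects dynamic scaling for `x`, `w` and `dL_dY`. A minimal sketch of the new default behavior (toy model assumed; actually running the converted model still needs float8-capable hardware or emulation):

```python
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 16))

# With the new defaults, omitting the scaling_type_* arguments selects
# dynamic scaling for x, w and dL_dY ...
swap_linear_with_float8_linear(m, Float8Linear)

# ... which is now equivalent to spelling the defaults out explicitly:
# swap_linear_with_float8_linear(
#     m,
#     Float8Linear,
#     scaling_type_x=TensorScalingType.DYNAMIC,
#     scaling_type_w=TensorScalingType.DYNAMIC,
#     scaling_type_dL_dY=TensorScalingType.DYNAMIC,
# )
```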
16 changes: 14 additions & 2 deletions test/test_compile.py
@@ -299,7 +299,13 @@ def test_sync_amax_func():
    module = torch.nn.Sequential(
        nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
    )
-    float8_mod = swap_linear_with_float8_linear(module, Float8Linear)
+    float8_mod = swap_linear_with_float8_linear(
+        module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
    compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
    compiled_swap_func(float8_mod)
    assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"
@@ -329,7 +335,13 @@ def test_sync_amax_func_cuda_graph_success():
    my_module = nn.Sequential(
        nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
    ).to("cuda")
-    swap_linear_with_float8_linear(my_module, Float8Linear)
+    swap_linear_with_float8_linear(
+        my_module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
    inpt = torch.randn(
        16, 16, device="cuda", dtype=torch.float32, requires_grad=True
    )
9 changes: 8 additions & 1 deletion test/test_fsdp.py
@@ -23,6 +23,8 @@
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    LinearType,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
@@ -130,7 +132,12 @@ def forward_backward(model, optim, is_fp8, i):
    optim.zero_grad()
    y_local = model(ref_input_local[i])
    y_local.backward(ref_grad_local[i])
-    if is_fp8:
+    if is_fp8 and linear_requires_sync(
+        LinearType.DELAYED,
+        TensorScalingType.DYNAMIC,
+        scaling_type_w,
+        TensorScalingType.DYNAMIC,
+    ):
        sync_float8_func(model)
    optim.step()
    return y_local
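Context for the `test_fsdp.py` change: the test now consults `linear_requires_sync` so the amax/scale sync only runs when some tensor uses delayed scaling, which is no longer the default. The sketch below mirrors the positional call in the diff; the argument order and the helper's exact semantics are assumptions based on that call, not on documentation:

```python
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import linear_requires_sync, LinearType

# Mirrors the call in forward_backward above: only scaling_type_w varies in
# this test; x and dL_dY stay dynamic.
scaling_type_w = TensorScalingType.DELAYED  # example parametrization

needs_sync = linear_requires_sync(
    LinearType.DELAYED,         # linear type under test
    TensorScalingType.DYNAMIC,  # scaling_type_x
    scaling_type_w,             # scaling_type_w
    TensorScalingType.DYNAMIC,  # scaling_type_dL_dY
)
# Expected: True here, since the weight uses delayed scaling; False if all
# three scaling types are dynamic (assumption about the helper's semantics).
print(needs_sync)
```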
11 changes: 9 additions & 2 deletions test/test_fsdp_compile.py
@@ -18,7 +18,7 @@
import torch.multiprocessing as mp
import torch.nn as nn
from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
@@ -49,7 +49,14 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
        nn.Linear(K, N, dtype=base_dtype),
        nn.ReLU(),
    )
-    swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+    swap_linear_with_float8_linear(
+        m,
+        Float8Linear,
+        emulate=emulate,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
    return m

