This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[4/x] add tests for DTensor TP/SP + Float8Linear
Summary:

Makes the DTensor TP/SP tests also test `Float8Linear` with all scaling
types configured to be dynamic.

We can add support for delayed scaling with float8 all-gather for `x`
and `dL_dY` in a future PR, as needed.
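
For reference, a minimal sketch (not part of this commit) of the configuration the new test variants exercise. The two-layer model below is illustrative; the helper, class, and keyword names come from the diff:

```
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Illustrative stand-in for the ToyModel used by the tests.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# Swap nn.Linear -> Float8Linear with every scaling type set to dynamic;
# this is the configuration the updated TP/SP tests run with.
model_fp8 = swap_linear_with_float8_linear(
    model,
    Float8Linear,
    emulate=True,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
```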

Test Plan:

```
./test/test_dtensor.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 8e121ecaf4e05c5eb69b1612c084c459abe95589
Pull Request resolved: #294
vkuzo committed Jul 1, 2024
1 parent a0ad964 commit 54364a8
Showing 2 changed files with 73 additions and 15 deletions.
35 changes: 28 additions & 7 deletions float8_experimental/float8_tensor_parallel.py
@@ -4,6 +4,7 @@
cast_to_float8_e4m3_dynamic,
cast_to_float8_e5m2_dynamic_bw,
)
from float8_experimental.float8_linear import TensorScalingType
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
@@ -22,6 +23,15 @@
# NOTE: This only works and tested with the DynamicLinear


def _float8_linear_supports_float8_allgather(m):
# TODO(future PR): add support for delayed scaling for activations
# and gradients
return (
m.scaling_type_x == TensorScalingType.DYNAMIC
and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC
)


class Float8ColwiseParallel(ColwiseParallel):
@staticmethod
def _prepare_input_fn(
@@ -61,11 +71,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

if not isinstance(module, Float8DynamicLinear):
if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
raise ValueError(
f"Expecting module to be Float8DynamicLinear but found {type(module)}"
f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
)
elif isinstance(
module, Float8Linear
) and not _float8_linear_supports_float8_allgather(module):
raise AssertionError("unsupported")

return super()._apply(module, device_mesh)

@@ -107,11 +122,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

if not isinstance(module, Float8DynamicLinear):
if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
raise ValueError(
f"Expecting module to be Float8DynamicLinear but found {type(module)}"
f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
)
elif isinstance(
module, Float8Linear
) and not _float8_linear_supports_float8_allgather(module):
raise AssertionError("unsupported")

return super()._apply(module, device_mesh)

@@ -184,22 +204,23 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

fwd_linear_config = None
if self.fwd_config_submodule_fqn is not None:
fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
assert isinstance(fwd_linear, Float8DynamicLinear)
assert isinstance(fwd_linear, (Float8DynamicLinear, Float8Linear))
fwd_linear_config = fwd_linear.forward_config
else:
# search for ScaledMM configs for all the submodules and make sure they are the same
for mod in module.modules():
if isinstance(mod, Float8DynamicLinear):
if isinstance(mod, (Float8DynamicLinear, Float8Linear)):
if fwd_linear_config is None:
fwd_linear_config = mod.forward_config
else:
assert (
fwd_linear_config == mod.forward_config
), "All the Float8DynamicLinear modules should have same forward config!"
), "All the Float8DynamicLinear and Float8Linear modules should have same forward config!"

self.fwd_linear_config = fwd_linear_config
super()._apply(module, device_mesh)
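For orientation (not part of the diff), here is a sketch of how these float8-aware TP plans might be wired up with `parallelize_module`. The `FFN` module, the `w1`/`w2` submodule names, and the mesh setup are illustrative assumptions, and `Float8RowwiseParallel` is assumed to be the row-wise counterpart defined in this file; only modules that pass the new all-dynamic-scaling guard are accepted:

```
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,  # assumed row-wise counterpart to Float8ColwiseParallel
)


# Illustrative feed-forward block; the real tests use their own ToyModel.
class FFN(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


# Assumes this runs under torchrun so a default process group can be set up.
mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
model = swap_linear_with_float8_linear(FFN().cuda(), Float8DynamicLinear, emulate=True)

# Column-parallel first projection, row-parallel second projection; the float8
# plan variants insert the casts so the TP collectives can run in float8.
model = parallelize_module(
    model,
    mesh,
    {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
)
```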
53 changes: 45 additions & 8 deletions test/test_dtensor.py
@@ -18,6 +18,7 @@
Float8DynamicLinear,
NoopFwToFloat8E5M2Bw,
)
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_tensor_parallel import (
@@ -169,23 +170,37 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
loss.backward()


def test_fp8_mlp_tensor_parallelism_base(
mesh: DeviceMesh, size=16, compile: bool = False
def _test_fp8_mlp_tensor_parallelism_base(
mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False
):
device = mesh.device_type
# TODO(future): delete Float8DynamicLinear from this test once all the
# code is unified
float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
extra_kwargs = {}
if use_float8_linear:
        # For now, just use Float8Linear with dynamic scaling, which is the
        # same behavior as Float8DynamicLinear.
# TODO(future): add support for float8 all-gather with delayed scaling
# for activations and gradients.
extra_kwargs = {
"scaling_type_x": TensorScalingType.DYNAMIC,
"scaling_type_w": TensorScalingType.DYNAMIC,
"scaling_type_dL_dY": TensorScalingType.DYNAMIC,
}

toy_model = ToyModel().to(device)
toy_model_fp8 = swap_linear_with_float8_linear(
toy_model, Float8DynamicLinear, emulate=True
toy_model, float8_cls, emulate=True, **extra_kwargs
)

tp_model = copy.deepcopy(toy_model)
tp_model = swap_linear_with_float8_linear(
tp_model, Float8DynamicLinear, emulate=True
tp_model, float8_cls, emulate=True, **extra_kwargs
)
sp_model = copy.deepcopy(toy_model)
sp_model = swap_linear_with_float8_linear(
sp_model, Float8DynamicLinear, emulate=True
sp_model, float8_cls, emulate=True, **extra_kwargs
)

# vanilla TP
@@ -218,7 +233,7 @@ def test_fp8_mlp_tensor_parallelism_base(
# PrepareFloat8ModuleInput with specific submodule fqn
sp_model2 = copy.deepcopy(toy_model)
sp_model2 = swap_linear_with_float8_linear(
sp_model2, Float8DynamicLinear, emulate=True
sp_model2, Float8DynamicLinear, emulate=True, **extra_kwargs
)

sp_model2 = parallelize_module(
@@ -271,8 +286,28 @@ def test_fp8_mlp_tensor_parallelism_base(
)


def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(
mesh, size, compile=False, use_float8_linear=False
)


def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(
mesh, size, compile=False, use_float8_linear=True
)


def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
_test_fp8_mlp_tensor_parallelism_base(
mesh, size, compile=True, use_float8_linear=False
)


def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(
mesh, size, compile=True, use_float8_linear=True
)


if __name__ == "__main__":
@@ -285,8 +320,10 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
test_fp8_redistribute,
test_dtensor_cast_to_fp8,
test_dtensor_fp8_autograd,
test_fp8_mlp_tensor_parallelism_base,
test_fp8_mlp_tensor_parallelism_eager,
test_fp8_mlp_tensor_parallelism_eager_float8_linear,
test_fp8_mlp_tensor_parallelism_compile,
test_fp8_mlp_tensor_parallelism_compile_float8_linear,
]

for test in tqdm(tests, desc="Running tests"):
