diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py
index b84c2e9..b778c53 100644
--- a/float8_experimental/float8_tensor_parallel.py
+++ b/float8_experimental/float8_tensor_parallel.py
@@ -4,6 +4,7 @@
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
 )
+from float8_experimental.float8_linear import TensorScalingType
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
@@ -19,7 +20,17 @@
 # here is that in input/output handling we do casting after
 # creating the DTensor.

-# NOTE: This only works and tested with the DynamicLinear
+# NOTE: this only works with, and is only tested with, dynamic scaling
+# (Float8DynamicLinear, or Float8Linear with dynamic scaling for all tensors)
+
+
+def _float8_linear_supports_float8_allgather(m):
+    # TODO(future): add support for delayed scaling for activations
+    # and gradients
+    return (
+        m.scaling_type_x == TensorScalingType.DYNAMIC
+        and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC
+    )


 class Float8ColwiseParallel(ColwiseParallel):
@@ -61,11 +72,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear

-        if not isinstance(module, Float8DynamicLinear):
+        if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
             raise ValueError(
-                f"Expecting module to be Float8DynamicLinear but found {type(module)}"
+                f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
             )
+        elif isinstance(
+            module, Float8Linear
+        ) and not _float8_linear_supports_float8_allgather(module):
+            raise AssertionError("float8 all-gather requires dynamic scaling for x and dL_dY")

         return super()._apply(module, device_mesh)

@@ -107,11 +123,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear

-        if not isinstance(module, Float8DynamicLinear):
+        if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
             raise ValueError(
-                f"Expecting module to be Float8DynamicLinear but found {type(module)}"
+                f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
             )
+        elif isinstance(
+            module, Float8Linear
+        ) and not _float8_linear_supports_float8_allgather(module):
+            raise AssertionError("float8 all-gather requires dynamic scaling for x and dL_dY")

         return super()._apply(module, device_mesh)

@@ -184,22 +205,23 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear

         fwd_linear_config = None
         if self.fwd_config_submodule_fqn is not None:
             fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
-            assert isinstance(fwd_linear, Float8DynamicLinear)
+            assert isinstance(fwd_linear, (Float8DynamicLinear, Float8Linear))
             fwd_linear_config = fwd_linear.forward_config
         else:
             # search for ScaledMM configs for all the submodules and make sure they are the same
             for mod in module.modules():
-                if isinstance(mod, Float8DynamicLinear):
+                if isinstance(mod, (Float8DynamicLinear, Float8Linear)):
                     if fwd_linear_config is None:
                         fwd_linear_config = mod.forward_config
                     else:
                         assert (
                             fwd_linear_config == mod.forward_config
-                        ), "All the Float8DynamicLinear modules should have same forward config!"
+                        ), "All the Float8DynamicLinear and Float8Linear modules should have the same forward config!"

         self.fwd_linear_config = fwd_linear_config
         super()._apply(module, device_mesh)
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 354f831..24a5e58 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -18,6 +18,7 @@
     Float8DynamicLinear,
     NoopFwToFloat8E5M2Bw,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
@@ -169,23 +170,37 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False
+def _test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False
 ):
     device = mesh.device_type
+    # TODO(future): delete Float8DynamicLinear from this test once all the
+    # code is unified
+    float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+    extra_kwargs = {}
+    if use_float8_linear:
+        # For now, just use Float8Linear with dynamic scaling, which is the
+        # same behavior as Float8DynamicLinear.
+        # TODO(future): add support for float8 all-gather with delayed scaling
+        # for activations and gradients.
+        extra_kwargs = {
+            "scaling_type_x": TensorScalingType.DYNAMIC,
+            "scaling_type_w": TensorScalingType.DYNAMIC,
+            "scaling_type_dL_dY": TensorScalingType.DYNAMIC,
+        }

     toy_model = ToyModel().to(device)
     toy_model_fp8 = swap_linear_with_float8_linear(
-        toy_model, Float8DynamicLinear, emulate=True
+        toy_model, float8_cls, emulate=True, **extra_kwargs
     )

     tp_model = copy.deepcopy(toy_model)
     tp_model = swap_linear_with_float8_linear(
-        tp_model, Float8DynamicLinear, emulate=True
+        tp_model, float8_cls, emulate=True, **extra_kwargs
     )
     sp_model = copy.deepcopy(toy_model)
     sp_model = swap_linear_with_float8_linear(
-        sp_model, Float8DynamicLinear, emulate=True
+        sp_model, float8_cls, emulate=True, **extra_kwargs
     )

     # vanilla TP
@@ -218,7 +233,7 @@ def test_fp8_mlp_tensor_parallelism_base(
     # PrepareFloat8ModuleInput with specific submodule fqn
     sp_model2 = copy.deepcopy(toy_model)
     sp_model2 = swap_linear_with_float8_linear(
-        sp_model2, Float8DynamicLinear, emulate=True
+        sp_model2, Float8DynamicLinear, emulate=True, **extra_kwargs
     )

     sp_model2 = parallelize_module(
@@ -271,8 +286,28 @@ def test_fp8_mlp_tensor_parallelism_base(
     )


+def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=False, use_float8_linear=False
+    )
+
+
+def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=False, use_float8_linear=True
+    )
+
+
 def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=True, use_float8_linear=False
+    )
+
+
+def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=True, use_float8_linear=True
+    )


 if __name__ == "__main__":
@@ -285,8 +320,10 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
         test_fp8_redistribute,
         test_dtensor_cast_to_fp8,
         test_dtensor_fp8_autograd,
-        test_fp8_mlp_tensor_parallelism_base,
+        test_fp8_mlp_tensor_parallelism_eager,
+        test_fp8_mlp_tensor_parallelism_eager_float8_linear,
         test_fp8_mlp_tensor_parallelism_compile,
+        test_fp8_mlp_tensor_parallelism_compile_float8_linear,
     ]

     for test in tqdm(tests, desc="Running tests"):
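
For context, a minimal usage sketch of the path this change enables: swap nn.Linear modules for Float8Linear with all-dynamic scaling, then apply the float8-aware TP styles. This is not part of the diff; it assumes the signatures of swap_linear_with_float8_linear, TensorScalingType, Float8ColwiseParallel, and Float8RowwiseParallel as they appear above, and ToyMLP with submodules "w1"/"w2" is a hypothetical stand-in for the test's ToyModel.

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


class ToyMLP(nn.Module):
    # hypothetical two-layer MLP; "w1"/"w2" are illustrative submodule names
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


# assumes a torchrun launch; a single mesh dimension over all ranks
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

model = ToyMLP().cuda()
# the float8 all-gather path added in this diff requires dynamic scaling for
# x and dL_dY (see _float8_linear_supports_float8_allgather); emulate=True
# mirrors the test above and avoids requiring float8-capable hardware
model = swap_linear_with_float8_linear(
    model,
    Float8Linear,
    emulate=True,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
# column-parallel first linear, row-parallel second linear
model = parallelize_module(
    model,
    mesh,
    {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
)

out = model(torch.randn(4, 16, device="cuda"))
out.sum().backward()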