diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index f6015cd234a5f..b20549a57a5ad 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed an issue with `LightningModule.*_step` methods bypassing the DDP/FSDP wrapper ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))
 
 
 ## [2.0.1] - 2023-03-30
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index dc503dc29d1a7..58533a9c6043c 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -15,6 +15,7 @@
 from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
 
 import torch
+from lightning_utilities import WarningCache
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import nn as nn
 from torch import Tensor
@@ -28,8 +29,11 @@
 from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.types import Optimizable
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 
+warning_cache = WarningCache()
 T_destination = TypeVar("T_destination", bound=Dict[str, Any])
+_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
 
 
 class _FabricOptimizer:
@@ -132,7 +136,42 @@ def state_dict(
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> _IncompatibleKeys:
         return self._original_module.load_state_dict(state_dict=state_dict, strict=strict)
 
+    def _redirection_through_forward(self, method_name: str) -> Callable:
+        assert method_name != "forward"
+        original_forward = self._original_module.forward
+
+        def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
+            # Unpatch ourselves immediately before calling the method `method_name`,
+            # because the method itself may want to call the real `forward`
+            self._original_module.forward = original_forward
+            # Call the actual method, e.g. `.training_step(...)`
+            method = getattr(self._original_module, method_name)
+            return method(*args, **kwargs)
+
+        # We make the caller "unknowingly" send their arguments through the forward_module's `__call__`.
+        # We expect that the `forward_module` will eventually call `original_module.forward`, which we
+        # have patched to redirect back to `original_module.method_name()`.
+        def call_forward_module(*args: Any, **kwargs: Any) -> Any:
+            # Patch the original_module's forward so we can redirect the arguments back to the real method
+            self._original_module.forward = wrapped_forward
+            return self._forward_module(*args, **kwargs)
+
+        return call_forward_module
+
+    def _validate_method_access(self, name: str, attribute: Any) -> None:
+        if inspect.ismethod(attribute) and self._forward_module != self._original_module:
+            warning_cache.warn(
+                f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
+                " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
+                f" `.backward()`. You should pass your inputs through `{type(self._original_module)}.forward()`.",
+                category=PossibleUserWarning,
+            )
+
     def __getattr__(self, item: Any) -> Any:
+        if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
+            # Special support for `LightningModule`, to prevent bypassing DDP's forward
+            return self._redirection_through_forward(item)
+
         try:
             # __getattr__ gets called as a last resort if the attribute does not exist
             # call nn.Module's implementation first
@@ -140,7 +179,9 @@ def __getattr__(self, item: Any) -> Any:
         except AttributeError:
             # If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
             original_module = super().__getattr__("_original_module")
-            return getattr(original_module, item)
+            attr = getattr(original_module, item)
+            self._validate_method_access(item, attr)
+            return attr
 
 
 class _FabricDataLoader:
diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index acd4f41e6e3e9..ac14f5c070507 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -16,6 +16,7 @@
 
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import BatchSampler, DistributedSampler
 from torch.utils.data.dataloader import DataLoader
@@ -59,13 +60,43 @@ def __init__(self):
     fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
     assert fabric_module.attribute == 1
     assert fabric_module.layer is original_module.layer
-    assert fabric_module.method() == 2
     assert fabric_module.forward.__self__.__class__ == _FabricModule
 
     with pytest.raises(AttributeError):
         _ = fabric_module.not_exists
 
 
+def test_fabric_module_method_lookup():
+    """Test that access to methods warns about improper use when a wrapper from a strategy is involved."""
+    from lightning.fabric.wrappers import warning_cache
+
+    class OriginalModule(torch.nn.Module):
+        def method(self):
+            return 100
+
+    class ModuleWrapper(torch.nn.Module):
+        def __init__(self, module):
+            super().__init__()
+            self.wrapped = module
+
+    # Regular case: forward_module == original_module -> no warnings
+    original_module = OriginalModule()
+    fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
+    warning_cache.clear()
+    with no_warning_call(UserWarning):
+        assert fabric_module.method() == 100
+    assert not warning_cache
+
+    # Special case: original module wrapped by forward module: -> warn
+    original_module = OriginalModule()
+    wrapped_module = ModuleWrapper(original_module)
+    fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
+    warning_cache.clear()
+    with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method\(\)` from outside the"):
+        assert fabric_module.method() == 100
+    warning_cache.clear()
+
+
 def test_fabric_module_state_dict_access():
     """Test that state_dict access passes through to the original module."""
@@ -353,3 +384,68 @@ def test_is_wrapped():
     assert not is_wrapped(dataloader)
     wrapped = _FabricDataLoader(dataloader)
     assert is_wrapped(wrapped)
+
+
+def test_step_method_redirection():
+    """Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
+    module."""
+
+    class DDP(torch.nn.Module):
+        def __init__(self, module):
+            super().__init__()
+            self.module = module
+
+        def forward(self, *args, **kwargs):
+            return self.module(*args, **kwargs)
+
+    class LightningModule(torch.nn.Module):
+        def forward(self):
+            return "forward_return"
+
+        def training_step(self, arg, kwarg=None):
+            assert self() == "forward_return"
+            assert arg == "train_arg"
+            assert kwarg == "train_kwarg"
+            return "training_step_return"
+
+        def validation_step(self, arg, kwarg=None):
+            assert self() == "forward_return"
+            assert arg == "val_arg"
+            assert kwarg == "val_kwarg"
+            return "validation_step_return"
+
+        def normal_method(self):
+            pass
+
+    original_module = LightningModule()
+    forward_module = DDP(original_module)
+    fabric_module = _FabricModule(forward_module=forward_module, precision=Mock(), original_module=original_module)
+
+    # Regular methods on the original_module are visible and identical on the fabric_module ...
+    assert fabric_module.normal_method == original_module.normal_method
+
+    # ... but special methods like training_step get redirected to the forward_module
+    assert fabric_module.training_step.__name__ == "call_forward_module"
+    assert fabric_module.validation_step.__name__ == "call_forward_module"
+    assert fabric_module.test_step.__name__ == "call_forward_module"
+    assert fabric_module.predict_step.__name__ == "call_forward_module"
+
+    with pytest.raises(AttributeError, match="has no attribute 'predict_step'"):
+        # A special method that does not exist will raise its AttributeError when being called
+        fabric_module.predict_step()
+
+    # The forward method on the original module remains untouched
+    assert original_module.forward.__name__ == "forward"
+
+    # The special methods get redirected correctly to produce the expected output
+    assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
+    assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"  # call 2nd time
+    assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
+
+    # The forward method remains untouched/unpatched after the special methods have been called
+    assert original_module.forward.__name__ == "forward"
+
+    # Special case: forward_module == original_module -> no special treatment applied
+    fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
+    assert fabric_module.training_step == original_module.training_step
+    assert fabric_module.validation_step == original_module.validation_step
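
Illustration (not part of the patch): a minimal sketch of how the redirection behaves once this change is applied. Calling a `*_step` method on a `_FabricModule` is routed through the strategy wrapper's `forward` before reaching the real method, and `original_module.forward` is restored afterwards. The `DDP` and `MyModule` classes below are hypothetical stand-ins mirroring the classes used in `test_step_method_redirection`, not Lightning APIs.

# Minimal usage sketch, assuming this patch is applied.
from unittest.mock import Mock

import torch
from lightning.fabric.wrappers import _FabricModule


class DDP(torch.nn.Module):
    # Hypothetical stand-in for a strategy wrapper such as DistributedDataParallel
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Gradient-synchronization hooks would run here in a real wrapper
        print("wrapper forward ran")
        return self.module(*args, **kwargs)


class MyModule(torch.nn.Module):
    # Hypothetical stand-in for a LightningModule with a training_step
    def forward(self, x):
        return x + 1

    def training_step(self, x):
        # `self(x)` uses the real forward, which has already been restored by this point
        return self(x).sum()


original = MyModule()
fabric_module = _FabricModule(forward_module=DDP(original), precision=Mock(), original_module=original)

# The call goes through DDP.forward (prints "wrapper forward ran") before reaching
# MyModule.training_step, and the temporary patch on `forward` is undone afterwards.
loss = fabric_module.training_step(torch.zeros(2))
assert original.forward.__name__ == "forward"

Patching `forward` only for the duration of the call keeps direct calls to `fabric_module(...)` unaffected.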