Minor strategy fixes [TPU] (#18774)
carmocca authored Oct 11, 2023
1 parent 4df6e13 · commit 5a83f54
Showing 17 changed files with 186 additions and 184 deletions.
15 changes: 1 addition & 14 deletions src/lightning/fabric/connector.py
@@ -454,10 +454,9 @@ def _init_strategy(self) -> None:
             self.strategy = self._strategy_flag
 
     def _check_and_init_precision(self) -> Precision:
-        self._validate_precision_choice()
         if isinstance(self._precision_instance, Precision):
             return self._precision_instance
-        if isinstance(self.accelerator, XLAAccelerator):
+        if isinstance(self.strategy, (SingleDeviceXLAStrategy, XLAStrategy, XLAFSDPStrategy)):
             return XLAPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
@@ -492,18 +491,6 @@ def _check_and_init_precision(self) -> Precision:
 
         raise RuntimeError("No precision set")
 
-    def _validate_precision_choice(self) -> None:
-        """Validate the combination of choices for precision, and accelerator."""
-        if (
-            isinstance(self.accelerator, XLAAccelerator)
-            and self._precision_instance
-            and not isinstance(self._precision_instance, XLAPrecision)
-        ):
-            raise ValueError(
-                f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin,"
-                f" found: {self._precision_instance}."
-            )
-
     def _lazy_init_strategy(self) -> None:
         """Lazily set missing attributes on the previously instantiated strategy."""
         self.strategy.accelerator = self.accelerator
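
The net effect of this connector change is that Fabric now selects the precision plugin from the resolved strategy rather than from the accelerator, and the separate accelerator/precision cross-check is gone. A minimal usage sketch, assuming a TPU/XLA environment is available (the parameter values are illustrative, not prescribed by the diff):

from lightning.fabric import Fabric

# With an XLA strategy resolved (here implicitly via accelerator="tpu"),
# _check_and_init_precision() returns an XLAPrecision instance because the
# strategy is one of SingleDeviceXLAStrategy / XLAStrategy / XLAFSDPStrategy,
# not because the accelerator happens to be an XLAAccelerator.
fabric = Fabric(accelerator="tpu", devices=1, precision="bf16-true")
fabric.launch()
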
24 changes: 7 additions & 17 deletions src/lightning/fabric/strategies/fsdp.py
@@ -305,25 +305,15 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         flattened parameters.
         """
-        if _TORCH_GREATER_EQUAL_2_0:
-            return optimizer
-
-        from torch.distributed.fsdp import FlatParameter
-
-        num_groups = len(optimizer.param_groups)
-        if num_groups > 1:
+        if self._fsdp_kwargs.get("use_orig_params"):
+            return super().setup_optimizer(optimizer)
+        if not _optimizer_has_flat_params(optimizer):
+            # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
-                "An optimizer used with an FSDP model does not support multiple param groups."
-                f" Found {num_groups} parameter groups."
+                "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
+                " after setting up the model."
             )
-
-        if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]):
-            return optimizer
-
-        raise ValueError(
-            "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
-            " after setting up the model."
-        )
+        return optimizer
 
     def module_to_device(self, module: Module) -> None:
         pass
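
For context, here is a standalone sketch of the check this hunk converges on (and which src/lightning/pytorch/strategies/fsdp.py below mirrors): with `use_orig_params` the optimizer sees the original parameters and no flat-parameter check is needed; otherwise the optimizer must already reference FSDP's flattened parameters. The helper name `_optimizer_has_flat_params` is taken from the diff; the wrapper function itself is illustrative only, not part of the strategy.

from torch.optim import Optimizer

from lightning.fabric.strategies.fsdp import _optimizer_has_flat_params


def _validate_fsdp_optimizer(optimizer: Optimizer, use_orig_params: bool) -> Optimizer:
    # Illustrative re-statement of the logic in FSDPStrategy.setup_optimizer above.
    if use_orig_params:
        # The optimizer references the original (unflattened) parameters directly.
        return optimizer
    if not _optimizer_has_flat_params(optimizer):
        raise ValueError(
            "The optimizer does not seem to reference any FSDP parameters."
            " HINT: Make sure to create the optimizer after setting up the model."
        )
    return optimizer
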
37 changes: 27 additions & 10 deletions src/lightning/fabric/strategies/single_xla.py
@@ -17,9 +17,8 @@
 
 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
-from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
-from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.single_device import SingleDeviceStrategy
 from lightning.fabric.utilities.types import _DEVICE
@@ -32,8 +31,8 @@ def __init__(
         self,
         device: _DEVICE,
         accelerator: Optional[Accelerator] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
-        precision: Optional[Precision] = None,
+        checkpoint_io: Optional[XLACheckpointIO] = None,
+        precision: Optional[XLAPrecision] = None,
     ):
         if not _XLA_AVAILABLE:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
@@ -50,16 +49,34 @@ def __init__(
             precision=precision,
         )
 
-    @property
-    def checkpoint_io(self) -> CheckpointIO:
-        if self._checkpoint_io is None:
-            self._checkpoint_io = XLACheckpointIO()
-        return self._checkpoint_io
+    @property  # type: ignore[override]
+    def checkpoint_io(self) -> XLACheckpointIO:
+        plugin = self._checkpoint_io
+        if plugin is not None:
+            assert isinstance(plugin, XLACheckpointIO)
+            return plugin
+        return XLACheckpointIO()
 
     @checkpoint_io.setter
-    def checkpoint_io(self, io: CheckpointIO) -> None:
+    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+        if io is not None and not isinstance(io, XLACheckpointIO):
+            raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
 
+    @property  # type: ignore[override]
+    def precision(self) -> XLAPrecision:
+        plugin = self._precision
+        if plugin is not None:
+            assert isinstance(plugin, XLAPrecision)
+            return plugin
+        return XLAPrecision("32-true")
+
+    @precision.setter
+    def precision(self, precision: Optional[XLAPrecision]) -> None:
+        if precision is not None and not isinstance(precision, XLAPrecision):
+            raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}")
+        self._precision = precision
+
     @classmethod
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("single_xla", cls, description=cls.__name__)
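
The `checkpoint_io` and `precision` properties above (repeated in xla.py and xla_fsdp.py below) all follow the same pattern: the getter narrows the stored plugin to the XLA-specific type or builds a default, and the setter rejects anything else. A minimal self-contained sketch of that pattern, using hypothetical class names rather than the real plugins:

from typing import Optional


class BasePlugin:
    ...


class XLAOnlyPlugin(BasePlugin):
    ...


class Strategy:
    def __init__(self, plugin: Optional[XLAOnlyPlugin] = None) -> None:
        self._plugin = plugin

    @property
    def plugin(self) -> XLAOnlyPlugin:
        # Narrow the stored value, or hand back a fresh default when unset.
        if self._plugin is not None:
            assert isinstance(self._plugin, XLAOnlyPlugin)
            return self._plugin
        return XLAOnlyPlugin()

    @plugin.setter
    def plugin(self, plugin: Optional[XLAOnlyPlugin]) -> None:
        # Reject plugins of the wrong type up front instead of failing later.
        if plugin is not None and not isinstance(plugin, XLAOnlyPlugin):
            raise TypeError(f"Expected an `XLAOnlyPlugin`, found {plugin}")
        self._plugin = plugin

One difference from the previous getters worth noting: the new code returns a fresh default instance instead of caching it in the private attribute, so an unset plugin is no longer silently persisted on first access.
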
38 changes: 27 additions & 11 deletions src/lightning/fabric/strategies/xla.py
@@ -22,10 +22,9 @@
 
 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1, _using_pjrt
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
-from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
-from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.strategies.strategy import TBroadcast
@@ -44,8 +43,8 @@ def __init__(
         self,
         accelerator: Optional[Accelerator] = None,
         parallel_devices: Optional[List[torch.device]] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
-        precision: Optional[Precision] = None,
+        checkpoint_io: Optional[XLACheckpointIO] = None,
+        precision: Optional[XLAPrecision] = None,
         sync_module_states: bool = True,
     ) -> None:
         super().__init__(
@@ -55,7 +54,6 @@ def __init__(
             checkpoint_io=checkpoint_io,
             precision=precision,
         )
-        self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = None  # XLA synchronizes gradients in the optimizer.step() call
         self._launched = False
         self._sync_module_states = sync_module_states
@@ -72,16 +70,34 @@ def root_device(self) -> torch.device:
     def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
-    @property
-    def checkpoint_io(self) -> CheckpointIO:
-        if self._checkpoint_io is None:
-            self._checkpoint_io = XLACheckpointIO()
-        return self._checkpoint_io
+    @property  # type: ignore[override]
+    def checkpoint_io(self) -> XLACheckpointIO:
+        plugin = self._checkpoint_io
+        if plugin is not None:
+            assert isinstance(plugin, XLACheckpointIO)
+            return plugin
+        return XLACheckpointIO()
 
     @checkpoint_io.setter
-    def checkpoint_io(self, io: CheckpointIO) -> None:
+    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+        if io is not None and not isinstance(io, XLACheckpointIO):
+            raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
 
+    @property  # type: ignore[override]
+    def precision(self) -> XLAPrecision:
+        plugin = self._precision
+        if plugin is not None:
+            assert isinstance(plugin, XLAPrecision)
+            return plugin
+        return XLAPrecision("32-true")
+
+    @precision.setter
+    def precision(self, precision: Optional[XLAPrecision]) -> None:
+        if precision is not None and not isinstance(precision, XLAPrecision):
+            raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}")
+        self._precision = precision
+
     @property
     def global_rank(self) -> int:
         return super().global_rank if self._launched else 0
53 changes: 29 additions & 24 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -24,10 +24,9 @@
 from torch.utils.data import DataLoader
 
 from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _using_pjrt
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
 from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
-from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
 from lightning.fabric.strategies.fsdp import _apply_filter
@@ -85,22 +84,23 @@ def __init__(
         self,
         accelerator: Optional[Accelerator] = None,
         parallel_devices: Optional[List[torch.device]] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
+        checkpoint_io: Optional[XLACheckpointIO] = None,
         precision: Optional[XLAPrecision] = None,
         auto_wrap_policy: Optional[_POLICY] = None,
         activation_checkpointing_policy: Optional[_POLICY_SET] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
         sequential_save: bool = False,
         **kwargs: Any,
     ) -> None:
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
             cluster_environment=XLAEnvironment(),
             checkpoint_io=checkpoint_io,
             precision=precision,
         )
-        self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = _XLAFSDPBackwardSyncControl()
 
         self._auto_wrap_policy = auto_wrap_policy
@@ -122,16 +122,34 @@ def root_device(self) -> torch.device:
     def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
-    @property
-    def checkpoint_io(self) -> CheckpointIO:
-        if self._checkpoint_io is None:
-            self._checkpoint_io = XLACheckpointIO()
-        return self._checkpoint_io
+    @property  # type: ignore[override]
+    def checkpoint_io(self) -> XLACheckpointIO:
+        plugin = self._checkpoint_io
+        if plugin is not None:
+            assert isinstance(plugin, XLACheckpointIO)
+            return plugin
+        return XLACheckpointIO()
 
     @checkpoint_io.setter
-    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
+    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+        if io is not None and not isinstance(io, XLACheckpointIO):
+            raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
 
+    @property  # type: ignore[override]
+    def precision(self) -> XLAPrecision:
+        plugin = self._precision
+        if plugin is not None:
+            assert isinstance(plugin, XLAPrecision)
+            return plugin
+        return XLAPrecision("32-true")
+
+    @precision.setter
+    def precision(self, precision: Optional[XLAPrecision]) -> None:
+        if precision is not None and not isinstance(precision, XLAPrecision):
+            raise TypeError(f"The XLA FSDP strategy can only work with the `XLAPrecision` plugin, found {precision}")
+        self._precision = precision
+
     @property
     def global_rank(self) -> int:
         return super().global_rank if self._launched else 0
@@ -227,21 +245,8 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         flattened parameters.
         """
-        if _TORCH_GREATER_EQUAL_2_0:
-            return optimizer
-
-        from torch_xla.distributed.fsdp.xla_flatten_params_wrapper import FlatParameter
-
-        num_groups = len(optimizer.param_groups)
-        if num_groups > 1:
-            raise ValueError(
-                "An optimizer used with an XLAFSDP model does not support multiple param groups."
-                f" Found {num_groups} parameter groups."
-            )
-
-        if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]):
+        if any(getattr(p, "_is_sharded", False) for group in optimizer.param_groups for p in group["params"]):
             return optimizer
 
         raise ValueError(
             "The optimizer does not seem to reference any XLAFSDP parameters. HINT: Make sure to create the optimizer"
             " after setting up the model."
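
The new XLAFSDP optimizer check above stops importing `FlatParameter` and instead probes the `_is_sharded` marker that torch_xla's FSDP wrapper is expected to set on sharded parameters, scanning every param group rather than only the first. A small illustrative helper with the same logic (the function name is hypothetical, not part of the strategy):

import torch


def _references_xla_fsdp_params(optimizer: torch.optim.Optimizer) -> bool:
    # Mirrors the check in XLAFSDPStrategy.setup_optimizer above: any parameter
    # carrying a truthy `_is_sharded` attribute counts as an XLAFSDP parameter.
    return any(
        getattr(p, "_is_sharded", False)
        for group in optimizer.param_groups
        for p in group["params"]
    )
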
11 changes: 1 addition & 10 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -327,16 +327,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
-
-        invalid_params_error = False
-        try:
-            super().setup_optimizers(trainer)
-        except ValueError as ex:
-            if "optimizer got an empty parameter list" not in str(ex):
-                raise
-            invalid_params_error = True
-
-        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+        if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
             # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
11 changes: 6 additions & 5 deletions src/lightning/pytorch/strategies/launchers/xla.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 import os
 import queue
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch.multiprocessing as mp
 
-import lightning.pytorch as pl
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
 from lightning.fabric.strategies.launchers.xla import _rank_teardown
 from lightning.fabric.utilities import move_data_to_device
@@ -29,6 +28,9 @@
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug
 
+if TYPE_CHECKING:
+    import lightning.pytorch as pl
+
 
 class _XLALauncher(_MultiProcessingLauncher):
     r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
@@ -145,12 +147,11 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
             else None
         )
 
-        # requires to compute the state_dict on all processes in case Metrics are present
-        state_dict = trainer.lightning_module.state_dict()
-
         # save the last weights
         weights_path = None
         if trainer.state.fn == TrainerFn.FITTING:
+            # requires to compute the state_dict on all processes in case Metrics are present
+            state_dict = self._strategy.lightning_module_state_dict()
             weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
             self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)
 
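
Two things change in this launcher: `lightning.pytorch` is now imported only under `TYPE_CHECKING` (keeping the quoted `"pl.Trainer"` annotations valid while avoiding the runtime import cycle), and the rank-zero `state_dict` is computed only when the trainer is fitting, via the strategy. A minimal sketch of the deferred-import pattern on its own (the `collect_results` function is a placeholder for illustration, not part of the launcher):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Only evaluated by static type checkers; nothing is imported at runtime,
    # so a circular import between this module and the trainer is avoided.
    import lightning.pytorch as pl


def collect_results(trainer: "pl.Trainer") -> Optional[str]:
    # The quoted annotation is resolved lazily, so this function can be defined
    # even though `pl` is not bound at runtime in this module.
    return getattr(trainer, "default_root_dir", None)
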
(The remaining changed files are not shown here.)
