From 4a9c538838d6642e58acfb2e26dcc1b6a5a42538 Mon Sep 17 00:00:00 2001
From: Kiyoon Kim
Date: Thu, 27 Apr 2023 13:59:06 +0100
Subject: [PATCH] fix: typing issues, and replace deprecated python typing (Optional, Union) to `|`

---
 README.md                     |  2 +-
 src/accelerate/accelerator.py | 48 ++++++++++++++++++-----------
 2 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/README.md b/README.md
index ffb658d1883..fac8bb0ed4d 100644
--- a/README.md
+++ b/README.md
@@ -223,7 +223,7 @@ If you like the simplicity of 🤗 Accelerate but would prefer a higher-level ab
 
 ## Installation
 
-This repository is tested on Python 3.6+ and PyTorch 1.4.0+
+This repository is tested on Python 3.7+ and PyTorch 1.4.0+
 
 You should install 🤗 Accelerate in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're
 unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 7f035061611..ec5bc06ef29 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import contextlib
 import inspect
 import math
@@ -23,7 +25,7 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 from functools import partial
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable
 
 import torch
 import torch.utils.hooks as hooks
@@ -198,7 +200,7 @@ class Accelerator:
         step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`):
             Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
             done under certain circumstances (at the end of each epoch, for instance).
-        kwargs_handlers (`List[KwargHandler]`, *optional*)
+        kwargs_handlers (`list[KwargHandler]`, *optional*)
             A list of `KwargHandler` to customize how the objects related to distributed training or mixed precision
             are created. See [kwargs](kwargs) for more information.
         dynamo_backend (`str` or `DynamoBackend`, *optional*, defaults to `"no"`):
@@ -227,24 +229,24 @@ def __init__(
         self,
         device_placement: bool = True,
         split_batches: bool = False,
-        mixed_precision: Union[PrecisionType, str] = None,
+        mixed_precision: PrecisionType | str | None = None,
         gradient_accumulation_steps: int = 1,
         cpu: bool = False,
-        deepspeed_plugin: DeepSpeedPlugin = None,
-        fsdp_plugin: FullyShardedDataParallelPlugin = None,
-        megatron_lm_plugin: MegatronLMPlugin = None,
-        ipex_plugin: IntelPyTorchExtensionPlugin = None,
-        rng_types: Optional[List[Union[str, RNGType]]] = None,
-        log_with: Optional[List[Union[str, LoggerType, GeneralTracker]]] = None,
-        project_dir: Optional[Union[str, os.PathLike]] = None,
-        project_config: Optional[ProjectConfiguration] = None,
-        logging_dir: Optional[Union[str, os.PathLike]] = None,
-        gradient_accumulation_plugin: Optional[GradientAccumulationPlugin] = None,
-        dispatch_batches: Optional[bool] = None,
+        deepspeed_plugin: DeepSpeedPlugin | None = None,
+        fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+        megatron_lm_plugin: MegatronLMPlugin | None = None,
+        ipex_plugin: IntelPyTorchExtensionPlugin | None = None,
+        rng_types: list[str | RNGType] | None = None,
+        log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
+        project_dir: str | os.PathLike | None = None,
+        project_config: ProjectConfiguration | None = None,
+        logging_dir: str | os.PathLike | None = None,
+        gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
+        dispatch_batches: bool | None = None,
         even_batches: bool = True,
         step_scheduler_with_optimizer: bool = True,
-        kwargs_handlers: Optional[List[KwargsHandler]] = None,
-        dynamo_backend: Union[DynamoBackend, str] = None,
+        kwargs_handlers: list[KwargsHandler] | None = None,
+        dynamo_backend: DynamoBackend | str | None = None,
     ):
         if project_config is not None:
             self.project_configuration = project_config
@@ -890,7 +892,7 @@ def join_uneven_inputs(self, joinables, even_batches=None):
         length of the dataset.
 
         Args:
-            joinables (`List[torch.distributed.algorithms.Joinable]`):
+            joinables (`list[torch.distributed.algorithms.Joinable]`):
                 A list of models or optimizers that subclass `torch.distributed.algorithms.Joinable`. Most commonly, a
                 PyTorch Module that was prepared with `Accelerator.prepare` for DistributedDataParallel training.
             even_batches (`bool`, *optional*)
@@ -1060,7 +1062,7 @@ def prepare(self, *args, device_placement=None):
                 - `torch.optim.Optimizer`: PyTorch Optimizer
                 - `torch.optim.lr_scheduler.LRScheduler`: PyTorch LR Scheduler
 
-            device_placement (`List[bool]`, *optional*):
+            device_placement (`list[bool]`, *optional*):
                 Used to customize whether automatic device placement should be performed for each object passed. Needs
                 to be a list of the same length as `args`.
 
@@ -1748,7 +1750,7 @@ def unscale_gradients(self, optimizer=None):
         Likely should be called through [`Accelerator.clip_grad_norm_`] or [`Accelerator.clip_grad_value_`]
 
         Args:
-            optimizer (`torch.optim.Optimizer` or `List[torch.optim.Optimizer]`, *optional*):
+            optimizer (`torch.optim.Optimizer` or `list[torch.optim.Optimizer]`, *optional*):
                 The optimizer(s) for which to unscale gradients. If not set, will unscale gradients on all optimizers
                 that were passed to [`~Accelerator.prepare`].
 
@@ -2048,7 +2050,7 @@ def wait_for_everyone(self):
         wait_for_everyone()
 
     @on_main_process
-    def init_trackers(self, project_name: str, config: Optional[dict] = None, init_kwargs: Optional[dict] = {}):
+    def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = {}):
         """
         Initializes a run for all trackers stored in `self.log_with`, potentially with starting configurations
 
@@ -2128,7 +2130,7 @@ def get_tracker(self, name: str, unwrap: bool = False):
         return GeneralTracker(_blank=True)
 
     @on_main_process
-    def log(self, values: dict, step: Optional[int] = None, log_kwargs: Optional[dict] = {}):
+    def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
         """
         Logs `values` to all stored trackers in `self.trackers` on the main process only.
 
@@ -2207,7 +2209,7 @@ def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.Remov
 
         The hook should have the following signature:
 
-        `hook(models: List[torch.nn.Module], weights: List[Dict[str, torch.Tensor]], input_dir: str) -> None`
+        `hook(models: list[torch.nn.Module], weights: list[dict[str, torch.Tensor]], input_dir: str) -> None`
 
         The `models` argument are the models as saved in the accelerator state under `accelerator._models`, `weigths`
         argument are the state dicts of the `models`, and the `input_dir` argument is the `input_dir` argument passed
@@ -2353,7 +2355,7 @@ def register_load_state_pre_hook(self, hook: Callable[..., None]) -> hooks.Remov
 
         The hook should have the following signature:
 
-        `hook(models: List[torch.nn.Module], input_dir: str) -> None`
+        `hook(models: list[torch.nn.Module], input_dir: str) -> None`
 
         The `models` argument are the models as saved in the accelerator state under `accelerator._models`, and the
         `input_dir` argument is the `input_dir` argument passed to [`Accelerator.load_state`].
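
A minimal sketch of the typing pattern this patch applies, assuming Python 3.7+ with postponed evaluation of annotations; `configure` and its parameters are illustrative stand-ins, not part of the Accelerate API:

    from __future__ import annotations  # PEP 563: annotations are stored as strings, not evaluated

    import os


    def configure(
        project_dir: str | os.PathLike | None = None,  # previously Optional[Union[str, os.PathLike]]
        rng_types: list[str] | None = None,            # previously Optional[List[str]]
    ) -> None:
        """Illustrative only: the `X | Y` and `list[...]` forms in the signature parse on
        Python 3.7+ because the future import defers their evaluation."""


    configure(project_dir="runs/exp1", rng_types=["torch", "cuda"])
    print(configure.__annotations__["project_dir"])  # prints the raw string "str | os.PathLike | None"

Because the future import keeps annotations as plain strings, the `X | None` unions and built-in generics parse on 3.7+ even though they only evaluate natively on 3.10/3.9; code that resolves them at runtime (for example `typing.get_type_hints`) would still need the newer interpreter.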