diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 31a6d64a20d..d97327de42d 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -35,6 +35,14 @@
     import torch_xla.core.xla_model as xm
 
 
+def is_initialized() -> bool:
+    """
+    Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`,
+    but works as a module method.
+    """
+    return AcceleratorState._shared_state != {}
+
+
 # Inspired by Alex Martelli's 'Borg'.
 class AcceleratorState:
     """
@@ -45,6 +53,7 @@ class AcceleratorState:
         - **device** (`torch.device`) -- The device to use.
         - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
           in use.
+        - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
         - **local_process_index** (`int`) -- The index of the current process on the current server.
         - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the
           type of mixed precision being performed.
@@ -69,8 +78,7 @@ def __init__(
         if parse_flag_from_env("ACCELERATE_USE_CPU"):
             cpu = True
         self._check_initialized(mixed_precision, cpu)
-        self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
-        if not getattr(self, "initialized", False):
+        if not self.initialized:
             self.backend = None
             self.deepspeed_plugin = None
             mixed_precision = (
@@ -245,18 +253,17 @@ def __init__(
                 and self.device.type == "cuda"
             ):
                 torch.backends.cuda.matmul.allow_tf32 = True
-            self.initialized = True
 
-    def __repr__(self):
-        mixed_precision = self.mixed_precision
+        self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
 
+    def __repr__(self):
         repr = (
             f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
             f"Num processes: {self.num_processes}\n"
             f"Process index: {self.process_index}\n"
             f"Local process index: {self.local_process_index}\n"
             f"Device: {self.device}\n"
-            f"Mixed precision type: {mixed_precision}\n"
+            f"Mixed precision type: {self.mixed_precision}\n"
         )
         if self.distributed_type == DistributedType.DEEPSPEED:
             repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
@@ -286,9 +293,14 @@ def _reset_state():
         "Resets `_shared_state`, is used internally and should not be called"
         AcceleratorState._shared_state = {}
 
+    @property
+    def initialized(self) -> bool:
+        "Returns whether the `AcceleratorState` has been initialized"
+        return self._shared_state != {}
+
     def _check_initialized(self, mixed_precision=None, cpu=None):
         "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
-        if getattr(self, "initialized", False):
+        if self.initialized:
             err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerate()`."
             if cpu and self.device.type != "cpu":
                 raise ValueError(err.format(flag="cpu=True"))
@@ -311,11 +323,15 @@ class GradientState:
 
     def __init__(self):
         self.__dict__ = self._shared_state
-        if not getattr(self, "initialized", False):
+        if not self.initialized:
             self.sync_gradients = True
             self.end_of_dataloader = False
             self.remainder = -1
-            self.initialized = True
+
+    @property
+    def initialized(self) -> bool:
+        "Returns whether the `GradientState` has been initialized"
+        return GradientState._shared_state != {}
 
     def __repr__(self):
         return (
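
For reference, a minimal usage sketch of the behavior this diff introduces, assuming a plain single-process run where `Accelerator()` can be constructed (`Accelerator` is the public entry point from `accelerate`; `is_initialized`, `AcceleratorState`, and `GradientState` come from the `accelerate.state` module touched above):

```python
from accelerate import Accelerator
from accelerate.state import AcceleratorState, GradientState, is_initialized

# Before anything has populated the shared "Borg" state, the module-level helper
# reports False. It reads AcceleratorState._shared_state directly, so it does not
# instantiate AcceleratorState (which would itself fill the shared state).
assert not is_initialized()

# Creating an Accelerator populates AcceleratorState._shared_state (and sets up
# a GradientState), so the helper and the new properties now all agree.
accelerator = Accelerator()

assert is_initialized()
assert AcceleratorState().initialized
assert GradientState().initialized
```

Deriving `initialized` from `_shared_state != {}` means there is no separate flag to set or tear down: `_reset_state()` emptying the shared dict automatically flips the property and the module-level check back to `False`.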