Add is_initialized method and refactor (#949)
* Add is_initialized method and refactor

* As module method
muellerzr authored Jan 3, 2023
1 parent e60f3ca commit bf8fe03
Showing 1 changed file with 25 additions and 9 deletions.

src/accelerate/state.py
@@ -35,6 +35,14 @@
 import torch_xla.core.xla_model as xm


+def is_initialized() -> bool:
+    """
+    Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`,
+    but works as a module method.
+    """
+    return AcceleratorState._shared_state != {}
+
+
 # Inspired by Alex Martelli's 'Borg'.
 class AcceleratorState:
     """
@@ -45,6 +53,7 @@ class AcceleratorState:
     - **device** (`torch.device`) -- The device to use.
     - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
       in use.
+    - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
     - **local_process_index** (`int`) -- The index of the current process on the current server.
     - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
       of mixed precision being performed.
@@ -69,8 +78,7 @@ def __init__(
         if parse_flag_from_env("ACCELERATE_USE_CPU"):
             cpu = True
         self._check_initialized(mixed_precision, cpu)
-        self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
-        if not getattr(self, "initialized", False):
+        if not self.initialized:
             self.backend = None
             self.deepspeed_plugin = None
             mixed_precision = (
@@ -245,18 +253,17 @@ def __init__(
             and self.device.type == "cuda"
         ):
             torch.backends.cuda.matmul.allow_tf32 = True
-        self.initialized = True
-
-    def __repr__(self):
-        mixed_precision = self.mixed_precision
+        self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
+
+    def __repr__(self):
         repr = (
             f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
             f"Num processes: {self.num_processes}\n"
             f"Process index: {self.process_index}\n"
             f"Local process index: {self.local_process_index}\n"
             f"Device: {self.device}\n"
-            f"Mixed precision type: {mixed_precision}\n"
+            f"Mixed precision type: {self.mixed_precision}\n"
         )
         if self.distributed_type == DistributedType.DEEPSPEED:
             repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
@@ -286,9 +293,14 @@ def _reset_state():
         "Resets `_shared_state`, is used internally and should not be called"
         AcceleratorState._shared_state = {}

+    @property
+    def initialized(self) -> bool:
+        "Returns whether the `AcceleratorState` has been initialized"
+        return self._shared_state != {}
+
     def _check_initialized(self, mixed_precision=None, cpu=None):
         "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
-        if getattr(self, "initialized", False):
+        if self.initialized:
             err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerate()`."
             if cpu and self.device.type != "cpu":
                 raise ValueError(err.format(flag="cpu=True"))
@@ -311,11 +323,15 @@ class GradientState:

     def __init__(self):
         self.__dict__ = self._shared_state
-        if not getattr(self, "initialized", False):
+        if not self.initialized:
             self.sync_gradients = True
             self.end_of_dataloader = False
             self.remainder = -1
-            self.initialized = True
+
+    @property
+    def initialized(self) -> bool:
+        "Returns whether the `GradientState` has been initialized"
+        return GradientState._shared_state != {}

     def __repr__(self):
         return (
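The new module-level helper can be called before any state object exists, which is the point of exposing it outside the class. A minimal usage sketch, assuming a single-process run with accelerate at or after this commit (the printed values follow from the shared-state check, not from output in the commit itself):

from accelerate import Accelerator
from accelerate.state import AcceleratorState, is_initialized

print(is_initialized())  # False: nothing has populated AcceleratorState._shared_state yet

accelerator = Accelerator()  # initializes the shared AcceleratorState

print(is_initialized())                # True, via the module method
print(AcceleratorState().initialized)  # True, via the new property on the class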
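A practical consequence of the refactor: `initialized` is now derived from `_shared_state` instead of being a flag set at the end of `__init__`, so the internal `_reset_state()` helper flips it back automatically. A sketch of that behavior; `_reset_state` is documented in the diff as internal-only, so this is illustration rather than recommended usage:

from accelerate import Accelerator
from accelerate.state import AcceleratorState, is_initialized

Accelerator()
assert is_initialized()

# Rebinding _shared_state to a fresh empty dict makes the property
# (and the module method) report "not initialized" again, with no
# stale `initialized = True` attribute left behind.
AcceleratorState._reset_state()
assert not is_initialized()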
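For background, the "Borg" comment in the diff refers to Alex Martelli's pattern in which every instance aliases one class-level `__dict__`, so checking whether that dict is empty is equivalent to asking whether any instance has been initialized. A minimal standalone sketch of the pattern (plain Python, not accelerate code):

class Borg:
    _shared_state = {}

    def __init__(self):
        # Every instance shares the same dict, so attributes set on one
        # instance are immediately visible on all others.
        self.__dict__ = self._shared_state

a = Borg()
a.value = 42
b = Borg()
print(b.value)             # 42: state is shared across instances
print(Borg._shared_state)  # {'value': 42} -> non-empty means "initialized"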
