Support TPU v2 and v3 on new PyTorch/XLA TPU runtime (#1385)
* Use numpy Generator instead of global seed

* Implement SharedDict descriptor

* Formatting and comments

* Remove `GlobalSharedDict`

* Formatting

* Formatting with `doc-builder` installed correctly
will-cromar authored May 9, 2023
1 parent fafadc5 commit d95d68e
Showing 3 changed files with 42 additions and 11 deletions.
44 changes: 38 additions & 6 deletions src/accelerate/state.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import threading
import warnings
from contextlib import contextmanager
from functools import partial
@@ -54,6 +55,37 @@ def do_nothing(*args, **kwargs):
return None


class ThreadLocalSharedDict(threading.local):
"""
Descriptor that holds a dict shared between instances of a class in the same thread.
Note: Descriptors have slightly different semantics than just a dict field on its own.
`PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class) give the same value: the
underlying _storage dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the _storage dict inside
the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor
object with a dict instead. Thus, you should modify the _storage dict in-place (e.g. `_shared_state.clear()`).
See Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html
This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).
See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
"""

def __init__(self, thread_local: bool = False):
self._storage = {}

def __get__(self, obj, objtype=None):
return self._storage

def __set__(self, obj, value):
self._storage = value


# Prefer global shared dictionary, except when using TPU.
SharedDict = dict if not is_tpu_available(check_device=False) else ThreadLocalSharedDict
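
A minimal sketch of the semantics described in the docstring above (the `Borg` class and the demo code are hypothetical illustrations, not part of this commit): state is shared between instances within a thread, each new thread sees its own storage, and resets must mutate the dict in place.

import threading


class ThreadLocalSharedDict(threading.local):
    # Simplified copy of the descriptor added above, for illustration only.
    def __init__(self):
        self._storage = {}

    def __get__(self, obj, objtype=None):
        return self._storage

    def __set__(self, obj, value):
        self._storage = value


class Borg:
    _shared_state = ThreadLocalSharedDict()

    def __init__(self):
        # Every instance aliases the descriptor-backed dict, as `PartialState` does.
        self.__dict__ = self._shared_state


a, b = Borg(), Borg()
a.name = "first"
assert b.name == "first"                      # state is shared between instances
assert Borg._shared_state is a._shared_state  # class and instance access return the same dict

# Because the descriptor subclasses `threading.local`, another thread sees its own
# empty storage (under PJRT, each TPU v2/v3 core runs in its own thread).
t = threading.Thread(target=lambda: print("name" in Borg._shared_state))  # prints False
t.start()
t.join()

# Resets must mutate the dict in place: `Borg._shared_state = {}` would replace the
# descriptor itself with a plain dict on the class.
Borg._shared_state.clear()
assert not hasattr(Borg(), "name")

This in-place requirement is why the `_reset_state` methods below now call `_shared_state.clear()` instead of rebinding `_shared_state = {}`.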


# Inspired by Alex Martelli's 'Borg'.
class PartialState:
"""
@@ -76,7 +108,7 @@ class PartialState:
- **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
"""

_shared_state = {}
_shared_state = SharedDict()

def __init__(self, cpu: bool = False, **kwargs):
self.__dict__ = self._shared_state
@@ -211,7 +243,7 @@ def __repr__(self) -> str:
@staticmethod
def _reset_state():
"Resets `_shared_state`, is used internally and should not be called"
PartialState._shared_state = {}
PartialState._shared_state.clear()

@property
def initialized(self) -> bool:
@@ -528,7 +560,7 @@ class AcceleratorState:
- **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
"""

_shared_state = {}
_shared_state = SharedDict()

def __init__(
self,
@@ -652,7 +684,7 @@ def mixed_precision(self):
@staticmethod
def _reset_state(reset_partial_state: bool = False):
"Resets `_shared_state`, is used internally and should not be called"
AcceleratorState._shared_state = {}
AcceleratorState._shared_state.clear()
if reset_partial_state:
PartialState._reset_state()

@@ -722,7 +754,7 @@ class GradientState:
accumulation
"""

_shared_state = {}
_shared_state = SharedDict()

def __init__(self, gradient_accumulation_plugin: Optional[GradientAccumulationPlugin] = None):
self.__dict__ = self._shared_state
@@ -793,4 +825,4 @@ def in_dataloader(self) -> bool:
@staticmethod
def _reset_state():
"Resets `_shared_state`, is used internally and should not be called"
GradientState._shared_state = {}
GradientState._shared_state.clear()
2 changes: 1 addition & 1 deletion src/accelerate/test_utils/scripts/test_script.py
@@ -278,7 +278,7 @@ def central_dl_preparation_check():
def mock_training(length, batch_size, generator):
set_seed(42)
generator.manual_seed(42)
train_set = RegressionDataset(length=length)
train_set = RegressionDataset(length=length, seed=42)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
7 changes: 3 additions & 4 deletions src/accelerate/test_utils/training.py
@@ -21,11 +21,10 @@

class RegressionDataset:
def __init__(self, a=2, b=3, length=64, seed=None):
if seed is not None:
np.random.seed(seed)
rng = np.random.default_rng(seed)
self.length = length
self.x = np.random.normal(size=(length,)).astype(np.float32)
self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,)).astype(np.float32)
self.x = rng.normal(size=(length,)).astype(np.float32)
self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32)

def __len__(self):
return self.length
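
The switch from the global `np.random.seed` to a per-instance `numpy.random.Generator` keeps each dataset's seeding isolated from global NumPy state (and `seed=None` still yields fresh entropy from `default_rng`). A minimal standalone sketch of the difference, not taken from the commit:

import numpy as np

# Local generators are reproducible and independent of each other.
first = np.random.default_rng(42).normal(size=4)
second = np.random.default_rng(42).normal(size=4)
assert np.allclose(first, second)

# The legacy global seed is shared: any unrelated draw from np.random in
# between shifts the stream a dataset would sample from.
np.random.seed(42)
before = np.random.normal(size=4)
np.random.seed(42)
np.random.uniform()  # unrelated consumer of the global stream
after = np.random.normal(size=4)
assert not np.allclose(before, after)

This is presumably why `mock_training` in test_script.py above now passes `seed=42` to `RegressionDataset` explicitly rather than relying on the global seed set by `set_seed`.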
