diff --git a/src/accelerate/state.py b/src/accelerate/state.py index d1d0dec12c5..726c0ea7247 100644 --- a/src/accelerate/state.py +++ b/src/accelerate/state.py @@ -78,13 +78,6 @@ class PartialState: def __init__(self, cpu: bool = False, **kwargs): self.__dict__ = self._shared_state - # Raise an error if the user tries to reinitialize on a different device setup in the same launch - if self.initialized and (self._cpu != cpu): - raise AssertionError( - "The current device and desired device are not the same. If the `PartialState` was generated " - "before the `Accelerator` has been instantiated, ensure the `cpu` flag is the same for both. In this case, " - f"the `PartialState` has {self._cpu} and the desired device is {cpu}. Please use `cpu={self._cpu}`." - ) if not self.initialized: self._cpu = cpu self.backend = None @@ -540,10 +533,12 @@ def __init__( **kwargs, ): self.__dict__ = self._shared_state - if PartialState._shared_state == {} or (cpu != PartialState._shared_state.get("_cpu", False)): + if parse_flag_from_env("ACCELERATE_USE_CPU"): + cpu = True + if PartialState._shared_state == {}: PartialState(cpu, **kwargs) self.__dict__.update(PartialState._shared_state) - self._check_initialized(mixed_precision) + self._check_initialized(mixed_precision, cpu) if not self.initialized: self.deepspeed_plugin = None mixed_precision = ( @@ -599,10 +594,12 @@ def __repr__(self): repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n" return repr - def _check_initialized(self, mixed_precision=None): + def _check_initialized(self, mixed_precision=None, cpu=None): "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized" if self.initialized: err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`." + if cpu and self.device.type != "cpu": + raise ValueError(err.format(flag="cpu=True")) if ( mixed_precision is not None and mixed_precision != self._mixed_precision diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index d6b74fc1557..9c846639855 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -40,7 +40,7 @@ def test_accelerator_can_be_reinstantiated(self): _ = Accelerator() assert PartialState._shared_state["_cpu"] is False assert PartialState._shared_state["device"].type == "cuda" - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): _ = Accelerator(cpu=True) def test_prepared_objects_are_referenced(self): @@ -226,3 +226,10 @@ def test_accelerator_bnb_multi_gpu(self): # This should not work and get value error with self.assertRaises(ValueError): _ = accelerator.prepare(model) + + @require_cuda + def test_accelerator_cpu_flag_prepare(self): + model = torch.nn.Linear(10, 10) + sgd = torch.optim.SGD(model.parameters(), lr=0.01) + accelerator = Accelerator(cpu=True) + _ = accelerator.prepare(sgd)