diff --git a/setup.py b/setup.py
index 1846f7bf97b5d4..efbe239df08ab3 100644
--- a/setup.py
+++ b/setup.py
@@ -96,7 +96,7 @@
 # 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
 _deps = [
     "Pillow>=10.0.1,<=15.0",
-    "accelerate>=0.26.0",
+    "accelerate>=1.0.0",
     "av==9.2.0",  # Latest version of PyAV (10.0.0) has issues with audio stream.
     "beautifulsoup4",
     "blobfile",
@@ -183,6 +183,7 @@
     "tokenizers>=0.20,<0.21",
     "torch",
     "torchaudio",
+    "torchdata>=0.8.0",
     "torchvision",
     "pyctcdecode>=0.4.0",
     "tqdm>=4.27",
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 5ce23f4b7647d5..988dd79c6a264d 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -3,7 +3,7 @@
 # 2. run `make deps_table_update``
 deps = {
     "Pillow": "Pillow>=10.0.1,<=15.0",
-    "accelerate": "accelerate>=0.26.0",
+    "accelerate": "accelerate>=1.0.0",
     "av": "av==9.2.0",
     "beautifulsoup4": "beautifulsoup4",
     "blobfile": "blobfile",
@@ -88,6 +88,7 @@
     "tokenizers": "tokenizers>=0.20,<0.21",
     "torch": "torch",
     "torchaudio": "torchaudio",
+    "torchdata": "torchdata>=0.8.0",
     "torchvision": "torchvision",
     "pyctcdecode": "pyctcdecode>=0.4.0",
     "tqdm": "tqdm>=4.27",
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 7bb2d5049dccf8..72f2c4dc6ef0f7 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -55,6 +55,7 @@
 from .utils import (
     ACCELERATE_MIN_VERSION,
     GGUF_MIN_VERSION,
+    TORCHDATA_MIN_VERSION,
     is_accelerate_available,
     is_apex_available,
     is_aqlm_available,
@@ -135,6 +136,7 @@
     is_torch_xpu_available,
     is_torchao_available,
     is_torchaudio_available,
+    is_torchdata_available,
     is_torchdynamo_available,
     is_torchvision_available,
     is_vision_available,
@@ -957,6 +959,15 @@ def require_torch_multi_xpu(test_case):
 jax_device = None
 
 
+def require_torchdata(test_case, min_version: str = TORCHDATA_MIN_VERSION):
+    """
+    Decorator marking a test that requires torchdata. These tests are skipped when torchdata isn't installed.
+    """
+    return unittest.skipUnless(
+        is_torchdata_available(min_version), f"test requires torchdata version >= {min_version}"
+    )(test_case)
+
+
 def require_torchdynamo(test_case):
     """Decorator marking a test that requires TorchDynamo"""
     return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 58a20f66f4e81b..bcea313ffd6e41 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2086,6 +2086,12 @@ def train(
             state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             if state.train_batch_size is not None:
                 self._train_batch_size = state.train_batch_size
+            if state.train_dataloader_state_dict is None and self.use_stateful_dataloader:
+                raise ValueError(
+                    "TrainingArguments.accelerator_config.use_stateful_dataloader is True, however a checkpoint with "
+                    "no saved dataloader state_dict has been loaded. If this is the correct checkpoint to be loaded, "
+                    "please set `accelerator_config.use_stateful_dataloader` to False."
+                )
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
@@ -2396,7 +2402,10 @@ def _inner_training_loop(
             rng_to_sync = False
             steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
-                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
+                if self.use_stateful_dataloader:
+                    epoch_dataloader.load_state_dict(self.state.train_dataloader_state_dict)
+                else:
+                    epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
                 steps_trained_in_current_epoch = 0
                 rng_to_sync = True
@@ -2526,10 +2535,12 @@
                         # Delay optimizer scheduling until metrics are generated
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                             self.lr_scheduler.step()
 
-                    model.zero_grad()
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+                    # Maybe we should only update the state dict right before checkpointing?
+                    if self.use_stateful_dataloader:
+                        self.state.train_dataloader_state_dict = epoch_dataloader.state_dict()
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
@@ -4930,6 +4941,15 @@ def create_accelerator_and_postprocess(self):
         # this would have been updated above, no need for it anymore
         accelerator_config.pop("gradient_accumulation_kwargs")
 
+        use_stateful_dataloader = accelerator_config.pop("use_stateful_dataloader")
+        if use_stateful_dataloader:
+            if not is_accelerate_available("1.0.0"):
+                raise ImportError(
+                    "`use_stateful_dataloader` is only supported in accelerate v1.0.0 and above. Please upgrade accelerate to use this feature."
+                )
+            dataloader_config.use_stateful_dataloader = use_stateful_dataloader
+        self.use_stateful_dataloader = use_stateful_dataloader
+
         args = {
             "deepspeed_plugin": self.args.deepspeed_plugin,
             "gradient_accumulation_plugin": gradient_accumulation_plugin,
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 405874acf8f4c4..2bec536770cf5d 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -19,7 +19,7 @@
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 from tqdm.auto import tqdm
@@ -87,6 +87,9 @@ class TrainerState:
         stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*):
             Callbacks attached to the `Trainer` that should have their states be saved or restored.
             Relevent callbacks should implement a `state` and `from_state` function.
+        train_dataloader_state_dict (`Dict[str, Any]`, *optional*):
+            State dict tracking the inner state of the `StatefulDataLoader` serving the training data. Only present
+            when the trainer is using a stateful dataloader.
""" epoch: Optional[float] = None @@ -108,6 +111,7 @@ class TrainerState: trial_name: str = None trial_params: Dict[str, Union[str, float, int, bool]] = None stateful_callbacks: List["TrainerCallback"] = None + train_dataloader_state_dict: Dict[str, Any] = None def __post_init__(self): if self.log_history is None: diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 5f78860fe6c115..f39cfaa3995585 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -1250,6 +1250,10 @@ class AcceleratorConfig: Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState` must be initialized. May lead to issues using sweeps or hyperparameter tuning. + use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether or not to have the dataloaders prepared by the Accelerator be backed by + `[torchdata.StatefulDataLoader]`(https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). + This requires `accelerate` version 1.0.0 or higher, and `torchdata` version 0.8.0 to be installed." """ @@ -1320,6 +1324,13 @@ class AcceleratorConfig: }, ) + use_stateful_dataloader: bool = field( + default=False, + metadata={ + "help": "Whether or not to have the dataloaders prepared by the Accelerator be backed by`[torchdata.StatefulDataLoader]`(https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `accelerate` version 1.0.0 or higher, and `torchdata` version 0.8.0 to be installed." + }, + ) + @classmethod def from_json_file(cls, json_file): # Check if exists diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 485610dd9baa28..ecc6af8eee1721 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -602,6 +602,10 @@ class TrainingArguments: Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues with hyperparameter tuning. + - use_stateful_dataloader (`bool`, *optional*, defaults to `False`): + Whether or not to have the dataloaders prepared by the Accelerator be backed by + `[torchdata.StatefulDataLoader]`(https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). + This requires `accelerate` version 1.0.0 or higher, and `torchdata` version 0.8.0 to be installed." label_smoothing_factor (`float`, *optional*, defaults to 0.0): The label smoothing factor to use. 
Zero means no label smoothing, otherwise the underlying onehot-encoded diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2876eef9ea02df..837fffd8b416cb 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -104,6 +104,7 @@ ENV_VARS_TRUE_VALUES, GGUF_MIN_VERSION, TORCH_FX_REQUIRED_VERSION, + TORCHDATA_MIN_VERSION, USE_JAX, USE_TF, USE_TORCH, @@ -221,6 +222,7 @@ is_torch_xpu_available, is_torchao_available, is_torchaudio_available, + is_torchdata_available, is_torchdistx_available, is_torchdynamo_available, is_torchdynamo_compiling, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 2f0cfe1d6dcec8..521908a142be14 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -91,6 +91,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ ACCELERATE_MIN_VERSION = "0.26.0" FSDP_MIN_VERSION = "1.12.0" GGUF_MIN_VERSION = "0.10.0" +TORCHDATA_MIN_VERSION = "0.8.0" XLA_FSDPV2_MIN_VERSION = "2.2.0" HQQ_MIN_VERSION = "0.2.1" @@ -106,6 +107,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _lomo_available = _is_package_available("lomo_optim") _grokadamw_available = _is_package_available("grokadamw") _schedulefree_available = _is_package_available("schedulefree") +_torchdata_available, _torchdata_version = _is_package_available("torchdata", return_version=True) # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. _bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -318,6 +320,10 @@ def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION): return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version) +def is_torchdata_available(min_version: str = TORCHDATA_MIN_VERSION): + return _torchdata_available and version.parse(_torchdata_version) >= version.parse(min_version) + + def is_torch_deterministic(): """ Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2" diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5c03355785d2b5..71146a789b5f9e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -93,6 +93,7 @@ require_torch_tensorrt_fx, require_torch_tf32, require_torch_up_to_2_accelerators, + require_torchdata, require_torchdynamo, require_vision, require_wandb, @@ -137,6 +138,7 @@ Trainer, TrainerState, ) + from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer_pt_utils import AcceleratorConfig if is_safetensors_available(): @@ -146,6 +148,7 @@ # for version specific tests in TrainerIntegrationTest require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28") require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30") +require_accelerate_version_min_1_0_0 = partial(require_accelerate, min_version="1.0.0") GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28") if is_accelerate_available(): from accelerate import Accelerator @@ -1436,6 +1439,57 @@ def test_train_and_eval_dataloaders(self): new_eval_dataset = RegressionDataset(length=128) self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu)) + @require_accelerate_version_min_1_0_0 + @require_torchdata 
+    def test_train_and_eval_dataloaders_with_use_stateful_dataloader(self):
+        """Identical to `test_train_and_eval_dataloaders`, but with an `AcceleratorConfig` that sets `use_stateful_dataloader=True`.
+        (Note: kept as a separate test rather than a parameterization because of the extra dependencies it requires.)
+        """
+        accelerator_config = AcceleratorConfig(use_stateful_dataloader=True)
+        if torch_device == "cuda":
+            n_gpu = max(1, backend_device_count(torch_device))
+        else:
+            n_gpu = 1
+        trainer = get_regression_trainer(
+            learning_rate=0.1, per_device_train_batch_size=16, accelerator_config=accelerator_config
+        )
+        self.assertTrue(trainer.use_stateful_dataloader)
+        self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu)
+        trainer = get_regression_trainer(
+            learning_rate=0.1, per_device_eval_batch_size=16, accelerator_config=accelerator_config
+        )
+        self.assertEqual(trainer.get_eval_dataloader().total_batch_size, 16 * n_gpu)
+
+        # Check drop_last works
+        trainer = get_regression_trainer(
+            train_len=66,
+            eval_len=74,
+            learning_rate=0.1,
+            per_device_train_batch_size=16,
+            per_device_eval_batch_size=32,
+            accelerator_config=accelerator_config,
+        )
+        self.assertTrue(trainer.use_stateful_dataloader)
+        self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu) + 1)
+        self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu) + 1)
+
+        trainer = get_regression_trainer(
+            train_len=66,
+            eval_len=74,
+            learning_rate=0.1,
+            per_device_train_batch_size=16,
+            per_device_eval_batch_size=32,
+            dataloader_drop_last=True,
+            accelerator_config=accelerator_config,
+        )
+        self.assertTrue(trainer.use_stateful_dataloader)
+        self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu))
+        self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu))
+
+        # Check passing a new dataset for evaluation works
+        new_eval_dataset = RegressionDataset(length=128)
+        self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
+
     # tests that we do not require dataloader to have a .dataset attribute
     def test_dataloader_without_dataset(self):
         train_dataset = RegressionDataset(length=128)
@@ -2710,6 +2764,133 @@ def test_resume_training_with_randomness(self):
         self.assertAlmostEqual(a, a1, delta=1e-5)
         self.assertAlmostEqual(b, b1, delta=1e-5)
 
+    @require_accelerate_version_min_1_0_0
+    @require_torchdata
+    @require_torch_up_to_2_accelerators
+    def test_resume_training_with_stateful_dataloaders(self):
+        # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
+        # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
+        # won't be the same since the training dataloader is shuffled).
+        accelerator_config = AcceleratorConfig(use_stateful_dataloader=True)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            kwargs = {
+                "output_dir": tmpdir,
+                "train_len": 128,
+                "save_steps": 5,
+                "learning_rate": 0.1,
+                "logging_steps": 5,
+                "accelerator_config": accelerator_config,
+            }
+            trainer = get_regression_trainer(**kwargs)
+            self.assertTrue(trainer.use_stateful_dataloader)
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            self.assertIsNotNone(trainer.state.train_dataloader_state_dict)
+            state = dataclasses.asdict(trainer.state)
+
+            checkpoint = os.path.join(tmpdir, "checkpoint-5")
+
+            # Assert the checkpoint has a saved state_dict.
+            checkpoint_5_state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
+            self.assertIsNotNone(checkpoint_5_state.train_dataloader_state_dict)
+
+            # Reinitialize trainer
+            trainer = get_regression_trainer(**kwargs)
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            self.assertIsNotNone(trainer.state.train_dataloader_state_dict)
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Assert the checkpoint has a saved state_dict.
+            checkpoint_15_state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
+            self.assertIsNotNone(checkpoint_15_state.train_dataloader_state_dict)
+
+            # Reinitialize trainer and load model
+            trainer = get_regression_trainer(**kwargs)
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            self.assertIsNotNone(trainer.state.train_dataloader_state_dict)
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
+        # With a regular model that is not a PreTrainedModel
+        with tempfile.TemporaryDirectory() as tmpdir:
+            kwargs = {
+                "output_dir": tmpdir,
+                "train_len": 128,
+                "save_steps": 5,
+                "learning_rate": 0.1,
+                "pretrained": False,
+                "accelerator_config": accelerator_config,
+            }
+
+            trainer = get_regression_trainer(**kwargs)
+            self.assertTrue(trainer.use_stateful_dataloader)
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            self.assertIsNotNone(trainer.state.train_dataloader_state_dict)
+            state = dataclasses.asdict(trainer.state)
+
+            checkpoint = os.path.join(tmpdir, "checkpoint-5")
+
+            # Assert the checkpoint has a saved state_dict.
+            checkpoint_5_state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
+            self.assertIsNotNone(checkpoint_5_state.train_dataloader_state_dict)
+
+            # Reinitialize trainer and load model
+            trainer = get_regression_trainer(**kwargs)
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            self.assertIsNotNone(trainer.state.train_dataloader_state_dict)
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Assert the checkpoint has a saved state_dict.
+            checkpoint_15_state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
+            self.assertIsNotNone(checkpoint_15_state.train_dataloader_state_dict)
+
+            # Reinitialize trainer and load model
+            trainer = get_regression_trainer(**kwargs)
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            self.assertIsNotNone(trainer.state.train_dataloader_state_dict)
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
+        # Now check failures
+
+        # 1. fail to find a bogus checkpoint
+        trainer = get_regression_trainer(accelerator_config=accelerator_config)
+        with self.assertRaises(Exception) as context:
+            trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
+        self.assertTrue("Can't find a valid checkpoint at" in str(context.exception))
+
+        # 2. fail to find any checkpoint - due to a fresh output_dir
+        output_dir2 = self.get_auto_remove_tmp_dir()
+        trainer = get_regression_trainer(output_dir=output_dir2, accelerator_config=accelerator_config)
+        with self.assertRaises(Exception) as context:
+            trainer.train(resume_from_checkpoint=True)
+        self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
+
     @slow
     @require_accelerate
     @require_torch_non_multi_accelerator
@@ -3643,6 +3824,7 @@ def test_accelerator_config_empty(self):
         self.assertEqual(trainer.accelerator.dispatch_batches, None)
         self.assertEqual(trainer.accelerator.even_batches, True)
         self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
+        self.assertEqual(trainer.accelerator.use_stateful_dataloader, False)
 
         if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
             # gradient accumulation kwargs configures gradient_state
@@ -3675,6 +3857,7 @@ def test_accelerator_config_from_dict(self):
         self.assertEqual(trainer.accelerator.dispatch_batches, True)
         self.assertEqual(trainer.accelerator.even_batches, False)
         self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
+        self.assertEqual(trainer.accelerator.use_stateful_dataloader, False)
 
         if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
             self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
@@ -3705,6 +3888,7 @@ def test_accelerator_config_from_yaml(self):
         self.assertEqual(trainer.accelerator.dispatch_batches, True)
         self.assertEqual(trainer.accelerator.even_batches, False)
         self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
+        self.assertEqual(trainer.accelerator.use_stateful_dataloader, False)
 
         if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
             self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
@@ -3729,6 +3913,7 @@ def test_accelerator_config_from_dataclass(self):
         self.assertEqual(trainer.accelerator.dispatch_batches, True)
         self.assertEqual(trainer.accelerator.even_batches, False)
         self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
+        self.assertEqual(trainer.accelerator.use_stateful_dataloader, False)
 
     @require_accelerate_version_min_0_28
     def test_accelerate_config_from_dataclass_grad_accum(self):
@@ -3779,6 +3964,7 @@ def test_accelerator_config_from_partial(self):
         self.assertEqual(trainer.accelerator.dispatch_batches, None)
         self.assertEqual(trainer.accelerator.even_batches, True)
         self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
+        self.assertEqual(trainer.accelerator.use_stateful_dataloader, False)
 
     def test_accelerator_config_from_dict_with_deprecated_args(self):
         # Checks that accelerator kwargs can be passed through
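For reviewers, a minimal usage sketch of the feature this diff adds (not part of the patch). The toy model, toy dataset, and the "stateful_out" output directory are placeholders invented for illustration; the only APIs relied on are the ones touched here (`AcceleratorConfig.use_stateful_dataloader`, `TrainingArguments.accelerator_config`, `Trainer.train(resume_from_checkpoint=...)`), and it assumes `accelerate>=1.0.0` and `torchdata>=0.8.0` are installed:

import torch
from torch import nn
from transformers import Trainer, TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig


class ToyDataset(torch.utils.data.Dataset):
    # 64 random regression samples; enough to produce a few checkpoints with save_steps=5
    def __len__(self):
        return 64

    def __getitem__(self, i):
        x = torch.randn(4)
        return {"input_ids": x, "labels": x.sum().unsqueeze(0)}


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, input_ids=None, labels=None):
        preds = self.linear(input_ids)
        return {"loss": nn.functional.mse_loss(preds, labels), "logits": preds}


# The new flag: back the prepared dataloaders with torchdata.StatefulDataLoader so the
# dataloader position is recorded in TrainerState.train_dataloader_state_dict at each step.
args = TrainingArguments(
    output_dir="stateful_out",
    per_device_train_batch_size=8,
    save_steps=5,
    accelerator_config=AcceleratorConfig(use_stateful_dataloader=True),
)

Trainer(model=ToyModel(), args=args, train_dataset=ToyDataset()).train()

# On resume, instead of replaying the epoch with skip_first_batches(), the saved state dict is
# loaded back into the StatefulDataLoader, so iteration continues where the checkpoint left off.
trainer = Trainer(model=ToyModel(), args=args, train_dataset=ToyDataset())
trainer.train(resume_from_checkpoint="stateful_out/checkpoint-5")

This mirrors what the new tests exercise: the dataloader state travels inside `trainer_state.json`, so resuming mid-epoch does not depend on re-iterating and discarding already-seen batches.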