Support AcceleratorConfig.use_stateful_dataloader in Trainer #34205

Open · wants to merge 20 commits into main
3 changes: 2 additions & 1 deletion setup.py
@@ -96,7 +96,7 @@
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow>=10.0.1,<=15.0",
"accelerate>=0.26.0",
"accelerate>=1.0.0",
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
"beautifulsoup4",
"blobfile",
@@ -183,6 +183,7 @@
"tokenizers>=0.20,<0.21",
"torch",
"torchaudio",
"torchdata>=0.8.0",
"torchvision",
"pyctcdecode>=0.4.0",
"tqdm>=4.27",
3 changes: 2 additions & 1 deletion src/transformers/dependency_versions_table.py
@@ -3,7 +3,7 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow>=10.0.1,<=15.0",
"accelerate": "accelerate>=0.26.0",
"accelerate": "accelerate>=1.0.0",
"av": "av==9.2.0",
"beautifulsoup4": "beautifulsoup4",
"blobfile": "blobfile",
@@ -88,6 +88,7 @@
"tokenizers": "tokenizers>=0.20,<0.21",
"torch": "torch",
"torchaudio": "torchaudio",
"torchdata": "torchdata>=0.8.0",
"torchvision": "torchvision",
"pyctcdecode": "pyctcdecode>=0.4.0",
"tqdm": "tqdm>=4.27",
11 changes: 11 additions & 0 deletions src/transformers/testing_utils.py
@@ -55,6 +55,7 @@
from .utils import (
ACCELERATE_MIN_VERSION,
GGUF_MIN_VERSION,
TORCHDATA_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_aqlm_available,
@@ -135,6 +136,7 @@
is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available,
is_torchdata_available,
is_torchdynamo_available,
is_torchvision_available,
is_vision_available,
@@ -957,6 +959,15 @@ def require_torch_multi_xpu(test_case):
jax_device = None


def require_torchdata(test_case, min_version: str = TORCHDATA_MIN_VERSION):
"""
Decorator marking a test that requires torchdata. These tests are skipped when torchdata isn't installed or is older than `min_version`.
"""
return unittest.skipUnless(
is_torchdata_available(min_version), f"test requires torchdata version >= {min_version}"
)(test_case)


def require_torchdynamo(test_case):
"""Decorator marking a test that requires TorchDynamo"""
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
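As a usage illustration (not part of the diff), the new `require_torchdata` decorator can gate a test on torchdata being installed at the minimum version; the test class and body below are hypothetical:

```python
import unittest

from transformers.testing_utils import require_torchdata


class StatefulDataLoaderUtilsTest(unittest.TestCase):
    @require_torchdata
    def test_state_dict_roundtrip(self):
        # Skipped automatically unless torchdata >= 0.8.0 is installed.
        from torchdata.stateful_dataloader import StatefulDataLoader

        loader = StatefulDataLoader(list(range(8)), batch_size=2)
        self.assertTrue(hasattr(loader, "state_dict"))
        self.assertTrue(hasattr(loader, "load_state_dict"))
```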
24 changes: 22 additions & 2 deletions src/transformers/trainer.py
@@ -2086,6 +2086,12 @@ def train(
state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
if state.train_batch_size is not None:
self._train_batch_size = state.train_batch_size
if state.train_dataloader_state_dict is None and self.use_stateful_dataloader:
raise ValueError(
"TrainerArguments.accelerator_config.use_stateful_dataloader is true, however a checkpoint with"
"no saved dataloader state_dict has been loaded. If this is the correct checkpoint to be loaded,"
"please set `accelerator_config.use_stateful_dataloader` to false."
)

# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
@@ -2396,7 +2402,10 @@ def _inner_training_loop(
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
if self.use_stateful_dataloader:
epoch_dataloader.load_state_dict(self.state.train_dataloader_state_dict)
else:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True
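The branch above relies on the stateful dataloader exposing `state_dict()`/`load_state_dict()`, so mid-epoch progress can be restored directly instead of replaying batches with `skip_first_batches`. A minimal standalone sketch of that round trip, assuming `torchdata>=0.8.0` is installed (the toy dataset is illustrative):

```python
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(list(range(8)), batch_size=2)
batches = iter(loader)
next(batches)                     # consume the first batch
snapshot = loader.state_dict()    # mid-epoch progress, analogous to what the Trainer stores

resumed = StatefulDataLoader(list(range(8)), batch_size=2)
resumed.load_state_dict(snapshot)
print([batch.tolist() for batch in resumed])  # iteration resumes from the second batch
```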
@@ -2526,10 +2535,12 @@ def _inner_training_loop(
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()

model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
# Maybe we should only update the state dict right before checkpointing?
if self.use_stateful_dataloader:
self.state.train_dataloader_state_dict = epoch_dataloader.state_dict()
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
else:
@@ -4930,6 +4941,15 @@ def create_accelerator_and_postprocess(self):
# this would have been updated above, no need for it anymore
accelerator_config.pop("gradient_accumulation_kwargs")

use_stateful_dataloader = accelerator_config.pop("use_stateful_dataloader")
if use_stateful_dataloader:
if not is_accelerate_available("1.0.0"):
raise ImportError(
"`use_stateful_dataloader` is only supported in accelerate v1.0.0 and above. Please upgrade accelerate to use this feature."
)
dataloader_config.use_stateful_dataloader = use_stateful_dataloader
self.use_stateful_dataloader = use_stateful_dataloader

args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
6 changes: 5 additions & 1 deletion src/transformers/trainer_callback.py
@@ -19,7 +19,7 @@
import dataclasses
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
from tqdm.auto import tqdm
@@ -87,6 +87,9 @@ class TrainerState:
stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*):
Callbacks attached to the `Trainer` that should have their states be saved or restored.
Relevant callbacks should implement a `state` and `from_state` function.
train_dataloader_state_dict (`Dict[str, Any]`, *optional*):
State dict tracking the internal state of the training `StatefulDataLoader`. Only present when the
trainer serves training data through a stateful dataloader.
"""

epoch: Optional[float] = None
@@ -108,6 +111,7 @@ class TrainerState:
trial_name: str = None
trial_params: Dict[str, Union[str, float, int, bool]] = None
stateful_callbacks: List["TrainerCallback"] = None
train_dataloader_state_dict: Dict[str, Any] = None

def __post_init__(self):
if self.log_history is None:
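Because the new field lives on `TrainerState`, it is serialized and restored with the rest of the trainer state. A small sketch of that round trip (the state-dict contents below are placeholders; real values come from `StatefulDataLoader.state_dict()`):

```python
from transformers.trainer_callback import TrainerState

state = TrainerState(global_step=120)
state.train_dataloader_state_dict = {"steps_seen": 120}  # placeholder payload
state.save_to_json("trainer_state.json")                 # written alongside a checkpoint

restored = TrainerState.load_from_json("trainer_state.json")
assert restored.train_dataloader_state_dict == {"steps_seen": 120}
```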
11 changes: 11 additions & 0 deletions src/transformers/trainer_pt_utils.py
@@ -1250,6 +1250,10 @@ class AcceleratorConfig:
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined
before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState`
must be initialized. May lead to issues using sweeps or hyperparameter tuning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether or not to have the dataloaders prepared by the Accelerator be backed by
[`torchdata.StatefulDataLoader`](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `accelerate` version 1.0.0 or higher and `torchdata` version 0.8.0 or higher to be installed.

"""

@@ -1320,6 +1324,13 @@ class AcceleratorConfig:
},
)

use_stateful_dataloader: bool = field(
default=False,
metadata={
"help": "Whether or not to have the dataloaders prepared by the Accelerator be backed by`[torchdata.StatefulDataLoader]`(https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `accelerate` version 1.0.0 or higher, and `torchdata` version 0.8.0 to be installed."
},
)

@classmethod
def from_json_file(cls, json_file):
# Check if exists
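For reference, a sketch of opting in through the new field, either by constructing `AcceleratorConfig` directly or via the existing `from_json_file` path (the file name below is illustrative):

```python
import json

from transformers.trainer_pt_utils import AcceleratorConfig

# Direct construction
config = AcceleratorConfig(use_stateful_dataloader=True)

# Or via a JSON file consumed by from_json_file
with open("accelerator_config.json", "w") as f:
    json.dump({"use_stateful_dataloader": True}, f)
config_from_file = AcceleratorConfig.from_json_file("accelerator_config.json")
assert config_from_file.use_stateful_dataloader
```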
4 changes: 4 additions & 0 deletions src/transformers/training_args.py
@@ -602,6 +602,10 @@ class TrainingArguments:
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
with hyperparameter tuning.
- use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether or not to have the dataloaders prepared by the Accelerator be backed by
[`torchdata.StatefulDataLoader`](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `accelerate` version 1.0.0 or higher and `torchdata` version 0.8.0 or higher to be installed.

label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
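At the `TrainingArguments` level the option can also be passed as a plain dict; a sketch (output directory name is illustrative):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    accelerator_config={"use_stateful_dataloader": True},
)
# A Trainer built with these args stores the train dataloader's state_dict in TrainerState
# on each step and calls load_state_dict() on it when resuming from a checkpoint.
```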
2 changes: 2 additions & 0 deletions src/transformers/utils/__init__.py
@@ -104,6 +104,7 @@
ENV_VARS_TRUE_VALUES,
GGUF_MIN_VERSION,
TORCH_FX_REQUIRED_VERSION,
TORCHDATA_MIN_VERSION,
USE_JAX,
USE_TF,
USE_TORCH,
@@ -221,6 +222,7 @@
is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available,
is_torchdata_available,
is_torchdistx_available,
is_torchdynamo_available,
is_torchdynamo_compiling,
6 changes: 6 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -91,6 +91,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
ACCELERATE_MIN_VERSION = "0.26.0"
FSDP_MIN_VERSION = "1.12.0"
GGUF_MIN_VERSION = "0.10.0"
TORCHDATA_MIN_VERSION = "0.8.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"

@@ -106,6 +107,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw")
_schedulefree_available = _is_package_available("schedulefree")
_torchdata_available, _torchdata_version = _is_package_available("torchdata", return_version=True)
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -318,6 +320,10 @@ def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)


def is_torchdata_available(min_version: str = TORCHDATA_MIN_VERSION):
return _torchdata_available and version.parse(_torchdata_version) >= version.parse(min_version)


def is_torch_deterministic():
"""
Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2
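Finally, a sketch of the availability guard the new helper enables elsewhere in the codebase (the usage below is illustrative):

```python
from transformers.utils import is_torchdata_available

if is_torchdata_available():  # defaults to TORCHDATA_MIN_VERSION ("0.8.0")
    from torchdata.stateful_dataloader import StatefulDataLoader
else:
    StatefulDataLoader = None  # stateful dataloaders remain unavailable
```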