Skip to content

Commit

Permalink
Remove skip_first_batches support for StatefulDataloader and fix al…
Browse files Browse the repository at this point in the history
…l the tests (#3068)

* Pippy tests - good

* Fix dataloader example tests

* SD issue

* Rm test

* Docs

* Rm from doc
  • Loading branch information
muellerzr authored Sep 2, 2024
1 parent a848592 commit 8931e5e
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 83 deletions.
21 changes: 6 additions & 15 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,11 +1164,11 @@ def prepare_data_loader(
class SkipBatchSampler(BatchSampler):
"""
A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
Should not be used if the original dataloader is a `StatefulDataLoader`.
"""

def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
self.sampler = batch_sampler.sampler
self.skip_batches = skip_batches

def __iter__(self):
Expand All @@ -1186,15 +1186,14 @@ def __len__(self):

class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.
Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
`skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
Args:
dataset (`torch.utils.data.dataset.Dataset`):
The dataset to use to build this datalaoder.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
kwargs:
All other keyword arguments to pass to the regular `DataLoader` initialization.
"""
Expand All @@ -1215,11 +1214,9 @@ def __iter__(self):

def skip_first_batches(dataloader, num_batches=0):
"""
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
the original dataloader is a `StatefulDataLoader`.
"""
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

state = PartialState()
if state.distributed_type == DistributedType.XLA:
device = dataloader.device
Expand Down Expand Up @@ -1263,7 +1260,6 @@ def skip_first_batches(dataloader, num_batches=0):
split_batches=dataloader.split_batches,
batch_sampler=new_batch_sampler,
_drop_last=dataloader._drop_last,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
elif isinstance(dataloader, DataLoaderShard):
Expand All @@ -1280,17 +1276,12 @@ def skip_first_batches(dataloader, num_batches=0):
device=dataloader.device,
rng_types=dataloader.rng_types,
synchronized_generator=dataloader.synchronized_generator,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
else:
if new_batch_sampler is None:
# Need to manually skip batches in the dataloader
dataloader = SkipDataLoader(
dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs
)
elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader):
dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
else:
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

Expand Down
106 changes: 46 additions & 60 deletions src/accelerate/test_utils/scripts/external_deps/test_pippy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision.models import resnet34
from transformers import (
BertConfig,
BertForMaskedLM,
GPT2Config,
GPT2ForSequenceClassification,
T5Config,
T5ForConditionalGeneration,
)

from accelerate import PartialState
from accelerate.inference import prepare_pippy
from accelerate.utils import DistributedType, send_to_device, set_seed
from accelerate.utils import DistributedType, set_seed


model_to_config = {
"t5": (T5ForConditionalGeneration, T5Config, 1024),
"bert": (BertForMaskedLM, BertConfig, 512),
"gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
}
Expand All @@ -42,23 +38,19 @@ def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
# config_args["pad_token_id"] = 0
model_config = config(**config_args)
model = initializer(model_config)
return model, torch.randint(
low=0,
high=model_config.vocab_size,
size=(num_processes, seq_len),
device=device,
dtype=torch.int64,
requires_grad=False,
)
kwargs = dict(low=0, high=model_config.vocab_size, device=device, dtype=torch.int64, requires_grad=False)
trace_input = torch.randint(size=(1, seq_len), **kwargs)
inference_inputs = torch.randint(size=(num_processes, seq_len), **kwargs)
return model, trace_input, inference_inputs


def test_gpt2(batch_size: int = 2):
def test_bert(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = inputs.to("cuda")
inputs = inference_inputs.to("cuda")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
Expand All @@ -68,63 +60,57 @@ def test_gpt2(batch_size: int = 2):
assert output is not None, "Output was not generated in the last process!"


def test_t5(batch_size: int = 2):
def test_gpt2(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size)
example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
model = prepare_pippy(
model,
no_split_module_classes=model._no_split_modules,
example_kwargs=example_inputs,
)
model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = send_to_device(example_inputs, "cuda:0")
inputs = inference_inputs.to("cuda")
with torch.no_grad():
output = model(*inputs.values())
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"


def test_resnet(batch_size: int = 2):
set_seed(42)
state = PartialState()
model = resnet34()
input_tensor = torch.rand(batch_size, 3, 224, 224)
model = prepare_pippy(
model,
example_args=(input_tensor,),
)
inputs = send_to_device(input_tensor, "cuda:0")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"
# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34
# def test_resnet(batch_size: int = 2):
# set_seed(42)
# state = PartialState()
# model = resnet34()
# input_tensor = torch.rand(1, 3, 224, 224)
# model = prepare_pippy(
# model,
# example_args=(input_tensor,),
# )
# inference_inputs = torch.rand(batch_size, 3, 224, 224)
# inputs = send_to_device(inference_inputs, "cuda:0")
# with torch.no_grad():
# output = model(inputs)
# # Zach: Check that we just grab the real outputs we need at the end
# if not state.is_last_process:
# assert output is None, "Output was not generated on just the last process!"
# else:
# assert output is not None, "Output was not generated in the last process!"


if __name__ == "__main__":
state = PartialState()
state.print("Testing pippy integration...")
if state.distributed_type == DistributedType.MULTI_GPU:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
# test_gpt2(3)
state.print("Testing T5...")
test_t5()
test_t5(1)
test_t5(3)
state.print("Testing CV model...")
test_resnet()
test_resnet(3)
try:
if state.distributed_type == DistributedType.MULTI_GPU:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
# test_gpt2(3)
state.print("Testing BERT...")
test_bert()
else:
print("Less than two GPUs found, not running tests!")
finally:
state.destroy_process_group()
else:
print("Less than two GPUs found, not running tests!")
7 changes: 0 additions & 7 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,6 @@ def test_skip_data_loader(self):
assert isinstance(dataloader, StatefulDataLoader)
assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]

@require_torchdata_stateful_dataloader
def test_skip_first_batches(self):
dataloader = StatefulDataLoader(list(range(16)), batch_size=4)
new_dataloader = skip_first_batches(dataloader, num_batches=2)
assert isinstance(new_dataloader, StatefulDataLoader)
assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]

@require_torchdata_stateful_dataloader
def test_end_of_dataloader(self):
dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tempfile
import unittest
from pathlib import Path
from unittest import mock
from unittest import mock, skip

import torch

Expand Down Expand Up @@ -261,6 +261,9 @@ def test_ddp_comm_hook(self):
testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"]
run_command(self.launch_args + testargs)

@skip(
reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added."
)
@require_multi_device
def test_distributed_inference_examples_stable_diffusion(self):
testargs = ["examples/inference/distributed/stable_diffusion.py"]
Expand Down

0 comments on commit 8931e5e

Please sign in to comment.