
Raise error if max duration is in epochs and dataloader is infinite (#…
dakinggg committed Feb 4, 2023
1 parent 7d45bc2 commit bb856ad
Showing 4 changed files with 68 additions and 12 deletions.
12 changes: 9 additions & 3 deletions composer/trainer/trainer.py
@@ -57,7 +57,7 @@

__all__ = ['Trainer']

-# syntax to shorten the Scheduler type annoations
+# syntax to shorten the Scheduler type annotations
Scheduler = Union[ComposerScheduler, PyTorchScheduler]


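The alias above lets annotations accept either Composer's own schedulers or plain PyTorch ones. A minimal sketch constructing one of each (the model and optimizer here are stand-ins, not part of this commit):

    import torch
    from composer.optim import ExponentialScheduler
    from torch.optim.lr_scheduler import StepLR

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    composer_scheduler = ExponentialScheduler(gamma=0.5)  # a ComposerScheduler
    pytorch_scheduler = StepLR(optimizer, step_size=10)   # a PyTorchScheduler
    # Either object satisfies the Scheduler alias, e.g. Trainer(..., schedulers=composer_scheduler)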
@@ -622,7 +622,7 @@ class Trainer:
If ``None`` then no checkpoint will be loaded. (default: ``None``)
load_object_store (Union[ObjectStore, LoggerDestination], optional): If the ``load_path`` is in an
object store (i.e. AWS S3 or Google Cloud Storage), an instance of :class:`.ObjectStore` or
-:class:`.LoggerDestination` which will be used to retreive the checkpoint. Otherwise, if the
+:class:`.LoggerDestination` which will be used to retrieve the checkpoint. Otherwise, if the
checkpoint is a local filepath, set to ``None``. Also, it can be ``None`` if the ``load_path`` is
an S3 URI because the appropriate object store will be automatically constructed in that case.
Ignored if ``load_path`` is ``None``.
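As a hedged illustration of the ``load_object_store`` behavior described above (the bucket and key are hypothetical, and ``model`` is assumed to be an existing ComposerModel):

    # Sketch: for an s3:// URI, the object store can be constructed automatically,
    # so load_object_store may be left as None.
    trainer = Trainer(
        model=model,
        load_path='s3://my-bucket/checkpoints/ep1.pt',  # hypothetical checkpoint URI
        load_object_store=None,
    )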
@@ -1007,7 +1007,7 @@ def __init__(
optimizers = map_collection(optimizers, device.optimizer_to_device)

# Microbatching
-# To support backwards compatability, we currently support both device_train_microbatch_size
+# To support backwards compatibility, we currently support both device_train_microbatch_size
# and grad_accum. If both are specified with grad_accum=1, we will use device_train_microbatch_size.
if device_train_microbatch_size is not None:
using_device_microbatch_size = True
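A short sketch of the precedence the comment above describes, assuming an existing ``model`` and ``train_dataloader``; if both knobs were given with ``grad_accum=1``, ``device_train_microbatch_size`` would win:

    # Preferred: state the microbatch size directly.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration='1ep',
        device_train_microbatch_size=16,  # each forward/backward pass sees 16 samples
    )

    # Legacy spelling kept for backwards compatibility: split each device batch in two.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration='1ep',
        grad_accum=2,
    )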
@@ -1689,6 +1689,12 @@ def fit(
if self.state.max_duration is None:
_raise_missing_argument_exception('max_duration')

+if self.state.dataloader_len is None and self.state.max_duration.unit == TimeUnit.EPOCH:
+    raise ValueError(
+        ('max_duration cannot be specified in epochs when using an infinite dataloader. Please either '
+         'provide a dataloader with a length, specify max_duration in batches, samples, or tokens, or provide '
+         'train_subset_num_batches.'))

if self.state.max_duration <= self.state.timestamp.get(self.state.max_duration.unit) and not reset_time:
raise ValueError(
(f'The max_duration ({self.state.max_duration}) is less than or equal to the elapsed training duration '
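The new guard fires whenever the dataloader's length is unknown and the time unit is epochs. A minimal reproduction, assuming an existing ComposerModel ``model`` and the test helper added later in this commit:

    from torch.utils.data import DataLoader

    from composer import Trainer
    from tests.common import InfiniteClassificationDataset

    trainer = Trainer(
        model=model,
        train_dataloader=DataLoader(InfiniteClassificationDataset(), batch_size=4),
        max_duration='1ep',  # epochs are undefined for a loader with no length
    )
    trainer.fit()  # raises ValueError: 'max_duration cannot be specified in epochs ...'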
27 changes: 21 additions & 6 deletions tests/common/__init__.py
@@ -5,8 +5,8 @@
from typing import List, Type

from tests.common.compare import deep_compare
-from tests.common.datasets import (RandomClassificationDataset, RandomImageDataset, RandomSegmentationDataset,
-                                   RandomTextClassificationDataset, SimpleDataset)
+from tests.common.datasets import (InfiniteClassificationDataset, RandomClassificationDataset, RandomImageDataset,
+                                   RandomSegmentationDataset, RandomTextClassificationDataset, SimpleDataset)
from tests.common.events import EventCounterCallback
from tests.common.markers import device, world_size
from tests.common.models import (ConvModel, EmbeddedWeightTiedModel, SimpleConvModel, SimpleModel,
@@ -20,8 +20,23 @@ def get_module_subclasses(module: types.ModuleType, cls: Type) -> List[Type]:


__all__ = [
-    'assert_state_equivalent', 'RandomClassificationDataset', 'RandomTextClassificationDataset', 'RandomImageDataset',
-    'RandomSegmentationDataset', 'ConvModel', 'SimpleConvModel', 'SimpleModel', 'SimpleTransformerClassifier',
-    'EmbeddedWeightTiedModel', 'SimpleWeightTiedModel', 'EventCounterCallback', 'deep_compare', 'device', 'world_size',
-    'get_module_subclasses', 'SimpleModelWithDropout', 'SimpleDataset'
+    'assert_state_equivalent',
+    'RandomClassificationDataset',
+    'RandomTextClassificationDataset',
+    'RandomImageDataset',
+    'RandomSegmentationDataset',
+    'ConvModel',
+    'SimpleConvModel',
+    'SimpleModel',
+    'SimpleTransformerClassifier',
+    'EmbeddedWeightTiedModel',
+    'SimpleWeightTiedModel',
+    'EventCounterCallback',
+    'deep_compare',
+    'device',
+    'world_size',
+    'get_module_subclasses',
+    'SimpleModelWithDropout',
+    'SimpleDataset',
+    'InfiniteClassificationDataset',
]
19 changes: 18 additions & 1 deletion tests/common/datasets.py
@@ -5,13 +5,30 @@
import pytest
import torch
from PIL import Image
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchvision.datasets import VisionDataset

from composer.utils import dist
from tests.common.models import configure_tiny_bert_tokenizer, configure_tiny_gpt2_tokenizer


+class InfiniteClassificationDataset(IterableDataset):
+    """Classification dataset that never ends.
+
+    Args:
+        shape (Sequence[int]): shape of features (default: (1, 1, 1))
+        num_classes (int): number of classes (default: 2)
+    """
+
+    def __init__(self, shape: Sequence[int] = (1, 1, 1), num_classes: int = 2):
+        self.shape = shape
+        self.num_classes = num_classes
+
+    def __iter__(self):
+        while True:
+            yield torch.randn(*self.shape), torch.randint(0, self.num_classes, size=(1,))[0]
+
+
class RandomClassificationDataset(Dataset):
"""Classification dataset drawn from a normal distribution.
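Because the class above is an ``IterableDataset`` with no ``__len__``, a ``DataLoader`` wrapped around it yields batches forever and reports no length, which is exactly the ``dataloader_len is None`` case the trainer now rejects. A usage sketch:

    import itertools

    from torch.utils.data import DataLoader

    from tests.common import InfiniteClassificationDataset

    loader = DataLoader(InfiniteClassificationDataset(shape=(3, 32, 32)), batch_size=4)
    # len(loader) would raise TypeError here, so state.dataloader_len stays None.
    for features, label in itertools.islice(loader, 3):  # take 3 batches; a bare loop never ends
        print(features.shape, label.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])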
22 changes: 20 additions & 2 deletions tests/trainer/test_trainer.py
@@ -28,8 +28,8 @@
from composer.optim import ExponentialScheduler
from composer.trainer.trainer import _generate_run_name
from composer.utils import dist, is_model_deepspeed, is_model_fsdp, map_collection, reproducibility
-from tests.common import (RandomClassificationDataset, RandomImageDataset, SimpleConvModel, SimpleModel, device,
-                          world_size)
+from tests.common import (InfiniteClassificationDataset, RandomClassificationDataset, RandomImageDataset,
+                          SimpleConvModel, SimpleModel, device, world_size)
from tests.common.events import EventCounterCallback
from tests.test_state import assert_state_equivalent

@@ -186,6 +186,24 @@ def test_max_duration(
# Assert that the states are equivalent
assert_state_equivalent(init_trainer.state, fit_trainer.state)

+@pytest.mark.parametrize('max_duration', [1, '1ep', '1ba', '1sp'])
+@pytest.mark.parametrize('train_subset_num_batches', [-1, 1])
+def test_infinite_train_loader(self, model: ComposerModel, max_duration: Union[int, str],
+                               train_subset_num_batches: int):
+    should_raise = (isinstance(max_duration, int) or
+                    max_duration.endswith('ep')) and (train_subset_num_batches is None or
+                                                      train_subset_num_batches == -1)
+    context = pytest.raises(
+        ValueError,
+        match='max_duration cannot be specified in epochs') if should_raise else contextlib.nullcontext()
+    with context:
+        train_loader = DataLoader(InfiniteClassificationDataset(), batch_size=4)
+        trainer = Trainer(model=model,
+                          train_dataloader=train_loader,
+                          max_duration=max_duration,
+                          train_subset_num_batches=train_subset_num_batches)
+        trainer.fit()

@pytest.mark.parametrize('reset_time', [True, False])
@pytest.mark.parametrize('new_duration', [
Time.from_timestring('1ep'),
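The non-raising branches of the test show the supported ways out: give ``max_duration`` in batches, samples, or tokens, or bound the epoch with ``train_subset_num_batches``. A sketch of the latter, with ``model`` again assumed to be an existing ComposerModel:

    trainer = Trainer(
        model=model,
        train_dataloader=DataLoader(InfiniteClassificationDataset(), batch_size=4),
        max_duration='1ep',
        train_subset_num_batches=100,  # an 'epoch' is now defined as 100 batches
    )
    trainer.fit()  # no ValueError; max_duration='500ba' would also work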
