Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update wallclock logging to default hours #2005

Merged
merged 4 commits on Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions composer/callbacks/runtime_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,29 @@ class RuntimeEstimator(Callback):
skip_batches (int, optional): Number of batches to skip before starting clock to estimate
remaining time. Typically, the first few batches are slower due to dataloader, cache
warming, and other reasons. Defaults to 1.
time_unit (str, optional): Time unit to use for `wall_clock` logging. Can be one of
'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
"""

def __init__(self, skip_batches: int = 1, time_unit: str = 'hours') -> None:
    """Estimate the remaining training wall-clock time and log it each batch.

    Args:
        skip_batches (int, optional): Number of batches to skip before starting the clock used
            to estimate remaining time. Typically, the first few batches are slower due to
            dataloader warm-up, cache warming, and other reasons. Defaults to 1.
        time_unit (str, optional): Time unit to use for ``wall_clock`` logging. Can be one of
            'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.

    Raises:
        ValueError: If ``time_unit`` is not one of the supported unit names.
    """
    self._enabled = True
    self.batches_left_to_skip = skip_batches
    self.start_time = None
    self.start_dur = None

    # Seconds-per-unit lookup table; replaces a repetitive if/elif chain and
    # avoids the redundant `self.divider = 1` pre-assignment.
    dividers = {
        'seconds': 1,
        'minutes': 60,
        'hours': 60 * 60,
        'days': 60 * 60 * 24,
    }
    if time_unit not in dividers:
        raise ValueError(
            f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".')
    self.divider = dividers[time_unit]

    # Keep track of time spent evaluating
    self.total_eval_wct = 0.0
    self.eval_wct_per_label: Dict[str, List[float]] = {}
Expand Down Expand Up @@ -140,7 +155,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
remaining_calls = num_total_evals - num_evals_finished
remaining_time += eval_wct_avg * remaining_calls

logger.log_metrics({'wall_clock/remaining_estimate': remaining_time})
logger.log_metrics({'wall_clock/remaining_estimate': remaining_time / self.divider})

def eval_end(self, state: State, logger: Logger) -> None:
# If eval is called before training starts, ignore it
Expand Down
31 changes: 27 additions & 4 deletions composer/callbacks/speed_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,37 @@ class SpeedMonitor(Callback):
Args:
window_size (int, optional): Number of batches to use for a rolling average of throughput.
Defaults to 100.
gpu_flops_available (float, optional): Number of flops available on the GPU. If not set,
SpeedMonitor will attempt to determine this automatically. Defaults to None.
time_unit (str, optional): Time unit to use for `wall_clock` logging. Can be one of
'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
"""

def __init__(
    self,
    window_size: int = 100,
    gpu_flops_available: Optional[Union[float, int]] = None,
    time_unit: str = 'hours',
):
    """Monitor throughput and wall-clock time during training.

    Args:
        window_size (int, optional): Number of batches to use for a rolling average of
            throughput. Defaults to 100.
        gpu_flops_available (float, optional): Number of flops available on the GPU. If not
            set, SpeedMonitor will attempt to determine this automatically. Defaults to None.
        time_unit (str, optional): Time unit to use for ``wall_clock`` logging. Can be one of
            'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.

    Raises:
        ValueError: If ``time_unit`` is not one of the supported unit names.
    """
    # Track the batch num samples and wct to compute throughput over a window of batches
    self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
    self.history_wct: Deque[float] = deque(maxlen=window_size + 1)

    self.gpu_flops_available = gpu_flops_available

    # Seconds-per-unit lookup table; replaces a repetitive if/elif chain and
    # avoids the redundant `self.divider = 1` pre-assignment.
    dividers = {
        'seconds': 1,
        'minutes': 60,
        'hours': 60 * 60,
        'days': 60 * 60 * 24,
    }
    if time_unit not in dividers:
        raise ValueError(
            f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".')
    self.divider = dividers[time_unit]

    # Keep track of time spent evaluating
    self.total_eval_wct = 0.0
Expand Down Expand Up @@ -281,10 +303,11 @@ def batch_end(self, state: State, logger: Logger):

# Log the time
# `state.timestamp` excludes any time spent in evaluation
train_wct = state.timestamp.total_wct.total_seconds()
logger.log_metrics({
'wall_clock/train': state.timestamp.total_wct.total_seconds(),
'wall_clock/val': self.total_eval_wct,
'wall_clock/total': state.timestamp.total_wct.total_seconds() + self.total_eval_wct,
'wall_clock/train': train_wct / self.divider,
'wall_clock/val': self.total_eval_wct / self.divider,
'wall_clock/total': (train_wct + self.total_eval_wct) / self.divider,
})

def eval_end(self, state: State, logger: Logger):
Expand Down
37 changes: 32 additions & 5 deletions tests/callbacks/test_runtime_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

import datetime
import time

import pytest
from torch.utils.data import DataLoader

from composer.callbacks import RuntimeEstimator
Expand All @@ -23,23 +25,33 @@ def _assert_no_negative_values(logged_values):
assert v >= 0


@pytest.mark.parametrize('time_unit', ['seconds', 'minutes', 'hours', 'days'])
def test_runtime_estimator(time_unit: str):
    """RuntimeEstimator logs one remaining-time estimate per non-skipped batch, in `time_unit`."""
    # Construct the callbacks
    skip_batches = 1
    runtime_estimator = RuntimeEstimator(skip_batches=skip_batches, time_unit=time_unit)
    in_memory_logger = InMemoryLogger()  # track the logged metrics in the in_memory_logger

    simple_model = SimpleModel()
    original_fwd = simple_model.forward

    # Slow the forward pass slightly so the wall-clock estimate is measurably non-zero.
    def new_fwd(x):
        time.sleep(0.02)
        return original_fwd(x)

    simple_model.forward = new_fwd  # type: ignore

    # Construct the trainer and train
    trainer = Trainer(
        model=simple_model,
        callbacks=runtime_estimator,
        loggers=in_memory_logger,
        train_dataloader=DataLoader(RandomClassificationDataset()),
        eval_dataloader=DataLoader(RandomClassificationDataset()),
        max_duration='2ep',
        eval_interval='1ep',
        train_subset_num_batches=5,
        eval_subset_num_batches=5,
    )
    trainer.fit()

    # One estimate is logged per batch after the skipped batches.
    # NOTE(review): this line was elided in the diff context — reconstructed from its
    # use below; confirm against the full file.
    wall_clock_remaining_calls = len(in_memory_logger.data['wall_clock/remaining_estimate'])
    expected_calls = int(trainer.state.timestamp.batch) - skip_batches
    assert wall_clock_remaining_calls == expected_calls

    # The estimate logged after the second batch should be ~0.2 seconds, expressed in
    # `time_unit`. A single seconds-per-unit table replaces the four-branch elif chain
    # that duplicated the same bounds in each unit.
    divider = {'seconds': 1, 'minutes': 60, 'hours': 60 * 60, 'days': 60 * 60 * 24}[time_unit]
    ba_2_estimate = in_memory_logger.data['wall_clock/remaining_estimate'][1][-1]
    assert 0.1 / divider < ba_2_estimate < 1 / divider