Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Patch Speed Monitor MFU #2013

Merged
48 changes: 32 additions & 16 deletions composer/callbacks/speed_monitor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(
# Track the batch num samples and wct to compute throughput over a window of batches
self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
self.history_flops: Deque[float] = deque(maxlen=window_size + 1)

self.gpu_flops_available = gpu_flops_available

Expand Down Expand Up @@ -284,22 +285,37 @@ def batch_end(self, state: State, logger: Logger):
except AttributeError:
pass

composer_model = state.model
if not isinstance(composer_model, ComposerModel):
composer_model = composer_model.module # Pass through DDP wrapping
if hasattr(composer_model, 'flops_per_batch'):
model_flops_per_batch = composer_model.flops_per_batch # type: ignore
if not isinstance(model_flops_per_batch, Callable):
raise TypeError('flops_per_batch must a callable accepting a batch and '
f'returning an int or float. Instead, got {type(model_flops_per_batch)}.')
flops_per_batch = model_flops_per_batch(state.batch)
flops_per_sec = flops_per_batch * batches_per_sec
logger.log_metrics({'throughput/flops_per_sec': flops_per_sec})
dev_flops_per_sec = flops_per_sec / world_size
logger.log_metrics({'throughput/device/flops_per_sec': dev_flops_per_sec})
if self.gpu_flops_available:
mfu = dev_flops_per_sec / self.gpu_flops_available
logger.log_metrics({'throughput/device/mfu': mfu})
# Compute flops stats if model has flops_per_batch
composer_model = state.model
if not isinstance(composer_model, ComposerModel):
composer_model = composer_model.module # Pass through DDP wrapping
if hasattr(composer_model, 'flops_per_batch'):
model_flops_per_batch = composer_model.flops_per_batch # type: ignore
if not isinstance(model_flops_per_batch, Callable):
raise TypeError('flops_per_batch must a callable accepting a batch and '
f'returning an int or float. Instead, got {type(model_flops_per_batch)}.')
device_flops_per_batch = model_flops_per_batch(state.batch)

# Sum flops across all ranks since each rank computes the flops for its own batch
flops_per_batch_tensor = state.device.tensor_to_device(
torch.tensor(device_flops_per_batch, dtype=torch.float))
dist.all_reduce(flops_per_batch_tensor, reduce_operation='SUM')
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved
flops_per_batch = flops_per_batch_tensor.item()
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved

self.history_flops.append(flops_per_batch)

# Log the flops throughput
if len(self.history_flops) == self.history_flops.maxlen:
world_size = dist.get_world_size()
elapsed_flops = sum(self.history_flops) - self.history_flops[0]
elapsed_wct = self.history_wct[-1] - self.history_wct[0]
flops_per_sec = elapsed_flops / elapsed_wct
device_flops_per_sec = flops_per_sec / world_size
logger.log_metrics({'throughput/flops_per_sec': flops_per_sec})
logger.log_metrics({'throughput/device/flops_per_sec': device_flops_per_sec})
if self.gpu_flops_available:
mfu = device_flops_per_sec / self.gpu_flops_available
logger.log_metrics({'throughput/device/mfu': mfu})

# Log the time
# `state.timestamp` excludes any time spent in evaluation
Expand Down