Feature to automatically choose batch size #1615
Comments
I'd recommend a proper binsearch instead. Alternative mode: just use the highest power of 2 batch size that fits.
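A minimal sketch of that alternative "highest power of 2 that fits" mode, assuming a hypothetical `runs_without_oom(batch_size)` callable that attempts a short training step at the given batch size and returns False on a CUDA out-of-memory error:

```python
from typing import Callable


def largest_fitting_power_of_two(
    runs_without_oom: Callable[[int], bool], max_exponent: int = 16
) -> int:
    """Return the largest power-of-two batch size for which `runs_without_oom` succeeds."""
    best = 1
    for exponent in range(max_exponent + 1):
        candidate = 2 ** exponent
        if not runs_without_oom(candidate):
            break  # this size runs out of memory; keep the last one that worked
        best = candidate
    return best
```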
I am doing this right now:

```python
import logging
import math
import tempfile
from typing import Optional, Union

import pytorch_lightning as pl
import torch

# `data` is a project-specific module providing TextDataset / TextLoader
logger = logging.getLogger(__name__)


def max_gpu_batch_size(
    dataset: data.TextDataset,
    finetuner: pl.LightningModule,
    n_samples: int = 128,
    device: Union[torch.device, int] = 0,
) -> int:
    """
    Tries to find a maximal batch size for a device, assuming only that the memory usage of the
    model and the total available memory are both stable.
    Should be reliable, but slow, you probably only want to run it once.
    """
    device = torch.device(device)  # type: ignore
    device_max_mem = torch.cuda.get_device_properties(device.index).total_memory

    def test_run(batch_size: int) -> Optional[int]:
        logger.debug(f"Trying a run with batch size {batch_size}")
        with tempfile.TemporaryDirectory() as temp_dir:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(device)
            loader = data.TextLoader(dataset, batch_size=batch_size)
            # Note: written against an older pytorch-lightning API
            # (default_save_path / overfit_pct / gpus)
            trainer = pl.Trainer(
                default_save_path=temp_dir,
                overfit_pct=n_samples / len(loader),
                gpus=[device.index],
                max_epochs=2,
            )
            try:
                trainer.fit(finetuner, train_dataloader=loader)
            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    logger.debug("Exceeded memory capacity")
                    return None
                else:
                    raise e
            usage = torch.cuda.max_memory_allocated(device)
            logger.debug(f"Registered usage: {usage} / {device_max_mem} B")
            return usage

    # Find an upper bound for the max batch size as a power of two
    usage_with_min_size = 0
    for exponent in range(math.floor(math.log2(n_samples)) + 1):
        max_size = 2 ** exponent
        usage_with_max_size = test_run(max_size)
        if usage_with_max_size is None:
            break
        # This will only change as long as we don't break out, at which point it will
        # equal the usage for the previous test run
        usage_with_min_size = usage_with_max_size
    if usage_with_max_size is not None:
        logger.warning(
            f"Ran out of examples without finding a maximal batch size (max tried: {max_size})"
            ", you probably want to try with more examples"
        )

    # Bisect to find the max batch size
    min_size = max_size // 2
    while max_size - min_size > 1:
        try_size = (max_size + min_size) // 2
        usage_with_try_size = test_run(try_size)
        if usage_with_try_size is None:
            max_size = try_size
        else:
            min_size = try_size
            usage_with_min_size = usage_with_try_size
    logger.debug(
        f"Mem usage with inferred batch size: {usage_with_min_size} / {device_max_mem} B"
    )
    return min_size
```

However, I usually have to downsize the result, since in distributed mode there is the additional requirement that the device batch size should be a multiple of the number of devices used.
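One way to apply that constraint is to round the found size down to the nearest multiple of the device count. A small sketch, assuming `world_size` comes from the distributed setup (e.g. `torch.distributed.get_world_size()`):

```python
def round_down_to_multiple(batch_size: int, world_size: int) -> int:
    """Largest batch size <= `batch_size` that is divisible by `world_size`."""
    return max(world_size, (batch_size // world_size) * world_size)
```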
Sorry to comment on this closed issue, but I am kind of confused about the …
@ma-batita I am guessing you've found your answer by now, but if not (and for anybody else), the docs explain this pretty clearly here: https://pytorch-lightning.readthedocs.io/en/latest/advanced/training_tricks.html#batch-size-finder
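For reference, a hedged sketch of how the batch size finder was typically invoked in Lightning 1.x (API names have moved around between versions, so check the linked docs for the current syntax; `model` is assumed to be a LightningModule exposing a `batch_size` attribute the tuner can overwrite):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(auto_scale_batch_size="binsearch")
trainer.tune(model)  # runs the batch size finder and updates model.batch_size
trainer.fit(model)   # trains with the scaled batch size
```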
Let's add a flag:
This should do binary search on batch size:
And so on until we find the optimal batch size. At this point, log it so the user knows (including in TensorBoard), and continue training with the new batch size.
Ideally the user fixes the batch size in future runs to tune the learning rate.
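A hedged sketch of that logging step once the search converges; `found_batch_size` and the log directory are placeholders for illustration, not part of the actual Lightning implementation:

```python
from torch.utils.tensorboard import SummaryWriter

found_batch_size = 256  # placeholder: result of the binary search above

# Record the inferred batch size so the user can pin it in future runs.
writer = SummaryWriter(log_dir="runs/batch_size_finder")
writer.add_scalar("tuning/auto_batch_size", found_batch_size)
writer.close()
print(f"Auto-scaled batch size: {found_batch_size}")
```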