Skip to content

Commit

Permalink
Allows the default data step to operate on more types
Browse files Browse the repository at this point in the history
than just dictionaries

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
  • Loading branch information
jomitchellnv committed Jul 9, 2024
1 parent 339b0b6 commit 5971ff5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemo/lightning/megatron_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
# Canonical case.
case batch:
pass
# If the dataloader_iter is empty, return None.
# If the dataloader_iter is empty, return a ValueError.
case _:
batch = None

if batch is not None:
return move_data_to_device(batch, torch.cuda.current_device())
else:
raise ValueError("No valid batch found from dataloader_iter.")
raise ValueError("None returned from dataloader.")


def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
Expand Down

0 comments on commit 5971ff5

Please sign in to comment.