Skip to content

Commit

Permalink
Apply garbage collection interval to validation steps (#6870) (#6872)
Browse files Browse the repository at this point in the history
* Apply garbage collection inverval to validation steps



* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
  • Loading branch information
4 people authored Jun 28, 2023
1 parent e9b0b11 commit 7e20750
Showing 1 changed file with 19 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
# The automatic garbage collector sould be disabled before training starts.
if self.gc_interval > 0:
gc.disable()
self.validation_global_step = 1

def _enable_nvidia_optimizations(self):
"These optimizations are present in NVIDIA NGC PyTorch Containers"
Expand Down Expand Up @@ -225,6 +226,16 @@ def on_train_start(self) -> None:
super().on_train_start()
self.init_global_step = self.trainer.global_step

def on_validation_start(self) -> None:
super().on_validation_start()
if self.gc_interval > 0:
gc.collect()

def on_validation_end(self) -> None:
super().on_validation_end()
if self.gc_interval > 0:
gc.collect()

def _build_vocab(self):
"""
Manipulate vocabulary (e.g., pad vocabulary for increased performance)/
Expand Down Expand Up @@ -373,6 +384,14 @@ def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unus
if self.gc_interval > 0 and (self.trainer.global_step % self.gc_interval == 0):
gc.collect()

def on_validation_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

if self.gc_interval > 0:
if self.validation_global_step % self.gc_interval == 0:
gc.collect()
self.validation_global_step += 1

def setup_optimization(
self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down

0 comments on commit 7e20750

Please sign in to comment.