From a7735905214053b8273004956cb0bb6f62247a59 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 13 Oct 2022 16:57:11 +0100 Subject: [PATCH] Add EMA support to NeMo (#4764) * Added Base files Signed-off-by: SeanNaren * Some refactors, swap to using MNIST Lnet Signed-off-by: SeanNaren * Add a few more tests, allow the callback to be set via the exp manager Signed-off-by: SeanNaren * Actually run validation for testing Signed-off-by: SeanNaren * Run isort Signed-off-by: SeanNaren * Add test for saving state/fix saving state Signed-off-by: SeanNaren * Use dummy model Signed-off-by: SeanNaren * Fix test Signed-off-by: SeanNaren * Add copyright Signed-off-by: SeanNaren * Support saving separate EMA weight module Signed-off-by: SeanNaren * Add standalone functionality/logging Signed-off-by: SeanNaren * Expose more parameters Signed-off-by: SeanNaren * Modify to allow option to replace validation Signed-off-by: SeanNaren * Add jenkins test, formatting Signed-off-by: SeanNaren * Pin Transformers version to fix CI (#4955) * Pin transformers version in CI to prevent offline tokenizer loading error Signed-off-by: SeanNaren * Drop version Signed-off-by: SeanNaren * Disable offline temporarily Signed-off-by: SeanNaren * Disable offline temporarily Signed-off-by: SeanNaren * Enable offline Signed-off-by: SeanNaren Signed-off-by: SeanNaren * Add cherry-pick action (#4958) (#4961) * add cherry-pick action Signed-off-by: ericharper * Pin Transformers version to fix CI (#4955) * Pin transformers version in CI to prevent offline tokenizer loading error Signed-off-by: SeanNaren * Drop version Signed-off-by: SeanNaren * Disable offline temporarily Signed-off-by: SeanNaren * Disable offline temporarily Signed-off-by: SeanNaren * Enable offline Signed-off-by: SeanNaren Signed-off-by: SeanNaren Signed-off-by: ericharper Signed-off-by: SeanNaren Co-authored-by: Sean Naren Signed-off-by: ericharper Signed-off-by: SeanNaren Co-authored-by: Eric Harper Co-authored-by: Sean Naren Signed-off-by: SeanNaren * Fix changelog builder (#4962) (#4963) Signed-off-by: smajumdar Signed-off-by: smajumdar Signed-off-by: smajumdar Signed-off-by: SeanNaren * fix cherry pick workflow (#4964) (#4965) Signed-off-by: ericharper Signed-off-by: ericharper Signed-off-by: ericharper Co-authored-by: Eric Harper Signed-off-by: SeanNaren * reorder model check (#4959) (#4967) Signed-off-by: nithinraok Signed-off-by: nithinraok Signed-off-by: nithinraok Co-authored-by: Nithin Rao Signed-off-by: SeanNaren * check for active conda environment (#4970) (#4971) Signed-off-by: SeanNaren * [TTS] fix broken tutorial for MixerTTS. (#4949) (#4976) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: SeanNaren * Checkpoint averaging class fix (#4946) * 1. Added args.class_path to provide it externally. Signed-off-by: Micha Livne * 1. Fixed style. 
Signed-off-by: Micha Livne Signed-off-by: Micha Livne Signed-off-by: SeanNaren * Add ability to give separate datasets for test, train and validation (#4798) * Add ability to give separate datasets for test, train and validation * Addressed Sandeep's comments * Addressed Sandeep's comments * Add ability to give separate datasets for test, train and validation * Add ability to give separate datasets for test, train and validation * Addressed review comments * Bug fix for common dataset utils * Add CI tests Signed-off-by: shanmugamr1992 * Reformat code Signed-off-by: shanmugamr1992 * Bug fix Signed-off-by: shanmugamr1992 * Bug fix * Bug Fix * Bug Fix * Update Jenkinsfile * Addressed comments * Addressed Erik's comments. * Addressed Sandeep * Update Jenkinsfile * Update Jenkinsfile * Update dataset_utils.py * Update Jenkinsfile * Update Jenkinsfile * Use GPT CI config Signed-off-by: MaximumEntropy Signed-off-by: shanmugamr1992 Signed-off-by: MaximumEntropy Co-authored-by: MaximumEntropy Signed-off-by: SeanNaren * fix label models restoring issue from weighted cross entropy (#4968) (#4975) Signed-off-by: nithinraok Signed-off-by: nithinraok Signed-off-by: nithinraok Co-authored-by: Nithin Rao Signed-off-by: SeanNaren * Add simple pre-commit file (#4983) * Add simple pre-commit file Signed-off-by: SeanNaren * Exclude docs folder Signed-off-by: SeanNaren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: SeanNaren * Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" This reverts commit 053bd5ba579537a5f311b431871c21f3381b43eb. Signed-off-by: SeanNaren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: SeanNaren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: SeanNaren * Import pycuda.autoprimaryctx or pycuda.autoinit to init pycuda execution environment (#4951) Signed-off-by: Jin Li Signed-off-by: Jin Li Co-authored-by: Somshubra Majumdar Signed-off-by: SeanNaren * Adding speaker embedding conditioning in fastpitch (#4986) Signed-off-by: subhankar-ghosh Signed-off-by: subhankar-ghosh Signed-off-by: SeanNaren * Fix ASR issues (#4984) (#4991) * Fix ASR issues Signed-off-by: smajumdar * Revert fix Signed-off-by: smajumdar Signed-off-by: smajumdar Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar Signed-off-by: SeanNaren * Fix current tests Signed-off-by: SeanNaren * More test coverage Signed-off-by: SeanNaren * Address reviews Signed-off-by: SeanNaren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address review Signed-off-by: SeanNaren * Drop bf16 test Signed-off-by: SeanNaren * Address review Signed-off-by: SeanNaren * remove print Signed-off-by: SeanNaren * Add bf16 Signed-off-by: SeanNaren Signed-off-by: SeanNaren Signed-off-by: ericharper Signed-off-by: smajumdar Signed-off-by: nithinraok Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Micha Livne Signed-off-by: shanmugamr1992 Signed-off-by: MaximumEntropy Signed-off-by: Jin Li Signed-off-by: subhankar-ghosh Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Eric Harper Co-authored-by: Somshubra Majumdar Co-authored-by: Nithin Rao Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Micha Livne Co-authored-by: shanmugamr1992 
<111910568+shanmugamr1992@users.noreply.github.com> Co-authored-by: MaximumEntropy Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: liji-nv <59594262+liji-nv@users.noreply.github.com> Co-authored-by: Subhankar Ghosh --- Jenkinsfile | 14 + nemo/collections/common/callbacks/__init__.py | 1 + nemo/collections/common/callbacks/ema.py | 184 ++++++++ nemo/utils/exp_manager.py | 57 ++- tests/collections/common/test_ema.py | 405 ++++++++++++++++++ 5 files changed, 654 insertions(+), 7 deletions(-) create mode 100644 nemo/collections/common/callbacks/ema.py create mode 100644 tests/collections/common/test_ema.py diff --git a/Jenkinsfile b/Jenkinsfile index 464dcd2013af..c8add78216a7 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -233,6 +233,20 @@ pipeline { } } + stage('Speech to Text EMA') { + steps { + sh 'python examples/asr/asr_ctc/speech_to_text_ctc.py \ + model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ + model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ + trainer.devices=2 \ + trainer.accelerator="gpu" \ + +trainer.fast_dev_run=True \ + +exp_manager.ema.enable=True \ + exp_manager.exp_dir=examples/asr/speech_to_text_results' + sh 'rm -rf examples/asr/speech_to_text_results' + } + } + stage('L2: Speech to Text WPE - CitriNet') { steps { sh 'python examples/asr/asr_ctc/speech_to_text_ctc_bpe.py \ diff --git a/nemo/collections/common/callbacks/__init__.py b/nemo/collections/common/callbacks/__init__.py index 9ad5c9c85a5f..0cf495d94696 100644 --- a/nemo/collections/common/callbacks/__init__.py +++ b/nemo/collections/common/callbacks/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback +from nemo.collections.common.callbacks.ema import EMA diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py new file mode 100644 index 000000000000..58d6a1668ab2 --- /dev/null +++ b/nemo/collections/common/callbacks/ema.py @@ -0,0 +1,184 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +import warnings +from typing import Any, Dict, List, Optional + +import pytorch_lightning as pl +import torch +from pytorch_lightning import Callback +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import STEP_OUTPUT + +from nemo.utils import logging + +try: + import amp_C + + apex_available = True +except Exception: + apex_available = False + + +class EMA(Callback): + """ + Implements Exponential Moving Averaging (EMA). + + When training a model, this callback will maintain moving averages of the trained parameters. + When evaluating, we use the moving averages copy of the trained parameters. + When saving, we save an additional set of parameters with the prefix `ema`. 
+ + Args: + decay: The exponential decay used when calculating the moving average. Has to be between 0-1. + apply_ema_every_n_steps: Apply EMA every n global steps. + start_step: Start applying EMA from ``start_step`` global step onwards. + evaluate_ema_weights_instead: Validate the EMA weights instead of the original weights. + Note this means that when saving the model, the validation metrics are calculated with the EMA weights. + save_ema_weights_in_callback_state: Enable saving ema weights in callback state. + This is not required when using NeMo as the experiment manager handles saving weights. + """ + + def __init__( + self, + decay: float, + apply_ema_every_n_steps: int = 1, + start_step: int = 0, + save_ema_weights_in_callback_state: bool = False, + evaluate_ema_weights_instead: bool = False, + ): + if not apex_available: + rank_zero_warn( + "EMA has better performance when Apex is installed: https://github.com/NVIDIA/apex#installation." + ) + if not (0 <= decay <= 1): + raise MisconfigurationException("EMA decay value must be between 0 and 1") + self._ema_model_weights: Optional[List[torch.Tensor]] = None + self._overflow_buf: Optional[torch.Tensor] = None + self._cur_step: Optional[int] = None + self._weights_buffer: Optional[List[torch.Tensor]] = None + self.apply_ema_every_n_steps = apply_ema_every_n_steps + self.start_step = start_step + self.save_ema_weights_in_callback_state = save_ema_weights_in_callback_state + self.evaluate_ema_weights_instead = evaluate_ema_weights_instead + self.decay = decay + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + logging.info('Creating EMA weights copy.') + if self._ema_model_weights is None: + self._ema_model_weights = [p.detach().clone() for p in pl_module.state_dict().values()] + # ensure that all the weights are on the correct device + self._ema_model_weights = [p.to(pl_module.device) for p in self._ema_model_weights] + self._overflow_buf = torch.IntTensor([0]).to(pl_module.device) + + def ema(self, pl_module: "pl.LightningModule") -> None: + if apex_available and pl_module.device.type == "cuda": + return self.apply_multi_tensor_ema(pl_module) + return self.apply_ema(pl_module) + + def apply_multi_tensor_ema(self, pl_module: "pl.LightningModule") -> None: + model_weights = list(pl_module.state_dict().values()) + amp_C.multi_tensor_axpby( + 65536, # todo (sean): chunk size, should we expose? 
+ self._overflow_buf, + [self._ema_model_weights, model_weights, self._ema_model_weights], + self.decay, + 1 - self.decay, + -1, + ) + + def apply_ema(self, pl_module: "pl.LightningModule") -> None: + for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._ema_model_weights): + diff = ema_weight.data - orig_weight.data + diff.mul_(1.0 - self.decay) + ema_weight.sub_(diff) + + def should_apply_ema(self, step: int) -> bool: + return step != self._cur_step and step >= self.start_step and step % self.apply_ema_every_n_steps == 0 + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + if self.should_apply_ema(trainer.global_step): + self._cur_step = trainer.global_step + self.ema(pl_module) + + def state_dict(self) -> Dict[str, Any]: + if self.save_ema_weights_in_callback_state: + return dict(cur_step=self._cur_step, ema_weights=self._ema_model_weights) + return dict(cur_step=self._cur_step) + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._cur_step = state_dict['cur_step'] + # when loading using NeMo, ema weights will be loaded by the experiment manager separately. + if self._ema_model_weights is None: + self._ema_model_weights = state_dict.get('ema_weights') + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> None: + checkpoint_callback = trainer.checkpoint_callback + + if trainer.ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__: + ext = checkpoint_callback.FILE_EXTENSION + if trainer.ckpt_path.endswith(f'-EMA{ext}'): + logging.info( + "loading EMA based weights. " + "The callback will treat the loaded EMA weights as the main weights" + " and create a new EMA copy when training." + ) + return + ema_path = trainer.ckpt_path.replace(ext, f'-EMA{ext}') + if os.path.exists(ema_path): + ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) + self._ema_model_weights = ema_state_dict['state_dict'].values() + del ema_state_dict + logging.info("EMA weights have been loaded successfully. 
Continuing training with saved EMA weights.") + else: + warnings.warn( + "we were unable to find the associated EMA weights when re-loading, " + "training will start with new EMA weights.", + UserWarning, + ) + + def replace_model_weights(self, pl_module: "pl.LightningModule") -> None: + self._weights_buffer = [p.detach().clone().to('cpu') for p in pl_module.state_dict().values()] + new_state_dict = {k: v for k, v in zip(pl_module.state_dict().keys(), self._ema_model_weights)} + pl_module.load_state_dict(new_state_dict) + + def restore_original_weights(self, pl_module: "pl.LightningModule") -> None: + state_dict = pl_module.state_dict() + new_state_dict = {k: v for k, v in zip(state_dict.keys(), self._weights_buffer)} + pl_module.load_state_dict(new_state_dict) + del self._weights_buffer + + @property + def ema_initialized(self) -> bool: + return self._ema_model_weights is not None + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + self.replace_model_weights(pl_module) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + self.restore_original_weights(pl_module) + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + self.replace_model_weights(pl_module) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + self.restore_original_weights(pl_module) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index b5eecb95dc87..5933b41e8cd0 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -36,7 +36,9 @@ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning.loops import TrainingEpochLoop from pytorch_lightning.strategies.ddp import DDPStrategy +from pytorch_lightning.utilities import rank_zero_info +from nemo.collections.common.callbacks import EMA from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION from nemo.utils import logging, timers from nemo.utils.app_state import AppState @@ -95,6 +97,15 @@ class StepTimingParams: buffer_size: Optional[int] = 1 +@dataclass +class EMAParams: + enable: Optional[bool] = False + evaluate_ema_weights_instead: Optional[bool] = False + decay: Optional[float] = 0.999 + apply_ema_every_n_steps: Optional[int] = 1 + start_step: Optional[int] = 0 + + @dataclass class ExpManagerConfig: # Log dir creation parameters @@ -124,6 +135,7 @@ class ExpManagerConfig: log_global_rank_0_only: Optional[bool] = False # disable initial validation when resuming from a checkpoint saved during validation disable_validation_on_resume: Optional[bool] = True + ema: Optional[EMAParams] = EMAParams() class TimingCallback(Callback): @@ -332,6 +344,15 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo timing_callback = TimingCallback(timer_kwargs=cfg.step_timing_kwargs or {}) trainer.callbacks.insert(0, timing_callback) + if cfg.ema.enable: + ema_callback = EMA( + decay=cfg.ema.decay, + apply_ema_every_n_steps=cfg.ema.apply_ema_every_n_steps, + start_step=cfg.ema.start_step, + evaluate_ema_weights_instead=cfg.ema.evaluate_ema_weights_instead, + ) + trainer.callbacks.append(ema_callback) + if cfg.create_checkpoint_callback: 
configure_checkpointing( trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params @@ -695,12 +716,12 @@ class NeMoModelCheckpoint(ModelCheckpoint): def __init__( self, - always_save_nemo=False, - save_nemo_on_train_end=True, - save_best_model=False, - postfix=".nemo", - n_resume=False, - model_parallel_size=None, + always_save_nemo: bool = False, + save_nemo_on_train_end: bool = True, + save_best_model: bool = False, + postfix: str = ".nemo", + n_resume: bool = False, + model_parallel_size: int = None, **kwargs, ): # Parse and store "extended" parameters: save_best model and postfix. @@ -862,9 +883,31 @@ def _del_model_without_trainer(self, filepath: str) -> None: except: logging.info(f"Tried to remove checkpoint: {filepath} but failed.") + def _get_ema_callback(self, trainer) -> Optional[EMA]: + ema_callback = None + for callback in trainer.callbacks: + if isinstance(callback, EMA): + ema_callback = callback + return ema_callback + + def _save_checkpoint(self, trainer, filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) + ema_callback = self._get_ema_callback(trainer) + if ema_callback is not None: + # save EMA copy of the model as well. + ema_callback.replace_model_weights(trainer.lightning_module) + filepath = self._ema_format_filepath(filepath) + if self.verbose: + rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") + super()._save_checkpoint(trainer, filepath) + ema_callback.restore_original_weights(trainer.lightning_module) + + def _ema_format_filepath(self, filepath: str) -> str: + return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') + def configure_checkpointing( - trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig' + trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', ): """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint callback diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py new file mode 100644 index 000000000000..a4b955c64815 --- /dev/null +++ b/tests/collections/common/test_ema.py @@ -0,0 +1,405 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os.path +from copy import deepcopy +from typing import Any, Dict, Union +from unittest import mock + +import pytest +import pytorch_lightning as pl +import torch +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import STEP_OUTPUT + +from nemo.collections.common.callbacks import EMA +from nemo.core import ModelPT +from nemo.utils.exp_manager import exp_manager + + +class OnesDataset(torch.utils.data.Dataset): + def __init__(self, dataset_len): + super().__init__() + self.__dataset_len = dataset_len + + def __getitem__(self, *args): + return torch.ones(2) + + def __len__(self): + return self.__dataset_len + + +class ExampleModel(ModelPT): + def __init__(self, *args, **kwargs): + cfg = OmegaConf.structured({}) + super().__init__(cfg) + self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1) + + def train_dataloader(self): + dataset = OnesDataset(16) + return torch.utils.data.DataLoader(dataset, batch_size=2) + + def val_dataloader(self): + dataset = OnesDataset(10) + return torch.utils.data.DataLoader(dataset, batch_size=2) + + def forward(self, batch): + output = self.l1(batch) + return torch.nn.functional.l1_loss(output, torch.zeros(output.size()).to(output.device)) + + def validation_step(self, batch, batch_idx): + return self(batch) + + def training_step(self, batch, batch_idx): + return self(batch) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.1) + + def list_available_models(self): + pass + + def setup_training_data(self, train_data_config: Union[DictConfig, Dict]): + pass + + def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]): + pass + + def validation_epoch_end(self, loss): + self.log("val_loss", torch.stack(loss).mean()) + + +class TestEMAConfig: + @pytest.mark.unit + def test_ema_value(self): + with pytest.raises(MisconfigurationException, match="between 0 and 1"): + EMA(decay=2) + + @mock.patch('nemo.collections.common.callbacks.ema.apex_available', False) + def test_ema_apex_unavailable(self): + with pytest.warns(UserWarning, match="EMA has better performance when Apex is installed"): + EMA(decay=0.999) + + @pytest.mark.unit + @pytest.mark.run_only_on('GPU') + def test_ema_saved_state(self, tmpdir, caplog): + """Test to ensure that when we re-load the EMA callback, it loads the EMA weights correctly.""" + temp_path = os.path.join(tmpdir, 'saved_state') + + class TerminateCallback(Callback): + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + self.saved_ema_weights = ema_callback._ema_model_weights + self.pl_module_weights = list(pl_module.state_dict().values()) + raise SystemExit + + model = ExampleModel() + terminate_callback = TerminateCallback() + + trainer = Trainer( + max_epochs=2, + limit_val_batches=1, + limit_train_batches=16, + logger=False, + val_check_interval=0.5, + enable_checkpointing=False, + accelerator='gpu', + devices=1, + callbacks=[terminate_callback], + ) + exp_manager( + trainer, + { + "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "explicit_log_dir": str(temp_path), + "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, + }, + ) + with pytest.raises(SystemExit): + trainer.fit(model=model) + resume_path = os.path.join(temp_path, 'checkpoints/epoch=0-step=8.ckpt') + + 
model = ExampleModel() + + class CheckStateCallback(Callback): + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + weights = list(pl_module.state_dict().values()) + for x, y in zip(weights, terminate_callback.pl_module_weights): + assert torch.allclose(x.cpu(), y.cpu()) + for x, y in zip(ema_callback._ema_model_weights, terminate_callback.saved_ema_weights): + assert torch.allclose(x.cpu(), y.cpu()) + assert ema_callback._cur_step == 8 + + trainer = Trainer( + max_epochs=2, + limit_val_batches=0, + limit_train_batches=16, + logger=False, + enable_checkpointing=False, + accelerator='gpu', + devices=1, + ) + exp_manager( + trainer, + { + "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "explicit_log_dir": str(temp_path), + "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, + }, + ) + # add the callback after the exp manager has made modifications. + trainer.callbacks.append(CheckStateCallback()) + trainer.fit(model, ckpt_path=resume_path) + + # ensure we can resume from the EMA weights + ema_path = os.path.join(temp_path, 'checkpoints/epoch=0-step=8-EMA.ckpt') + + trainer = Trainer( + max_epochs=1, + limit_val_batches=0, + limit_train_batches=1, + logger=False, + enable_checkpointing=False, + accelerator='gpu', + devices=1, + ) + exp_manager( + trainer, + { + "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "explicit_log_dir": str(temp_path), + "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, + }, + ) + trainer.fit(model, ckpt_path=ema_path) + + # ensure that we warn when the EMA weights do not exist + os.remove(ema_path) + + trainer = Trainer( + max_epochs=1, + limit_val_batches=0, + limit_train_batches=1, + logger=False, + enable_checkpointing=False, + accelerator='gpu', + devices=1, + ) + exp_manager( + trainer, + { + "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "explicit_log_dir": str(temp_path), + "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, + }, + ) + with pytest.warns(UserWarning, match="we were unable to find the associated EMA weights when re-loading"): + trainer.fit(model, ckpt_path=resume_path) + + @pytest.mark.unit + @pytest.mark.run_only_on('GPU') + def test_exp_manager_ema_weights(self, tmpdir): + """Test to ensure that the exp manager adds the EMA callback, and we save an additional EMA checkpoint.""" + tmp_path = tmpdir / "exp_manager_test" + model = ExampleModel() + trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False, accelerator='gpu', devices=1) + exp_manager( + trainer, + { + "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "explicit_log_dir": str(tmp_path), + "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, + }, + ) + assert any(isinstance(callback, EMA) for callback in trainer.callbacks) + trainer.fit(model) + + assert os.path.exists(tmp_path / "checkpoints/epoch=0-step=8.ckpt") + ema_path = tmp_path / "checkpoints/epoch=0-step=8-EMA.ckpt" + assert os.path.exists(ema_path) + + duplicate_model = ExampleModel.load_from_checkpoint(str(ema_path)) + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + for saved_weight, ema_weight in zip(duplicate_model.state_dict().values(), ema_callback._ema_model_weights): + assert torch.allclose(saved_weight.cpu(), ema_weight.cpu()) + + @pytest.mark.unit + @pytest.mark.run_only_on('GPU') + def test_ema_save_in_callback(self, tmpdir): + """Test to 
ensure when `save_ema_weights_in_callback_state` is enabled, we save to the callback state.""" + temp_path = os.path.join(tmpdir, 'saved_state') + + model = ExampleModel() + + trainer = Trainer( + max_epochs=2, + limit_val_batches=1, + limit_train_batches=16, + logger=False, + val_check_interval=0.5, + enable_checkpointing=False, + accelerator='gpu', + devices=1, + callbacks=[EMA(decay=0.999, save_ema_weights_in_callback_state=True, evaluate_ema_weights_instead=True)], + ) + exp_manager( + trainer, + {"explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},}, + ) + trainer.fit(model=model) + + resume_path = os.path.join(temp_path, "checkpoints/epoch=0-step=8.ckpt") + callback = EMA(decay=0.999, save_ema_weights_in_callback_state=True) + + class AssertCallback(Callback): + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + assert callback._ema_model_weights is not None + + model = ExampleModel() + + trainer = Trainer( + max_epochs=2, + limit_val_batches=1, + limit_train_batches=16, + logger=False, + val_check_interval=0.5, + enable_checkpointing=False, + accelerator='gpu', + devices=1, + callbacks=[callback, AssertCallback()], + ) + trainer.fit(model, ckpt_path=resume_path) + + +class TestEMATrain: + @pytest.mark.unit + @pytest.mark.parametrize("precision", [16, "bf16", 32]) + @pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) + @pytest.mark.parametrize("evaluate_ema_weights_instead", [True, False]) + @pytest.mark.parametrize("apex_available_mock", [True, False]) + @pytest.mark.run_only_on('GPU') + def test_ema_run_cuda( + self, + test_data_dir, + precision, + accumulate_grad_batches, + evaluate_ema_weights_instead, + apex_available_mock, + tmpdir, + ): + with mock.patch('nemo.collections.common.callbacks.ema.apex_available', apex_available_mock): + self.run_training_test( + accumulate_grad_batches=accumulate_grad_batches, + evaluate_ema_weights_instead=evaluate_ema_weights_instead, + accelerator='gpu', + precision=precision, + tmpdir=tmpdir, + ) + + @pytest.mark.unit + @pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) + @pytest.mark.parametrize("evaluate_ema_weights_instead", [True, False]) + @pytest.mark.run_only_on('GPU') + def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, evaluate_ema_weights_instead, tmpdir): + self.run_training_test( + accumulate_grad_batches=accumulate_grad_batches, + evaluate_ema_weights_instead=evaluate_ema_weights_instead, + accelerator='cpu', + precision=32, + tmpdir=tmpdir, + ) + + def run_training_test(self, accumulate_grad_batches, evaluate_ema_weights_instead, accelerator, precision, tmpdir): + pl.seed_everything(123) + model = ExampleModel() + trainer = Trainer( + max_epochs=1, + precision=precision, + limit_train_batches=10, + limit_val_batches=10, + logger=False, + accumulate_grad_batches=accumulate_grad_batches, + num_sanity_val_steps=0, + enable_model_summary=False, + enable_checkpointing=False, + accelerator=accelerator, + devices=1, + ) + exp_manager( + trainer, + { + "ema": {"enable": True, "evaluate_ema_weights_instead": evaluate_ema_weights_instead}, + "explicit_log_dir": str(tmpdir), + "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, + }, + ) + # add the check callback after the exp manager has made modifications. 
+ trainer.callbacks.append(EMAAssertCallback()) + trainer.fit(model=model, val_dataloaders=model.train_dataloader()) + + +class EMAAssertCallback(Callback): + def __init__(self): + self._before_calc_ema_weights = None + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + model_weights = list(pl_module.state_dict().values()) + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + for x, y in zip(model_weights, ema_callback._ema_model_weights): + assert torch.allclose(x, y) + + def on_train_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int + ) -> None: + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + # saved for manual calculation of ema to compare against implementation + self._before_calc_ema_weights = deepcopy(ema_callback._ema_model_weights) + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + if (batch_idx + 1) % trainer.accumulate_grad_batches != 0: + # skip assertion as ema weights are not updated. + return + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + decay = ema_callback.decay + expected_ema_weights = [] + for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._before_calc_ema_weights): + expected_ema_weight = orig_weight * (1 - decay) + ema_weight * decay + expected_ema_weights.append(expected_ema_weight) + + for actual_ema_weight, expected_ema_weight in zip(ema_callback._ema_model_weights, expected_ema_weights): + assert torch.allclose(actual_ema_weight, expected_ema_weight) + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + if ema_callback.evaluate_ema_weights_instead: + # todo (sean): shouldn't use the weights buffer to check original weights + self._original_weights = list(x.detach().clone() for x in ema_callback._weights_buffer) + if ema_callback.ema_initialized: + for ema_weights, module_weights in zip( + ema_callback._ema_model_weights, pl_module.state_dict().values() + ): + assert torch.allclose(ema_weights, module_weights) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + if ema_callback.evaluate_ema_weights_instead: + model_weights = list(pl_module.state_dict().values()) + if ema_callback.ema_initialized: + for orig_weights, module_weights in zip(self._original_weights, model_weights): + assert torch.allclose(orig_weights, module_weights.cpu())
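
Note on the update this patch implements: the callback keeps a copy of every tensor in `model.state_dict()` and, every `apply_ema_every_n_steps` global steps once `start_step` is reached, moves that copy toward the current weights. A minimal standalone sketch of that update rule follows; the toy tensors, the `ema_update` helper name, and the decay value are illustrative placeholders, not part of the patch:

    import torch

    decay = 0.999  # same default as EMAParams.decay

    # stand-ins for the values of model.state_dict(); the EMA copy is created at on_train_start
    params = [torch.randn(4), torch.randn(2, 2)]
    ema_weights = [p.detach().clone() for p in params]

    def ema_update(params, ema_weights, decay):
        # equivalent to EMA.apply_ema: ema -= (ema - param) * (1 - decay),
        # i.e. ema <- decay * ema + (1 - decay) * param
        with torch.no_grad():
            for p, e in zip(params, ema_weights):
                e.copy_(decay * e + (1.0 - decay) * p)

    ema_update(params, ema_weights, decay)

In NeMo the callback is normally enabled through the experiment manager rather than constructed directly, e.g. `+exp_manager.ema.enable=True` on the command line (as in the Jenkins stage above) or by passing `{"ema": {"enable": True, "decay": 0.999, "evaluate_ema_weights_instead": True}}` to `exp_manager`, as the tests do.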