Add EMA support to NeMo #4764

Merged (48 commits), Oct 13, 2022

Changes from all commits:
e341879  Added Base files (SeanNaren, Aug 17, 2022)
5401e6a  Some refactors, swap to using MNIST Lnet (SeanNaren, Aug 17, 2022)
9abe158  Add a few more tests, allow the callback to be set via the exp manager (SeanNaren, Aug 18, 2022)
0331ee1  Actually run validation for testing (SeanNaren, Aug 25, 2022)
3c386f9  Run isort (SeanNaren, Aug 25, 2022)
856330e  Add test for saving state/fix saving state (SeanNaren, Aug 30, 2022)
100b58c  Merge branch 'main' into feature/ema (SeanNaren, Sep 9, 2022)
7a36bcd  Use dummy model (SeanNaren, Sep 9, 2022)
3c458a4  Fix test (SeanNaren, Sep 9, 2022)
672518d  Merge branch 'main' into feature/ema (SeanNaren, Sep 13, 2022)
484b490  Add copyright (SeanNaren, Sep 13, 2022)
ac1b9ba  Support saving separate EMA weight module (SeanNaren, Sep 14, 2022)
1d119b9  Add standalone functionality/logging (SeanNaren, Sep 15, 2022)
56212e3  Merge branch 'main' into feature/ema (SeanNaren, Sep 15, 2022)
fb387ca  Expose more parameters (SeanNaren, Sep 18, 2022)
ffad454  Modify to allow option to replace validation (SeanNaren, Sep 21, 2022)
5f79265  Add jenkins test, formatting (SeanNaren, Sep 21, 2022)
8648d76  Pin Transformers version to fix CI (#4955) (SeanNaren, Sep 20, 2022)
63eff7e  Add cherry-pick action (#4958) (#4961) (github-actions[bot], Sep 20, 2022)
8e3826c  Fix changelog builder (#4962) (#4963) (titu1994, Sep 21, 2022)
1409c45  fix cherry pick workflow (#4964) (#4965) (github-actions[bot], Sep 21, 2022)
19df5de  reorder model check (#4959) (#4967) (github-actions[bot], Sep 21, 2022)
9d60d7a  check for active conda environment (#4970) (#4971) (github-actions[bot], Sep 21, 2022)
8d15cc3  [TTS] fix broken tutorial for MixerTTS. (#4949) (#4976) (github-actions[bot], Sep 21, 2022)
2d33bcd  Checkpoint averaging class fix (#4946) (michalivne, Sep 21, 2022)
7593043  Add ability to give seperate datasets for test, train and validation … (shanmugamr1992, Sep 21, 2022)
2f1f49a  fix label models restoring issue from wrighted cross entropy (#4968) … (github-actions[bot], Sep 21, 2022)
88ded3f  Add simple pre-commit file (#4983) (SeanNaren, Sep 22, 2022)
3105553  Import pycuda.autoprimaryctx or pycuda.autoinit to init pycuda execut… (liji-nv, Sep 22, 2022)
504af09  Adding speaker embedding conditioning in fastpitch (#4986) (subhankar-ghosh, Sep 22, 2022)
9d51a4d  Fix ASR issues (#4984) (#4991) (github-actions[bot], Sep 23, 2022)
5fda255  Fix current tests (SeanNaren, Sep 23, 2022)
ff61f12  Merge branch 'main' into feature/ema (SeanNaren, Sep 23, 2022)
b3c0be7  More test coverage (SeanNaren, Sep 23, 2022)
32799b2  Address reviews (SeanNaren, Sep 28, 2022)
a8b23d3  Merge branch 'main' into feature/ema (SeanNaren, Sep 28, 2022)
aa18207  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 28, 2022)
919c9b5  Address review (SeanNaren, Sep 28, 2022)
ca96104  Merge remote-tracking branch 'fork/feature/ema' into feature/ema (SeanNaren, Sep 28, 2022)
a44ad98  Drop bf16 test (SeanNaren, Sep 28, 2022)
0514a99  Address review (SeanNaren, Sep 28, 2022)
8fe1275  remove print (SeanNaren, Sep 28, 2022)
18d68c8  Merge branch 'main' into feature/ema (SeanNaren, Sep 29, 2022)
26df436  Merge branch 'main' into feature/ema (SeanNaren, Sep 29, 2022)
12c97ba  Merge branch 'main' into feature/ema (SeanNaren, Sep 30, 2022)
a8a2a81  Merge branch 'main' into feature/ema (SeanNaren, Oct 7, 2022)
f19dd4e  Merge branch 'main' into feature/ema (SeanNaren, Oct 13, 2022)
e8aab0c  Add bf16 (SeanNaren, Oct 13, 2022)
14 changes: 14 additions & 0 deletions Jenkinsfile
@@ -233,6 +233,20 @@ pipeline {
        }
      }

      stage('Speech to Text EMA') {
        steps {
          sh 'python examples/asr/asr_ctc/speech_to_text_ctc.py \
          model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
          model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
          trainer.devices=2 \
          trainer.accelerator="gpu" \
          +trainer.fast_dev_run=True \
          +exp_manager.ema.enable=True \
          exp_manager.exp_dir=examples/asr/speech_to_text_results'
          sh 'rm -rf examples/asr/speech_to_text_results'
        }
      }

      stage('L2: Speech to Text WPE - CitriNet') {
        steps {
          sh 'python examples/asr/asr_ctc/speech_to_text_ctc_bpe.py \
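For context, the `+exp_manager.ema.enable=True` override in the Jenkins stage above is the only flag needed to turn the feature on from the CLI. The same can be done from Python; the sketch below assumes exp_manager merges a partial DictConfig into its ExpManagerConfig schema (as it does for its other options), and the trainer arguments and exp_dir are illustrative placeholders rather than values from this PR:

    # Sketch: enable EMA through exp_manager instead of a Hydra CLI override (values illustrative).
    from omegaconf import OmegaConf
    from pytorch_lightning import Trainer

    from nemo.utils.exp_manager import exp_manager

    trainer = Trainer(devices=1, accelerator="gpu", fast_dev_run=True, logger=False, enable_checkpointing=False)
    cfg = OmegaConf.create(
        {
            "exp_dir": "examples/asr/speech_to_text_results",
            "ema": {"enable": True, "decay": 0.999, "evaluate_ema_weights_instead": True},
        }
    )
    exp_manager(trainer, cfg)  # appends the EMA callback to trainer.callbacks when ema.enable is True

Disabling the Trainer's own logger and checkpointing mirrors the usual NeMo example configs, since exp_manager sets those up itself.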
1 change: 1 addition & 0 deletions nemo/collections/common/callbacks/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.

from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback
from nemo.collections.common.callbacks.ema import EMA
184 changes: 184 additions & 0 deletions nemo/collections/common/callbacks/ema.py
@@ -0,0 +1,184 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import warnings
from typing import Any, Dict, List, Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

from nemo.utils import logging

try:
    import amp_C

    apex_available = True
except Exception:
    apex_available = False


class EMA(Callback):
    """
    Implements Exponential Moving Averaging (EMA).

    When training a model, this callback will maintain moving averages of the trained parameters.
    When evaluating, we use the moving averages copy of the trained parameters.
    When saving, we save an additional set of parameters with the prefix `ema`.

    Args:
        decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
        apply_ema_every_n_steps: Apply EMA every n global steps.
        start_step: Start applying EMA from ``start_step`` global step onwards.
        evaluate_ema_weights_instead: Validate the EMA weights instead of the original weights.
            Note this means that when saving the model, the validation metrics are calculated with the EMA weights.
        save_ema_weights_in_callback_state: Enable saving ema weights in callback state.
            This is not required when using NeMo as the experiment manager handles saving weights.
    """

    def __init__(
        self,
        decay: float,
        apply_ema_every_n_steps: int = 1,
        start_step: int = 0,
        save_ema_weights_in_callback_state: bool = False,
        evaluate_ema_weights_instead: bool = False,
    ):
        if not apex_available:
            rank_zero_warn(
                "EMA has better performance when Apex is installed: https://github.com/NVIDIA/apex#installation."
            )
        if not (0 <= decay <= 1):
            raise MisconfigurationException("EMA decay value must be between 0 and 1")
        self._ema_model_weights: Optional[List[torch.Tensor]] = None
        self._overflow_buf: Optional[torch.Tensor] = None
        self._cur_step: Optional[int] = None
        self._weights_buffer: Optional[List[torch.Tensor]] = None
        self.apply_ema_every_n_steps = apply_ema_every_n_steps
        self.start_step = start_step
        self.save_ema_weights_in_callback_state = save_ema_weights_in_callback_state
        self.evaluate_ema_weights_instead = evaluate_ema_weights_instead
        self.decay = decay

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        logging.info('Creating EMA weights copy.')
        if self._ema_model_weights is None:
            self._ema_model_weights = [p.detach().clone() for p in pl_module.state_dict().values()]
        # ensure that all the weights are on the correct device
        self._ema_model_weights = [p.to(pl_module.device) for p in self._ema_model_weights]
        self._overflow_buf = torch.IntTensor([0]).to(pl_module.device)

    def ema(self, pl_module: "pl.LightningModule") -> None:
        if apex_available and pl_module.device.type == "cuda":
            return self.apply_multi_tensor_ema(pl_module)
        return self.apply_ema(pl_module)

    def apply_multi_tensor_ema(self, pl_module: "pl.LightningModule") -> None:
        model_weights = list(pl_module.state_dict().values())
        amp_C.multi_tensor_axpby(
            65536,  # todo (sean): chunk size, should we expose?
            self._overflow_buf,
            [self._ema_model_weights, model_weights, self._ema_model_weights],
            self.decay,
            1 - self.decay,
            -1,
        )

    def apply_ema(self, pl_module: "pl.LightningModule") -> None:
        for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._ema_model_weights):
            diff = ema_weight.data - orig_weight.data
            diff.mul_(1.0 - self.decay)
            ema_weight.sub_(diff)

    def should_apply_ema(self, step: int) -> bool:
        return step != self._cur_step and step >= self.start_step and step % self.apply_ema_every_n_steps == 0

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        if self.should_apply_ema(trainer.global_step):
            self._cur_step = trainer.global_step
            self.ema(pl_module)

    def state_dict(self) -> Dict[str, Any]:
        if self.save_ema_weights_in_callback_state:
            return dict(cur_step=self._cur_step, ema_weights=self._ema_model_weights)
        return dict(cur_step=self._cur_step)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._cur_step = state_dict['cur_step']
        # when loading using NeMo, ema weights will be loaded by the experiment manager separately.
        if self._ema_model_weights is None:
            self._ema_model_weights = state_dict.get('ema_weights')

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        checkpoint_callback = trainer.checkpoint_callback

        if trainer.ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__:
            ext = checkpoint_callback.FILE_EXTENSION
            if trainer.ckpt_path.endswith(f'-EMA{ext}'):
                logging.info(
                    "loading EMA based weights. "
                    "The callback will treat the loaded EMA weights as the main weights"
                    " and create a new EMA copy when training."
                )
                return
            ema_path = trainer.ckpt_path.replace(ext, f'-EMA{ext}')
            if os.path.exists(ema_path):
                ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu'))
                self._ema_model_weights = ema_state_dict['state_dict'].values()
                del ema_state_dict
                logging.info("EMA weights have been loaded successfully. Continuing training with saved EMA weights.")
            else:
                warnings.warn(
                    "we were unable to find the associated EMA weights when re-loading, "
                    "training will start with new EMA weights.",
                    UserWarning,
                )

    def replace_model_weights(self, pl_module: "pl.LightningModule") -> None:
        self._weights_buffer = [p.detach().clone().to('cpu') for p in pl_module.state_dict().values()]
        new_state_dict = {k: v for k, v in zip(pl_module.state_dict().keys(), self._ema_model_weights)}
        pl_module.load_state_dict(new_state_dict)

    def restore_original_weights(self, pl_module: "pl.LightningModule") -> None:
        state_dict = pl_module.state_dict()
        new_state_dict = {k: v for k, v in zip(state_dict.keys(), self._weights_buffer)}
        pl_module.load_state_dict(new_state_dict)
        del self._weights_buffer

    @property
    def ema_initialized(self) -> bool:
        return self._ema_model_weights is not None

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema_initialized and self.evaluate_ema_weights_instead:
            self.replace_model_weights(pl_module)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema_initialized and self.evaluate_ema_weights_instead:
            self.restore_original_weights(pl_module)

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema_initialized and self.evaluate_ema_weights_instead:
            self.replace_model_weights(pl_module)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema_initialized and self.evaluate_ema_weights_instead:
            self.restore_original_weights(pl_module)
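
The callback keeps one EMA tensor per entry of the module's state_dict and, on every qualifying step, applies ema = decay * ema + (1 - decay) * weight (apply_ema expresses this as an in-place subtraction; apply_multi_tensor_ema is the fused Apex path). A minimal standalone sketch outside of exp_manager, where MyModule and my_datamodule are placeholders rather than anything defined in this PR:

    import pytorch_lightning as pl

    from nemo.collections.common.callbacks import EMA

    ema = EMA(
        decay=0.999,                        # required; must lie in [0, 1]
        apply_ema_every_n_steps=1,          # update the moving average every global step
        start_step=0,                       # start averaging from the first step
        evaluate_ema_weights_instead=True,  # run validation/test with the EMA weights swapped in
    )
    trainer = pl.Trainer(max_epochs=5, callbacks=[ema])
    trainer.fit(MyModule(), datamodule=my_datamodule)  # MyModule / my_datamodule are placeholders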
57 changes: 50 additions & 7 deletions nemo/utils/exp_manager.py
@@ -36,7 +36,9 @@
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.loops import TrainingEpochLoop
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities import rank_zero_info

from nemo.collections.common.callbacks import EMA
from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
from nemo.utils import logging, timers
from nemo.utils.app_state import AppState
@@ -95,6 +97,15 @@ class StepTimingParams:
    buffer_size: Optional[int] = 1


@dataclass
class EMAParams:
    enable: Optional[bool] = False
    evaluate_ema_weights_instead: Optional[bool] = False
    decay: Optional[float] = 0.999
    apply_ema_every_n_steps: Optional[int] = 1
    start_step: Optional[int] = 0


@dataclass
class ExpManagerConfig:
    # Log dir creation parameters
@@ -124,6 +135,7 @@ class ExpManagerConfig:
    log_global_rank_0_only: Optional[bool] = False
    # disable initial validation when resuming from a checkpoint saved during validation
    disable_validation_on_resume: Optional[bool] = True
    ema: Optional[EMAParams] = EMAParams()


class TimingCallback(Callback):
@@ -332,6 +344,15 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
    timing_callback = TimingCallback(timer_kwargs=cfg.step_timing_kwargs or {})
    trainer.callbacks.insert(0, timing_callback)

    if cfg.ema.enable:
        ema_callback = EMA(
            decay=cfg.ema.decay,
            apply_ema_every_n_steps=cfg.ema.apply_ema_every_n_steps,
            start_step=cfg.ema.start_step,
            evaluate_ema_weights_instead=cfg.ema.evaluate_ema_weights_instead,
        )
        trainer.callbacks.append(ema_callback)

    if cfg.create_checkpoint_callback:
        configure_checkpointing(
            trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params
@@ -695,12 +716,12 @@ class NeMoModelCheckpoint(ModelCheckpoint):

    def __init__(
        self,
-       always_save_nemo=False,
-       save_nemo_on_train_end=True,
-       save_best_model=False,
-       postfix=".nemo",
-       n_resume=False,
-       model_parallel_size=None,
+       always_save_nemo: bool = False,
+       save_nemo_on_train_end: bool = True,
+       save_best_model: bool = False,
+       postfix: str = ".nemo",
+       n_resume: bool = False,
+       model_parallel_size: int = None,
        **kwargs,
    ):
        # Parse and store "extended" parameters: save_best model and postfix.
@@ -862,9 +883,31 @@ def _del_model_without_trainer(self, filepath: str) -> None:
        except:
            logging.info(f"Tried to remove checkpoint: {filepath} but failed.")

    def _get_ema_callback(self, trainer) -> Optional[EMA]:
        ema_callback = None
        for callback in trainer.callbacks:
            if isinstance(callback, EMA):
                ema_callback = callback
        return ema_callback

    def _save_checkpoint(self, trainer, filepath: str) -> None:
        super()._save_checkpoint(trainer, filepath)
        ema_callback = self._get_ema_callback(trainer)
        if ema_callback is not None:
            # save EMA copy of the model as well.
            ema_callback.replace_model_weights(trainer.lightning_module)
            filepath = self._ema_format_filepath(filepath)
            if self.verbose:
                rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
            super()._save_checkpoint(trainer, filepath)
            ema_callback.restore_original_weights(trainer.lightning_module)

    def _ema_format_filepath(self, filepath: str) -> str:
        return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}')


def configure_checkpointing(
-   trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig'
+   trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig',
):
    """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback
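
Because _save_checkpoint above writes a second checkpoint with the `-EMA` suffix (Lightning's default FILE_EXTENSION is `.ckpt`), the averaged weights can be evaluated after training without re-running it. A rough sketch with a hypothetical checkpoint path, assuming `model` is an already constructed LightningModule of the matching class:

    import torch

    ema_ckpt = torch.load("checkpoints/my_model-EMA.ckpt", map_location="cpu")  # hypothetical path
    model.load_state_dict(ema_ckpt["state_dict"])  # standard Lightning checkpoint layout
    model.eval()

Resuming training directly from a `-EMA.ckpt` file instead takes the on_load_checkpoint branch in the callback, which treats the EMA weights as the main weights and starts a fresh EMA copy.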