diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index c31be8a9109e..0149af26db21 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -590,8 +590,8 @@ def __init__( def optimizer_step( self, - model: Union["pl.LightningModule", torch.nn.Module], optimizer: torch.optim.Optimizer, + model: Union["pl.LightningModule", torch.nn.Module], optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any,