Skip to content

Commit

Permalink
Fixed Apex guard when imported classes are used for default values (#…
Browse files Browse the repository at this point in the history
…3700)

* 1. Added ApexGuardDefaults class that provides None attributes when an Apex class is missing.
2. Fixed Apex guard.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1.Added more guard.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added another guard.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
  • Loading branch information
3 people authored Feb 17, 2022
1 parent 31cf580 commit 1b89a70
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/nlp/modules/common/megatron/clip_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

"""Gradient clipping."""

import amp_C
import torch
from torch._six import inf

from nemo.collections.nlp.modules.common.megatron.module import param_is_not_shared

try:
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer
from nemo.collections.nlp.modules.common.megatron.utils import (
ApexGuardDefaults,
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
Expand All @@ -35,6 +36,9 @@
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

# fake missing classes with None attributes
AttnMaskType = ApexGuardDefaults()


def get_language_model(
hidden_size,
Expand Down
7 changes: 5 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import attention_mask_func, erf_gelu
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, attention_mask_func, erf_gelu

try:
from apex.transformer import parallel_state, tensor_parallel
Expand All @@ -39,6 +39,8 @@
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

# fake missing classes with None attributes
AttnMaskType = AttnType = LayerType = ApexGuardDefaults()

""" We use the following notation throughout this file:
h: hidden size
Expand Down Expand Up @@ -421,7 +423,8 @@ def __init__(
self.layer_number = layer_number
self.layer_type = layer_type

self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm # if true apply residual connection post layer norm (like original bert)
# if true apply residual connection post layer norm (like original bert)
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm

self.fp32_residual_connection = fp32_residual_connection # if true move residual connections to fp32

Expand Down
12 changes: 12 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@
HAVE_APEX = False


class ApexGuardDefaults(object):
"""
This class can be used to replace missing classes when apex is missing.
"""

def __init__(self):
super().__init__()

def __getattr__(self, item):
return None


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
Expand Down

0 comments on commit 1b89a70

Please sign in to comment.