export_utils bugfix (#5482)
* export_utils bugfix (#5480)

* updated export_utils

Signed-off-by: David Mosallanezhad <dmosallanezh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: David Mosallanezhad <dmosallanezh@nvidia.com>
Co-authored-by: David Mosallanezhad <dmosallanezh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: David Mosallanezhad <dmosallanezh@nvidia.com>
Signed-off-by: Boris Fomitchev <borisfom@users.noreply.github.com>
Co-authored-by: David <amosalla@asu.edu>
Co-authored-by: David Mosallanezhad <dmosallanezh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Boris Fomitchev <borisfom@users.noreply.github.com>
5 people authored Nov 23, 2022
1 parent bbed82f commit 0a2413d
1 changed file: nemo/utils/export_utils.py (7 additions, 39 deletions)
@@ -59,42 +59,6 @@ def forward(self, x):
         return F.linear(x, self.weight, self.bias), None
 
 
-class ExportableMatchedScaleMaskSoftmax(nn.Module):
-    def __init__(self, mod):
-        super(ExportableMatchedScaleMaskSoftmax, self).__init__()
-        self.init_module(mod.input_in_fp16, mod.input_in_bf16, mod.mask_func, mod.softmax_in_fp32, mod.scale)
-
-    def init_module(
-        self, input_in_fp16, input_in_bf16, mask_func, softmax_in_fp32, scale,
-    ):
-        self.input_in_fp16 = input_in_fp16
-        self.input_in_bf16 = input_in_bf16
-        self.softmax_in_fp32 = softmax_in_fp32
-        self.mask_func = mask_func
-        self.scale = scale
-
-        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
-
-    def forward(self, input, mask):
-        if self.input_in_float16 and self.softmax_in_fp32:
-            input = input.float()
-
-        if self.scale is not None:
-            input = input * self.scale
-        mask_output = self.mask_func(input, mask) if mask is not None else input
-        probs = torch.nn.Softmax(dim=-1)(mask_output)
-        all_k_masked = mask.all(axis=-1)
-        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
-        probs = probs * zero_attention_mask
-
-        if self.input_in_float16 and self.softmax_in_fp32:
-            if self.input_in_fp16:
-                probs = probs.half()
-            else:
-                probs = probs.bfloat16()
-        return probs
-
-
 def get_export_format(filename: str):
     _, ext = os.path.splitext(filename)
     try:
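The class deleted above implemented one behavior worth noting: when every key for a query is masked, it forced that attention row to zero instead of letting softmax return a uniform distribution. A minimal standalone sketch of that logic on dummy tensors (the shapes and the -10000.0 fill are illustrative assumptions, not NeMo code):

import torch

# Toy reproduction of the deleted forward(): softmax over masked attention
# scores, then rows whose keys are ALL masked are zeroed out.
scores = torch.randn(1, 2, 3, 4)                 # [batch, heads, queries, keys]
mask = torch.zeros(1, 1, 3, 4, dtype=torch.bool)
mask[0, 0, 2, :] = True                          # query 2: every key masked

masked = scores.masked_fill(mask, -10000.0)      # stand-in for mask_func
probs = torch.softmax(masked, dim=-1)
print(probs[0, 0, 2])                            # uniform 0.25s, not zeros

all_k_masked = mask.all(dim=-1)                  # [batch, 1, queries]
zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
probs = probs * zero_attention_mask
print(probs[0, 0, 2])                            # now all zeros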
@@ -366,9 +330,13 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
     Returns:
         exportable module
     """
+    # including the import here to avoid circular imports
+    from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax
 
-    mod = ExportableMatchedScaleMaskSoftmax(n.input_in_fp16, n.input_in_bf16, n.mask_func, n.softmax_in_fp32, n.scale)
-
+    # disabling fusion for the MatchedScaleMaskSoftmax
+    mod = MatchedScaleMaskSoftmax(
+        n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
+    )
     return mod
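Replacement functions like replace_MatchedScaleMaskSoftmax take a module and return its exportable substitute. A hypothetical driver showing how such a mapping is typically applied to a module tree (this helper is a sketch under assumed semantics, not NeMo's actual replacement machinery):

import torch.nn as nn

def swap_modules(model: nn.Module, mapping: dict) -> nn.Module:
    # Hypothetical helper: walk the module tree and swap any child whose
    # class name appears in `mapping` with whatever the mapped function
    # returns; recurse into children that are not replaced.
    for name, child in model.named_children():
        swap_fn = mapping.get(type(child).__name__)
        if swap_fn is not None:
            replacement = swap_fn(child)
            if replacement is not None:
                setattr(model, name, replacement)
        else:
            swap_modules(child, mapping)
    return model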


@@ -440,7 +408,7 @@ def script_module(m: nn.Module):
     "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
     "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
     "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
-    "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax),
+    "MatchedScaleMaskSoftmax": wrap_module(None, replace_MatchedScaleMaskSoftmax),
 }
 
 script_replacements = {
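The table entry now passes None as the base type and the free function replace_MatchedScaleMaskSoftmax as the destination. A hedged sketch of what a wrap_module factory consistent with this usage might look like (an assumption about its semantics, not the actual NeMo implementation):

from typing import Callable, Optional, Type
import torch.nn as nn

def wrap_module(
    base_t: Optional[Type[nn.Module]],
    dest_t: Callable[[nn.Module], Optional[nn.Module]],
) -> Callable[[nn.Module], Optional[nn.Module]]:
    # Binds a replacement strategy into a one-argument expansion function.
    # dest_t may be a wrapper class (e.g. CastToFloat, constructed from the
    # module) or, as in the fix, a free function that builds the replacement;
    # base_t is kept for signature compatibility and is unused in this sketch.
    def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
        return dest_t(mod)
    return expansion_fn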
