diff --git a/examples/nlp/language_modeling/conf/megatron_t5_config.yaml b/examples/nlp/language_modeling/conf/megatron_t5_config.yaml
index 1e614aaceaa3..d67ff6dc59cb 100644
--- a/examples/nlp/language_modeling/conf/megatron_t5_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_t5_config.yaml
@@ -72,6 +72,7 @@ model:
   persist_layer_norm: True # Use of persistent fused layer norm kernel.
   gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
   bias_gelu_fusion: True # Use a kernel that fuses the bias addition from weight matrices with the subsequent gelu activation.
+  bias_activation_fusion: True # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
   masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask.
   bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
   bias: True # Whether to use bias terms in all weight matrices.
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index 3e536caf445a..201245778985 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -147,7 +147,10 @@ def model_provider_func(self, pre_process, post_process, add_encoder, add_decode
             activations_checkpoint_num_layers=self.cfg.get('activations_checkpoint_num_layers', 1),
             layernorm_epsilon=self.cfg.get('layernorm_epsilon', 1e-5),
             persist_layer_norm=self.cfg.get('persist_layer_norm', False),
-            bias_gelu_fusion=self.cfg.get('bias_gelu_fusion', True),
+            bias_activation_fusion=(
+                (self.cfg.get('bias_gelu_fusion', True) and self.cfg.get('activation', 'gelu') == 'gelu')
+                or (self.cfg.get('bias_activation_fusion', True) and self.cfg.get('activation', 'gelu') == 'geglu')
+            ),
             bias_dropout_add_fusion=self.cfg.get('bias_dropout_add_fusion', True),
             masked_softmax_fusion=self.cfg.get('masked_softmax_fusion', True),
             onnx_safe=self.cfg.get('onnx_safe', False),
diff --git a/nemo/collections/nlp/modules/common/megatron/fused_bias_geglu.py b/nemo/collections/nlp/modules/common/megatron/fused_bias_geglu.py
new file mode 100644
index 000000000000..010034157b6b
--- /dev/null
+++ b/nemo/collections/nlp/modules/common/megatron/fused_bias_geglu.py
@@ -0,0 +1,57 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import bias_gelu, bias_gelu_back
+
+try:
+    from apex._autocast_utils import _cast_if_autocast_enabled
+
+    HAVE_APEX = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
+
+
+@torch.jit.script
+def bias_geglu(bias, y, bias_2, y_2):
+    x_2 = bias_2 + y_2
+    return bias_gelu(bias, y) * x_2
+
+
+@torch.jit.script
+def bias_geglu_back(g, bias, y, bias_2, y_2):
+    x_2 = bias_2 + y_2
+    return bias_gelu_back(g, bias, y) * x_2, bias_gelu(bias, y) * g
+
+
+class GeGLUFunction(torch.autograd.Function):
+    @staticmethod
+    # bias and bias_2 are optional arguments
+    def forward(ctx, input, bias, input_2, bias_2):
+        ctx.save_for_backward(input, bias, input_2, bias_2)
+        return bias_geglu(bias, input, bias_2, input_2)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, bias, input_2, bias_2 = ctx.saved_tensors
+        tmp, tmp2 = bias_geglu_back(grad_output, bias, input, bias_2, input_2)
+        return tmp, tmp, tmp2, tmp2
+
+
+def fused_bias_geglu(input, bias, input_2, bias_2):
+    args = _cast_if_autocast_enabled(input, bias, input_2, bias_2)
+    with torch.cuda.amp.autocast(enabled=False):
+        return GeGLUFunction.apply(*args)
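A quick sanity-check sketch for the new fused path (a sketch only: it assumes a CUDA device and an Apex install, since fused_bias_geglu casts its inputs with apex's _cast_if_autocast_enabled, and the tanh_gelu helper is illustrative, mirroring the tanh approximation used by bias_gelu in fused_bias_gelu.py):

import torch

from nemo.collections.nlp.modules.common.megatron.fused_bias_geglu import fused_bias_geglu

torch.manual_seed(0)
x = torch.randn(4, 8, device='cuda')
b = torch.randn(8, device='cuda')
x_2 = torch.randn(4, 8, device='cuda')
b_2 = torch.randn(8, device='cuda')


def tanh_gelu(t):
    # Same tanh-based GeLU approximation as bias_gelu in fused_bias_gelu.py.
    return t * 0.5 * (1.0 + torch.tanh(0.79788456 * t * (1.0 + 0.044715 * t * t)))


# Unfused GeGLU reference: gelu(x + b) * (x_2 + b_2).
reference = tanh_gelu(x + b) * (x_2 + b_2)
# Fused autograd function added in this file.
fused = fused_bias_geglu(x, b, x_2, b_2)

assert torch.allclose(reference, fused, atol=1e-6)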
diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py
index e369bfc5fbab..29d1ac3adf36 100755
--- a/nemo/collections/nlp/modules/common/megatron/language_model.py
+++ b/nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -420,7 +420,7 @@ def __init__(
             layernorm_epsilon=layernorm_epsilon,
             hidden_dropout=hidden_dropout,
             use_cpu_initialization=use_cpu_initialization,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_gelu_fusion,
             persist_layer_norm=persist_layer_norm,
             openai_gelu=openai_gelu,
             onnx_safe=onnx_safe,
@@ -451,7 +451,7 @@ def __init__(
             layernorm_epsilon=layernorm_epsilon,
             hidden_dropout=hidden_dropout,
             use_cpu_initialization=use_cpu_initialization,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_gelu_fusion,
             persist_layer_norm=persist_layer_norm,
             openai_gelu=openai_gelu,
             onnx_safe=onnx_safe,
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py
index b7b80a951e88..ae8d261edd06 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py
@@ -64,7 +64,7 @@ def get_decoder_model(
     activations_checkpoint_method=None,
     activations_checkpoint_num_layers=1,
     layernorm_epsilon=1e-5,
-    bias_gelu_fusion=True,
+    bias_activation_fusion=True,
     bias_dropout_add_fusion=True,
     masked_softmax_fusion=True,
     persist_layer_norm=False,
@@ -121,7 +121,7 @@ def get_decoder_model(
             activations_checkpoint_method=activations_checkpoint_method,
             activations_checkpoint_num_layers=activations_checkpoint_num_layers,
             layernorm_epsilon=layernorm_epsilon,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_activation_fusion,
             bias_dropout_add_fusion=bias_dropout_add_fusion,
             masked_softmax_fusion=masked_softmax_fusion,
             persist_layer_norm=persist_layer_norm,
@@ -155,7 +155,7 @@ def get_decoder_model(
             activations_checkpoint_method=activations_checkpoint_method,
             activations_checkpoint_num_layers=activations_checkpoint_num_layers,
             layernorm_epsilon=layernorm_epsilon,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_activation_fusion,
             bias_dropout_add_fusion=bias_dropout_add_fusion,
             masked_softmax_fusion=masked_softmax_fusion,
             persist_layer_norm=persist_layer_norm,
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py
index 39858fe3b4d2..7b83601f3e59 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py
@@ -63,7 +63,7 @@ def get_encoder_model(
     activations_checkpoint_method=None,
     activations_checkpoint_num_layers=1,
     layernorm_epsilon=1e-5,
-    bias_gelu_fusion=True,
+    bias_activation_fusion=True,
     bias_dropout_add_fusion=True,
     masked_softmax_fusion=True,
     persist_layer_norm=False,
@@ -120,7 +120,7 @@ def get_encoder_model(
             activations_checkpoint_method=activations_checkpoint_method,
             activations_checkpoint_num_layers=activations_checkpoint_num_layers,
             layernorm_epsilon=layernorm_epsilon,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_activation_fusion,
             bias_dropout_add_fusion=bias_dropout_add_fusion,
             masked_softmax_fusion=masked_softmax_fusion,
             persist_layer_norm=persist_layer_norm,
@@ -154,7 +154,7 @@ def get_encoder_model(
             activations_checkpoint_method=activations_checkpoint_method,
             activations_checkpoint_num_layers=activations_checkpoint_num_layers,
             layernorm_epsilon=layernorm_epsilon,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_activation_fusion,
             bias_dropout_add_fusion=bias_dropout_add_fusion,
             masked_softmax_fusion=masked_softmax_fusion,
             persist_layer_norm=persist_layer_norm,
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py
index 6a564f6f496d..068aeabc5e4d 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py
@@ -66,7 +66,7 @@ def __init__(
         activations_checkpoint_method=None,
         activations_checkpoint_num_layers=1,
         layernorm_epsilon=1e-5,
-        bias_gelu_fusion=True,
+        bias_activation_fusion=True,
         bias_dropout_add_fusion=True,
         masked_softmax_fusion=True,
         persist_layer_norm=False,
@@ -124,7 +124,7 @@ def __init__(
             relative_attention_num_buckets=relative_attention_num_buckets,
             relative_attention_max_distance=relative_attention_max_distance,
             use_cpu_initialization=use_cpu_initialization,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_activation_fusion,
             bias_dropout_fusion=bias_dropout_add_fusion,
             masked_softmax_fusion=masked_softmax_fusion,
             persist_layer_norm=persist_layer_norm,
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py
index 9c1865eedbd5..4d82d95fedde 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py
@@ -63,7 +63,7 @@ def __init__(
         activations_checkpoint_method=None,
         activations_checkpoint_num_layers=1,
         layernorm_epsilon=1e-5,
-        bias_gelu_fusion=True,
+        bias_activation_fusion=True,
         bias_dropout_add_fusion=True,
         masked_softmax_fusion=True,
         persist_layer_norm=False,
@@ -121,7 +121,7 @@ def __init__(
             relative_attention_num_buckets=relative_attention_num_buckets,
             relative_attention_max_distance=relative_attention_max_distance,
             use_cpu_initialization=use_cpu_initialization,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_activation_fusion,
             bias_dropout_fusion=bias_dropout_add_fusion,
             masked_softmax_fusion=masked_softmax_fusion,
             persist_layer_norm=persist_layer_norm,
diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/retrieval_token_level_encoder_decoder.py
index 682aab39902f..665b9fa8aeec 100644
--- a/nemo/collections/nlp/modules/common/megatron/retrieval_token_level_encoder_decoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/retrieval_token_level_encoder_decoder.py
@@ -153,7 +153,7 @@ def __init__(
                 activations_checkpoint_method=activations_checkpoint_method,
                 activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                 layernorm_epsilon=layernorm_epsilon,
-                bias_gelu_fusion=bias_gelu_fusion,
+                bias_activation_fusion=bias_gelu_fusion,
                 bias_dropout_add_fusion=bias_dropout_add_fusion,
                 masked_softmax_fusion=masked_softmax_fusion,
                 persist_layer_norm=persist_layer_norm,
@@ -214,7 +214,7 @@ def __init__(
                 activations_checkpoint_method=activations_checkpoint_method,
                 activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                 layernorm_epsilon=layernorm_epsilon,
-                bias_gelu_fusion=bias_gelu_fusion,
+                bias_activation_fusion=bias_gelu_fusion,
                 bias_dropout_add_fusion=bias_dropout_add_fusion,
                 masked_softmax_fusion=masked_softmax_fusion,
                 persist_layer_norm=persist_layer_norm,
@@ -255,7 +255,7 @@ def __init__(
                 activations_checkpoint_method=activations_checkpoint_method,
                 activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                 layernorm_epsilon=layernorm_epsilon,
-                bias_gelu_fusion=bias_gelu_fusion,
+                bias_activation_fusion=bias_gelu_fusion,
                 bias_dropout_add_fusion=bias_dropout_add_fusion,
                 masked_softmax_fusion=masked_softmax_fusion,
                 persist_layer_norm=persist_layer_norm,
diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_transformer.py b/nemo/collections/nlp/modules/common/megatron/retrieval_transformer.py
index 66f1c1af4a65..227653dee092 100644
--- a/nemo/collections/nlp/modules/common/megatron/retrieval_transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/retrieval_transformer.py
@@ -60,7 +60,7 @@ def __init__(
         activations_checkpoint_method=None,
         activations_checkpoint_num_layers=1,
         layernorm_epsilon=1e-5,
-        bias_gelu_fusion=True,
+        bias_activation_fusion=True,
         bias_dropout_add_fusion=True,
         masked_softmax_fusion=True,
         persist_layer_norm=False,
@@ -115,7 +115,7 @@ def __init__(
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            use_cpu_initialization=use_cpu_initialization,
-           bias_gelu_fusion=bias_gelu_fusion,
+           bias_activation_fusion=bias_activation_fusion,
            bias_dropout_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
@@ -323,7 +323,7 @@ def __init__(
         activations_checkpoint_method=None,
         activations_checkpoint_num_layers=1,
         layernorm_epsilon=1e-5,
-        bias_gelu_fusion=True,
+        bias_activation_fusion=True,
         bias_dropout_add_fusion=True,
         masked_softmax_fusion=True,
         persist_layer_norm=False,
@@ -378,7 +378,7 @@ def __init__(
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            use_cpu_initialization=use_cpu_initialization,
-           bias_gelu_fusion=bias_gelu_fusion,
+           bias_activation_fusion=bias_activation_fusion,
            bias_dropout_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
index 8ee263920c01..ab915ae6ff92 100644
--- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
@@ -101,7 +101,7 @@ def __init__(
         activations_checkpoint_num_layers=1,
         layernorm_epsilon=1e-5,
         persist_layer_norm=False,
-        bias_gelu_fusion=True,
+        bias_activation_fusion=True,
         bias_dropout_add_fusion=True,
         masked_softmax_fusion=True,
         openai_gelu=False,
@@ -183,7 +183,7 @@ def __init__(
                 activations_checkpoint_method=activations_checkpoint_method,
                 activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                 layernorm_epsilon=layernorm_epsilon,
-                bias_gelu_fusion=bias_gelu_fusion,
+                bias_activation_fusion=bias_activation_fusion,
                 bias_dropout_add_fusion=bias_dropout_add_fusion,
                 masked_softmax_fusion=masked_softmax_fusion,
                 persist_layer_norm=persist_layer_norm,
@@ -248,7 +248,7 @@ def __init__(
                 activations_checkpoint_method=activations_checkpoint_method,
                 activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                 layernorm_epsilon=layernorm_epsilon,
-                bias_gelu_fusion=bias_gelu_fusion,
+                bias_activation_fusion=bias_activation_fusion,
                 bias_dropout_add_fusion=bias_dropout_add_fusion,
                 masked_softmax_fusion=masked_softmax_fusion,
                 persist_layer_norm=persist_layer_norm,
diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index 3839e62e1df0..97117d723ad2 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -26,6 +26,7 @@
     bias_dropout_add_fused_train,
     dropout_add,
 )
+from nemo.collections.nlp.modules.common.megatron.fused_bias_geglu import fused_bias_geglu
 from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
 from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm
 from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType
@@ -112,7 +113,7 @@ def __init__(
         hidden_size,
         ffn_hidden_size,
         use_cpu_initialization=False,
-        bias_gelu_fusion=True,
+        bias_activation_fusion=True,
         openai_gelu=False,
         onnx_safe=False,
         activation='gelu',
@@ -157,13 +158,12 @@ def __init__(
                 use_cpu_initialization=use_cpu_initialization,
                 bias=bias,
             )
-            glu_activation_family = True
-        else:
-            glu_activation_family = False
 
-        if glu_activation_family and bias_gelu_fusion:
+        glu_activation_family = activation in ['reglu', 'swiglu']
+
+        if glu_activation_family and bias_activation_fusion:
             raise ValueError(
-                f"Cannot use bias_gelu_fusion with {activation} activation. Please turn bias gelu fusion off."
+                f"Cannot use bias_activation_fusion with {activation} activation. Please turn bias gelu fusion off."
             )
 
         if glu_activation_family and openai_gelu:
@@ -176,14 +176,14 @@ def __init__(
                 f"Cannot use onnx_safe with specificed activation function : {activation} Please turn onnx safe off."
             )
 
-        if bias_gelu_fusion and not bias:
+        if bias_activation_fusion and not bias:
             raise ValueError(
-                f"Cannot use bias_gelu_fusion without bias terms. Please set bias=True or bias_gelu_fusion=False."
+                f"Cannot use bias_activation_fusion without bias terms. Please set bias=True or bias_activation_fusion=False."
             )
         else:
             glu_activation_family = False
 
-        self.bias_gelu_fusion = bias_gelu_fusion
+        self.bias_activation_fusion = bias_activation_fusion
 
         if activation in ["gelu", "geglu"]:
             self.activation_func = F.gelu
@@ -226,6 +226,16 @@ def forward(self, hidden_states):
 
         if self.activation in ['geglu', 'reglu', 'swiglu']:
             intermediate_parallel_2, bias_parallel_2 = self.dense_h_to_4h_2(hidden_states)
+
+        if self.bias_activation_fusion:
+            if self.activation == 'gelu':
+                intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
+            else:
+                intermediate_parallel = fused_bias_geglu(
+                    intermediate_parallel, bias_parallel, intermediate_parallel_2, bias_parallel_2
+                )
+
+        elif self.activation in ['geglu', 'reglu', 'swiglu']:
             if bias_parallel is not None:
                 intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) * (
                     intermediate_parallel_2 + bias_parallel_2
@@ -233,8 +243,6 @@
             else:
                 intermediate_parallel = self.activation_func(intermediate_parallel) * intermediate_parallel_2
 
-        elif self.bias_gelu_fusion and self.activation == 'gelu':
-            intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
         else:
             if bias_parallel is not None:
                 intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
@@ -952,7 +960,7 @@ def __init__(
         bias_dropout_fusion=True,
         persist_layer_norm=False,
         use_cpu_initialization=False,
-        bias_gelu_fusion=True,
+        bias_activation_fusion=True,
         openai_gelu=False,
         onnx_safe=False,
         masked_softmax_fusion=True,
@@ -1134,7 +1142,7 @@ def __init__(
             hidden_size=hidden_size,
             ffn_hidden_size=ffn_hidden_size,
             use_cpu_initialization=use_cpu_initialization,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_activation_fusion,
             openai_gelu=openai_gelu,
             onnx_safe=onnx_safe,
             activation=activation,
@@ -1426,7 +1434,7 @@ def __init__(
         relative_attention_num_buckets=32,
         relative_attention_max_distance=128,
         use_cpu_initialization=False,
-        bias_gelu_fusion=True,
+        bias_activation_fusion=True,
         bias_dropout_fusion=True,
         masked_softmax_fusion=True,
         persist_layer_norm=False,
@@ -1498,7 +1506,7 @@ def build_layer(layer_number):
             relative_attention_num_buckets=relative_attention_num_buckets,
             relative_attention_max_distance=relative_attention_max_distance,
             use_cpu_initialization=use_cpu_initialization,
-            bias_gelu_fusion=bias_gelu_fusion,
+            bias_activation_fusion=bias_activation_fusion,
            bias_dropout_fusion=bias_dropout_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
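Taken together, the config keys and the gating in model_provider_func resolve as sketched below (a summary sketch, not code from this diff): the legacy bias_gelu_fusion flag only takes effect when activation is 'gelu', the new bias_activation_fusion flag only takes effect when activation is 'geglu', and ParallelMLP raises if fusion is requested for 'reglu'/'swiglu' or when bias=False.

def resolve_bias_activation_fusion(cfg: dict) -> bool:
    # Mirrors the expression passed to bias_activation_fusion in
    # megatron_lm_encoder_decoder_model.py's model_provider_func above.
    activation = cfg.get('activation', 'gelu')
    return (cfg.get('bias_gelu_fusion', True) and activation == 'gelu') or (
        cfg.get('bias_activation_fusion', True) and activation == 'geglu'
    )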