Fuse bias with geglu in ParallelMLP (#4213)
* add code of fused_bias_geglu

* call fused_bias_geglu in ParallelMLP

* fix some bugs

* change bias_gelu_activation to bias_activation_fusion

* fix the setting of bias_activation_fusion for T5

* delete bias_gelu_fusion from T5 example config

* push reformatted files

* h_to_4h GEMMs fusion

* remove h_to_4h GEMMs fusion

* push reformatted files

* disable bias_activation_fusion when the activation is not geglu

* add bias_activation_fusion in yaml config file

* add bias_gelu_fusion in T5 config yaml file to pass CI test

* change bias_gelu_fusion to bias_activation_fusion for T5 CI test

* recover latest change

Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
xrennvidia and MaximumEntropy authored Jun 3, 2022
1 parent 416d033 commit ff72f90
Showing 12 changed files with 107 additions and 38 deletions.
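For context, the GeGLU path fused by this commit applies GELU to one h-to-4h projection and gates it with a second, linear h-to-4h projection, with both bias additions folded into the same kernel. A minimal unfused sketch of the computation (plain PyTorch; shapes and names are chosen for illustration, and the fused kernel in this commit uses the jitted tanh-approximation of GELU, so results would only match approximately):

```python
import torch
import torch.nn.functional as F


def reference_bias_geglu(y, bias, y_2, bias_2):
    # Unfused reference: GELU(y + bias) gated elementwise by the second projection.
    return F.gelu(y + bias) * (y_2 + bias_2)


# Illustrative shapes: [sequence, batch, ffn_hidden_size] per projection output.
y, y_2 = torch.randn(8, 2, 64), torch.randn(8, 2, 64)
bias, bias_2 = torch.randn(64), torch.randn(64)
out = reference_bias_geglu(y, bias, y_2, bias_2)  # same shape as y
```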
[changed file]
@@ -72,6 +72,7 @@ model:
persist_layer_norm: True # Use of persistent fused layer norm kernel.
gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
bias_gelu_fusion: True # Use a kernel that fuses the bias addition from weight matrices with the subsequent gelu activation.
bias_activation_fusion: True # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with its mask.
bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
bias: True # Whether to use bias terms in all weight matrices.
[changed file]
@@ -147,7 +147,10 @@ def model_provider_func(self, pre_process, post_process, add_encoder, add_decode
activations_checkpoint_num_layers=self.cfg.get('activations_checkpoint_num_layers', 1),
layernorm_epsilon=self.cfg.get('layernorm_epsilon', 1e-5),
persist_layer_norm=self.cfg.get('persist_layer_norm', False),
bias_gelu_fusion=self.cfg.get('bias_gelu_fusion', True),
bias_activation_fusion=(
(self.cfg.get('bias_gelu_fusion', True) and self.cfg.get('activation', 'gelu') == 'gelu')
or (self.cfg.get('bias_activation_fusion', True) and self.cfg.get('activation', 'gelu') == 'geglu')
),
bias_dropout_add_fusion=self.cfg.get('bias_dropout_add_fusion', True),
masked_softmax_fusion=self.cfg.get('masked_softmax_fusion', True),
onnx_safe=self.cfg.get('onnx_safe', False),
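The expression above keeps backward compatibility: the legacy bias_gelu_fusion flag still controls fusion for 'gelu', the new bias_activation_fusion flag controls it for 'geglu', and any other activation disables fusion. A small standalone sketch of that resolution (a plain dict stands in for the OmegaConf model config; the helper name is illustrative):

```python
def resolve_bias_activation_fusion(cfg: dict) -> bool:
    # Mirror of the expression above: keep the legacy bias_gelu_fusion flag for
    # 'gelu', use the new bias_activation_fusion flag for 'geglu', else no fusion.
    activation = cfg.get('activation', 'gelu')
    return (
        (cfg.get('bias_gelu_fusion', True) and activation == 'gelu')
        or (cfg.get('bias_activation_fusion', True) and activation == 'geglu')
    )


assert resolve_bias_activation_fusion({'activation': 'gelu'})
assert resolve_bias_activation_fusion({'activation': 'geglu'})
assert not resolve_bias_activation_fusion({'activation': 'swiglu'})  # no fused kernel for swiglu/reglu
```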
57 changes: 57 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/fused_bias_geglu.py
@@ -0,0 +1,57 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import bias_gelu, bias_gelu_back

try:
    from apex._autocast_utils import _cast_if_autocast_enabled

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


@torch.jit.script
def bias_geglu(bias, y, bias_2, y_2):
    # GeGLU forward: gelu(y + bias) gated elementwise by the second projection (y_2 + bias_2).
    x_2 = bias_2 + y_2
    return bias_gelu(bias, y) * x_2


@torch.jit.script
def bias_geglu_back(g, bias, y, bias_2, y_2):
    # Gradients w.r.t. the two projection outputs; each bias shares its input's gradient.
    x_2 = bias_2 + y_2
    return bias_gelu_back(g, bias, y) * x_2, bias_gelu(bias, y) * g


class GeGLUFunction(torch.autograd.Function):
    @staticmethod
    # bias and bias_2 are optional arguments
    def forward(ctx, input, bias, input_2, bias_2):
        ctx.save_for_backward(input, bias, input_2, bias_2)
        return bias_geglu(bias, input, bias_2, input_2)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias, input_2, bias_2 = ctx.saved_tensors
        tmp, tmp2 = bias_geglu_back(grad_output, bias, input, bias_2, input_2)
        return tmp, tmp, tmp2, tmp2


def fused_bias_geglu(input, bias, input_2, bias_2):
    # Run the jitted kernel outside autocast; inputs are pre-cast to the autocast dtype.
    args = _cast_if_autocast_enabled(input, bias, input_2, bias_2)
    with torch.cuda.amp.autocast(enabled=False):
        return GeGLUFunction.apply(*args)
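A hypothetical usage sketch of the new module (assumes Apex is installed and a CUDA device is available; tensor shapes are illustrative). It checks the autograd-aware entry point against a direct call to the jitted forward kernel, noting the difference in argument order:

```python
import torch

from nemo.collections.nlp.modules.common.megatron.fused_bias_geglu import bias_geglu, fused_bias_geglu

# Two h-to-4h projection outputs and their biases, as ParallelMLP would produce them.
y = torch.randn(8, 2, 64, device='cuda', dtype=torch.float16)
b = torch.randn(64, device='cuda', dtype=torch.float16)
y_2 = torch.randn(8, 2, 64, device='cuda', dtype=torch.float16)
b_2 = torch.randn(64, device='cuda', dtype=torch.float16)

fused = fused_bias_geglu(y, b, y_2, b_2)   # autograd-aware wrapper (input first)
reference = bias_geglu(b, y, b_2, y_2)     # forward-only jitted kernel (bias first)
torch.testing.assert_close(fused, reference)
```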
[changed file]
@@ -420,7 +420,7 @@ def __init__(
layernorm_epsilon=layernorm_epsilon,
hidden_dropout=hidden_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_gelu_fusion,
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
@@ -451,7 +451,7 @@ def __init__(
layernorm_epsilon=layernorm_epsilon,
hidden_dropout=hidden_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_gelu_fusion,
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
[changed file]
@@ -64,7 +64,7 @@ def get_decoder_model(
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
bias_activation_fusion=True,
bias_dropout_add_fusion=True,
masked_softmax_fusion=True,
persist_layer_norm=False,
@@ -121,7 +121,7 @@ def get_decoder_model(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
@@ -155,7 +155,7 @@ def get_decoder_model(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
[changed file]
@@ -63,7 +63,7 @@ def get_encoder_model(
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
bias_activation_fusion=True,
bias_dropout_add_fusion=True,
masked_softmax_fusion=True,
persist_layer_norm=False,
@@ -120,7 +120,7 @@ def get_encoder_model(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
@@ -154,7 +154,7 @@ def get_encoder_model(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
[changed file]
@@ -66,7 +66,7 @@ def __init__(
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
bias_activation_fusion=True,
bias_dropout_add_fusion=True,
masked_softmax_fusion=True,
persist_layer_norm=False,
@@ -124,7 +124,7 @@ def __init__(
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
[changed file]
@@ -63,7 +63,7 @@ def __init__(
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
bias_activation_fusion=True,
bias_dropout_add_fusion=True,
masked_softmax_fusion=True,
persist_layer_norm=False,
@@ -121,7 +121,7 @@ def __init__(
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
[changed file]
@@ -153,7 +153,7 @@ def __init__(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_gelu_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
@@ -214,7 +214,7 @@ def __init__(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_gelu_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
@@ -255,7 +255,7 @@ def __init__(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_gelu_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
[changed file]
@@ -60,7 +60,7 @@ def __init__(
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
bias_activation_fusion=True,
bias_dropout_add_fusion=True,
masked_softmax_fusion=True,
persist_layer_norm=False,
@@ -115,7 +115,7 @@ def __init__(
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
@@ -323,7 +323,7 @@ def __init__(
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
bias_activation_fusion=True,
bias_dropout_add_fusion=True,
masked_softmax_fusion=True,
persist_layer_norm=False,
@@ -378,7 +378,7 @@ def __init__(
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
[changed file]
@@ -101,7 +101,7 @@ def __init__(
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
persist_layer_norm=False,
bias_gelu_fusion=True,
bias_activation_fusion=True,
bias_dropout_add_fusion=True,
masked_softmax_fusion=True,
openai_gelu=False,
@@ -183,7 +183,7 @@ def __init__(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
@@ -248,7 +248,7 @@ def __init__(
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
38 changes: 23 additions & 15 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -26,6 +26,7 @@
bias_dropout_add_fused_train,
dropout_add,
)
from nemo.collections.nlp.modules.common.megatron.fused_bias_geglu import fused_bias_geglu
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm
from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType
@@ -112,7 +113,7 @@ def __init__(
hidden_size,
ffn_hidden_size,
use_cpu_initialization=False,
bias_gelu_fusion=True,
bias_activation_fusion=True,
openai_gelu=False,
onnx_safe=False,
activation='gelu',
@@ -157,13 +158,12 @@ def __init__(
use_cpu_initialization=use_cpu_initialization,
bias=bias,
)
glu_activation_family = True
else:
glu_activation_family = False

if glu_activation_family and bias_gelu_fusion:
glu_activation_family = activation in ['reglu', 'swiglu']

if glu_activation_family and bias_activation_fusion:
raise ValueError(
f"Cannot use bias_gelu_fusion with {activation} activation. Please turn bias gelu fusion off."
f"Cannot use bias_activation_fusion with {activation} activation. Please turn bias gelu fusion off."
)

if glu_activation_family and openai_gelu:
@@ -176,14 +176,14 @@
f"Cannot use onnx_safe with specificed activation function : {activation} Please turn onnx safe off."
)

if bias_gelu_fusion and not bias:
if bias_activation_fusion and not bias:
raise ValueError(
f"Cannot use bias_gelu_fusion without bias terms. Please set bias=True or bias_gelu_fusion=False."
f"Cannot use bias_activation_fusion without bias terms. Please set bias=True or bias_activation_fusion=False."
)
else:
glu_activation_family = False

self.bias_gelu_fusion = bias_gelu_fusion
self.bias_activation_fusion = bias_activation_fusion

if activation in ["gelu", "geglu"]:
self.activation_func = F.gelu
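Taken together, the checks above make the fused path legal only for 'gelu' and 'geglu' and only when bias terms are enabled; 'reglu' and 'swiglu' must run unfused. A compact sketch of that constraint (illustrative helper, not part of the commit; the openai_gelu and onnx_safe checks are omitted):

```python
def fusion_config_is_valid(activation: str, bias: bool, bias_activation_fusion: bool) -> bool:
    # Mirrors the ValueErrors raised above: fusion requires bias terms and a fused
    # kernel, which exists only for 'gelu' (fused_bias_gelu) and 'geglu' (fused_bias_geglu).
    if not bias_activation_fusion:
        return True
    return bias and activation in ('gelu', 'geglu')


assert fusion_config_is_valid('geglu', bias=True, bias_activation_fusion=True)
assert not fusion_config_is_valid('swiglu', bias=True, bias_activation_fusion=True)
assert not fusion_config_is_valid('gelu', bias=False, bias_activation_fusion=True)
```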
@@ -226,15 +226,23 @@ def forward(self, hidden_states):

        if self.activation in ['geglu', 'reglu', 'swiglu']:
            intermediate_parallel_2, bias_parallel_2 = self.dense_h_to_4h_2(hidden_states)

        if self.bias_activation_fusion:
            if self.activation == 'gelu':
                intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
            else:
                intermediate_parallel = fused_bias_geglu(
                    intermediate_parallel, bias_parallel, intermediate_parallel_2, bias_parallel_2
                )

        elif self.activation in ['geglu', 'reglu', 'swiglu']:
            if bias_parallel is not None:
                intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) * (
                    intermediate_parallel_2 + bias_parallel_2
                )
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel) * intermediate_parallel_2

        elif self.bias_gelu_fusion and self.activation == 'gelu':
            intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
        else:
            if bias_parallel is not None:
                intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
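Reading only the added lines, the new activation dispatch in ParallelMLP.forward amounts to the following standalone sketch (a free function with explicit arguments stands in for the module attributes; the final no-bias branch is an assumed completion, since the hunk is truncated here):

```python
import torch.nn.functional as F

from nemo.collections.nlp.modules.common.megatron.fused_bias_geglu import fused_bias_geglu
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu


def apply_mlp_activation(activation, activation_func, bias_activation_fusion,
                         intermediate, bias, intermediate_2=None, bias_2=None):
    # Fused kernels first: bias + gelu, or bias + geglu over the two projections.
    if bias_activation_fusion:
        if activation == 'gelu':
            return fused_bias_gelu(intermediate, bias)
        return fused_bias_geglu(intermediate, bias, intermediate_2, bias_2)
    # Unfused GLU family (geglu/reglu/swiglu): activation of one projection gates the other.
    if activation in ('geglu', 'reglu', 'swiglu'):
        if bias is not None:
            return activation_func(intermediate + bias) * (intermediate_2 + bias_2)
        return activation_func(intermediate) * intermediate_2
    # Plain activations (assumed completion of the truncated hunk).
    if bias is not None:
        return activation_func(intermediate + bias)
    return activation_func(intermediate)
```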
@@ -952,7 +960,7 @@ def __init__(
bias_dropout_fusion=True,
persist_layer_norm=False,
use_cpu_initialization=False,
bias_gelu_fusion=True,
bias_activation_fusion=True,
openai_gelu=False,
onnx_safe=False,
masked_softmax_fusion=True,
@@ -1134,7 +1142,7 @@ def __init__(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
activation=activation,
@@ -1426,7 +1434,7 @@ def __init__(
relative_attention_num_buckets=32,
relative_attention_max_distance=128,
use_cpu_initialization=False,
bias_gelu_fusion=True,
bias_activation_fusion=True,
bias_dropout_fusion=True,
masked_softmax_fusion=True,
persist_layer_norm=False,
@@ -1498,7 +1506,7 @@ def build_layer(layer_number):
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_fusion=bias_dropout_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
