support rms norm
Mddct committed Feb 26, 2024
1 parent 4d20b69 commit 1afd9de
Showing 7 changed files with 141 additions and 51 deletions.
12 changes: 8 additions & 4 deletions wenet/transformer/convolution.py
@@ -19,6 +19,8 @@
import torch
from torch import nn

from wenet.utils.class_utils import WENET_NORM_CLASSES


class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
@@ -29,7 +31,8 @@ def __init__(self,
activation: nn.Module = nn.ReLU(),
norm: str = "batch_norm",
causal: bool = False,
bias: bool = True):
bias: bool = True,
eps: float = 1e-5):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
@@ -68,13 +71,14 @@ def __init__(
bias=bias,
)

assert norm in ['batch_norm', 'layer_norm']
assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1d(channels)
self.norm = WENET_NORM_CLASSES['batch_norm'](channels, eps=eps)
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
# layer_norm or rms_norm
self.norm = WENET_NORM_CLASSES[norm](channels, eps=eps)

self.pointwise_conv2 = nn.Conv1d(
channels,
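
For reference, a minimal usage sketch (not part of the commit) of the new option in this module. The channels/kernel_size parameters and the activation default follow the usual wenet signature, which this truncated hunk does not fully show, so treat them as assumptions.

from torch import nn
from wenet.transformer.convolution import ConvolutionModule

# Illustrative construction only; parameter names other than norm/causal/bias/eps
# are assumptions based on the surrounding (truncated) signature.
conv = ConvolutionModule(
    channels=256,
    kernel_size=15,
    activation=nn.ReLU(),
    norm="rms_norm",   # now accepted alongside 'batch_norm' and 'layer_norm'
    causal=True,
    bias=True,
    eps=1e-6,          # forwarded to the norm class picked from WENET_NORM_CLASSES
)
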
20 changes: 17 additions & 3 deletions wenet/transformer/decoder.py
@@ -25,6 +25,7 @@
WENET_ATTENTION_CLASSES,
WENET_ACTIVATION_CLASSES,
WENET_MLP_CLASSES,
WENET_NORM_CLASSES,
)
from wenet.utils.common import mask_to_bias
from wenet.utils.mask import (subsequent_mask, make_pad_mask)
@@ -77,6 +78,8 @@ def __init__(
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
bias: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
):
super().__init__()
attention_dim = encoder_output_size
@@ -90,7 +93,8 @@
)

self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim,
eps=eps)
self.use_output_layer = use_output_layer
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim,
@@ -128,6 +132,8 @@ def __init__(
),
dropout_rate,
normalize_before,
layer_norm_type,
eps,
) for _ in range(self.num_blocks)
])

@@ -320,6 +326,8 @@ def __init__(
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
bias: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
):

super().__init__()
@@ -342,7 +350,10 @@ def __init__(
tie_word_embedding=tie_word_embedding,
use_sdpa=use_sdpa,
mlp_type=mlp_type,
bias=bias)
bias=bias,
layer_norm_type=layer_norm_type,
eps=eps,
)

self.right_decoder = TransformerDecoder(
vocab_size,
@@ -362,7 +373,10 @@ def __init__(
tie_word_embedding=tie_word_embedding,
use_sdpa=use_sdpa,
mlp_type=mlp_type,
bias=bias)
bias=bias,
layer_norm_type=layer_norm_type,
eps=eps,
)

def forward(
self,
10 changes: 7 additions & 3 deletions wenet/transformer/decoder_layer.py
@@ -18,6 +18,8 @@
import torch
from torch import nn

from wenet.utils.class_utils import WENET_NORM_CLASSES


class DecoderLayer(nn.Module):
"""Single decoder layer module.
@@ -46,16 +48,18 @@ def __init__(
feed_forward: nn.Module,
dropout_rate: float,
normalize_before: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
):
"""Construct an DecoderLayer object."""
super().__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, eps=1e-5)
self.norm2 = nn.LayerNorm(size, eps=1e-5)
self.norm3 = nn.LayerNorm(size, eps=1e-5)
self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=eps)
self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=eps)
self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=eps)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before

95 changes: 62 additions & 33 deletions wenet/transformer/encoder.py
@@ -25,6 +25,7 @@
from wenet.utils.class_utils import (
WENET_EMB_CLASSES,
WENET_MLP_CLASSES,
WENET_NORM_CLASSES,
WENET_SUBSAMPLE_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_ACTIVATION_CLASSES,
@@ -55,6 +56,8 @@ def __init__(
use_dynamic_left_chunk: bool = False,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
):
"""
Args:
@@ -101,7 +104,9 @@ def __init__(
)

self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
assert layer_norm_type in ['layer_norm', 'rms_norm']
self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
eps=eps)
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
@@ -341,28 +346,32 @@ def forward_chunk_by_chunk(
class TransformerEncoder(BaseEncoder):
"""Transformer encoder module."""

def __init__(self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "abs_pos",
normalize_before: bool = True,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
key_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
bias: bool = True):
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "abs_pos",
normalize_before: bool = True,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
key_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
bias: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
):
""" Construct TransformerEncoder
See Encoder for the meaning of each parameter.
@@ -373,7 +382,7 @@ def __init__(self,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa)
use_sdpa, layer_norm_type, eps)
activation = WENET_ACTIVATION_CLASSES[activation_type]()
mlp_class = WENET_MLP_CLASSES[mlp_type]
self.encoders = torch.nn.ModuleList([
@@ -392,6 +401,8 @@
bias=bias),
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
eps=eps,
) for _ in range(num_blocks)
])

@@ -429,6 +440,8 @@ def __init__(
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
bias: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
):
"""Construct ConformerEncoder
@@ -447,13 +460,27 @@ def __init__(
causal (bool): whether to use causal convolution or not.
key_bias: whether use bias in attention.linear_k, False for whisper models.
"""
super().__init__(input_size, output_size, attention_heads,
linear_units, num_blocks, dropout_rate,
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa)
super().__init__(
input_size,
output_size,
attention_heads,
linear_units,
num_blocks,
dropout_rate,
positional_dropout_rate,
attention_dropout_rate,
input_layer,
pos_enc_layer_type,
normalize_before,
static_chunk_size,
use_dynamic_chunk,
global_cmvn,
use_dynamic_left_chunk,
gradient_checkpointing,
use_sdpa,
layer_norm_type,
eps,
)
activation = WENET_ACTIVATION_CLASSES[activation_type]()

# self-attention module definition
@@ -485,9 +512,11 @@ def __init__(
*encoder_selfattn_layer_args),
mlp_class(*positionwise_layer_args),
mlp_class(*positionwise_layer_args) if macaron_style else None,
ConvolutionModule(
*convolution_layer_args) if use_cnn_module else None,
ConvolutionModule(*convolution_layer_args, eps=eps)
if use_cnn_module else None,
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
eps=eps,
) for _ in range(num_blocks)
])
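
As a usage sketch (not part of the commit), the new arguments can be passed directly to the encoder constructors shown above; everything else keeps the defaults from the signature in this diff. The input_size value is illustrative.

from wenet.transformer.encoder import TransformerEncoder

# Minimal sketch: only input_size is required; layer_norm_type and eps are the new knobs.
encoder = TransformerEncoder(
    input_size=80,               # e.g. 80-dim fbank features (illustrative value)
    layer_norm_type='rms_norm',  # 'layer_norm' (default) or 'rms_norm'
    eps=1e-6,                    # forwarded to every norm built from WENET_NORM_CLASSES
)
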
25 changes: 17 additions & 8 deletions wenet/transformer/encoder_layer.py
@@ -20,6 +20,8 @@
import torch
from torch import nn

from wenet.utils.class_utils import WENET_NORM_CLASSES


class TransformerEncoderLayer(nn.Module):
"""Encoder layer module.
@@ -44,13 +46,16 @@ def __init__(
feed_forward: torch.nn.Module,
dropout_rate: float,
normalize_before: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, eps=1e-5)
self.norm2 = nn.LayerNorm(size, eps=1e-5)
assert layer_norm_type in ['layer_norm', 'rms_norm']
self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=eps)
self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=eps)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
@@ -135,24 +140,28 @@ def __init__(
conv_module: Optional[nn.Module] = None,
dropout_rate: float = 0.1,
normalize_before: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
self.norm_ff = nn.LayerNorm(size, eps=eps) # for the FNN module
self.norm_mha = nn.LayerNorm(size, eps=eps) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
self.norm_ff_macaron = WENET_NORM_CLASSES[layer_norm_type](size,
eps=eps)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
self.norm_final = nn.LayerNorm(
size, eps=1e-5) # for the final output of the block
self.norm_conv = WENET_NORM_CLASSES[layer_norm_type](
size, eps=eps) # for the CNN module
self.norm_final = WENET_NORM_CLASSES[layer_norm_type](
    size, eps=eps)  # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
22 changes: 22 additions & 0 deletions wenet/transformer/norm.py
@@ -0,0 +1,22 @@
import torch


class RMSNorm(torch.nn.Module):
""" https://arxiv.org/pdf/1910.07467.pdf
"""

def __init__(
self,
dim: int,
eps: float = 1e-6,
):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
x = self._norm(x.float()).type_as(x)
return x * self.weight
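
RMSNorm rescales each feature vector by the reciprocal of its root mean square, x * rsqrt(mean(x**2) + eps), computed in float32 and cast back, then applies a learned per-dimension gain; unlike LayerNorm there is no mean subtraction and no bias term. The seventh changed file (presumably wenet/utils/class_utils.py, not expanded above) registers the norm classes under the names used throughout this diff; below is a hypothetical sketch of that registry plus a quick sanity check, not the actual file contents.

import torch
from wenet.transformer.norm import RMSNorm

# Hypothetical registry; the real mapping lives in the un-expanded class_utils.py diff.
WENET_NORM_CLASSES = {
    'layer_norm': torch.nn.LayerNorm,
    'batch_norm': torch.nn.BatchNorm1d,
    'rms_norm': RMSNorm,
}

x = torch.randn(4, 100, 256)
norm = WENET_NORM_CLASSES['rms_norm'](256, eps=1e-6)
y = norm(x)
print(y.pow(2).mean(-1).sqrt())  # ~1.0 everywhere, since weight starts as ones
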