From fcd25bdfffb17b64a0c9d98250ae6021338e573f Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 6 Feb 2022 18:22:56 +0800 Subject: [PATCH 001/185] Fix torch.nn.Embedding error for torch below 1.8.0 --- egs/librispeech/ASR/transducer/beam_search.py | 4 +++- egs/librispeech/ASR/transducer/model.py | 1 + egs/librispeech/ASR/transducer_lstm/beam_search.py | 4 +++- egs/librispeech/ASR/transducer_lstm/model.py | 1 + egs/librispeech/ASR/transducer_stateless/beam_search.py | 2 +- egs/librispeech/ASR/transducer_stateless/model.py | 1 + 6 files changed, 10 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py index f45d06ce9c..11032f31ae 100644 --- a/egs/librispeech/ASR/transducer/beam_search.py +++ b/egs/librispeech/ASR/transducer/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py index fa0b2dd680..8305248c9c 100644 --- a/egs/librispeech/ASR/transducer/model.py +++ b/egs/librispeech/ASR/transducer/model.py @@ -99,6 +99,7 @@ def forward( sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out, _ = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py index dfc22fcf87..3531a96334 100644 --- a/egs/librispeech/ASR/transducer_lstm/beam_search.py +++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py index cb9afd8a28..31843b60ef 100644 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -101,6 +101,7 @@ def forward( sos_y = add_sos(y, sos_id=sos_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out, _ = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 341c74fab3..1cce482352 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -48,7 +48,7 @@ def greedy_search( device = model.device decoder_input = torch.tensor( - [blank_id] * context_size, device=device + [blank_id] * context_size, device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 7aac290d98..17b5f63e58 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -93,6 +93,7 @@ def forward( sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out = self.decoder(sos_y_padded) From 8f8ec223a715776f8e92a6daaa082627deae3cc8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Feb 2022 21:18:40 +0800 Subject: [PATCH 002/185] Changes to fbank computation, use lilcom chunky writer --- egs/librispeech/ASR/local/compute_fbank_librispeech.py | 4 ++-- egs/librispeech/ASR/local/compute_fbank_musan.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index b26034eb20..5c33ff8bef 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -28,7 +28,7 @@ from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer +from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -85,7 +85,7 @@ def compute_fbank_librispeech(): # when an executor is specified, make more partitions num_jobs=num_jobs if ex is None else 80, executor=ex, - storage_type=LilcomHdf5Writer, + storage_type=ChunkedLilcomHdf5Writer, ) cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index d44524e70b..f5911746b9 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -28,7 +28,7 @@ from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine +from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -82,7 +82,7 @@ def compute_fbank_musan(): storage_path=f"{output_dir}/feats_musan", num_jobs=num_jobs if ex is None else 80, executor=ex, - storage_type=LilcomHdf5Writer, + storage_type=ChunkedLilcomHdf5Writer, ) ) musan_cuts.to_json(musan_cuts_path) From 48a764eccf30e0d7a178563255a648d292a19673 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Feb 2022 21:19:37 +0800 Subject: [PATCH 003/185] Add min in q,k,v of attention --- .../ASR/transducer_stateless/conformer.py | 51 +++++++++++++++++-- .../ASR/transducer_stateless/decoder.py | 1 + 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 81d7708f9f..f803ee9b67 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -440,8 +440,19 @@ def __init__( ), "embed_dim must be divisible by num_heads" self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + + self.in_proj_floor_scale = 10.0 # so it learns fast enough.. + with torch.no_grad(): + in_proj_floor = torch.Tensor(3 * embed_dim) + # key and query get a floor value quite close to zero. + in_proj_floor[:2*embed_dim] = -0.2 / self.in_proj_floor_scale + # value gets very low floor, may be close to having no effectc. + in_proj_floor[2*embed_dim:] = -1.5 / self.in_proj_floor_scale + self.in_proj_floor = nn.Parameter(in_proj_floor) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + # linear transformation for positional encoding. self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d @@ -526,6 +537,7 @@ def forward( key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, + in_proj_floor=self.in_proj_floor*self.in_proj_floor_scale ) def rel_shift(self, x: Tensor) -> Tensor: @@ -570,6 +582,7 @@ def multi_head_attention_forward( key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + in_proj_floor: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -629,9 +642,12 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( + _qkv = nn.functional.linear( query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + ) + if in_proj_floor is not None: + _qkv = torch.maximum(_qkv, in_proj_floor) + q, k, v = _qkv.chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -643,6 +659,10 @@ def multi_head_attention_forward( if _b is not None: _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + q = torch.maximum(q, _f) + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim @@ -650,7 +670,11 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + _kv = nn.functional.linear(key, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + _kv = torch.maximum(_kv, _f) + k, v = _kv.chunk(2, dim=-1) else: # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -661,6 +685,10 @@ def multi_head_attention_forward( if _b is not None: _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + q = torch.maximum(q, _f) + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -670,6 +698,9 @@ def multi_head_attention_forward( if _b is not None: _b = _b[_start:_end] k = nn.functional.linear(key, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + k = torch.maximum(k, _f) # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -679,6 +710,10 @@ def multi_head_attention_forward( if _b is not None: _b = _b[_start:] v = nn.functional.linear(value, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + v = torch.maximum(v, _f) + if attn_mask is not None: assert ( @@ -918,3 +953,13 @@ def forward(self, x: Tensor) -> Tensor: def identity(x): return x + + +if __name__ == '__main__': + feature_dim = 50 + c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c(torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64)) diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index c2c6552a96..003b03a2e7 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -82,6 +82,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: Returns: Return a tensor of shape (N, U, embedding_dim). """ + y = y.to(torch.int64) embedding_out = self.embedding(y) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) From a859dcb20504e7b4bbc2ea9b1f1b28ad5f5e0757 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 7 Feb 2022 12:14:48 +0800 Subject: [PATCH 004/185] Remove learnable offset, use relu instead. --- .../ASR/transducer_stateless/conformer.py | 46 +++---------------- 1 file changed, 6 insertions(+), 40 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index f803ee9b67..c063359050 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -440,19 +440,8 @@ def __init__( ), "embed_dim must be divisible by num_heads" self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - - self.in_proj_floor_scale = 10.0 # so it learns fast enough.. - with torch.no_grad(): - in_proj_floor = torch.Tensor(3 * embed_dim) - # key and query get a floor value quite close to zero. - in_proj_floor[:2*embed_dim] = -0.2 / self.in_proj_floor_scale - # value gets very low floor, may be close to having no effectc. - in_proj_floor[2*embed_dim:] = -1.5 / self.in_proj_floor_scale - self.in_proj_floor = nn.Parameter(in_proj_floor) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - # linear transformation for positional encoding. self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d @@ -537,7 +526,6 @@ def forward( key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, - in_proj_floor=self.in_proj_floor*self.in_proj_floor_scale ) def rel_shift(self, x: Tensor) -> Tensor: @@ -582,7 +570,6 @@ def multi_head_attention_forward( key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, - in_proj_floor: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -642,12 +629,7 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - _qkv = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ) - if in_proj_floor is not None: - _qkv = torch.maximum(_qkv, in_proj_floor) - q, k, v = _qkv.chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -658,10 +640,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - q = torch.maximum(q, _f) + q = nn.functional.linear(query, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -670,11 +649,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - _kv = nn.functional.linear(key, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - _kv = torch.maximum(_kv, _f) - k, v = _kv.chunk(2, dim=-1) + k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1) else: # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -684,10 +659,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - q = torch.maximum(q, _f) + q = nn.functional.linear(query, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -697,10 +669,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - k = torch.maximum(k, _f) + k = nn.functional.linear(key, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -709,10 +678,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - v = torch.maximum(v, _f) + v = nn.functional.linear(value, _w, _b).relu() if attn_mask is not None: From 3323cabf467324b5d8bc3b1247a37724cd778ed0 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Tue, 8 Feb 2022 14:25:31 +0800 Subject: [PATCH 005/185] Experiments based on SpecAugment change --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 213 +++++++++++++++++- 1 file changed, 211 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e075a2d038..e5fcc5893a 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -28,7 +28,6 @@ K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, - SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures from torch.utils.data import DataLoader @@ -219,10 +218,11 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=2, + num_frame_masks=10, features_mask_size=27, num_feature_masks=2, frames_mask_size=100, + max_frames_mask_fraction=0.4, ) ) else: @@ -383,3 +383,212 @@ def test_clean_cuts(self) -> CutSet: def test_other_cuts(self) -> CutSet: logging.info("About to get test-other cuts") return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") + + +import math +import random +import numpy as np +from typing import Optional, Dict + +import torch + +from lhotse import CutSet + +class SpecAugment(torch.nn.Module): + """ + SpecAugment performs three augmentations: + - time warping of the feature matrix + - masking of ranges of features (frequency bands) + - masking of ranges of frames (time) + + The current implementation works with batches, but processes each example separately + in a loop rather than simultaneously to achieve different augmentation parameters for + each example. + """ + + def __init__( + self, + time_warp_factor: Optional[int] = 80, + num_feature_masks: int = 1, + features_mask_size: int = 13, + num_frame_masks: int = 1, + frames_mask_size: int = 70, + max_frames_mask_fraction: float = 0.2, + p=0.5, + ): + """ + SpecAugment's constructor. + + :param time_warp_factor: parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable. + :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). + This is the ``F`` parameter from the SpecAugment paper. + :param num_frame_masks: how many frame (temporal) masks should be applied. Set to ``0`` to disable. + :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). + This is the ``T`` parameter from the SpecAugment paper. + :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length + of the utterance (or supervision segment). + This is the parameter denoted by ``p`` in the SpecAugment paper. + :param p: the probability of applying this transform. + It is different from ``p`` in the SpecAugment paper! + """ + super().__init__() + assert 0 <= p <= 1 + assert num_feature_masks >= 0 + assert num_frame_masks >= 0 + assert features_mask_size > 0 + assert frames_mask_size > 0 + self.time_warp_factor = time_warp_factor + self.num_feature_masks = num_feature_masks + self.features_mask_size = features_mask_size + self.num_frame_masks = num_frame_masks + self.frames_mask_size = frames_mask_size + self.max_frames_mask_fraction = max_frames_mask_fraction + self.p = p + + def forward( + self, + features: torch.Tensor, + supervision_segments: Optional[torch.IntTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Computes SpecAugment for a batch of feature matrices. + + Since the batch will usually already be padded, the user can optionally + provide a ``supervision_segments`` tensor that will be used to apply SpecAugment + only to selected areas of the input. The format of this input is described below. + + :param features: a batch of feature matrices with shape ``(B, T, F)``. + :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features`` -- there may be either + less or more than the batch size. + The second dimension encoder three kinds of information: + the sequence index of the corresponding feature matrix in `features`, + the start frame index, and the number of frames for each segment. + :return: an augmented tensor of shape ``(B, T, F)``. + """ + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of " "single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + features[sequence_idx] = self._forward_single(features[sequence_idx]) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + end_frame = start_frame + num_frames + features[sequence_idx, start_frame:end_frame] = self._forward_single( + features[sequence_idx, start_frame:end_frame], warp=True, mask=False + ) + # ... and then time-mask the full feature matrices. Note that in this mode, + # it might happen that masks are applied to different sequences/examples + # than the time warping. + for sequence_idx in range(features.size(0)): + features[sequence_idx] = self._forward_single( + features[sequence_idx], warp=False, mask=True + ) + return features + + def _forward_single( + self, features: torch.Tensor, warp: bool = True, mask: bool = True + ) -> torch.Tensor: + """ + Apply SpecAugment to a single feature matrix of shape (T, F). + """ + if random.random() > self.p: + # Randomly choose whether this transform is applied + return features + if warp: + if self.time_warp_factor is not None and self.time_warp_factor >= 1: + features = time_warp(features, factor=self.time_warp_factor) + if mask: + from torchaudio.functional import mask_along_axis + + mean = features.mean() + for _ in range(self.num_feature_masks): + features = mask_along_axis( + features.unsqueeze(0), + mask_param=self.features_mask_size, + mask_value=mean, + axis=2, + ).squeeze(0) + for _ in range(self.num_frame_masks): + _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) + max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) + + features = mask_along_axis( + features.unsqueeze(0), + mask_param=max_mask_frames, + mask_value=mean, + axis=1, + ).squeeze(0) + return features + + def state_dict(self) -> Dict: + return dict( + time_warp_factor=self.time_warp_factor, + num_feature_masks=self.num_feature_masks, + features_mask_size=self.features_mask_size, + num_frame_masks=self.num_frame_masks, + frames_mask_size=self.frames_mask_size, + max_frames_mask_fraction=self.max_frames_mask_fraction, + p=self.p, + ) + + def load_state_dict(self, state_dict: Dict): + self.time_warp_factor = state_dict.get( + "time_warp_factor", self.time_warp_factor + ) + self.num_feature_masks = state_dict.get( + "num_feature_masks", self.num_feature_masks + ) + self.features_mask_size = state_dict.get( + "features_mask_size", self.features_mask_size + ) + self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) + self.frames_mask_size = state_dict.get( + "frames_mask_size", self.frames_mask_size + ) + self.max_frames_mask_fraction = state_dict.get( + "max_frames_mask_fraction", self.max_frames_mask_fraction + ) + self.p = state_dict.get("p", self.p) + + +def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: + """ + Time warping as described in the SpecAugment paper. + Implementation based on Espresso: + https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51 + + :param features: input tensor of shape ``(T, F)`` + :param factor: time warping parameter. + :return: a warped tensor of shape ``(T, F)`` + """ + t = features.size(0) + if t - factor <= factor + 1: + return features + center = np.random.randint(factor + 1, t - factor) + warped = np.random.randint(center - factor, center + factor + 1) + if warped == center: + return features + features = features.unsqueeze(0).unsqueeze(0) + left = torch.nn.functional.interpolate( + features[:, :, :center, :], + size=(warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + right = torch.nn.functional.interpolate( + features[:, :, center:, :], + size=(t - warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) \ No newline at end of file From beaf5bfbab85108f32751d5590fddc642437fdb7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 8 Feb 2022 19:42:23 +0800 Subject: [PATCH 006/185] Merge specaug change from Mingshuang. --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 950a88a356..5c447bc4b5 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp", + default="transducer_stateless/exp-100h-relu-specaug", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From bd36216e8cbc40b194e02e0e8d5bb86a3e60edf2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 8 Feb 2022 21:55:20 +0800 Subject: [PATCH 007/185] Use much more aggressive SpecAug setup --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e5fcc5893a..a5ab012e33 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -220,7 +220,7 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: time_warp_factor=self.args.spec_aug_time_warp_factor, num_frame_masks=10, features_mask_size=27, - num_feature_masks=2, + num_feature_masks=10, frames_mask_size=100, max_frames_mask_fraction=0.4, ) @@ -521,7 +521,7 @@ def _forward_single( _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) - + features = mask_along_axis( features.unsqueeze(0), mask_param=max_mask_frames, @@ -591,4 +591,4 @@ def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: mode="bicubic", align_corners=False, ) - return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) \ No newline at end of file + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) From dd19a6a2b13a7c452f5910fb4a0e123910540302 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Feb 2022 12:02:19 +0800 Subject: [PATCH 008/185] Fix to num_feature_masks bug I introduced; reduce max_frames_mask_fraction 0.4->0.3 --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a5ab012e33..11b07bd69d 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -220,9 +220,9 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: time_warp_factor=self.args.spec_aug_time_warp_factor, num_frame_masks=10, features_mask_size=27, - num_feature_masks=10, + num_feature_masks=2, frames_mask_size=100, - max_frames_mask_fraction=0.4, + max_frames_mask_fraction=0.3, ) ) else: From 8aa50df4f0c5b6d1edb2e850364c32fd3c666aab Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Feb 2022 22:52:53 +0800 Subject: [PATCH 009/185] Change p=0.5->0.9, mask_fraction 0.3->0.2 --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 3 ++- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 11b07bd69d..7df7a35254 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -222,7 +222,8 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: features_mask_size=27, num_feature_masks=2, frames_mask_size=100, - max_frames_mask_fraction=0.3, + max_frames_mask_fraction=0.2, + p=0.9 ) ) else: diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 5c447bc4b5..136faca579 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaug", + default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From c170c53006a7822e06832e960b630bcae964893a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Feb 2022 14:59:14 +0800 Subject: [PATCH 010/185] Change p=0.9 to p=0.8 in SpecAug --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 7df7a35254..044ad4fc62 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -223,7 +223,7 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: num_feature_masks=2, frames_mask_size=100, max_frames_mask_fraction=0.2, - p=0.9 + p=0.8 ) ) else: diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 136faca579..62cd1e764e 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2", + default="transducer_stateless/exp-100h-relu-specaugmod_p0.8_0.2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 4cd2c02fffac5cba4b0ca02d414fecbda90f7104 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Feb 2022 15:53:11 +0800 Subject: [PATCH 011/185] Fix num_time_masks code; revert 0.8 to 0.9 --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 11 +++++------ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 044ad4fc62..df2e484217 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -223,7 +223,7 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: num_feature_masks=2, frames_mask_size=100, max_frames_mask_fraction=0.2, - p=0.8 + p=0.9 ) ) else: @@ -518,11 +518,10 @@ def _forward_single( mask_value=mean, axis=2, ).squeeze(0) - for _ in range(self.num_frame_masks): - _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) - num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) - max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) - + _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) + max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) + for _ in range(num_frame_masks): features = mask_along_axis( features.unsqueeze(0), mask_param=max_mask_frames, diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 62cd1e764e..4bd85ca2ec 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaugmod_p0.8_0.2", + default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2_fix", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From d187ad8b739b4df3dbd1940b768393f0eed91a8e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Feb 2022 16:24:17 +0800 Subject: [PATCH 012/185] Change max_frames from 0.2 to 0.15 --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index df2e484217..c1b16bcf09 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -222,7 +222,7 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: features_mask_size=27, num_feature_masks=2, frames_mask_size=100, - max_frames_mask_fraction=0.2, + max_frames_mask_fraction=0.15, p=0.9 ) ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 4bd85ca2ec..dccf9b99be 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2_fix", + default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.15_fix", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 2af1b3af981d9ede788f0a16d6032dc4d55a6ed9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Feb 2022 19:39:19 +0800 Subject: [PATCH 013/185] Remove ReLU in attention --- .../ASR/transducer_stateless/conformer.py | 12 ++++++------ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index c063359050..4627dd147a 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -629,7 +629,7 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -640,7 +640,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b).relu() + q = nn.functional.linear(query, _w, _b) # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -649,7 +649,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1) + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) else: # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -659,7 +659,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b).relu() + q = nn.functional.linear(query, _w, _b) # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -669,7 +669,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b).relu() + k = nn.functional.linear(key, _w, _b) # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -678,7 +678,7 @@ def multi_head_attention_forward( _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - v = nn.functional.linear(value, _w, _b).relu() + v = nn.functional.linear(value, _w, _b) if attn_mask is not None: diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index dccf9b99be..7d1d7ff089 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.15_fix", + default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 581786a6d367e7d9313c43ae12030bc6044c9d0c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Feb 2022 13:44:43 +0800 Subject: [PATCH 014/185] Adding diagnostics code... --- .../ASR/transducer_stateless/diagnostics.py | 284 ++++++++++++++++++ .../ASR/transducer_stateless/train.py | 40 ++- 2 files changed, 313 insertions(+), 11 deletions(-) create mode 100644 egs/librispeech/ASR/transducer_stateless/diagnostics.py diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py new file mode 100644 index 0000000000..2dff918058 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -0,0 +1,284 @@ +import torch +from torch import Tensor +from torch import nn +import math +import random +from typing import Tuple, List + + +class TensorDiagnosticOptions(object): + """ + Options object for tensor diagnostics: + + Args: + memory_limit: the maximum number of bytes per tensor (limits how many copies + of the tensor we cache). + + """ + def __init__(self, memory_limit: int, + print_pos_ratio: bool = True): + self.memory_limit = memory_limit + self.print_pos_ratio = print_pos_ratio + + def dim_is_summarized(self, size: int): + return size > 10 and size != 31 + + def stats_types(self): + if self.print_pos_ratio: + return ["mean-abs", "pos-ratio"] + else: + return ["mean-abs"] + + + +def get_sum_abs_stats(x: Tensor, dim: int, + stats_type: str) -> Tuple[Tensor, int]: + """ + Returns the sum-of-absolute-value of this Tensor, for each + index into the specified axis/dim of the tensor. + Args: + x: Tensor, tensor to be analyzed + dim: dimension with 0 <= dim < x.ndim + stats_type: either "mean-abs" in which case the stats represent the + mean absolute value, or "pos-ratio" in which case the + stats represent the proportion of positive values (actually: + the tensor is count of positive values, count is the count of + all values). + Returns (sum_abs, count) + where sum_abs is a Tensor of shape (x.shape[dim],), and the count + is an integer saying how many items were counted in each element + of sum_abs. + """ + if stats_type == "mean-abs": + x = x.abs() + else: + assert stats_type == "pos-ratio" + x = (x > 0).to(dtype=torch.float) + orig_numel = x.numel() + sum_dims = [ d for d in range(x.ndim) if d != dim ] + x = torch.sum(x, dim=sum_dims) + count = orig_numel // x.numel() + x = x.flatten() + return x, count + +def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], + options: TensorDiagnosticOptions, + sizes_same: bool, + stats_type: str): + """ + This function gets diagnostics for a dimension of a module. + Args: + dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim + options: options object + sizes_same: true if all the tensor sizes are the same on this dimension + stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats + we accumulate, mean-abs is mean absolute value, "pos-ratio" + is proportion of positive to nonnegative values. + Returns: + Diagnostic as a string, either percentiles or the actual values, + see the code. + """ + # stats_and_counts is a list of pair (Tensor, int) + stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ] + stats = [ x[0] for x in stats_and_counts ] + counts = [ x[1] for x in stats_and_counts ] + if sizes_same: + stats = torch.stack(stats).sum(dim=0) + count = sum(counts) + stats = stats / count + else: + stats = [ x[0] / x[1] for x in stats_and_counts ] + stats = torch.cat(stats, dim=0) + # if `summarize` we print percentiles of the stats; else, + # we print out individual elements. + summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) + if summarize: + # print out percentiles. + stats = stats.sort()[0] + num_percentiles = 10 + size = stats.numel() + percentiles = [] + for i in range(num_percentiles + 1): + index = (i * (size - 1)) // num_percentiles + percentiles.append(stats[index].item()) + percentiles = [ '%.2g' % x for x in percentiles ] + percentiles = ' '.join(percentiles) + return f'percentiles: [{percentiles}]' + else: + stats = stats.tolist() + stats = [ '%.2g' % x for x in stats ] + stats = '[' + ' '.join(stats) + ']' + return stats + + + +def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], + options: TensorDiagnosticOptions): + + for stats_type in options.stats_types(): + # stats_type will be "mean-abs" or "pos-ratio". + sizes = [ x.shape[dim] for x in tensors ] + sizes_same = all([ x == sizes[0] for x in sizes ]) + s = get_diagnostics_for_dim(dim, tensors, + options, sizes_same, + stats_type) + + min_size = min(sizes) + max_size = max(sizes) + size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" + # stats_type will be "mean-abs" or "pos-ratio". + print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") + + +class TensorDiagnostic(object): + """ + This class is not directly used by the user, it is responsible for collecting + diagnostics for a single parameter tensor of a torch.Module. + """ + def __init__(self, + opts: TensorDiagnosticOptions, + name: str): + self.name = name + self.opts = opts + self.saved_tensors = [] + + def accumulate(self, x): + if isinstance(x, Tuple): + x = x[0] + if not isinstance(x, Tensor): + return + if x.device == torch.device('cpu'): + x = x.detach().clone() + else: + x = x.detach().to('cpu', non_blocking=True) + self.saved_tensors.append(x) + l = len(self.saved_tensors) + if l & (l - 1) == 0: # power of 2.. + self._limit_memory() + + def _limit_memory(self): + if len(self.saved_tensors) > 1024: + self.saved_tensors = self.saved_tensors[-1024:] + return + + tot_mem = 0.0 + for i in reversed(range(len(self.saved_tensors))): + tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size() + if tot_mem > self.opts.memory_limit: + self.saved_tensors = self.saved_tensors[i:] + return + + def print_diagnostics(self): + if len(self.saved_tensors) == 0: + print("{name}: no stats".format(name=self.name)) + return + if self.saved_tensors[0].ndim == 0: + # ensure there is at least one dim. + self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] + + ndim = self.saved_tensors[0].ndim + for dim in range(ndim): + print_diagnostics_for_dim(self.name, dim, + self.saved_tensors, + self.opts) + + +class ModelDiagnostic(object): + def __init__(self, opts: TensorDiagnosticOptions): + self.diagnostics = dict() + self.opts = opts + + def __getitem__(self, name: str): + if name not in self.diagnostics: + self.diagnostics[name] = TensorDiagnostic(self.opts, name) + return self.diagnostics[name] + + def print_diagnostics(self): + for k in sorted(self.diagnostics.keys()): + self.diagnostics[k].print_diagnostics() + + + +def attach_diagnostics(model: nn.Module, + opts: TensorDiagnosticOptions) -> ModelDiagnostic: + ans = ModelDiagnostic(opts) + for name, module in model.named_modules(): + if name == '': + name = "" + forward_diagnostic = TensorDiagnostic(opts, name + ".output") + backward_diagnostic = TensorDiagnostic(opts, name + ".grad") + + + # setting model_diagnostic=ans and n=name below, instead of trying to capture the variables, + # ensures that we use the current values. (matters for name, since + # the variable gets overwritten). these closures don't really capture + # by value, only by "the final value the variable got in the function" :-( + def forward_hook(_module, _input, _output, + _model_diagnostic=ans, _name=name): + if isinstance(_output, Tensor): + _model_diagnostic[f"{_name}.output"].accumulate(_output) + elif isinstance(_output, tuple): + for i, o in enumerate(_output): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o) + + def backward_hook(_module, _input, _output, + _model_diagnostic=ans, _name=name): + if isinstance(_output, Tensor): + _model_diagnostic[f"{_name}.grad"].accumulate(_output) + elif isinstance(_output, tuple): + for i, o in enumerate(_output): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o) + + module.register_forward_hook(forward_hook) + module.register_backward_hook(backward_hook) + + for name, parameter in model.named_parameters(): + + def param_backward_hook(grad, + _parameter=parameter, + _model_diagnostic=ans, + _name=name): + _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) + _model_diagnostic[f"{_name}.param_grad"].accumulate(grad) + + parameter.register_hook(param_backward_hook) + return ans + + + +def _test_tensor_diagnostic(): + opts = TensorDiagnosticOptions(2**20, True) + + diagnostic = TensorDiagnostic(opts, "foo") + + for _ in range(10): + diagnostic.accumulate(torch.randn(50, 100) * 10.0) + + diagnostic.print_diagnostics() + + model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80)) + + diagnostic = attach_diagnostics(model, opts) + for _ in range(10): + T = random.randint(200, 300) + x = torch.randn(T, 100) + y = model(x) + y.sum().backward() + + diagnostic.print_diagnostics() + + + +if __name__ == '__main__': + _test_tensor_diagnostic() + + +def _test_func(): + ans = [] + for i in range(10): + x = list() + x.append(i) + def func(): + return x + ans.append(func) + return ans diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 7d1d7ff089..0e1bbeaffb 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -34,6 +34,7 @@ import argparse import logging +import diagnostics # ./diagnostics.py from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -109,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix", + default="transducer_stateless/specaugmod_baseline", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -138,6 +139,13 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + return parser @@ -487,6 +495,9 @@ def train_one_epoch( loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() + if params.print_diagnostics and batch_idx == 5: + return + if batch_idx % params.log_interval == 0: logging.info( @@ -494,9 +505,6 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}" ) - - if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train @@ -599,6 +607,11 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions(2**22) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: train_cuts += librispeech.train_clean_360_cuts() @@ -626,13 +639,14 @@ def remove_short_and_long_utt(c: Cut): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) @@ -660,6 +674,10 @@ def remove_short_and_long_utt(c: Cut): world_size=world_size, ) + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + save_checkpoint( params=params, model=model, From 63d8d935d43b719a74bdaa5db3892e71a2b9fe69 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Feb 2022 13:56:15 +0800 Subject: [PATCH 015/185] Refactor/simplify ConformerEncoder --- .../ASR/transducer_stateless/conformer.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 4627dd147a..07b80076dd 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import copy import math import warnings from typing import Optional, Tuple @@ -264,13 +264,12 @@ def forward( return src -class ConformerEncoder(nn.TransformerEncoder): +class ConformerEncoder(nn.Module): r"""ConformerEncoder is a stack of N encoder layers Args: encoder_layer: an instance of the ConformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -281,11 +280,12 @@ class ConformerEncoder(nn.TransformerEncoder): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + self, encoder_layer: nn.Module, num_layers: int ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm - ) + super(ConformerEncoder, self).__init__() + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)]) + self.num_layers = num_layers + def forward( self, @@ -320,9 +320,6 @@ def forward( src_key_padding_mask=src_key_padding_mask, ) - if self.norm is not None: - output = self.norm(output) - return output From c1063def9552fd3af9a6d54b304a9cc6939a8b93 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Feb 2022 17:34:58 +0800 Subject: [PATCH 016/185] First version of rand-combine iterated-training-like idea. --- .../ASR/transducer_stateless/conformer.py | 224 +++++++++++++++++- .../ASR/transducer_stateless/train.py | 2 +- 2 files changed, 219 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 07b80076dd..327849485c 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -18,7 +18,7 @@ import copy import math import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Sequence import torch from torch import Tensor, nn @@ -56,6 +56,7 @@ def __init__( cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, + aux_layer_period: int = 3 ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -80,10 +81,11 @@ def __init__( cnn_module_kernel, normalize_before, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, + aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) self.normalize_before = normalize_before if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) + self.after_norm = nn.LayerNorm(d_model) # TODO: remove. else: # Note: TorchScript detects that self.after_norm could be used inside forward() # and throws an error without this change. @@ -280,12 +282,21 @@ class ConformerEncoder(nn.Module): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int + self, encoder_layer: nn.Module, + num_layers: int, + aux_layers: Sequence[int], ) -> None: super(ConformerEncoder, self).__init__() self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)]) + self.aux_layers = set(aux_layers + [num_layers - 1]) + assert num_layers - 1 not in aux_layers self.num_layers = num_layers - + num_channels = encoder_layer.norm_final.weight.numel() + self.combiner = RandomCombine(num_inputs=len(self.aux_layers), + num_channels=num_channels, + final_weight=0.5, + pure_prob=0.333, + stddev=2.0) def forward( self, @@ -312,14 +323,19 @@ def forward( """ output = src - for mod in self.layers: + outputs = [] + + for i, mod in enumerate(self.layers): output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, ) + if i in self.aux_layers: + outputs.append(output) + output = self.combiner(outputs) return output @@ -918,7 +934,203 @@ def identity(x): return x +class RandomCombine(torch.nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + All but the last input will have a linear transform before we + randomly combine them; these linear transforms will be initialzed + to the identity transform. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + def __init__(self, num_inputs: int, + num_channels: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0) -> None: + """ + Args: + num_inputs: The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + num_channels: The number of channels on the input, e.g. 512. + final_weight: The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, + or combinations of layers, to use, is conceptually as follows. + With probability `pure_prob`: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super(RandomCombine, self).__init__() + assert pure_prob >= 0 and pure_prob <= 1 + assert final_weight > 0 and final_weight < 1 + assert num_inputs >= 1 + self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True) + for _ in range(num_inputs - 1)]) + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev= stddev + + self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() + self._reset_parameters() + + def _reset_parameters(self): + for i in range(len(self.linear)): + nn.init.eye_(self.linear[i].weight) + nn.init.constant_(self.linear[i].bias, 0.0) + + def forward(self, inputs: Sequence[Tensor]) -> Tensor: + """ + Forward function. + Args: + inputs: a list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + a Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training: + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + mod_inputs = [] + for i in range(num_inputs - 1): + mod_inputs.append(self.linear[i](inputs[i])) + mod_inputs.append(inputs[num_inputs - 1]) + + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames, + num_channels, + num_inputs)) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, + num_frames) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) + + if __name__ == "__main__": + # for testing only... + print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + + def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: + """ + Return a tensor of random weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), such that + ans.sum(dim=1) is all ones. + + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) + + def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with + exactly one weight equal to 1.0 on each frame. + """ + + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + + indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, + final, nonfinal) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) + return ans + + + def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that + sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. + """ + logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev + logprobs[:,-1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") + num_inputs = 3 + num_channels = 50 + m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels, + final_weight=final_weight, pure_prob=pure_prob, stddev=stddev) + + x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + if __name__ == '__main__': + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + feature_dim = 50 c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) batch_size = 5 diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 0e1bbeaffb..8877d4e759 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline", + default="transducer_stateless/specaugmod_baseline_randcombine1", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 2ff520c8004eb7cfd43585ae840ca0fcb5bbcfae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 28 Feb 2022 12:15:56 +0800 Subject: [PATCH 017/185] Improvements to diagnostics (RE those with 1 dim --- .../ASR/transducer_stateless/diagnostics.py | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 2dff918058..088ef14cb9 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -31,32 +31,34 @@ def stats_types(self): -def get_sum_abs_stats(x: Tensor, dim: int, +def get_tensor_stats(x: Tensor, dim: int, stats_type: str) -> Tuple[Tensor, int]: """ - Returns the sum-of-absolute-value of this Tensor, for each - index into the specified axis/dim of the tensor. + Returns the specified transformation of the Tensor (either x or x.abs() + or (x > 0), summed over all but the index `dim`. + Args: x: Tensor, tensor to be analyzed dim: dimension with 0 <= dim < x.ndim - stats_type: either "mean-abs" in which case the stats represent the - mean absolute value, or "pos-ratio" in which case the - stats represent the proportion of positive values (actually: - the tensor is count of positive values, count is the count of - all values). - Returns (sum_abs, count) - where sum_abs is a Tensor of shape (x.shape[dim],), and the count + stats_type: + "mean-abs" or "abs-value" -> take abs() before summing + "pos-ratio" -> take (x > 0) before summing + "value -> just sum x itself + Returns (stats, count) + where stats is a Tensor of shape (x.shape[dim],), and the count is an integer saying how many items were counted in each element - of sum_abs. + of stats. """ - if stats_type == "mean-abs": + if stats_type == "mean-abs" or stats_type == "abs-value": x = x.abs() - else: - assert stats_type == "pos-ratio" + elif stats_type == "pos-ratio": x = (x > 0).to(dtype=torch.float) + else: + assert stats_type == "value" orig_numel = x.numel() sum_dims = [ d for d in range(x.ndim) if d != dim ] - x = torch.sum(x, dim=sum_dims) + if len(sum_dims) > 0: + x = torch.sum(x, dim=sum_dims) count = orig_numel // x.numel() x = x.flatten() return x, count @@ -79,7 +81,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], see the code. """ # stats_and_counts is a list of pair (Tensor, int) - stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ] + stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] stats = [ x[0] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ] if sizes_same: @@ -114,9 +116,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions): + ndim = tensors[0].ndim + # options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the + # normal case. + stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ] - for stats_type in options.stats_types(): - # stats_type will be "mean-abs" or "pos-ratio". + for stats_type in stats_types: sizes = [ x.shape[dim] for x in tensors ] sizes_same = all([ x == sizes[0] for x in sizes ]) s = get_diagnostics_for_dim(dim, tensors, From 9d1b4ae04682d12aef2fecb902f318fcf9cab716 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 2 Mar 2022 16:33:27 +0800 Subject: [PATCH 018/185] Add pelu to this good-performing setup.. --- .../ASR/conformer_ctc/subsampling.py | 38 ++++++++++++++++++- .../ASR/transducer_stateless/conformer.py | 17 +++------ .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 542fb0364e..b230719260 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -45,13 +45,14 @@ def __init__(self, idim: int, odim: int) -> None: nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - nn.ReLU(), + PeLU(cutoff=-1.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - nn.ReLU(), + PeLU(cutoff=-5.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -70,6 +71,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) return x @@ -159,3 +161,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) return x + + +class PeLUFunction(torch.autograd.Function): + """ + Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)). + The function is: + x.relu() + alpha * (cutoff - x).relu() + E.g. consider cutoff = -1, alpha = 0.01. This will tend to prevent die-off + of neurons. + """ + @staticmethod + def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor: + mask1 = (x >= 0) # >=, so there is deriv if x == 0. + p = cutoff - x + mask2 = (p >= 0) + ctx.save_for_backward(mask1, mask2) + ctx.alpha = alpha + return x.relu() + alpha * p.relu() + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]: + mask1, mask2 = ctx.saved_tensors + return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None + + + +class PeLU(torch.nn.Module): + def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None: + super(PeLU, self).__init__() + self.cutoff = cutoff + self.alpha = alpha + def forward(self, x: Tensor) -> Tensor: + return PeLUFunction.apply(x, self.cutoff, self.alpha) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 327849485c..066232a02e 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,6 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence +from subsampling import PeLU import torch from torch import Tensor, nn @@ -84,12 +85,7 @@ def __init__( self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) # TODO: remove. - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity + def forward( self, x: torch.Tensor, x_lens: torch.Tensor @@ -118,9 +114,6 @@ def forward( x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) - if self.normalize_before: - x = self.after_norm(x) - logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -163,14 +156,14 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - Swish(), + PeLU(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - Swish(), + PeLU(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -889,7 +882,7 @@ def __init__( padding=0, bias=bias, ) - self.activation = Swish() + self.activation = PeLU() def forward(self, x: Tensor) -> Tensor: """Compute convolution module. diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 8877d4e759..88b3662452 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1", + default="transducer_stateless/specaugmod_baseline_randcombine1_pelu", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 9ed7d55a846047373a24f0c084bf7f325e9cbe95 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 2 Mar 2022 16:34:55 +0800 Subject: [PATCH 019/185] Small bug fixes/imports --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index b230719260..c97f1ef486 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn +from torch import Tensor +from typing import Tuple class Conv2dSubsampling(nn.Module): From 3fb559d2f02402af91707fe0df633bcca497fc4d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 2 Mar 2022 18:27:08 +0800 Subject: [PATCH 020/185] Add baseline for the PeLU expt, keeping only the small normalization-related changes. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 +++--- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index c97f1ef486..0e5e2d3de0 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,11 +47,11 @@ def __init__(self, idim: int, odim: int) -> None: nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - PeLU(cutoff=-1.0), + nn.ReLU(), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - PeLU(cutoff=-5.0), + nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 066232a02e..2b97047cf5 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,14 +156,14 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - PeLU(), + Swish(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - PeLU(), + Swish(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -882,7 +882,7 @@ def __init__( padding=0, bias=bias, ) - self.activation = PeLU() + self.activation = Swish() def forward(self, x: Tensor) -> Tensor: """Compute convolution module. diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 88b3662452..283aaecdd4 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_pelu", + default="transducer_stateless/specaugmod_baseline_randcombine1_pelu_base", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 5c177fc52b551c188bbb828cad1d13450553aca6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 3 Mar 2022 23:52:03 +0800 Subject: [PATCH 021/185] pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 11 +++++++++++ egs/librispeech/ASR/transducer_stateless/conformer.py | 4 +++- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 0e5e2d3de0..73493a7ea7 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,10 +48,12 @@ def __init__(self, idim: int, odim: int) -> None: in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), + ExpScale(odim, 1, 1, speed=2.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), + ExpScale(odim, 1, 1, speed=2.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) @@ -195,3 +197,12 @@ def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None: self.alpha = alpha def forward(self, x: Tensor) -> Tensor: return PeLUFunction.apply(x, self.cutoff, self.alpha) + +class ExpScale(torch.nn.Module): + def __init__(self, *shape, speed: float = 1.0): + super(ExpScale, self).__init__() + self.scale = nn.Parameter(torch.zeros(*shape)) + self.speed = speed + + def forward(self, x: Tensor) -> Tensor: + return x * (self.scale * self.speed).exp() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 2b97047cf5..3789e02fdf 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU +from subsampling import PeLU, ExpScale import torch from torch import Tensor, nn @@ -157,6 +157,7 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), + ExpScale(dim_feedforward, speed=2.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -164,6 +165,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), + ExpScale(dim_feedforward, speed=2.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 283aaecdd4..183a924c61 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_pelu_base", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 23b3aa233c792de86fb23b04f4a0160ba74f4d51 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 00:42:37 +0800 Subject: [PATCH 022/185] Double learning rate of exp-scale units --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 73493a7ea7..3b35c2ebef 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,12 +48,12 @@ def __init__(self, idim: int, odim: int) -> None: in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=2.0), + ExpScale(odim, 1, 1, speed=4.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=2.0), + ExpScale(odim, 1, 1, speed=4.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3789e02fdf..59f317e900 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -157,7 +157,7 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScale(dim_feedforward, speed=2.0), + ExpScale(dim_feedforward, speed=4.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -165,7 +165,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScale(dim_feedforward, speed=2.0), + ExpScale(dim_feedforward, speed=4.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 183a924c61..a1ded87c64 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From bc6c720e257c0b586ca2257d5be14b5358012bc1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 10:52:05 +0800 Subject: [PATCH 023/185] Combine ExpScale and swish for memory reduction --- .../ASR/conformer_ctc/subsampling.py | 67 +++++++++++++++++++ .../ASR/transducer_stateless/conformer.py | 5 +- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 3b35c2ebef..600156bf14 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -206,3 +206,70 @@ def __init__(self, *shape, speed: float = 1.0): def forward(self, x: Tensor) -> Tensor: return x * (self.scale * self.speed).exp() + + + +def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: + return (x * torch.sigmoid(x)) * (scale * speed).exp() + + +def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor: + return (x * torch.sigmoid(x)) * (scale * speed).exp() + + +class ExpScaleSwishFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: + ctx.save_for_backward(x, scale) + ctx.speed = speed + return _exp_scale_swish(x, scale, speed) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + x, scale = ctx.saved_tensors + x.requires_grad = True + scale.requires_grad = True + with torch.enable_grad(): + y = _exp_scale_swish(x, scale, ctx.speed) + y.backward(gradient=y_grad) + return x.grad, scale.grad, None + + +class ExpScaleSwish(torch.nn.Module): + # combines ExpScale an Swish + # caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0) + def __init__(self, *shape, speed: float = 1.0): + super(ExpScaleSwish, self).__init__() + self.scale = nn.Parameter(torch.zeros(*shape)) + self.speed = speed + + def forward(self, x: Tensor) -> Tensor: + return ExpScaleSwishFunction.apply(x, self.scale, self.speed) + # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() + # return x * (self.scale * self.speed).exp() + +def _test_exp_scale_swish(): + class Swish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + x1 = torch.randn(50, 60).detach() + x2 = x1.detach() + + m1 = ExpScaleSwish(50, 1, speed=4.0) + m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0)) + x1.requires_grad = True + x2.requires_grad = True + + y1 = m1(x1) + y2 = m2(x2) + assert torch.allclose(y1, y2) + y1.sum().backward() + y2.sum().backward() + assert torch.allclose(x1.grad, x2.grad) + + + +if __name__ == '__main__': + _test_exp_scale_swish() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 59f317e900..3386ed9b26 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,8 +156,7 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - Swish(), - ExpScale(dim_feedforward, speed=4.0), + ExpScaleSwish(dim_feedforward, speed=4.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -165,7 +164,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScale(dim_feedforward, speed=4.0), + ExpScaleSwish(dim_feedforward, speed=4.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) From cd216f50b63e92f8bdce493428b553570615ead1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 11:03:01 +0800 Subject: [PATCH 024/185] Add import --- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3386ed9b26..83e0f8bcac 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale +from subsampling import PeLU, ExpScale, ExpScaleSwish import torch from torch import Tensor, nn From 3d9ddc201680747cab89838d9abe9797225f0128 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 12:29:44 +0800 Subject: [PATCH 025/185] Fix backprop bug --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 600156bf14..a66421adfe 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -220,7 +220,7 @@ def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor: class ExpScaleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x, scale) + ctx.save_for_backward(x.detach(), scale.detach()) ctx.speed = speed return _exp_scale_swish(x, scale, speed) From 503f8d521ce10d24e5fc1b62760c630843cf80c2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 13:08:56 +0800 Subject: [PATCH 026/185] Fix bug in diagnostics --- egs/librispeech/ASR/transducer_stateless/diagnostics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 2dff918058..1a2324775c 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -56,7 +56,8 @@ def get_sum_abs_stats(x: Tensor, dim: int, x = (x > 0).to(dtype=torch.float) orig_numel = x.numel() sum_dims = [ d for d in range(x.ndim) if d != dim ] - x = torch.sum(x, dim=sum_dims) + if len(sum_dims) != 0: + x = torch.sum(x, dim=sum_dims) count = orig_numel // x.numel() x = x.flatten() return x, count From 3207bd98a942f96e4a052d9984ea8ee0040f2269 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 13:16:40 +0800 Subject: [PATCH 027/185] Increase scale on Scale from 4 to 20 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index a66421adfe..e38a94d098 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,12 +48,12 @@ def __init__(self, idim: int, odim: int) -> None: in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=4.0), + ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=4.0), + ExpScale(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 83e0f8bcac..6907feb263 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,7 +156,7 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ExpScaleSwish(dim_feedforward, speed=4.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -164,7 +164,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScaleSwish(dim_feedforward, speed=4.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index a1ded87c64..c57968428f 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 7e889996413bc757c2cc7160c6e96644467ab57e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 14:31:29 +0800 Subject: [PATCH 028/185] Increase scale from 20 to 50. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index e38a94d098..97b9ae97be 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,12 +48,12 @@ def __init__(self, idim: int, odim: int) -> None: in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=20.0), + ExpScale(odim, 1, 1, speed=50.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=20.0), + ExpScale(odim, 1, 1, speed=50.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 6907feb263..ef6b4ac973 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,7 +156,7 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ExpScaleSwish(dim_feedforward, speed=20.0), + ExpScaleSwish(dim_feedforward, speed=50.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -164,7 +164,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScaleSwish(dim_feedforward, speed=20.0), + ExpScaleSwish(dim_feedforward, speed=50.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c57968428f..980633ed66 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale4", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 9cc5999829ef1441d99804204d1f61a796bc4948 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 15:50:51 +0800 Subject: [PATCH 029/185] Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module --- egs/librispeech/ASR/transducer_stateless/conformer.py | 9 +++------ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index ef6b4ac973..dc6b54399d 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -163,7 +163,6 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - Swish(), ExpScaleSwish(dim_feedforward, speed=50.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -874,7 +873,9 @@ def __init__( groups=channels, bias=bias, ) - self.norm = nn.LayerNorm(channels) + # shape: (channels, 1), broadcasts with (batch, channel, time). + self.activation = ExpScaleSwish(channels, 1, speed=50.0) + self.pointwise_conv2 = nn.Conv1d( channels, channels, @@ -883,7 +884,6 @@ def __init__( padding=0, bias=bias, ) - self.activation = Swish() def forward(self, x: Tensor) -> Tensor: """Compute convolution module. @@ -905,9 +905,6 @@ def forward(self, x: Tensor) -> Tensor: # 1D Depthwise Conv x = self.depthwise_conv(x) # x is (batch, channels, time) - x = x.permute(0, 2, 1) - x = self.norm(x) - x = x.permute(0, 2, 1) x = self.activation(x) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 980633ed66..973733d4bc 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale4", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From eb3ed5420249d1c70d48700d72837e1c8a646454 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 15:56:45 +0800 Subject: [PATCH 030/185] Reduce scale from 50 to 20 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 97b9ae97be..e38a94d098 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,12 +48,12 @@ def __init__(self, idim: int, odim: int) -> None: in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=50.0), + ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=50.0), + ExpScale(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index dc6b54399d..368165008a 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,14 +156,14 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ExpScaleSwish(dim_feedforward, speed=50.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ExpScaleSwish(dim_feedforward, speed=50.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -874,7 +874,7 @@ def __init__( bias=bias, ) # shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = ExpScaleSwish(channels, 1, speed=50.0) + self.activation = ExpScaleSwish(channels, 1, speed=20.0) self.pointwise_conv2 = nn.Conv1d( channels, From 6252282fd02f0a105e718091dd321cc71e205a95 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 20:19:11 +0800 Subject: [PATCH 031/185] Add deriv-balancing code --- .../ASR/conformer_ctc/subsampling.py | 87 +++++++++++++++++++ .../ASR/transducer_stateless/conformer.py | 6 +- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index e38a94d098..aa842a31f9 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,11 +47,15 @@ def __init__(self, idim: int, odim: int) -> None: nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), + DerivBalancer(channel_dim=1, threshold=0.02, + max_factor=0.02), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), + DerivBalancer(channel_dim=1, threshold=0.02, + max_factor=0.02), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), ) @@ -248,6 +252,68 @@ def forward(self, x: Tensor) -> Tensor: # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() # return x * (self.scale * self.speed).exp() + + + +class DerivBalancerFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, channel_dim: int, + threshold: 0.05, max_factor: 0.05, + epsilon: 1.0e-10) -> Tensor: + if x.requires_grad: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True) + factor = (threshold - proportion_positive).relu() * (max_factor / threshold) + + ctx.save_for_backward(factor) + ctx.epsilon = epsilon + ctx.sum_dims = sum_dims + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + factor, = ctx.saved_tensors + neg_delta_grad = x_grad.abs() * factor + if ctx.epsilon != 0.0: + sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True) + deriv_is_zero = (sum_abs_grad == 0.0) + neg_delta_grad += ctx.epsilon * deriv_is_zero + + return x_grad - neg_delta_grad, None, None, None, None + + + +class DerivBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 0 at the threshold to those extremal values when none + of the inputs are positive. + + When all grads are zero for a channel, this + module sets all the input derivatives for that channel to -epsilon; the + idea is to bring completely dead neurons back to life this way. + """ + def __init__(self, channel_dim: int, + threshold: float = 0.05, + max_factor: float = 0.05, + epsilon: float = 1.0e-10): + super(DerivBalancer, self).__init__() + self.channel_dim = channel_dim + self.threshold = threshold + self.max_factor = max_factor + self.epsilon = epsilon + + def forward(self, x: Tensor) -> Tensor: + return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, + self.max_factor, self.epsilon) + + + def _test_exp_scale_swish(): class Swish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: @@ -271,5 +337,26 @@ def forward(self, x: Tensor) -> Tensor: +def _test_deriv_balancer(): + channel_dim = 0 + probs = torch.arange(0, 1, 0.01) + N = 500 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + y_grad[-1,:] = 0 + + y = m(x) + y.backward(gradient=y_grad) + print("x = ", x) + print("y grad = ", y_grad) + print("x grad = ", x.grad) + + + if __name__ == '__main__': + _test_deriv_balancer() _test_exp_scale_swish() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 368165008a..056958ff64 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish +from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer import torch from torch import Tensor, nn @@ -156,6 +156,8 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), + DerivBalancer(channel_dim=-1, threshold=0.02, + max_factor=0.02), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -163,6 +165,8 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), + DerivBalancer(channel_dim=-1, threshold=0.02, + max_factor=0.02), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 973733d4bc..6d6b3f240a 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 65b09dd5f22f72923289fd68c1641ecd33fa0c52 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Mar 2022 00:07:14 +0800 Subject: [PATCH 032/185] Double the threshold in brelu; slightly increase max_factor. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 8 ++++---- egs/librispeech/ASR/transducer_stateless/conformer.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index aa842a31f9..ba0f08271e 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,15 +47,15 @@ def __init__(self, idim: int, odim: int) -> None: nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.02, - max_factor=0.02), + DerivBalancer(channel_dim=1, threshold=0.05, + max_factor=0.025), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.02, - max_factor=0.02), + DerivBalancer(channel_dim=1, threshold=0.05, + max_factor=0.025), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), ) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 056958ff64..42d159ff51 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,8 +156,8 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.02, - max_factor=0.02), + DerivBalancer(channel_dim=-1, threshold=0.05, + max_factor=0.025), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -165,8 +165,8 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.02, - max_factor=0.02), + DerivBalancer(channel_dim=-1, threshold=0.05, + max_factor=0.025), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), From 0cd14ae739ecfe9f01ebccf5f4b18cc7b9cbc8c0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Mar 2022 12:17:09 +0800 Subject: [PATCH 033/185] Fix exp dir --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 6d6b3f240a..eed89e6b97 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 5f2c0a09b7eede63054ef20627cf298e2223734d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Mar 2022 16:28:24 +0800 Subject: [PATCH 034/185] Convert swish nonlinearities to ReLU --- .../ASR/conformer_ctc/subsampling.py | 78 ++++++++++++++++++- .../ASR/transducer_stateless/conformer.py | 11 ++- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 82 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index ba0f08271e..a500e42a99 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -49,15 +49,13 @@ def __init__(self, idim: int, odim: int) -> None: ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.025), - nn.ReLU(), - ExpScale(odim, 1, 1, speed=20.0), + ExpScaleRelu(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.025), - nn.ReLU(), - ExpScale(odim, 1, 1, speed=20.0), + ExpScaleRelu(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) @@ -253,6 +251,60 @@ def forward(self, x: Tensor) -> Tensor: # return x * (self.scale * self.speed).exp() +def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: + return (x * (scale * speed).exp()).relu() + + +class ExpScaleReluFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: + ctx.save_for_backward(x.detach(), scale.detach()) + ctx.speed = speed + return _exp_scale_swish(x, scale, speed) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + x, scale = ctx.saved_tensors + x.requires_grad = True + scale.requires_grad = True + with torch.enable_grad(): + y = _exp_scale_swish(x, scale, ctx.speed) + y.backward(gradient=y_grad) + return x.grad, scale.grad, None + + + +class ExpScaleReluFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: + ctx.save_for_backward(x.detach(), scale.detach()) + ctx.speed = speed + return _exp_scale_relu(x, scale, speed) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + x, scale = ctx.saved_tensors + x.requires_grad = True + scale.requires_grad = True + with torch.enable_grad(): + y = _exp_scale_relu(x, scale, ctx.speed) + y.backward(gradient=y_grad) + return x.grad, scale.grad, None + +class ExpScaleRelu(torch.nn.Module): + # combines ExpScale and Relu. + # caution: need to specify name for speed, e.g. ExpScaleRelu(50, speed=4.0) + def __init__(self, *shape, speed: float = 1.0): + super(ExpScaleRelu, self).__init__() + self.scale = nn.Parameter(torch.zeros(*shape)) + self.speed = speed + + def forward(self, x: Tensor) -> Tensor: + return ExpScaleReluFunction.apply(x, self.scale, self.speed) + # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() + # return x * (self.scale * self.speed).exp() + + class DerivBalancerFunction(torch.autograd.Function): @@ -335,6 +387,23 @@ def forward(self, x: Tensor) -> Tensor: y2.sum().backward() assert torch.allclose(x1.grad, x2.grad) +def _test_exp_scale_relu(): + + x1 = torch.randn(50, 60).detach() + x2 = x1.detach() + + m1 = ExpScaleRelu(50, 1, speed=4.0) + m2 = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0)) + x1.requires_grad = True + x2.requires_grad = True + + y1 = m1(x1) + y2 = m2(x2) + assert torch.allclose(y1, y2) + y1.sum().backward() + y2.sum().backward() + assert torch.allclose(x1.grad, x2.grad) + def _test_deriv_balancer(): @@ -360,3 +429,4 @@ def _test_deriv_balancer(): if __name__ == '__main__': _test_deriv_balancer() _test_exp_scale_swish() + _test_exp_scale_relu() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 42d159ff51..7af145a1e2 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer +from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer import torch from torch import Tensor, nn @@ -158,7 +158,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleSwish(dim_feedforward, speed=20.0), + ExpScaleRelu(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -167,7 +167,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleSwish(dim_feedforward, speed=20.0), + ExpScaleRelu(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -877,8 +877,10 @@ def __init__( groups=channels, bias=bias, ) + self.balancer = DerivBalancer(channel_dim=1, threshold=0.05, + max_factor=0.025) # shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = ExpScaleSwish(channels, 1, speed=20.0) + self.activation = ExpScaleRelu(channels, 1, speed=20.0) self.pointwise_conv2 = nn.Conv1d( channels, @@ -910,6 +912,7 @@ def forward(self, x: Tensor) -> Tensor: x = self.depthwise_conv(x) # x is (batch, channels, time) + x = self.balancer(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index eed89e6b97..b1cb6d043e 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2relu", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 8a8b81cd181e209b7609d7e8d54467bfbe758271 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Mar 2022 22:21:42 +0800 Subject: [PATCH 035/185] Replace relu with swish-squared. --- .../ASR/conformer_ctc/subsampling.py | 18 ++++++++++-------- .../ASR/transducer_stateless/conformer.py | 6 +++--- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index a500e42a99..daf8fd251e 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -212,12 +212,11 @@ def forward(self, x: Tensor) -> Tensor: def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: - return (x * torch.sigmoid(x)) * (scale * speed).exp() - - -def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor: - return (x * torch.sigmoid(x)) * (scale * speed).exp() - + # double-swish! + x = (x * torch.sigmoid(x)) + x = (x * torch.sigmoid(x)) + x = x * (scale * speed).exp() + return x class ExpScaleSwishFunction(torch.autograd.Function): @staticmethod @@ -247,8 +246,11 @@ def __init__(self, *shape, speed: float = 1.0): def forward(self, x: Tensor) -> Tensor: return ExpScaleSwishFunction.apply(x, self.scale, self.speed) - # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() - # return x * (self.scale * self.speed).exp() + # x = (x * torch.sigmoid(x)) + # x = (x * torch.sigmoid(x)) + # x = x * (self.scale * self.speed).exp() + # return x + def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 7af145a1e2..5adb7ca4ee 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -158,7 +158,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleRelu(dim_feedforward, speed=20.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -167,7 +167,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleRelu(dim_feedforward, speed=20.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -880,7 +880,7 @@ def __init__( self.balancer = DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.025) # shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = ExpScaleRelu(channels, 1, speed=20.0) + self.activation = ExpScaleSwish(channels, 1, speed=20.0) self.pointwise_conv2 = nn.Conv1d( channels, diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index b1cb6d043e..a3eca26c9b 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2relu", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2swish2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a37d98463aeaf0fd9370128cd0f03663bb3aaab1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Mar 2022 11:55:02 +0800 Subject: [PATCH 036/185] Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset. --- .../ASR/conformer_ctc/subsampling.py | 5 ++--- .../ASR/transducer_stateless/conformer.py | 22 ++++++++++++++----- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index daf8fd251e..1fe1265fa2 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -212,9 +212,8 @@ def forward(self, x: Tensor) -> Tensor: def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: - # double-swish! - x = (x * torch.sigmoid(x)) - x = (x * torch.sigmoid(x)) + # double-swish, implemented/approximated as offset-swish + x = (x * torch.sigmoid(x - 1.0)) x = x * (scale * speed).exp() return x diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 5adb7ca4ee..62d9f382fc 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -877,10 +877,10 @@ def __init__( groups=channels, bias=bias, ) - self.balancer = DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.025) - # shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = ExpScaleSwish(channels, 1, speed=20.0) + + self.norm = nn.LayerNorm(channels) + # shape: (channels, 1), broadcasts with (batch, channel, time). + self.activation = SwishOffset() self.pointwise_conv2 = nn.Conv1d( channels, @@ -911,8 +911,10 @@ def forward(self, x: Tensor) -> Tensor: # 1D Depthwise Conv x = self.depthwise_conv(x) # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) - x = self.balancer(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) @@ -927,6 +929,16 @@ def forward(self, x: Tensor) -> Tensor: """Return Swich activation function.""" return x * torch.sigmoid(x) +class SwishOffset(torch.nn.Module): + """Construct an SwishOffset object.""" + def __init__(self, offset: float = -1.0) -> None: + super(SwishOffset, self).__init__() + self.offset = offset + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x + self.offset) + def identity(x): return x diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index a3eca26c9b..16746147fa 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2swish2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From e2ace9d5457139dbc5a8092c9cc6afffab633857 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 7 Mar 2022 11:24:04 +0800 Subject: [PATCH 037/185] Replace norm on input layer with scale of 0.1. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 3 +-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 1fe1265fa2..2df2678dd3 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -58,7 +58,6 @@ def __init__(self, idim: int, odim: int) -> None: ExpScaleRelu(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -77,7 +76,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) + x = x * 0.1 return x diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 16746147fa..0dbd8479b7 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From d074cf73c6ba428f3667ffede22a336febb72fb1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Mar 2022 20:37:20 +0800 Subject: [PATCH 038/185] Extensions to diagnostics code --- .../ASR/transducer_stateless/diagnostics.py | 52 +++++++++++++++---- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 088ef14cb9..dfbc2dced5 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -25,7 +25,7 @@ def dim_is_summarized(self, size: int): def stats_types(self): if self.print_pos_ratio: - return ["mean-abs", "pos-ratio"] + return ["mean-abs", "pos-ratio", "value"] else: return ["mean-abs"] @@ -49,17 +49,23 @@ def get_tensor_stats(x: Tensor, dim: int, is an integer saying how many items were counted in each element of stats. """ - if stats_type == "mean-abs" or stats_type == "abs-value": + count = x.numel() // x.shape[dim] + + if stats_type == "eigs": + x = x.transpose(dim, -1) + x = x.reshape(-1, x.shape[-1]) + # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x. + return torch.matmul(x.transpose(0, 1), x), count + elif stats_type == "mean-abs" or stats_type == "abs-value": x = x.abs() elif stats_type == "pos-ratio": x = (x > 0).to(dtype=torch.float) else: assert stats_type == "value" - orig_numel = x.numel() + sum_dims = [ d for d in range(x.ndim) if d != dim ] if len(sum_dims) > 0: x = torch.sum(x, dim=sum_dims) - count = orig_numel // x.numel() x = x.flatten() return x, count @@ -73,18 +79,35 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim options: options object sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats + stats_type: either "mean-abs" or "pos-ratio" or "eigs" or "value, + imdictates the type of stats we accumulate, mean-abs is mean absolute value, "pos-ratio" - is proportion of positive to nonnegative values. + is proportion of positive to nonnegative values, "eigs" + is eigenvalues after doing outer product on this dim, sum + over all other dimes. Returns: Diagnostic as a string, either percentiles or the actual values, - see the code. + see the code. Will return the empty string if the diagnostics did + not make sense to print out for this dimension, e.g. dimension + mismatch and stats_type == "eigs" """ # stats_and_counts is a list of pair (Tensor, int) + if tensors[0].shape[dim] > 512 and stats_type == 'eigs': + return '' # won't produce eigs stats if dim too large. stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] stats = [ x[0] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ] - if sizes_same: + + if stats_type == 'eigs': + try: + stats = torch.stack(stats).sum(dim=0) + except: + return '' + count = sum(counts) + stats = stats / count + stats, _ = torch.symeig(stats) + stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance + elif sizes_same: stats = torch.stack(stats).sum(dim=0) count = sum(counts) stats = stats / count @@ -121,12 +144,16 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], # normal case. stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ] + stats_types = stats_types + ["eigs"] + for stats_type in stats_types: sizes = [ x.shape[dim] for x in tensors ] sizes_same = all([ x == sizes[0] for x in sizes ]) s = get_diagnostics_for_dim(dim, tensors, options, sizes_same, stats_type) + if s == '': + continue min_size = min(sizes) max_size = max(sizes) @@ -181,10 +208,17 @@ def print_diagnostics(self): # ensure there is at least one dim. self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] + try: + device = torch.device('cuda') + torch.ones(1, 1, device) + except: + device = torch.device('cpu') + ndim = self.saved_tensors[0].ndim + tensors = [x.to(device) for x in self.saved_tensors] for dim in range(ndim): print_diagnostics_for_dim(self.name, dim, - self.saved_tensors, + tensors, self.opts) From 1e5455ba2904efab594e68e16d548de32f104a14 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 10:28:48 +0800 Subject: [PATCH 039/185] Update diagnostics --- .../ASR/transducer_stateless/diagnostics.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index dfbc2dced5..8ea35582a9 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -11,24 +11,21 @@ class TensorDiagnosticOptions(object): Options object for tensor diagnostics: Args: - memory_limit: the maximum number of bytes per tensor (limits how many copies + memory_limit: the maximum number of bytes we store per tensor (limits how many copies of the tensor we cache). - + max_eig_dim: the maximum dimension for which we print out eigenvalues + (limited for speed reasons). """ - def __init__(self, memory_limit: int, - print_pos_ratio: bool = True): + def __init__(self, + memory_limit: int = (2 ** 20), + max_eig_dim: int = 512): + self.memory_limit = memory_limit - self.print_pos_ratio = print_pos_ratio + self.max_eig_dim = max_eig_dim def dim_is_summarized(self, size: int): return size > 10 and size != 31 - def stats_types(self): - if self.print_pos_ratio: - return ["mean-abs", "pos-ratio", "value"] - else: - return ["mean-abs"] - def get_tensor_stats(x: Tensor, dim: int, @@ -41,8 +38,9 @@ def get_tensor_stats(x: Tensor, dim: int, x: Tensor, tensor to be analyzed dim: dimension with 0 <= dim < x.ndim stats_type: - "mean-abs" or "abs-value" -> take abs() before summing - "pos-ratio" -> take (x > 0) before summing + "abs" -> take abs() before summing + "positive" -> take (x > 0) before summing + "rms" -> square before summing, we'll take sqrt later "value -> just sum x itself Returns (stats, count) where stats is a Tensor of shape (x.shape[dim],), and the count @@ -56,9 +54,11 @@ def get_tensor_stats(x: Tensor, dim: int, x = x.reshape(-1, x.shape[-1]) # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x. return torch.matmul(x.transpose(0, 1), x), count - elif stats_type == "mean-abs" or stats_type == "abs-value": + elif stats_type == "abs": x = x.abs() - elif stats_type == "pos-ratio": + elif stats_type == "rms": + x = x ** 2 + elif stats_type == "positive": x = (x > 0).to(dtype=torch.float) else: assert stats_type == "value" @@ -79,9 +79,9 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim options: options object sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "mean-abs" or "pos-ratio" or "eigs" or "value, + stats_type: either "abs" or "positive" or "eigs" or "value, imdictates the type of stats - we accumulate, mean-abs is mean absolute value, "pos-ratio" + we accumulate, abs is mean absolute value, "positive" is proportion of positive to nonnegative values, "eigs" is eigenvalues after doing outer product on this dim, sum over all other dimes. @@ -92,13 +92,11 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], mismatch and stats_type == "eigs" """ # stats_and_counts is a list of pair (Tensor, int) - if tensors[0].shape[dim] > 512 and stats_type == 'eigs': - return '' # won't produce eigs stats if dim too large. stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] stats = [ x[0] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ] - if stats_type == 'eigs': + if stats_type == "eigs": try: stats = torch.stack(stats).sum(dim=0) except: @@ -114,6 +112,9 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], else: stats = [ x[0] / x[1] for x in stats_and_counts ] stats = torch.cat(stats, dim=0) + if stats_type == 'rms': + stats = stats.sqrt() + # if `summarize` we print percentiles of the stats; else, # we print out individual elements. summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) @@ -140,11 +141,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions): ndim = tensors[0].ndim - # options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the - # normal case. - stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ] - - stats_types = stats_types + ["eigs"] + if ndim > 1: + stats_types = ["abs", "positive", "value", "rms"] + if tensors[0].shape[dim] <= options.max_eig_dim: + stats_types.append("eigs") + else: + stats_types = [ "value", "abs" ] for stats_type in stats_types: sizes = [ x.shape[dim] for x in tensors ] @@ -158,7 +160,7 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], min_size = min(sizes) max_size = max(sizes) size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" - # stats_type will be "mean-abs" or "pos-ratio". + # stats_type will be "abs" or "positive". print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") @@ -223,7 +225,7 @@ def print_diagnostics(self): class ModelDiagnostic(object): - def __init__(self, opts: TensorDiagnosticOptions): + def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()): self.diagnostics = dict() self.opts = opts @@ -286,7 +288,7 @@ def param_backward_hook(grad, def _test_tensor_diagnostic(): - opts = TensorDiagnosticOptions(2**20, True) + opts = TensorDiagnosticOptions(2**20, 512) diagnostic = TensorDiagnostic(opts, "foo") From 059b57ad37c98ba228a708821eafea0f3c152146 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 14:32:05 +0800 Subject: [PATCH 040/185] Add BasicNorm module --- .../ASR/conformer_ctc/subsampling.py | 101 +++++++++++++++++- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 2df2678dd3..622495f214 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -336,6 +336,83 @@ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: return x_grad - neg_delta_grad, None, None, None, None +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. + + We also introduce a learned scaling factor on the output; and we + remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not + that useful unless the LayerNorm immediately follows a nonlinearity). + + + Args: + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + initial_eps_scale: a constant that determines the initial + "epsilon" that we add as ballast in: + scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5 + Note: our epsilon is actually large, not small, but we keep the name + to indicate the connection with normal LayerNorm. We set + epsilon initially to num_channels * initial_eps_scale. + speed: a scaling factor that can be interpreted as scaling the learning + rate for this module. CAUTION: the default value of 10.0 intended to be + used with Adam or amsgrad-type optimizers, e.g. Adam or Noam. + If you are using SGD you would probably have to set `speed` to + a value less than one, or the training would be unstable. + """ + def __init__(self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + initial_eps_scale: float = 0.25, + speed: float = 10.0): + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.speed = speed + eps = num_channels * initial_eps_scale + # log_eps = log(eps) / speed + log_eps = torch.tensor(eps).log() / speed + self.log_eps = nn.Parameter(log_eps.detach()) + # initial output-scale, to get LayerNorm-like behavior, is + # sqrt(num_channels). + initial_scale = torch.tensor(num_channels ** 0.5).log() / speed + self.log_scale = nn.Parameter(initial_scale.detach()) + + def _inner(self, x: Tensor) -> Tensor: + # inner product on last dim of x, keeping the dimension, + # i.e. torch.sum(x**2, dim=-1, keepdim=True), but more + # efficient. + if hasattr(torch, 'inner'): + return torch.inner(x).unsqueeze(-1) + else: + # TODO: we can do this with matrix multiplication, maybe.a + return torch.sum(x**2, dim=-1, keepdim=True) + + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + x = x.transpose(-1, self.channel_dim) + eps = (self.log_eps * self.speed).exp() + out_scale = (self.log_scale * self.speed).exp() + + scales = out_scale * (self._inner(x) + eps) ** -0.5 + x = x * scales + x = x.transpose(-1, self.channel_dim) + return x + + class DerivBalancer(torch.nn.Module): """ @@ -367,16 +444,16 @@ def forward(self, x: Tensor) -> Tensor: def _test_exp_scale_swish(): - class Swish(torch.nn.Module): + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return Swich activation function.""" - return x * torch.sigmoid(x) + return x * torch.sigmoid(x - 1.0) x1 = torch.randn(50, 60).detach() x2 = x1.detach() m1 = ExpScaleSwish(50, 1, speed=4.0) - m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0)) + m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0)) x1.requires_grad = True x2.requires_grad = True @@ -425,8 +502,26 @@ def _test_deriv_balancer(): print("x grad = ", x.grad) +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + if __name__ == '__main__': _test_deriv_balancer() _test_exp_scale_swish() _test_exp_scale_relu() + _test_basic_norm() From b55472bb427a2407797e028bf929ff0b7f55f18b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 14:43:54 +0800 Subject: [PATCH 041/185] Replace most normalizations with scales (still have norm in conv) --- .../ASR/conformer_ctc/subsampling.py | 9 ++- .../ASR/transducer_stateless/conformer.py | 57 ++++++------------- 2 files changed, 24 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 622495f214..29621bf526 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -58,6 +58,7 @@ def __init__(self, idim: int, odim: int) -> None: ExpScaleRelu(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + self.out_norm = BasicNorm(odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -76,7 +77,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = x * 0.1 + x = self.out_norm(x) return x @@ -200,9 +201,11 @@ def forward(self, x: Tensor) -> Tensor: return PeLUFunction.apply(x, self.cutoff, self.alpha) class ExpScale(torch.nn.Module): - def __init__(self, *shape, speed: float = 1.0): + def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0): super(ExpScale, self).__init__() - self.scale = nn.Parameter(torch.zeros(*shape)) + scale = torch.tensor(initial_scale) + scale = scale.log() / speed + self.scale = nn.Parameter(scale.detach()) self.speed = speed def forward(self, x: Tensor) -> Tensor: diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 62d9f382fc..acaf064b3d 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer +from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm import torch from torch import Tensor, nn @@ -150,6 +150,8 @@ def __init__( normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() + self.d_model = d_model + self.self_attn = RelPositionMultiheadAttention( d_model, nhead, dropout=0.0 ) @@ -174,22 +176,15 @@ def __init__( self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - - self.ff_scale = 0.5 + self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2) + self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5) + self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) + self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = BasicNorm(d_model) self.dropout = nn.Dropout(dropout) - self.normalize_before = normalize_before def forward( self, @@ -217,18 +212,15 @@ def forward( # macaron style feed forward module residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) - if not self.normalize_before: - src = self.norm_ff_macaron(src) + + + src = src + self.dropout(self.feed_forward_macaron( + self.scale_ff_macaron(src))) + # multi-headed self-attention module residual = src - if self.normalize_before: - src = self.norm_mha(src) + src = self.scale_mha(src) src_att = self.self_attn( src, src, @@ -238,27 +230,14 @@ def forward( key_padding_mask=src_key_padding_mask, )[0] src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) - if not self.normalize_before: - src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(self.scale_conv(src))) # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) + src = src + self.dropout(self.feed_forward(self.scale_ff(src))) - if self.normalize_before: - src = self.norm_final(src) + src = self.norm_final(src) return src @@ -288,7 +267,7 @@ def __init__( self.aux_layers = set(aux_layers + [num_layers - 1]) assert num_layers - 1 not in aux_layers self.num_layers = num_layers - num_channels = encoder_layer.norm_final.weight.numel() + num_channels = encoder_layer.d_model self.combiner = RandomCombine(num_inputs=len(self.aux_layers), num_channels=num_channels, final_weight=0.5, From 87b843f02301738395a6d7c0651a295e11a92a08 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 14:44:55 +0800 Subject: [PATCH 042/185] Change exp dir --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 0dbd8479b7..4fd4bf7646 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 425e274c82029217623b263944aaa2b407ef5847 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 16:01:53 +0800 Subject: [PATCH 043/185] Replace norm in ConvolutionModule with a scaling factor. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 5 +++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index acaf064b3d..4cf66e2fea 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -857,7 +857,8 @@ def __init__( bias=bias, ) - self.norm = nn.LayerNorm(channels) + self.scale = ExpScale(1, speed=10.0, initial_scale=1.0) + # shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() @@ -891,7 +892,7 @@ def forward(self, x: Tensor) -> Tensor: x = self.depthwise_conv(x) # x is (batch, channels, time) x = x.permute(0, 2, 1) - x = self.norm(x) + x = self.scale(x) x = x.permute(0, 2, 1) x = self.activation(x) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 4fd4bf7646..c355c7ad34 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 2fa9c636a44b105b18ef403afe6f4c1ff7d73529 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 23:24:55 +0800 Subject: [PATCH 044/185] use nonzero threshold in DerivBalancer --- .../ASR/conformer_ctc/subsampling.py | 47 +++++++++++++------ .../ASR/transducer_stateless/conformer.py | 6 +-- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 29621bf526..390d311155 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -219,7 +219,7 @@ def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: x = x * (scale * speed).exp() return x -class ExpScaleSwishFunction(torch.autograd.Function): +class SwishExpScaleFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: ctx.save_for_backward(x.detach(), scale.detach()) @@ -237,16 +237,16 @@ def backward(ctx, y_grad: Tensor) -> Tensor: return x.grad, scale.grad, None -class ExpScaleSwish(torch.nn.Module): - # combines ExpScale an Swish - # caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0) +class SwishExpScale(torch.nn.Module): + # combines ExpScale and a Swish (actually the ExpScale is after the Swish). + # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) def __init__(self, *shape, speed: float = 1.0): - super(ExpScaleSwish, self).__init__() + super(SwishExpScale, self).__init__() self.scale = nn.Parameter(torch.zeros(*shape)) self.speed = speed def forward(self, x: Tensor) -> Tensor: - return ExpScaleSwishFunction.apply(x, self.scale, self.speed) + return SwishExpScaleFunction.apply(x, self.scale, self.speed) # x = (x * torch.sigmoid(x)) # x = (x * torch.sigmoid(x)) # x = x * (self.scale * self.speed).exp() @@ -313,13 +313,15 @@ def forward(self, x: Tensor) -> Tensor: class DerivBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, - threshold: 0.05, max_factor: 0.05, - epsilon: 1.0e-10) -> Tensor: + threshold: float = 0.05, + max_factor: float = 0.05, + zero: float = 0.02, + epsilon: float = 1.0e-10) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True) + proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True) factor = (threshold - proportion_positive).relu() * (max_factor / threshold) ctx.save_for_backward(factor) @@ -328,7 +330,7 @@ def forward(ctx, x: Tensor, channel_dim: int, return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: factor, = ctx.saved_tensors neg_delta_grad = x_grad.abs() * factor if ctx.epsilon != 0.0: @@ -336,7 +338,7 @@ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: deriv_is_zero = (sum_abs_grad == 0.0) neg_delta_grad += ctx.epsilon * deriv_is_zero - return x_grad - neg_delta_grad, None, None, None, None + return x_grad - neg_delta_grad, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -429,20 +431,37 @@ class DerivBalancer(torch.nn.Module): When all grads are zero for a channel, this module sets all the input derivatives for that channel to -epsilon; the idea is to bring completely dead neurons back to life this way. + + Args: + channel_dim: the dimension/axi corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + threshold: the threshold, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives, + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.01]. + zero: we use this value in the comparison (x > 0), i.e. we actually use + (x > zero). The reason for using a threshold slightly greater + than zero is that it will tend to prevent situations where the + inputs shrink close to zero and the nonlinearity (e.g. swish) + behaves like a linear function and we learn nothing. """ def __init__(self, channel_dim: int, threshold: float = 0.05, - max_factor: float = 0.05, + max_factor: float = 0.02, + zero: float = 0.02, epsilon: float = 1.0e-10): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold self.max_factor = max_factor + self.zero = zero self.epsilon = epsilon def forward(self, x: Tensor) -> Tensor: return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, - self.max_factor, self.epsilon) + self.max_factor, self.zero, + self.epsilon) @@ -455,7 +474,7 @@ def forward(self, x: Tensor) -> Tensor: x1 = torch.randn(50, 60).detach() x2 = x1.detach() - m1 = ExpScaleSwish(50, 1, speed=4.0) + m1 = SwishExpScale(50, 1, speed=4.0) m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0)) x1.requires_grad = True x2.requires_grad = True diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 4cf66e2fea..7a7a09c276 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm +from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm import torch from torch import Tensor, nn @@ -160,7 +160,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleSwish(dim_feedforward, speed=20.0), + SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -169,7 +169,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleSwish(dim_feedforward, speed=20.0), + SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c355c7ad34..36a1ae8696 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 76560f255c3ab5f88f8bf318c14fc5d81eb9c429 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 23:48:46 +0800 Subject: [PATCH 045/185] Add min-abs-value 0.2 --- .../ASR/conformer_ctc/subsampling.py | 72 ++++++++++++------- .../ASR/transducer_stateless/train.py | 2 +- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 390d311155..d1ff7f2336 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -312,33 +312,36 @@ def forward(self, x: Tensor) -> Tensor: class DerivBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, channel_dim: int, + def forward(ctx, x: Tensor, + channel_dim: int, threshold: float = 0.05, max_factor: float = 0.05, - zero: float = 0.02, - epsilon: float = 1.0e-10) -> Tensor: + min_abs: float = 0.2) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True) + xgt0 = x > 0 + proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) factor = (threshold - proportion_positive).relu() * (max_factor / threshold) - ctx.save_for_backward(factor) - ctx.epsilon = epsilon + below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs) + + ctx.save_for_backward(factor, xgt0, below_threshold) + ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: - factor, = ctx.saved_tensors - neg_delta_grad = x_grad.abs() * factor - if ctx.epsilon != 0.0: - sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True) - deriv_is_zero = (sum_abs_grad == 0.0) - neg_delta_grad += ctx.epsilon * deriv_is_zero + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + factor, xgt0, below_threshold = ctx.saved_tensors + dtype = x_grad.dtype + too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0) + + neg_delta_grad = x_grad.abs() * (factor + too_small_factor) - return x_grad - neg_delta_grad, None, None, None, None, None + + return x_grad - neg_delta_grad, None, None, None, None class BasicNorm(torch.nn.Module): @@ -449,19 +452,17 @@ class DerivBalancer(torch.nn.Module): def __init__(self, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.02, - zero: float = 0.02, - epsilon: float = 1.0e-10): + min_abs: float = 0.2): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold self.max_factor = max_factor - self.zero = zero - self.epsilon = epsilon + self.min_abs = min_abs def forward(self, x: Tensor) -> Tensor: return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, - self.max_factor, self.zero, - self.epsilon) + self.max_factor, self.min_abs) + @@ -505,23 +506,41 @@ def _test_exp_scale_relu(): -def _test_deriv_balancer(): +def _test_deriv_balancer_sign(): channel_dim = 0 probs = torch.arange(0, 1, 0.01) N = 500 x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10) + m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) y_grad = torch.sign(torch.randn(probs.numel(), N)) y_grad[-1,:] = 0 y = m(x) y.backward(gradient=y_grad) - print("x = ", x) - print("y grad = ", y_grad) - print("x grad = ", x.grad) + print("_test_deriv_balancer_sign: x = ", x) + print("_test_deriv_balancer_sign: y grad = ", y_grad) + print("_test_deriv_balancer_sign: x grad = ", x.grad) + +def _test_deriv_balancer_magnitude(): + channel_dim = 0 + magnitudes = torch.arange(0, 1, 0.01) + N = 500 + x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + y_grad[-1,:] = 0 + + y = m(x) + y.backward(gradient=y_grad) + print("_test_deriv_balancer_magnitude: x = ", x) + print("_test_deriv_balancer_magnitude: y grad = ", y_grad) + print("_test_deriv_balancer_magnitude: x grad = ", x.grad) def _test_basic_norm(): @@ -543,7 +562,8 @@ def _test_basic_norm(): if __name__ == '__main__': - _test_deriv_balancer() + _test_deriv_balancer_sign() + _test_deriv_balancer_magnitude() _test_exp_scale_swish() _test_exp_scale_relu() _test_basic_norm() diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 36a1ae8696..618d904903 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.1", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From bfce5f63e498877f4d3c9681a0da341ce90d2e67 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 23:49:09 +0800 Subject: [PATCH 046/185] Fix dirname --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 618d904903..d75341a078 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.1", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From e3e14cf7a4850d2a7785a53a2bb7d47c53b44310 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:16:33 +0800 Subject: [PATCH 047/185] Change min-abs threshold from 0.2 to 0.5 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index d1ff7f2336..d7be46f17b 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -316,7 +316,7 @@ def forward(ctx, x: Tensor, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.05, - min_abs: float = 0.2) -> Tensor: + min_abs: float = 0.5) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim @@ -452,7 +452,7 @@ class DerivBalancer(torch.nn.Module): def __init__(self, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.02, - min_abs: float = 0.2): + min_abs: float = 0.5): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index d75341a078..80febc6771 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From ab9a17413ab966e013d275919791040c23002407 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:37:52 +0800 Subject: [PATCH 048/185] Scale up pos_bias_u and pos_bias_v before use. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 +++- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 7a7a09c276..d0be5af001 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -614,7 +614,9 @@ def multi_head_attention_forward( assert ( head_dim * num_heads == embed_dim ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + q = q * scaling if torch.equal(query, key) and torch.equal(key, value): # self-attention @@ -764,7 +766,7 @@ def multi_head_attention_forward( attn_output_weights = ( matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) + ) # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 80febc6771..c9654cc94a 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 137eae0b95ee5a0c4dcd137c3d1279301006c5ee Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:41:55 +0800 Subject: [PATCH 049/185] Reduce max_factor to 0.01 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index d7be46f17b..2e4eb754bc 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -451,7 +451,7 @@ class DerivBalancer(torch.nn.Module): """ def __init__(self, channel_dim: int, threshold: float = 0.05, - max_factor: float = 0.02, + max_factor: float = 0.01, min_abs: float = 0.5): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim From 2940d3106f08a06d37618229e872ace7e371fa66 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:43:57 +0800 Subject: [PATCH 050/185] Fix q*scaling logic --- egs/librispeech/ASR/transducer_stateless/conformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index d0be5af001..e14c7a02e4 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -616,7 +616,6 @@ def multi_head_attention_forward( ), "embed_dim must be divisible by num_heads" scaling = float(head_dim) ** -0.5 - q = q * scaling if torch.equal(query, key) and torch.equal(key, value): # self-attention @@ -721,7 +720,7 @@ def multi_head_attention_forward( ) key_padding_mask = key_padding_mask.to(torch.bool) - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + q = (q.contiguous() * scaling).view(tgt_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) From bcf417fce2b3115e7dda8d3b0e0a6cbafd5d71ac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:47:46 +0800 Subject: [PATCH 051/185] Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 2e4eb754bc..ce25ad8ea1 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,13 +48,13 @@ def __init__(self, idim: int, odim: int) -> None: in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.025), + max_factor=0.01), ExpScaleRelu(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.025), + max_factor=0.01), ExpScaleRelu(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index e14c7a02e4..051512969e 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -159,7 +159,7 @@ def __init__( self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.025), + max_factor=0.01), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -168,7 +168,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.025), + max_factor=0.01), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -720,7 +720,7 @@ def multi_head_attention_forward( ) key_padding_mask = key_padding_mask.to(torch.bool) - q = (q.contiguous() * scaling).view(tgt_len, bsz, num_heads, head_dim) + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) From bec33e6855afba8c4f2739cbcf6a7a67398ec210 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 16:37:17 +0800 Subject: [PATCH 052/185] init 1st conv module to smaller variance --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 8 ++++++++ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index ce25ad8ea1..6a697aa0e3 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -59,6 +59,14 @@ def __init__(self, idim: int, odim: int) -> None: ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = BasicNorm(odim) + self._reset_parameters() + + def _reset_parameters(self): + # init weights with smaller than default variance, because otherwise + # they learn too slowly in relative terms (assuming we're training with adam). + nn.init.normal_(self.conv[0].weight, std=0.05) + nn.init.constant_(self.conv[0].bias, 0.0) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c9654cc94a..5d6d724901 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs", + default="transducer_stateless/randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs_cinit", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 5eafccb36942b3da42024c48b1af237ef1f613ec Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 17:46:33 +0800 Subject: [PATCH 053/185] Change how scales are applied; fix residual bug --- .../ASR/transducer_stateless/conformer.py | 17 +++++++++++++---- .../ASR/transducer_stateless/train.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 051512969e..2c602bbeac 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -229,10 +229,17 @@ def forward( attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = residual + self.dropout(src_att) + # natural rms scale of mha output is about 2 to 6. scaling down by 0.1 takes it + # to 0.2 to 0.6, which is suitable to add to the inputs assuming the output + # of the previous convolution layer had a magnitude of around 1.0 + # (this magnitude of 1.0, or a bit less, like 0.3, is learned but is + # dictated by considerations of what is done to the output of the + # encoder. + post_scale_mha = 0.1 + src = residual + post_scale_mha * self.dropout(src_att) # convolution module - src = residual + self.dropout(self.conv_module(self.scale_conv(src))) + src = src + self.dropout(self.conv_module(self.scale_conv(src))) # feed forward module src = src + self.dropout(self.feed_forward(self.scale_ff(src))) @@ -891,13 +898,15 @@ def forward(self, x: Tensor) -> Tensor: # 1D Depthwise Conv x = self.depthwise_conv(x) + + # TODO: can have a learned scale in here, or a fixed one. + x = self.activation(x) + # x is (batch, channels, time) x = x.permute(0, 2, 1) x = self.scale(x) x = x.permute(0, 2, 1) - x = self.activation(x) - x = self.pointwise_conv2(x) # (batch, channel, time) return x.permute(2, 0, 1) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 5d6d724901..b5e9e846fe 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs_cinit", + default="transducer_stateless/randcombine1_expscale3_rework", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a0d5e2932ccfd2b2eadb271434dc30a14c980c7c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 18:17:49 +0800 Subject: [PATCH 054/185] Reduce min_abs from 0.5 to 0.2 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 6a697aa0e3..6b1cb128fe 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -460,7 +460,7 @@ class DerivBalancer(torch.nn.Module): def __init__(self, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.01, - min_abs: float = 0.5): + min_abs: float = 0.2): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold From 98156711efb11e92d8b50eb426041b62da4a5564 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 19:05:55 +0800 Subject: [PATCH 055/185] Introduce in_scale=0.5 for SwishExpScale --- .../ASR/conformer_ctc/subsampling.py | 19 ++++++++++++------- .../ASR/transducer_stateless/conformer.py | 4 ++-- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 6b1cb128fe..52a58d104e 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -221,18 +221,21 @@ def forward(self, x: Tensor) -> Tensor: -def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: +def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor: # double-swish, implemented/approximated as offset-swish + if in_scale != 1.0: + x = x * in_scale x = (x * torch.sigmoid(x - 1.0)) x = x * (scale * speed).exp() return x class SwishExpScaleFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: + def forward(ctx, x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor: ctx.save_for_backward(x.detach(), scale.detach()) ctx.speed = speed - return _exp_scale_swish(x, scale, speed) + ctx.in_scale = in_scale + return _exp_scale_swish(x, scale, speed, in_scale) @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: @@ -240,21 +243,23 @@ def backward(ctx, y_grad: Tensor) -> Tensor: x.requires_grad = True scale.requires_grad = True with torch.enable_grad(): - y = _exp_scale_swish(x, scale, ctx.speed) + y = _exp_scale_swish(x, scale, ctx.speed, ctx.in_scale) y.backward(gradient=y_grad) - return x.grad, scale.grad, None + return x.grad, scale.grad, None, None class SwishExpScale(torch.nn.Module): # combines ExpScale and a Swish (actually the ExpScale is after the Swish). # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) - def __init__(self, *shape, speed: float = 1.0): + # + def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0): super(SwishExpScale, self).__init__() + self.in_scale = in_scale self.scale = nn.Parameter(torch.zeros(*shape)) self.speed = speed def forward(self, x: Tensor) -> Tensor: - return SwishExpScaleFunction.apply(x, self.scale, self.speed) + return SwishExpScaleFunction.apply(x, self.scale, self.speed, self.in_scale) # x = (x * torch.sigmoid(x)) # x = (x * torch.sigmoid(x)) # x = x * (self.scale * self.speed).exp() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 2c602bbeac..7b9aff71f2 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -160,7 +160,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0), + SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -169,7 +169,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0), + SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index b5e9e846fe..190406491b 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework", + default="transducer_stateless/randcombine1_expscale3_rework_0.5", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From cc558faf262f7db5bfbc637e86a7102f23c1f77e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 19:11:50 +0800 Subject: [PATCH 056/185] Fix scale from 0.5 to 2.0 as I really intended.. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 7b9aff71f2..fa25e6ca02 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -160,7 +160,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5), + SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -169,7 +169,7 @@ def __init__( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5), + SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 190406491b..c72a9dd28d 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework_0.5", + default="transducer_stateless/randcombine1_expscale3_rework_2.0", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 2d3a76292d0649a358f39835cd5944c0ac406b37 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 20:12:45 +0800 Subject: [PATCH 057/185] Set scaling on SwishExpScale --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 52a58d104e..caac230ed0 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -255,7 +255,9 @@ class SwishExpScale(torch.nn.Module): def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0): super(SwishExpScale, self).__init__() self.in_scale = in_scale - self.scale = nn.Parameter(torch.zeros(*shape)) + initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed + initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach() + self.scale = nn.Parameter(initial_log_scale) self.speed = speed def forward(self, x: Tensor) -> Tensor: From 7eb5a84cbeb4242736b28d1d1ea5a118cb1cc256 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 21:00:43 +0800 Subject: [PATCH 058/185] Add identity pre_norm_final for diagnostics. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index fa25e6ca02..389a7cb7fb 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -181,6 +181,7 @@ def __init__( self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) + self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) self.dropout = nn.Dropout(dropout) @@ -244,7 +245,7 @@ def forward( # feed forward module src = src + self.dropout(self.feed_forward(self.scale_ff(src))) - src = self.norm_final(src) + src = self.norm_final(self.pre_norm_final(src)) return src @@ -930,8 +931,9 @@ def forward(self, x: Tensor) -> Tensor: return x * torch.sigmoid(x + self.offset) -def identity(x): - return x +class Identity(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + return x class RandomCombine(torch.nn.Module): From 76a2b9d36239566aae2125837f653ecbeb3a1ca9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 11:19:49 +0800 Subject: [PATCH 059/185] Add learnable post-scale for mha --- egs/librispeech/ASR/transducer_stateless/conformer.py | 10 ++-------- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 389a7cb7fb..963cb2cd92 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -177,6 +177,7 @@ def __init__( self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2) + self.post_scale_mha = ExpScale(1, speed=10.0, initial_scale=1.0) self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) @@ -230,14 +231,7 @@ def forward( attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - # natural rms scale of mha output is about 2 to 6. scaling down by 0.1 takes it - # to 0.2 to 0.6, which is suitable to add to the inputs assuming the output - # of the previous convolution layer had a magnitude of around 1.0 - # (this magnitude of 1.0, or a bit less, like 0.3, is learned but is - # dictated by considerations of what is done to the output of the - # encoder. - post_scale_mha = 0.1 - src = residual + post_scale_mha * self.dropout(src_att) + src = residual + post_scale_mha(self.dropout(src_att)) # convolution module src = src + self.dropout(self.conv_module(self.scale_conv(src))) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c72a9dd28d..be771b5172 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework_2.0", + default="transducer_stateless/randcombine1_expscale3_rework_2.0_b", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 0abba9e7a2eb849164459ee5ec22d7b2da28d9c5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 11:20:44 +0800 Subject: [PATCH 060/185] Fix self.post-scale-mha --- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 963cb2cd92..3f9becded9 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -231,7 +231,7 @@ def forward( attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = residual + post_scale_mha(self.dropout(src_att)) + src = residual + self.post_scale_mha(self.dropout(src_att)) # convolution module src = src + self.dropout(self.conv_module(self.scale_conv(src))) From ca8cf2a73b4d65406d9ca5b4648af4768926d3b3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 15:38:13 +0800 Subject: [PATCH 061/185] Another rework, use scales on linear/conv --- .../ASR/conformer_ctc/subsampling.py | 156 ++++++++++++------ .../ASR/transducer_stateless/conformer.py | 73 ++++---- 2 files changed, 140 insertions(+), 89 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index caac230ed0..831537d795 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -44,20 +44,20 @@ def __init__(self, idim: int, odim: int) -> None: assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( + ScaledConv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.01), ExpScaleRelu(odim, 1, 1, speed=20.0), - nn.Conv2d( + ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.01), ExpScaleRelu(odim, 1, 1, speed=20.0), ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = BasicNorm(odim) self._reset_parameters() @@ -221,21 +221,18 @@ def forward(self, x: Tensor) -> Tensor: -def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor: +def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: # double-swish, implemented/approximated as offset-swish - if in_scale != 1.0: - x = x * in_scale x = (x * torch.sigmoid(x - 1.0)) x = x * (scale * speed).exp() return x class SwishExpScaleFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor: + def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: ctx.save_for_backward(x.detach(), scale.detach()) ctx.speed = speed - ctx.in_scale = in_scale - return _exp_scale_swish(x, scale, speed, in_scale) + return _exp_scale_swish(x, scale, speed) @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: @@ -243,25 +240,24 @@ def backward(ctx, y_grad: Tensor) -> Tensor: x.requires_grad = True scale.requires_grad = True with torch.enable_grad(): - y = _exp_scale_swish(x, scale, ctx.speed, ctx.in_scale) + y = _exp_scale_swish(x, scale, ctx.speed) y.backward(gradient=y_grad) - return x.grad, scale.grad, None, None + return x.grad, scale.grad, None class SwishExpScale(torch.nn.Module): # combines ExpScale and a Swish (actually the ExpScale is after the Swish). # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) # - def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0): + def __init__(self, *shape, speed: float = 1.0): super(SwishExpScale, self).__init__() - self.in_scale = in_scale - initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed - initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach() + + initial_log_scale = torch.zeros(()).detach() self.scale = nn.Parameter(initial_log_scale) self.speed = speed def forward(self, x: Tensor) -> Tensor: - return SwishExpScaleFunction.apply(x, self.scale, self.speed, self.in_scale) + return SwishExpScaleFunction.apply(x, self.scale, self.speed) # x = (x * torch.sigmoid(x)) # x = (x * torch.sigmoid(x)) # x = x * (self.scale * self.speed).exp() @@ -383,12 +379,11 @@ class BasicNorm(torch.nn.Module): interprted as an offset from the input's ndim if negative. shis is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - initial_eps_scale: a constant that determines the initial - "epsilon" that we add as ballast in: - scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5 - Note: our epsilon is actually large, not small, but we keep the name - to indicate the connection with normal LayerNorm. We set - epsilon initially to num_channels * initial_eps_scale. + initial_eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with normal LayerNorm. + speed: a scaling factor that can be interpreted as scaling the learning rate for this module. CAUTION: the default value of 10.0 intended to be used with Adam or amsgrad-type optimizers, e.g. Adam or Noam. @@ -398,42 +393,101 @@ class BasicNorm(torch.nn.Module): def __init__(self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - initial_eps_scale: float = 0.25, - speed: float = 10.0): + eps: float = 0.25): super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.speed = speed - eps = num_channels * initial_eps_scale - # log_eps = log(eps) / speed - log_eps = torch.tensor(eps).log() / speed - self.log_eps = nn.Parameter(log_eps.detach()) - # initial output-scale, to get LayerNorm-like behavior, is - # sqrt(num_channels). - initial_scale = torch.tensor(num_channels ** 0.5).log() / speed - self.log_scale = nn.Parameter(initial_scale.detach()) - - def _inner(self, x: Tensor) -> Tensor: - # inner product on last dim of x, keeping the dimension, - # i.e. torch.sum(x**2, dim=-1, keepdim=True), but more - # efficient. - if hasattr(torch, 'inner'): - return torch.inner(x).unsqueeze(-1) - else: - # TODO: we can do this with matrix multiplication, maybe.a - return torch.sum(x**2, dim=-1, keepdim=True) + self.eps = eps def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - x = x.transpose(-1, self.channel_dim) - eps = (self.log_eps * self.speed).exp() - out_scale = (self.log_scale * self.speed).exp() + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5 + return x * scales + + +class ScaledLinear(nn.Linear): + def __init__(self, *args, scale_speed=5.0, **kwargs): + super(ScaledLinear, self).__init__(*args, **kwargs) + self.weight_scale = nn.Parameter(torch.zeros(())) + self.scale_speed = scale_speed + if self.bias is not None: + self.bias_scale = nn.Parameter(torch.zeros(())) + else: + self.register_parameter('bias_scale', None) + + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + + def forward(self, input: Tensor) -> Tensor: + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) + + +class ScaledConv1d(nn.Conv1d): + def __init__(self, *args, scale_speed = 5.0, **kwargs): + super(ScaledConv1d, self).__init__(*args, **kwargs) + self.scale_speed = scale_speed + self.weight_scale = nn.Parameter(torch.zeros(())) + if self.bias is not None: + self.bias_scale = nn.Parameter(torch.zeros(())) + else: + self.register_parameter('bias_scale', None) + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + + + +class ScaledConv2d(nn.Conv2d): + def __init__(self, *args, scale_speed=5.0, **kwargs): + super(ScaledConv2d, self).__init__(*args, **kwargs) + self.scale_speed = scale_speed + self.weight_scale = nn.Parameter(torch.zeros(())) + if self.bias is not None: + self.bias_scale = nn.Parameter(torch.zeros(())) + else: + self.register_parameter('bias_scale', None) + + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def _conv_forward(self, input, weight): + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.get_bias(), self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.get_weight()) - scales = out_scale * (self._inner(x) + eps) ** -0.5 - x = x * scales - x = x.transpose(-1, self.channel_dim) - return x @@ -576,6 +630,8 @@ def _test_basic_norm(): + + if __name__ == '__main__': _test_deriv_balancer_sign() _test_deriv_balancer_magnitude() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3f9becded9..93f7dd1707 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm +from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn @@ -157,30 +157,25 @@ def __init__( ) self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), + ScaledLinear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0), + SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), + ScaledLinear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0), + SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2) - self.post_scale_mha = ExpScale(1, speed=10.0, initial_scale=1.0) - self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5) - self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) - self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) @@ -216,13 +211,10 @@ def forward( residual = src - src = src + self.dropout(self.feed_forward_macaron( - self.scale_ff_macaron(src))) + src = src + self.dropout(self.feed_forward_macaron(src)) # multi-headed self-attention module - residual = src - src = self.scale_mha(src) src_att = self.self_attn( src, src, @@ -231,13 +223,13 @@ def forward( attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = residual + self.post_scale_mha(self.dropout(src_att)) + src = src + self.dropout(src_att) # convolution module - src = src + self.dropout(self.conv_module(self.scale_conv(src))) + src = src + self.dropout(self.conv_module(src)) # feed forward module - src = src + self.dropout(self.feed_forward(self.scale_ff(src))) + src = src + self.dropout(self.feed_forward(src)) src = self.norm_final(self.pre_norm_final(src)) @@ -420,6 +412,7 @@ def __init__( embed_dim: int, num_heads: int, dropout: float = 0.0, + scale_speed: float = 5.0 ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -430,18 +423,27 @@ def __init__( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.scale_speed = scale_speed + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) self._reset_parameters() + def _pos_bias_u(self): + return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() + + def _pos_bias_v(self): + return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() + def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.in_proj.weight) nn.init.constant_(self.in_proj.bias, 0.0) @@ -508,11 +510,11 @@ def forward( pos_emb, self.embed_dim, self.num_heads, - self.in_proj.weight, - self.in_proj.bias, + self.in_proj.get_weight(), + self.in_proj.get_bias(), self.dropout, - self.out_proj.weight, - self.out_proj.bias, + self.out_proj.get_weight(), + self.out_proj.get_bias(), training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -743,11 +745,11 @@ def multi_head_attention_forward( p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose( + q_with_bias_u = (q + self._pos_bias_u()).transpose( 1, 2 ) # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose( + q_with_bias_v = (q + self._pos_bias_v()).transpose( 1, 2 ) # (batch, head, time1, d_k) @@ -842,7 +844,7 @@ def __init__( # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.pointwise_conv1 = nn.Conv1d( + self.pointwise_conv1 = ScaledConv1d( channels, 2 * channels, kernel_size=1, @@ -850,7 +852,7 @@ def __init__( padding=0, bias=bias, ) - self.depthwise_conv = nn.Conv1d( + self.depthwise_conv = ScaledConv1d( channels, channels, kernel_size, @@ -860,12 +862,10 @@ def __init__( bias=bias, ) - self.scale = ExpScale(1, speed=10.0, initial_scale=1.0) - # shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() - self.pointwise_conv2 = nn.Conv1d( + self.pointwise_conv2 = ScaledConv1d( channels, channels, kernel_size=1, @@ -897,11 +897,6 @@ def forward(self, x: Tensor) -> Tensor: # TODO: can have a learned scale in here, or a fixed one. x = self.activation(x) - # x is (batch, channels, time) - x = x.permute(0, 2, 1) - x = self.scale(x) - x = x.permute(0, 2, 1) - x = self.pointwise_conv2(x) # (batch, channel, time) return x.permute(2, 0, 1) @@ -982,7 +977,7 @@ def __init__(self, num_inputs: int, assert pure_prob >= 0 and pure_prob <= 1 assert final_weight > 0 and final_weight < 1 assert num_inputs >= 1 - self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True) + self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True) for _ in range(num_inputs - 1)]) self.num_inputs = num_inputs From d906bc2a4f14fd9394363e3ec6d473d9ed2aff3b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 15:38:39 +0800 Subject: [PATCH 062/185] Change dir name --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index be771b5172..1a57d654fa 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework_2.0_b", + default="transducer_stateless/randcombine1_expscale3_rework2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a392cb9fbc5bc23228ff142354c7962b59fdaa74 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 16:53:03 +0800 Subject: [PATCH 063/185] Reduce initial scaling of modules --- .../ASR/conformer_ctc/subsampling.py | 22 +++++++++++-------- .../ASR/transducer_stateless/conformer.py | 2 +- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 831537d795..dab0e1e1df 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -407,12 +407,13 @@ def forward(self, x: Tensor) -> Tensor: class ScaledLinear(nn.Linear): - def __init__(self, *args, scale_speed=5.0, **kwargs): + def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) - self.weight_scale = nn.Parameter(torch.zeros(())) + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) self.scale_speed = scale_speed if self.bias is not None: - self.bias_scale = nn.Parameter(torch.zeros(())) + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) @@ -431,12 +432,14 @@ def forward(self, input: Tensor) -> Tensor: class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, scale_speed = 5.0, **kwargs): + def __init__(self, *args, scale_speed = 5.0, + initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) self.scale_speed = scale_speed - self.weight_scale = nn.Parameter(torch.zeros(())) + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: - self.bias_scale = nn.Parameter(torch.zeros(())) + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) @@ -459,12 +462,13 @@ def forward(self, input: Tensor) -> Tensor: class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, scale_speed=5.0, **kwargs): + def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): super(ScaledConv2d, self).__init__(*args, **kwargs) self.scale_speed = scale_speed - self.weight_scale = nn.Parameter(torch.zeros(())) + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: - self.bias_scale = nn.Parameter(torch.zeros(())) + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 93f7dd1707..aa35f5e7e0 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -162,7 +162,7 @@ def __init__( max_factor=0.01), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.feed_forward_macaron = nn.Sequential( diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 1a57d654fa..b871efd135 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2", + default="transducer_stateless/randcombine1_expscale3_rework2b", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a24572abd1285ff12c89e31908694689fa2e6d41 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 17:28:43 +0800 Subject: [PATCH 064/185] Bug-fix RE bias --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index dab0e1e1df..5f1e376a9e 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -486,7 +486,7 @@ def _conv_forward(self, input, weight): return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), weight, self.get_bias(), self.stride, _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.bias, self.stride, + return F.conv2d(input, weight, self.get_bias(), self.stride, self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: From b7b2d8970b608ff3954039e99dbbd95186b61bae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 17:47:35 +0800 Subject: [PATCH 065/185] Cosmetic change --- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index aa35f5e7e0..a270cd8ae3 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -171,7 +171,7 @@ def __init__( max_factor=0.01), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -208,9 +208,6 @@ def forward( """ # macaron style feed forward module - residual = src - - src = src + self.dropout(self.feed_forward_macaron(src)) @@ -872,6 +869,7 @@ def __init__( stride=1, padding=0, bias=bias, + initial_scale=0.5 ) def forward(self, x: Tensor) -> Tensor: From db7a3b6eea34e532240dae3409c6d64e8eab9806 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 18:50:02 +0800 Subject: [PATCH 066/185] Reduce initial_scale. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index a270cd8ae3..9dd6bae4d4 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -421,7 +421,7 @@ def __init__( ), "embed_dim must be divisible by num_heads" self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) # linear transformation for positional encoding. self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) @@ -869,7 +869,7 @@ def __init__( stride=1, padding=0, bias=bias, - initial_scale=0.5 + initial_scale=0.25 ) def forward(self, x: Tensor) -> Tensor: From be0a79cbcae9fb6a02f139ef4385af7fa6f80032 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 19:00:48 +0800 Subject: [PATCH 067/185] Replace ExpScaleRelu with DoubleSwish() --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 5f1e376a9e..13259d166d 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -49,13 +49,13 @@ def __init__(self, idim: int, odim: int) -> None: ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.01), - ExpScaleRelu(odim, 1, 1, speed=20.0), + DoubleSwish(), ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.01), - ExpScaleRelu(odim, 1, 1, speed=20.0), + DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = BasicNorm(odim) From 2117f46361c2b2deb63194de43098e1a17714d61 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 19:02:14 +0800 Subject: [PATCH 068/185] DoubleSwish fix --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 13259d166d..6bf0aefe4d 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -537,13 +537,13 @@ def forward(self, x: Tensor) -> Tensor: self.max_factor, self.min_abs) +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x - 1.0) def _test_exp_scale_swish(): - class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x - 1.0) x1 = torch.randn(50, 60).detach() x2 = x1.detach() From 6042c96db2f68c24f08aadf93904d0383dcd7fc9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 20:54:46 +0800 Subject: [PATCH 069/185] Use learnable scales for joiner and decoder --- .../ASR/transducer_stateless/decoder.py | 187 +++++++++++++++++- .../ASR/transducer_stateless/joiner.py | 4 +- .../ASR/transducer_stateless/train.py | 2 +- .../ASR/transducer_stateless/transformer.py | 4 +- 4 files changed, 190 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index 003b03a2e7..bc4bcb3f63 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -17,6 +17,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor +from typing import Optional +from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -52,7 +55,7 @@ def __init__( 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() - self.embedding = nn.Embedding( + self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=blank_id, @@ -62,7 +65,7 @@ def __init__( assert context_size >= 1, context_size self.context_size = context_size if context_size > 1: - self.conv = nn.Conv1d( + self.conv = ScaledConv1d( in_channels=embedding_dim, out_channels=embedding_dim, kernel_size=context_size, @@ -97,3 +100,183 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) return embedding_out + + + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[Tensor] = None, + scale_speed: float = 5.0) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale_speed = scale_speed + self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed) + + if _weight is None: + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + else: + assert list(_weight.shape) == [num_embeddings, embedding_dim], \ + 'Shape of weight does not match num_embeddings and embedding_dim' + self.weight = nn.Parameter(_weight) + self.sparse = sparse + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=self.embedding_dim**-0.5) + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = (self.scale * self.scale_speed).exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + @classmethod + def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, + sparse=False): + r"""Creates Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. + First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): See module initialization documentation. + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([1]) + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + embedding.weight.requires_grad = not freeze + return embedding diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 9fd9da4f17..8311461d30 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn - +from subsampling import ScaledLinear class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): @@ -24,7 +24,7 @@ def __init__(self, input_dim: int, output_dim: int): self.input_dim = input_dim self.output_dim = output_dim - self.output_linear = nn.Linear(input_dim, output_dim) + self.output_linear = ScaledLinear(input_dim, output_dim) def forward( self, diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index b871efd135..c2202fe1ec 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2b", + default="transducer_stateless/randcombine1_expscale3_rework2c", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py index e851dcc328..3fa847f4f2 100644 --- a/egs/librispeech/ASR/transducer_stateless/transformer.py +++ b/egs/librispeech/ASR/transducer_stateless/transformer.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from subsampling import Conv2dSubsampling, VggSubsampling +from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear from icefall.utils import make_pad_mask @@ -106,7 +106,7 @@ def __init__( # TODO(fangjun): remove dropout self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, output_dim) + nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) ) def forward( From e6a501d3c87222292eb83f0a2a158835e85606ba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 11:52:13 +0800 Subject: [PATCH 070/185] Add max-abs-value constraint in DerivBalancer --- .../ASR/conformer_ctc/subsampling.py | 42 +++++++++++++------ .../ASR/transducer_stateless/train.py | 2 +- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 6bf0aefe4d..ea02041387 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -325,9 +325,11 @@ class DerivBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, - threshold: float = 0.05, - max_factor: float = 0.05, - min_abs: float = 0.5) -> Tensor: + threshold: float, # e.g. 0.05 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 1000.0 + ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim @@ -336,23 +338,26 @@ def forward(ctx, x: Tensor, proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) factor = (threshold - proportion_positive).relu() * (max_factor / threshold) - below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs) + mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) + below_threshold = (mean_abs < min_abs) + above_threshold = (mean_abs > max_abs) - ctx.save_for_backward(factor, xgt0, below_threshold) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - factor, xgt0, below_threshold = ctx.saved_tensors + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0) + scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * + (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) - neg_delta_grad = x_grad.abs() * (factor + too_small_factor) + neg_delta_grad = x_grad.abs() * (factor + scale_factor) - return x_grad - neg_delta_grad, None, None, None, None + return x_grad - neg_delta_grad, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -521,20 +526,33 @@ class DerivBalancer(torch.nn.Module): than zero is that it will tend to prevent situations where the inputs shrink close to zero and the nonlinearity (e.g. swish) behaves like a linear function and we learn nothing. + min_abs: the minimum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. This is to prevent a failure mode where the activations + become so small that the nonlinearity effectively becomes linear, + which makes the module useless and it gets even smaller + to try to "turn it off" completely. + max_abs: the maximum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. This is to prevent the possibility of activations getting + out of floating point numerical range (especially in half precision). """ def __init__(self, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.01, - min_abs: float = 0.2): + min_abs: float = 0.2, + max_abs: float = 1000.0): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold self.max_factor = max_factor self.min_abs = min_abs + self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, - self.max_factor, self.min_abs) + self.max_factor, self.min_abs, + self.max_abs) class DoubleSwish(torch.nn.Module): diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c2202fe1ec..1434d6da4d 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 5d69acb25b45d80e554c55d8dbc0aacc3432217a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 13:15:20 +0800 Subject: [PATCH 071/185] Add max-abs-value --- .../ASR/conformer_ctc/subsampling.py | 52 +++++++++++-------- .../ASR/transducer_stateless/conformer.py | 6 +-- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index ea02041387..8d01d8fc04 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,14 +47,12 @@ def __init__(self, idim: int, odim: int) -> None: ScaledConv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=1), DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -325,7 +323,8 @@ class DerivBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, - threshold: float, # e.g. 0.05 + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 max_factor: float, # e.g. 0.01 min_abs: float, # e.g. 0.2 max_abs: float, # e.g. 1000.0 @@ -336,7 +335,13 @@ def forward(ctx, x: Tensor, sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor = (threshold - proportion_positive).relu() * (max_factor / threshold) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) + factor = factor1 + factor2 + if isinstance(factor, float): + factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) below_threshold = (mean_abs < min_abs) @@ -348,16 +353,14 @@ def forward(ctx, x: Tensor, return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) neg_delta_grad = x_grad.abs() * (factor + scale_factor) - - - return x_grad - neg_delta_grad, None, None, None, None, None + return x_grad - neg_delta_grad, None, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -516,7 +519,9 @@ class DerivBalancer(torch.nn.Module): Args: channel_dim: the dimension/axi corresponding to the channel, e.g. -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - threshold: the threshold, per channel, of the proportion of the time + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. max_factor: the maximum factor by which we modify the derivatives, e.g. with max_factor=0.02, the the derivatives would be multiplied by @@ -538,19 +543,22 @@ class DerivBalancer(torch.nn.Module): out of floating point numerical range (especially in half precision). """ def __init__(self, channel_dim: int, - threshold: float = 0.05, + min_positive: float = 0.05, + max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 1000.0): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim - self.threshold = threshold + self.min_positive = min_positive + self.max_positive = max_positive self.max_factor = max_factor self.min_abs = min_abs self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, + return DerivBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, self.max_factor, self.min_abs, self.max_abs) @@ -600,14 +608,14 @@ def _test_exp_scale_relu(): def _test_deriv_balancer_sign(): channel_dim = 0 probs = torch.arange(0, 1, 0.01) - N = 500 + N = 1000 x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) + m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) - y_grad[-1,:] = 0 y = m(x) y.backward(gradient=y_grad) @@ -618,14 +626,16 @@ def _test_deriv_balancer_sign(): def _test_deriv_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) - N = 500 - x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1)) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) + m = DerivBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - y_grad[-1,:] = 0 y = m(x) y.backward(gradient=y_grad) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 9dd6bae4d4..3516c2205b 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -158,8 +158,7 @@ def __init__( self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=-1), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), @@ -167,8 +166,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=-1), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), From 97c0bb82d329426d80d535348651bceaab58df1c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 13:19:20 +0800 Subject: [PATCH 072/185] Change dir name --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 1434d6da4d..897cf54113 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From f351777e9cc0da74e96212782f9056057b2407a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 17:29:39 +0800 Subject: [PATCH 073/185] Remove ExpScale in feedforward layes. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 13 +++++++++---- .../ASR/transducer_stateless/conformer.py | 6 +++--- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 8d01d8fc04..04481aa5bd 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -565,8 +565,13 @@ def forward(self, x: Tensor) -> Tensor: class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x - 1.0) + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient + backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1) + """ + x1 = x - 1.0 + s = torch.sigmoid(x1) + return (x1 * s) + s # (x-1) * s + s == x * s def _test_exp_scale_swish(): @@ -581,10 +586,10 @@ def _test_exp_scale_swish(): y1 = m1(x1) y2 = m2(x2) - assert torch.allclose(y1, y2) + assert torch.allclose(y1, y2, atol=1e-05) y1.sum().backward() y2.sum().backward() - assert torch.allclose(x1.grad, x2.grad) + assert torch.allclose(x1.grad, x2.grad, atol=1e-05) def _test_exp_scale_relu(): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3516c2205b..e6466d8e61 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from subsampling import PeLU, ExpScale, DoubleSwish, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn @@ -159,7 +159,7 @@ def __init__( self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1), - SwishExpScale(dim_feedforward, speed=20.0), + DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) @@ -167,7 +167,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( ScaledLinear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1), - SwishExpScale(dim_feedforward, speed=20.0), + DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 897cf54113..994b89e49a 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 437e8b208341bf027744be5d81f0126635150572 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 23:31:08 +0800 Subject: [PATCH 074/185] Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer. --- .../ASR/conformer_ctc/subsampling.py | 4 ++-- .../ASR/transducer_stateless/conformer.py | 22 ++++++++++++++++++- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 04481aa5bd..3a1eda3f15 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -327,7 +327,7 @@ def forward(ctx, x: Tensor, max_positive: float, # e.g. 0.95 max_factor: float, # e.g. 0.01 min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 1000.0 + max_abs: float, # e.g. 100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: @@ -547,7 +547,7 @@ def __init__(self, channel_dim: int, max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, - max_abs: float = 1000.0): + max_abs: float = 100.0): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index e6466d8e61..65a8431dee 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -847,6 +847,22 @@ def __init__( padding=0, bias=bias, ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) + self.depthwise_conv = ScaledConv1d( channels, channels, @@ -857,6 +873,8 @@ def __init__( bias=bias, ) + + self.deriv_balancer2 = DerivBalancer(channel_dim=1) # shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() @@ -885,12 +903,14 @@ def forward(self, x: Tensor) -> Tensor: # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv x = self.depthwise_conv(x) - # TODO: can have a learned scale in here, or a fixed one. + x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 994b89e49a..a0395a3988 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From ae2568825396f63b6c3a68eb3e8d6e132d407da9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Mar 2022 11:02:32 +0800 Subject: [PATCH 075/185] Make DoubleSwish more memory efficient --- .../ASR/conformer_ctc/subsampling.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 04481aa5bd..6ff9be4e61 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -267,23 +267,6 @@ def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: return (x * (scale * speed).exp()).relu() -class ExpScaleReluFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x.detach(), scale.detach()) - ctx.speed = speed - return _exp_scale_swish(x, scale, speed) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - x, scale = ctx.saved_tensors - x.requires_grad = True - scale.requires_grad = True - with torch.enable_grad(): - y = _exp_scale_swish(x, scale, ctx.speed) - y.backward(gradient=y_grad) - return x.grad, scale.grad, None - class ExpScaleReluFunction(torch.autograd.Function): @@ -563,16 +546,32 @@ def forward(self, x: Tensor) -> Tensor: self.max_abs) +def _double_swish(x: Tensor) -> Tensor: + # double-swish, implemented/approximated as offset-swish + return x * torch.sigmoid(x - 1.0) + +class DoubleSwishFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + ctx.save_for_backward(x.detach()) + return _double_swish(x) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + # TODO: can make this more efficient. + x, = ctx.saved_tensors + x.requires_grad = True + with torch.enable_grad(): + y = _double_swish(x) + y.backward(gradient=y_grad) + return x.grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient - backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1) + that we approximate closely with x * sigmoid(x-1). """ - x1 = x - 1.0 - s = torch.sigmoid(x1) - return (x1 * s) + s # (x-1) * s + s == x * s - + return DoubleSwishFunction.apply(x) def _test_exp_scale_swish(): From 8d17a05dd29ef78cd6063722f3f7bb2d92f8ad0e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Mar 2022 19:23:33 +0800 Subject: [PATCH 076/185] Reduce constraints from deriv-balancer in ConvModule. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 8 +++----- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 65a8431dee..07fe934aea 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -861,7 +861,8 @@ def __init__( # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) + self.deriv_balancer = DerivBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.0, max_positive=1.0) self.depthwise_conv = ScaledConv1d( channels, @@ -873,8 +874,6 @@ def __init__( bias=bias, ) - - self.deriv_balancer2 = DerivBalancer(channel_dim=1) # shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() @@ -904,13 +903,12 @@ def forward(self, x: Tensor) -> Tensor: # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = self.deriv_balancer1(x) + x = self.deriv_balancer(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv x = self.depthwise_conv(x) - x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index a0395a3988..f2d89b099a 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a23010fc1066a791966b0244831f3bb744751587 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Mar 2022 23:04:51 +0800 Subject: [PATCH 077/185] Add warmup mode --- .../ASR/transducer_stateless/conformer.py | 47 +++++++------------ .../transducer_stateless/encoder_interface.py | 4 +- .../ASR/transducer_stateless/model.py | 3 +- .../ASR/transducer_stateless/train.py | 11 +++-- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 07fe934aea..b68aced9fa 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -88,7 +88,7 @@ def __init__( def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -112,7 +112,8 @@ def forward( assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask, + warmup_mode=warmup_mode) # (T, N, C) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -258,7 +259,6 @@ def __init__( self.num_layers = num_layers num_channels = encoder_layer.d_model self.combiner = RandomCombine(num_inputs=len(self.aux_layers), - num_channels=num_channels, final_weight=0.5, pure_prob=0.333, stddev=2.0) @@ -269,6 +269,7 @@ def forward( pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + warmup_mode: bool = False ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -300,7 +301,7 @@ def forward( if i in self.aux_layers: outputs.append(output) - output = self.combiner(outputs) + output = self.combiner(outputs, warmup_mode) return output @@ -946,17 +947,12 @@ class RandomCombine(torch.nn.Module): is a random combination of all the inputs; but which in test time will be just the last input. - All but the last input will have a linear transform before we - randomly combine them; these linear transforms will be initialzed - to the identity transform. - The idea is that the list of Tensors will be a list of outputs of multiple conformer layers. This has a similar effect as iterated loss. (See: DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER NETWORKS). """ def __init__(self, num_inputs: int, - num_channels: int, final_weight: float = 0.5, pure_prob: float = 0.5, stddev: float = 2.0) -> None: @@ -965,7 +961,6 @@ def __init__(self, num_inputs: int, num_inputs: The number of tensor inputs, which equals the number of layers' outputs that are fed into this module. E.g. in an 18-layer neural net if we output layers 16, 12, 18, num_inputs would be 3. - num_channels: The number of channels on the input, e.g. 512. final_weight: The amount of weight or probability we assign to the final layer when randomly choosing layers or when choosing continuous layer weights. @@ -991,8 +986,6 @@ def __init__(self, num_inputs: int, assert pure_prob >= 0 and pure_prob <= 1 assert final_weight > 0 and final_weight < 1 assert num_inputs >= 1 - self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True) - for _ in range(num_inputs - 1)]) self.num_inputs = num_inputs self.final_weight = final_weight @@ -1000,14 +993,10 @@ def __init__(self, num_inputs: int, self.stddev= stddev self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() - self._reset_parameters() - def _reset_parameters(self): - for i in range(len(self.linear)): - nn.init.eye_(self.linear[i].weight) - nn.init.constant_(self.linear[i].bias, 0.0) - def forward(self, inputs: Sequence[Tensor]) -> Tensor: + def forward(self, inputs: Sequence[Tensor], + warmup_mode: bool) -> Tensor: """ Forward function. Args: @@ -1019,24 +1008,18 @@ def forward(self, inputs: Sequence[Tensor]) -> Tensor: """ num_inputs = self.num_inputs assert len(inputs) == num_inputs - if not self.training: + if not (self.training and warmup_mode): return inputs[-1] # Shape of weights: (*, num_inputs) num_channels = inputs[0].shape[-1] num_frames = inputs[0].numel() // num_channels - mod_inputs = [] - for i in range(num_inputs - 1): - mod_inputs.append(self.linear[i](inputs[i])) - mod_inputs.append(inputs[num_inputs - 1]) - - ndim = inputs[0].ndim # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames, - num_channels, - num_inputs)) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, + num_channels, + num_inputs)) # weights: (num_frames, num_inputs) weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, @@ -1118,12 +1101,14 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") num_inputs = 3 num_channels = 50 - m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels, - final_weight=final_weight, pure_prob=pure_prob, stddev=stddev) + m = RandomCombine(num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev) x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] - y = m(x) + y = m(x, True) assert y.shape == x[0].shape assert torch.allclose(y, x[0]) # .. since actually all ones. diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py index 257facce4f..b295ce94bc 100644 --- a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py +++ b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py @@ -22,7 +22,7 @@ class EncoderInterface(nn.Module): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -32,6 +32,8 @@ def forward( x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. + warmup_mode: for training only, if true then train in + "warmup mode" (use this for the first few thousand minibatches). Returns: Return a tuple containing two tensors: - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 17b5f63e58..a45f0e2958 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -62,6 +62,7 @@ def forward( x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + warmup_mode: bool = False ) -> torch.Tensor: """ Args: @@ -82,7 +83,7 @@ def forward( assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens) + encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index f2d89b099a..6c318c242c 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -203,6 +203,7 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 + "warmup_minibatches": 3000, # use warmup mode for 3k minibatches. # parameters for conformer "feature_dim": 80, "encoder_out_dim": 512, @@ -360,6 +361,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + is_warmup_mode: bool = False ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -391,7 +393,8 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + loss = model(x=feature, x_lens=feature_lens, y=y, + warmup_mode=is_warmup_mode) assert loss.requires_grad == is_training @@ -423,6 +426,7 @@ def compute_validation_loss( sp=sp, batch=batch, is_training=False, + is_warmup_mode=False ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info @@ -484,6 +488,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, + is_warmup_mode=(params.batch_idx_train Date: Tue, 15 Mar 2022 13:10:35 +0800 Subject: [PATCH 078/185] Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- egs/librispeech/ASR/transducer_stateless/conformer.py | 10 ++++++---- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 1e31c0a208..7c2b1ec045 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -527,7 +527,7 @@ class DerivBalancer(torch.nn.Module): """ def __init__(self, channel_dim: int, min_positive: float = 0.05, - max_positive: float = 0.95, + max_positive: float = 1.0, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index b68aced9fa..54729652b2 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -862,8 +862,7 @@ def __init__( # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer = DerivBalancer(channel_dim=1, max_abs=10.0, - min_positive=0.0, max_positive=1.0) + self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) self.depthwise_conv = ScaledConv1d( channels, @@ -875,7 +874,9 @@ def __init__( bias=bias, ) - # shape: (channels, 1), broadcasts with (batch, channel, time). + self.deriv_balancer2 = DerivBalancer(channel_dim=1) + + # Shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() self.pointwise_conv2 = ScaledConv1d( @@ -904,12 +905,13 @@ def forward(self, x: Tensor) -> Tensor: # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = self.deriv_balancer(x) + x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv x = self.depthwise_conv(x) + x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 6c318c242c..6408290b41 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 21ebd356e78b82a93485554a81402a2149874eb4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 13:49:15 +0800 Subject: [PATCH 079/185] Add some extra info to diagnostics --- .../ASR/transducer_stateless/diagnostics.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 8ea35582a9..238c50def9 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -79,7 +79,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim options: options object sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "abs" or "positive" or "eigs" or "value, + stats_type: either "abs" or "positive" or "eigs" or "value", imdictates the type of stats we accumulate, abs is mean absolute value, "positive" is proportion of positive to nonnegative values, "eigs" @@ -129,12 +129,23 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], percentiles.append(stats[index].item()) percentiles = [ '%.2g' % x for x in percentiles ] percentiles = ' '.join(percentiles) - return f'percentiles: [{percentiles}]' + ans = f'percentiles: [{percentiles}]' else: - stats = stats.tolist() - stats = [ '%.2g' % x for x in stats ] - stats = '[' + ' '.join(stats) + ']' - return stats + ans = stats.tolist() + ans = [ '%.2g' % x for x in ans ] + ans = '[' + ' '.join(ans) + ']' + if stats_type == "value": + norm = (stats ** 2).sum().sqrt().item() + mean_abs = stats.abs().mean().item() + # This norm is useful because it is strictly less than the largest + # sqrt(eigenvalue) of the variance, which we print out, and shows, + # speaking in an approximate way, how much of that largest eigenvalue + # can be attributed to the mean of the distribution. + ans += f', norm={norm:.2g}, mean_abs={mean_abs:.2g}' + else: + mean = stats.mean().item() + ans += f', mean={mean:.2g}' + return ans From 1962fe298b713a673cc4fd99c20e1deab45e2560 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 14:35:15 +0800 Subject: [PATCH 080/185] Add deriv-balancer at output of embedding. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 3 +++ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 7c2b1ec045..35de71e43e 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -57,6 +57,8 @@ def __init__(self, idim: int, odim: int) -> None: ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = BasicNorm(odim) + # constrain mean of output to be close to zero. + self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6) self._reset_parameters() def _reset_parameters(self): @@ -84,6 +86,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) x = self.out_norm(x) + x = self.out_balancer(x) return x diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 6408290b41..488de3ccc3 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From b2abcd721aae06deff49e7535141b9bd58bdf01a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 16:38:19 +0800 Subject: [PATCH 081/185] Add more stats. --- .../ASR/transducer_stateless/diagnostics.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 238c50def9..7fd83d56bc 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -135,16 +135,18 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], ans = [ '%.2g' % x for x in ans ] ans = '[' + ' '.join(ans) + ']' if stats_type == "value": - norm = (stats ** 2).sum().sqrt().item() - mean_abs = stats.abs().mean().item() # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue # can be attributed to the mean of the distribution. - ans += f', norm={norm:.2g}, mean_abs={mean_abs:.2g}' + norm = (stats ** 2).sum().sqrt().item() + mean = stats.mean().item() + rms = (stats ** 2).mean().sqrt().item() + ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}' else: mean = stats.mean().item() - ans += f', mean={mean:.2g}' + rms = (stats ** 2).mean().sqrt().item() + ans += f', mean={mean:.2g}, rms={rms:.2g}' return ans From fc873cc50d7e5a72344b0f081e93802acb441a73 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 17:00:17 +0800 Subject: [PATCH 082/185] Make epsilon in BasicNorm learnable, optionally. --- .../ASR/conformer_ctc/subsampling.py | 44 +++++++++++-------- .../ASR/transducer_stateless/conformer.py | 3 +- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 35de71e43e..78fcac664d 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -56,7 +56,10 @@ def __init__(self, idim: int, odim: int) -> None: DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) - self.out_norm = BasicNorm(odim) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(odim, learn_eps=False) # constrain mean of output to be close to zero. self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6) self._reset_parameters() @@ -361,42 +364,45 @@ class BasicNorm(torch.nn.Module): So the idea is to introduce this large constant value as an explicit parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. - - We also introduce a learned scaling factor on the output; and we - remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not - that useful unless the LayerNorm immediately follows a nonlinearity). - + doesn't have to do this trick. We make the "eps" learnable. Args: + num_channels: the number of channels, e.g. 512. channel_dim: the axis/dimension corresponding to the channel, interprted as an offset from the input's ndim if negative. shis is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - initial_eps: the initial "epsilon" that we add as ballast in: + eps: the initial "epsilon" that we add as ballast in: scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with normal LayerNorm. - - speed: a scaling factor that can be interpreted as scaling the learning - rate for this module. CAUTION: the default value of 10.0 intended to be - used with Adam or amsgrad-type optimizers, e.g. Adam or Noam. - If you are using SGD you would probably have to set `speed` to - a value less than one, or the training would be unstable. + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_speed: a constant that determines how fast "eps" learns; + with Adam and variants, this should probably be >= 1, + e.g. 5.0. For SGD and variants, probably a value less than one, + like 0.1, would be suitable, to prevent instability. """ def __init__(self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25): + eps: float = 0.25, + learn_eps: bool = True, + eps_speed: float = 5.0): super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.eps = eps + self.eps_speed = eps_speed + if learn_eps: + self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) + else: + self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5 + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + (self.eps * self.eps_speed).exp()) ** -0.5 return x * scales diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 54729652b2..8b229a2345 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -1129,4 +1129,5 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): seq_len = 20 # Just make sure the forward pass runs. f = c(torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64)) + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup_mode=True) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 488de3ccc3..2af306f948 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 261d7602a77ef46626454dce7b7d70b69c79226e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 23:46:53 +0800 Subject: [PATCH 083/185] Draft of 0mean changes.. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 6 +++--- .../ASR/transducer_stateless/conformer.py | 13 +++++++++---- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 78fcac664d..50a9db41ab 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -60,8 +60,8 @@ def __init__(self, idim: int, odim: int) -> None: # itself has learned scale, so the extra degree of freedom is not # needed. self.out_norm = BasicNorm(odim, learn_eps=False) - # constrain mean of output to be close to zero. - self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6) + # constrain median of output to be close to zero. + self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) self._reset_parameters() def _reset_parameters(self): @@ -536,7 +536,7 @@ class DerivBalancer(torch.nn.Module): """ def __init__(self, channel_dim: int, min_positive: float = 0.05, - max_positive: float = 1.0, + max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 8b229a2345..cc1ae53a18 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -88,7 +88,7 @@ def __init__( def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -179,6 +179,9 @@ def __init__( self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) + self.dropout = nn.Dropout(dropout) @@ -227,7 +230,7 @@ def forward( # feed forward module src = src + self.dropout(self.feed_forward(src)) - src = self.norm_final(self.pre_norm_final(src)) + src = self.balancer(self.norm_final(self.pre_norm_final(src))) return src @@ -862,7 +865,8 @@ def __init__( # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) + self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.05, max_positive=1.0) self.depthwise_conv = ScaledConv1d( channels, @@ -874,7 +878,8 @@ def __init__( bias=bias, ) - self.deriv_balancer2 = DerivBalancer(channel_dim=1) + self.deriv_balancer2 = DerivBalancer(channel_dim=1, + min_positive=0.05, max_positive=1.0) # Shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 2af306f948..41fdb4ef3b 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 633213424d24de73c09170d68b138ef830ed3cbd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 12:42:59 +0800 Subject: [PATCH 084/185] Rework of initialization --- .../ASR/conformer_ctc/subsampling.py | 70 ++++++++++++++++--- .../ASR/transducer_stateless/conformer.py | 16 ++--- .../ASR/transducer_stateless/decoder.py | 64 +++-------------- .../ASR/transducer_stateless/train.py | 3 +- 4 files changed, 78 insertions(+), 75 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 50a9db41ab..5e44c5b297 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -62,13 +62,6 @@ def __init__(self, idim: int, odim: int) -> None: self.out_norm = BasicNorm(odim, learn_eps=False) # constrain median of output to be close to zero. self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) - self._reset_parameters() - - def _reset_parameters(self): - # init weights with smaller than default variance, because otherwise - # they learn too slowly in relative terms (assuming we're training with adam). - nn.init.normal_(self.conv[0].weight, std=0.05) - nn.init.constant_(self.conv[0].bias, 0.0) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -406,8 +399,36 @@ def forward(self, x: Tensor) -> Tensor: return x * scales + + class ScaledLinear(nn.Linear): - def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): + """ + A modified version of nn.Linear where the parameters are scaled before + use, via: + weight = self.weight * (self.weight_scale * self.scale_speed).exp() + bias = self.bias * (self.bias_scale * self.scale_speed).exp() + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + scale_speed: a factor that affects how fast the weight_scale + and bias_scale learn; this value is suitable for Adam-type + optimizers. + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + + Note: it uses the default initialization for the weight and bias, + inherited from nn.Linear. For modules with small fan-in, this + may be larger than optimal. + """ + def __init__(self, *args, + scale_speed: float = 5.0, + initial_scale: float = 1.0, + **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = (torch.tensor(initial_scale).log() / scale_speed) self.weight_scale = nn.Parameter(initial_scale.clone().detach()) @@ -417,6 +438,17 @@ def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): else: self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in nn.Linear + + def _reset_parameters(self): + nn.init.normal_(self.weight, std=0.05) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -425,7 +457,6 @@ def get_bias(self): return (None if self.bias is None else self.bias * (self.bias_scale * self.scale_speed).exp()) - def forward(self, input: Tensor) -> Tensor: return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) @@ -442,6 +473,17 @@ def __init__(self, *args, scale_speed = 5.0, self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + nn.init.normal_(self.weight, std=0.05) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -471,6 +513,16 @@ def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + nn.init.normal_(self.weight, std=0.05) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) def get_weight(self): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index cc1ae53a18..0b89fdcd21 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -162,7 +162,7 @@ def __init__( DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ScaledLinear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( @@ -170,7 +170,7 @@ def __init__( DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ScaledLinear(dim_feedforward, d_model), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -423,7 +423,7 @@ def __init__( ), "embed_dim must be divisible by num_heads" self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) # linear transformation for positional encoding. self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) @@ -434,7 +434,6 @@ def __init__( self.scale_speed = scale_speed self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) - self._reset_parameters() def _pos_bias_u(self): @@ -444,12 +443,8 @@ def _pos_bias_v(self): return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) + nn.init.normal_(self.pos_bias_u, std=0.05) + nn.init.normal_(self.pos_bias_v, std=0.05) def forward( self, @@ -891,7 +886,6 @@ def __init__( stride=1, padding=0, bias=bias, - initial_scale=0.25 ) def forward(self, x: Tensor) -> Tensor: diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index bc4bcb3f63..838b6794d6 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -183,7 +183,7 @@ class ScaledEmbedding(nn.Module): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None, + sparse: bool = False, scale_speed: float = 5.0) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings @@ -198,19 +198,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed) - - if _weight is None: - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() - else: - assert list(_weight.shape) == [num_embeddings, embedding_dim], \ - 'Shape of weight does not match num_embeddings and embedding_dim' - self.weight = nn.Parameter(_weight) + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=self.embedding_dim**-0.5) + nn.init.normal_(self.weight, std=0.05) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + if self.padding_idx is not None: with torch.no_grad(): self.weight[self.padding_idx].fill_(0) @@ -228,7 +227,6 @@ def forward(self, input: Tensor) -> Tensor: None, 2.0, # None, 2.0 relates to normalization self.scale_grad_by_freq, self.sparse) - def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' if self.padding_idx is not None: @@ -238,45 +236,3 @@ def extra_repr(self) -> str: if self.sparse is not False: s += ', sparse=True' return s.format(**self.__dict__) - - @classmethod - def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, - max_norm=None, norm_type=2., scale_grad_by_freq=False, - sparse=False): - r"""Creates Embedding instance from given 2-dimensional FloatTensor. - - Args: - embeddings (Tensor): FloatTensor containing weights for the Embedding. - First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. - freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. - Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` - padding_idx (int, optional): See module initialization documentation. - max_norm (float, optional): See module initialization documentation. - norm_type (float, optional): See module initialization documentation. Default ``2``. - scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. - sparse (bool, optional): See module initialization documentation. - - Examples:: - - >>> # FloatTensor containing pretrained weights - >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) - >>> embedding = nn.Embedding.from_pretrained(weight) - >>> # Get embeddings for index 1 - >>> input = torch.LongTensor([1]) - >>> embedding(input) - tensor([[ 4.0000, 5.1000, 6.3000]]) - """ - assert embeddings.dim() == 2, \ - 'Embeddings parameter is expected to be 2-dimensional' - rows, cols = embeddings.shape - embedding = cls( - num_embeddings=rows, - embedding_dim=cols, - _weight=embeddings, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - embedding.weight.requires_grad = not freeze - return embedding diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 41fdb4ef3b..8f21577151 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,8 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean", + # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization.. + default="transducer_stateless/randcombine1_expscale3_rework2d" help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a783b9646729954623c37b932431ad0df0c253e3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 12:43:44 +0800 Subject: [PATCH 085/185] Fix typo --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 8f21577151..1190522e7e 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -111,7 +111,7 @@ def get_parser(): "--exp-dir", type=str, # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization.. - default="transducer_stateless/randcombine1_expscale3_rework2d" + default="transducer_stateless/randcombine1_expscale3_rework2d", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 00be56c7a0ef956a7790598e73cf19e9ce6086cf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 12:49:00 +0800 Subject: [PATCH 086/185] Remove dead code --- .../ASR/transducer_stateless/conformer.py | 21 +------------------ 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 0b89fdcd21..cafc04ed19 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -876,8 +876,7 @@ def __init__( self.deriv_balancer2 = DerivBalancer(channel_dim=1, min_positive=0.05, max_positive=1.0) - # Shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = SwishOffset() + self.activation = DoubleSwish() self.pointwise_conv2 = ScaledConv1d( channels, @@ -918,24 +917,6 @@ def forward(self, x: Tensor) -> Tensor: return x.permute(2, 0, 1) -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - -class SwishOffset(torch.nn.Module): - """Construct an SwishOffset object.""" - def __init__(self, offset: float = -1.0) -> None: - super(SwishOffset, self).__init__() - self.offset = offset - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x + self.offset) - - class Identity(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: return x From 0e9cad3f1f62abda43c6b218917525142c32b3d3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 14:42:53 +0800 Subject: [PATCH 087/185] Modifying initialization from normal->uniform; add initial_scale when initializing --- .../ASR/conformer_ctc/subsampling.py | 17 +++++++++++------ .../ASR/transducer_stateless/conformer.py | 7 ++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 5e44c5b297..6cc90c8a1d 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -441,15 +441,16 @@ def __init__(self, *args, self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] + fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) - def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -476,7 +477,9 @@ def __init__(self, *args, scale_speed = 5.0, self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() @@ -516,10 +519,12 @@ def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index cafc04ed19..0832d9385f 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -162,7 +162,7 @@ def __init__( DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.feed_forward_macaron = nn.Sequential( @@ -170,7 +170,7 @@ def __init__( DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -423,7 +423,7 @@ def __init__( ), "embed_dim must be divisible by num_heads" self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) # linear transformation for positional encoding. self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) @@ -885,6 +885,7 @@ def __init__( stride=1, padding=0, bias=bias, + initial_scale=0.25 ) def forward(self, x: Tensor) -> Tensor: From 6561743d7b454111582011936beb0aa09f8fa161 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 14:55:17 +0800 Subject: [PATCH 088/185] bug fix re sqrt --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 6cc90c8a1d..7c7d0ee6c8 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -442,7 +442,7 @@ def __init__(self, *args, def _reset_parameters(self): std = 0.05 - a = math.sqrt(3) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) @@ -478,7 +478,7 @@ def __init__(self, *args, scale_speed = 5.0, def _reset_parameters(self): std = 0.05 - a = math.sqrt(3) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) @@ -520,7 +520,7 @@ def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): def _reset_parameters(self): std = 0.05 - a = math.sqrt(3) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) From c82db4184a395177f4c2a79f1f20d7d3508777b2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 15:50:11 +0800 Subject: [PATCH 089/185] Remove xscale from pos_embedding --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 6 +++--- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 7c7d0ee6c8..867ababf23 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -449,7 +449,7 @@ def _reset_parameters(self): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -485,7 +485,7 @@ def _reset_parameters(self): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) def get_weight(self): @@ -527,7 +527,7 @@ def _reset_parameters(self): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) def get_weight(self): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 0832d9385f..b14e837805 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -327,7 +327,6 @@ def __init__( """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model - self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) @@ -379,7 +378,6 @@ def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]: """ self.extend_pe(x) - x = x * self.xscale pos_emb = self.pe[ :, self.pe.size(1) // 2 From dfc75752c40c931eb63385e793d1ababf0e02489 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 18:06:01 +0800 Subject: [PATCH 090/185] Remove some dead code. --- .../ASR/conformer_ctc/subsampling.py | 160 ------------------ .../ASR/transducer_stateless/conformer.py | 2 +- 2 files changed, 1 insertion(+), 161 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 867ababf23..500cacca86 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -174,130 +174,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class PeLUFunction(torch.autograd.Function): - """ - Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)). - The function is: - x.relu() + alpha * (cutoff - x).relu() - E.g. consider cutoff = -1, alpha = 0.01. This will tend to prevent die-off - of neurons. - """ - @staticmethod - def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor: - mask1 = (x >= 0) # >=, so there is deriv if x == 0. - p = cutoff - x - mask2 = (p >= 0) - ctx.save_for_backward(mask1, mask2) - ctx.alpha = alpha - return x.relu() + alpha * p.relu() - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]: - mask1, mask2 = ctx.saved_tensors - return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None - - - -class PeLU(torch.nn.Module): - def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None: - super(PeLU, self).__init__() - self.cutoff = cutoff - self.alpha = alpha - def forward(self, x: Tensor) -> Tensor: - return PeLUFunction.apply(x, self.cutoff, self.alpha) - -class ExpScale(torch.nn.Module): - def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0): - super(ExpScale, self).__init__() - scale = torch.tensor(initial_scale) - scale = scale.log() / speed - self.scale = nn.Parameter(scale.detach()) - self.speed = speed - - def forward(self, x: Tensor) -> Tensor: - return x * (self.scale * self.speed).exp() - - - -def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: - # double-swish, implemented/approximated as offset-swish - x = (x * torch.sigmoid(x - 1.0)) - x = x * (scale * speed).exp() - return x - -class SwishExpScaleFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x.detach(), scale.detach()) - ctx.speed = speed - return _exp_scale_swish(x, scale, speed) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - x, scale = ctx.saved_tensors - x.requires_grad = True - scale.requires_grad = True - with torch.enable_grad(): - y = _exp_scale_swish(x, scale, ctx.speed) - y.backward(gradient=y_grad) - return x.grad, scale.grad, None - - -class SwishExpScale(torch.nn.Module): - # combines ExpScale and a Swish (actually the ExpScale is after the Swish). - # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) - # - def __init__(self, *shape, speed: float = 1.0): - super(SwishExpScale, self).__init__() - - initial_log_scale = torch.zeros(()).detach() - self.scale = nn.Parameter(initial_log_scale) - self.speed = speed - - def forward(self, x: Tensor) -> Tensor: - return SwishExpScaleFunction.apply(x, self.scale, self.speed) - # x = (x * torch.sigmoid(x)) - # x = (x * torch.sigmoid(x)) - # x = x * (self.scale * self.speed).exp() - # return x - - - -def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: - return (x * (scale * speed).exp()).relu() - - - - -class ExpScaleReluFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x.detach(), scale.detach()) - ctx.speed = speed - return _exp_scale_relu(x, scale, speed) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - x, scale = ctx.saved_tensors - x.requires_grad = True - scale.requires_grad = True - with torch.enable_grad(): - y = _exp_scale_relu(x, scale, ctx.speed) - y.backward(gradient=y_grad) - return x.grad, scale.grad, None - -class ExpScaleRelu(torch.nn.Module): - # combines ExpScale and Relu. - # caution: need to specify name for speed, e.g. ExpScaleRelu(50, speed=4.0) - def __init__(self, *shape, speed: float = 1.0): - super(ExpScaleRelu, self).__init__() - self.scale = nn.Parameter(torch.zeros(*shape)) - self.speed = speed - - def forward(self, x: Tensor) -> Tensor: - return ExpScaleReluFunction.apply(x, self.scale, self.speed) - # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() - # return x * (self.scale * self.speed).exp() - @@ -639,40 +515,6 @@ def forward(self, x: Tensor) -> Tensor: """ return DoubleSwishFunction.apply(x) -def _test_exp_scale_swish(): - - x1 = torch.randn(50, 60).detach() - x2 = x1.detach() - - m1 = SwishExpScale(50, 1, speed=4.0) - m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0)) - x1.requires_grad = True - x2.requires_grad = True - - y1 = m1(x1) - y2 = m2(x2) - assert torch.allclose(y1, y2, atol=1e-05) - y1.sum().backward() - y2.sum().backward() - assert torch.allclose(x1.grad, x2.grad, atol=1e-05) - -def _test_exp_scale_relu(): - - x1 = torch.randn(50, 60).detach() - x2 = x1.detach() - - m1 = ExpScaleRelu(50, 1, speed=4.0) - m2 = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0)) - x1.requires_grad = True - x2.requires_grad = True - - y1 = m1(x1) - y2 = m2(x2) - assert torch.allclose(y1, y2) - y1.sum().backward() - y2.sum().backward() - assert torch.allclose(x1.grad, x2.grad) - def _test_deriv_balancer_sign(): @@ -737,6 +579,4 @@ def _test_basic_norm(): if __name__ == '__main__': _test_deriv_balancer_sign() _test_deriv_balancer_magnitude() - _test_exp_scale_swish() - _test_exp_scale_relu() _test_basic_norm() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index b14e837805..8de02628d3 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, DoubleSwish, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn From e838c192ef05b7a4a3659672cfa54ef37f23f57b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 19:27:45 +0800 Subject: [PATCH 091/185] Cosmetic changes/renaming things --- .../ASR/conformer_ctc/subsampling.py | 59 ++++++++----------- .../ASR/transducer_stateless/conformer.py | 20 ++++--- 2 files changed, 37 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 500cacca86..0a39b0f336 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,12 +47,12 @@ def __init__(self, idim: int, odim: int) -> None: ScaledConv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1), + ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1), + ActivationBalancer(channel_dim=1), DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -61,7 +61,9 @@ def __init__(self, idim: int, odim: int) -> None: # needed. self.out_norm = BasicNorm(odim, learn_eps=False) # constrain median of output to be close to zero. - self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -177,7 +179,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: -class DerivBalancerFunction(torch.autograd.Function): +class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, @@ -428,44 +430,33 @@ def forward(self, input: Tensor) -> Tensor: -class DerivBalancer(torch.nn.Module): +class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for each channel, that it is positive at least a proportion `threshold` of the time. It does this by multiplying negative derivative values by up to (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 0 at the threshold to those extremal values when none + interpolated from 1 at the threshold to those extremal values when none of the inputs are positive. - When all grads are zero for a channel, this - module sets all the input derivatives for that channel to -epsilon; the - idea is to bring completely dead neurons back to life this way. Args: - channel_dim: the dimension/axi corresponding to the channel, e.g. + channel_dim: the dimension/axis corresponding to the channel, e.g. -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. min_positive: the minimum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. max_positive: the maximum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives, + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.01]. - zero: we use this value in the comparison (x > 0), i.e. we actually use - (x > zero). The reason for using a threshold slightly greater - than zero is that it will tend to prevent situations where the - inputs shrink close to zero and the nonlinearity (e.g. swish) - behaves like a linear function and we learn nothing. + values in the range [0.98..1.02]. min_abs: the minimum average-absolute-value per channel, which we allow, before we start to modify the derivatives to prevent - this. This is to prevent a failure mode where the activations - become so small that the nonlinearity effectively becomes linear, - which makes the module useless and it gets even smaller - to try to "turn it off" completely. + this. max_abs: the maximum average-absolute-value per channel, which we allow, before we start to modify the derivatives to prevent - this. This is to prevent the possibility of activations getting - out of floating point numerical range (especially in half precision). + this. """ def __init__(self, channel_dim: int, min_positive: float = 0.05, @@ -473,7 +464,7 @@ def __init__(self, channel_dim: int, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0): - super(DerivBalancer, self).__init__() + super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive self.max_positive = max_positive @@ -482,10 +473,10 @@ def __init__(self, channel_dim: int, self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return DerivBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, + self.max_factor, self.min_abs, + self.max_abs) def _double_swish(x: Tensor) -> Tensor: @@ -524,8 +515,8 @@ def _test_deriv_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -542,10 +533,10 @@ def _test_deriv_balancer_magnitude(): x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 8de02628d3..6278734e58 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn @@ -159,7 +159,7 @@ def __init__( self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1), + ActivationBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), @@ -167,7 +167,7 @@ def __init__( self.feed_forward_macaron = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1), + ActivationBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), @@ -180,7 +180,9 @@ def __init__( self.norm_final = BasicNorm(d_model) # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) + self.balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) self.dropout = nn.Dropout(dropout) @@ -858,8 +860,9 @@ def __init__( # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0, - min_positive=0.05, max_positive=1.0) + self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.05, + max_positive=1.0) self.depthwise_conv = ScaledConv1d( channels, @@ -871,8 +874,9 @@ def __init__( bias=bias, ) - self.deriv_balancer2 = DerivBalancer(channel_dim=1, - min_positive=0.05, max_positive=1.0) + self.deriv_balancer2 = ActivationBalancer(channel_dim=1, + min_positive=0.05, + max_positive=1.0) self.activation = DoubleSwish() From 1f3a15f3c45814daefbf399d4a181b91af7cd8ea Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 22:14:30 +0800 Subject: [PATCH 092/185] Start adding some files.. --- egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py | 0 .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py | 1 + 3 files changed, 2 insertions(+) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py new file mode 120000 index 0000000000..07f39b4511 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -0,0 +1 @@ +../transducer/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py new file mode 120000 index 0000000000..227d2247c0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/beam_search.py \ No newline at end of file From cc8e4412f7954620224a5b2f4deb80e029ce7c36 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 22:16:40 +0800 Subject: [PATCH 093/185] Add more files.. --- .../pruned_transducer_stateless2/conformer.py | 1115 +++++++++++++++++ .../pruned_transducer_stateless2/decode.py | 1 + 2 files changed, 1116 insertions(+) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py new file mode 100644 index 0000000000..bf96b41f97 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -0,0 +1,1115 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import Optional, Tuple, Sequence +from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d + +import torch +from torch import Tensor, nn +from transformer import Transformer + +from icefall.utils import make_pad_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + output_dim (int): Number of output dimension + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + output_dim: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + aux_layer_period: int = 3 + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + output_dim=output_dim, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, + aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) + self.normalize_before = normalize_before + + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() + mask = make_pad_mask(lengths) + + x = self.encoder(x, pos_emb, src_key_padding_mask=mask, + warmup_mode=warmup_mode) # (T, N, C) + + logits = self.encoder_output_layer(x) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, lengths + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.d_model = d_model + + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + + self.pre_norm_final = Identity() + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + + self.dropout = nn.Dropout(dropout) + + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + + # multi-headed self-attention module + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout(src_att) + + # convolution module + src = src + self.dropout(self.conv_module(src)) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.balancer(self.norm_final(self.pre_norm_final(src))) + + return src + + +class ConformerEncoder(nn.Module): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int, + aux_layers: Sequence[int]) -> None: + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.aux_layers = set(aux_layers + [num_layers - 1]) + assert num_layers - 1 not in aux_layers + self.num_layers = num_layers + num_channels = encoder_layer.d_model + self.combiner = RandomCombine(num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup_mode: bool = False + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + outputs = [] + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + if i in self.aux_layers: + outputs.append(output) + + output = self.combiner(outputs, warmup_mode) + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + scale_speed: float = 5.0 + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.scale_speed = scale_speed + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() + + def _pos_bias_v(self): + return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.05) + nn.init.normal_(self.pos_bias_v, std=0.05) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self._pos_bias_u()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self._pos_bias_v()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.05, + max_positive=1.0) + + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer(channel_dim=1, + min_positive=0.05, + max_positive=1.0) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25 + ) + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Identity(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + return x + + +class RandomCombine(torch.nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + def __init__(self, num_inputs: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0) -> None: + """ + Args: + num_inputs: The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, + or combinations of layers, to use, is conceptually as follows. + With probability `pure_prob`: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super(RandomCombine, self).__init__() + assert pure_prob >= 0 and pure_prob <= 1 + assert final_weight > 0 and final_weight < 1 + assert num_inputs >= 1 + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev= stddev + + self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() + + + def forward(self, inputs: Sequence[Tensor], + warmup_mode: bool) -> Tensor: + """ + Forward function. + Args: + inputs: a list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + a Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not (self.training and warmup_mode): + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, + num_channels, + num_inputs)) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, + num_frames) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) + + if __name__ == "__main__": + # for testing only... + print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + + def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: + """ + Return a tensor of random weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), such that + ans.sum(dim=1) is all ones. + + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) + + def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with + exactly one weight equal to 1.0 on each frame. + """ + + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + + indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, + final, nonfinal) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) + return ans + + + def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that + sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. + """ + logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev + logprobs[:,-1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") + num_inputs = 3 + num_channels = 50 + m = RandomCombine(num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev) + + x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] + + y = m(x, True) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +if __name__ == '__main__': + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + feature_dim = 50 + c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c(torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup_mode=True) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py new file mode 120000 index 0000000000..c1125a9bae --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decode.py \ No newline at end of file From e3ad8f63e73e8cc6a1970d451285def55e97a776 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 22:22:10 +0800 Subject: [PATCH 094/185] update decode.py file type --- .../pruned_transducer_stateless2/decode.py | 424 +++++++++++++++++- 1 file changed, 423 insertions(+), 1 deletion(-) mode change 120000 => 100755 egs/librispeech/ASR/pruned_transducer_stateless2/decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py deleted file mode 120000 index c1125a9bae..0000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless/decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py new file mode 100755 index 0000000000..86ec6172fd --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import beam_search, greedy_search, modified_beam_search +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""Used only when --decoding-method is + beam_search or modified_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() From 11bea4513eff9b478df4ad02009fd0f491dd7ca5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 11:17:52 +0800 Subject: [PATCH 095/185] Add remaining files in pruned_transducer_stateless2 --- .../pruned_transducer_stateless2/conformer.py | 2 +- .../pruned_transducer_stateless2/decoder.py | 241 ++++++ .../encoder_interface.py | 1 + .../pruned_transducer_stateless2/export.py | 182 ++++ .../pruned_transducer_stateless2/joiner.py | 50 ++ .../ASR/pruned_transducer_stateless2/model.py | 170 ++++ .../pruned_transducer_stateless2/scaling.py | 418 +++++++++ .../subsampling.py | 176 ++++ .../ASR/pruned_transducer_stateless2/train.py | 810 ++++++++++++++++++ .../transformer.py | 418 +++++++++ .../transducer_stateless/encoder_interface.py | 4 +- 11 files changed, 2468 insertions(+), 4 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless2/export.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/model.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless2/train.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bf96b41f97..245af05e37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -19,7 +19,7 @@ import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py new file mode 100644 index 0000000000..7836ca9992 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -0,0 +1,241 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from typing import Optional +from scaling import ScaledConv1d, ScaledLinear + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + embedding_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + embedding_dim: + Dimension of the input embedding. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + self.embedding = ScaledEmbedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + padding_idx=blank_id, + ) + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + if context_size > 1: + self.conv = ScaledConv1d( + in_channels=embedding_dim, + out_channels=embedding_dim, + kernel_size=context_size, + padding=0, + groups=embedding_dim, + bias=False, + ) + self.output_linear = ScaledLinear(embedding_dim, vocab_size) + + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, embedding_dim). + """ + y = y.to(torch.int64) + embedding_out = self.embedding(y) + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = self.output_linear(F.relu(embedding_out)) + return embedding_out + + + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale_speed = scale_speed + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.05) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = (self.scale * self.scale_speed).exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py new file mode 120000 index 0000000000..aa5d0217a8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py new file mode 100755 index 0000000000..7d2a07817c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./pruned_transducer_stateless/export.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `pruned_transducer_stateless/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless/decode.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py new file mode 100644 index 0000000000..61bfe81867 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -0,0 +1,50 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scaling import ScaledLinear + +class Joiner(nn.Module): + def __init__(self, input_dim: int, inner_dim: int, output_dim: int): + super().__init__() + + self.inner_linear = ScaledLinear(input_dim, inner_dim) + self.output_linear = ScaledLinear(inner_dim, output_dim) + + def forward( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape == decoder_out.shape + + logit = encoder_out + decoder_out + + logit = self.inner_linear(torch.tanh(logit)) + + output = self.output_linear(F.relu(logit)) + + return output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py new file mode 100644 index 0000000000..e83d18e3eb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -0,0 +1,170 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, C) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, C). It should contain + one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, C) and (N, U, C). Its + output shape is (N, T, U, C). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup_mode: bool = False + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode=warmup_mode) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, C] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=decoder_out, + am=encoder_out, + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, C] + # lm_pruned : [B, T, prune_range, C] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=encoder_out, lm=decoder_out, ranges=ranges + ) + + # logits : [B, T, prune_range, C] + logits = self.joiner(am_pruned, lm_pruned) + + pruned_loss = k2.rnnt_loss_pruned( + logits=logits, + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py new file mode 100644 index 0000000000..c8bc35fd17 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -0,0 +1,418 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from torch import Tensor +from typing import Tuple + + + + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + ) -> Tensor: + if x.requires_grad: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + xgt0 = x > 0 + proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) + factor = factor1 + factor2 + if isinstance(factor, float): + factor = torch.zeros_like(proportion_positive) + + mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) + below_threshold = (mean_abs < min_abs) + above_threshold = (mean_abs > max_abs) + + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.max_factor = max_factor + ctx.sum_dims = sum_dims + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors + dtype = x_grad.dtype + scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * + (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + + neg_delta_grad = x_grad.abs() * (factor + scale_factor) + return x_grad - neg_delta_grad, None, None, None, None, None, None + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_speed: a constant that determines how fast "eps" learns; + with Adam and variants, this should probably be >= 1, + e.g. 5.0. For SGD and variants, probably a value less than one, + like 0.1, would be suitable, to prevent instability. + """ + def __init__(self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + eps_speed: float = 5.0): + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.eps_speed = eps_speed + if learn_eps: + self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) + else: + self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) + + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + (self.eps * self.eps_speed).exp()) ** -0.5 + return x * scales + + + + +class ScaledLinear(nn.Linear): + """ + A modified version of nn.Linear where the parameters are scaled before + use, via: + weight = self.weight * (self.weight_scale * self.scale_speed).exp() + bias = self.bias * (self.bias_scale * self.scale_speed).exp() + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + scale_speed: a factor that affects how fast the weight_scale + and bias_scale learn; this value is suitable for Adam-type + optimizers. + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + + Note: it uses the default initialization for the weight and bias, + inherited from nn.Linear. For modules with small fan-in, this + may be larger than optimal. + """ + def __init__(self, *args, + scale_speed: float = 5.0, + initial_scale: float = 1.0, + **kwargs): + super(ScaledLinear, self).__init__(*args, **kwargs) + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + self.scale_speed = scale_speed + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + + self._reset_parameters() # Overrides the reset_parameters in nn.Linear + + def _reset_parameters(self): + std = 0.05 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def forward(self, input: Tensor) -> Tensor: + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) + + +class ScaledConv1d(nn.Conv1d): + def __init__(self, *args, scale_speed = 5.0, + initial_scale=1.0, **kwargs): + super(ScaledConv1d, self).__init__(*args, **kwargs) + self.scale_speed = scale_speed + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + std = 0.05 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + + + +class ScaledConv2d(nn.Conv2d): + def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): + super(ScaledConv2d, self).__init__(*args, **kwargs) + self.scale_speed = scale_speed + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + std = 0.05 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def _conv_forward(self, input, weight): + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.get_bias(), self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.get_weight()) + + + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + + Args: + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + min_abs: the minimum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + max_abs: the maximum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + """ + def __init__(self, channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0): + super(ActivationBalancer, self).__init__() + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + + def forward(self, x: Tensor) -> Tensor: + return ActivationBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, + self.max_factor, self.min_abs, + self.max_abs) + + +def _double_swish(x: Tensor) -> Tensor: + # double-swish, implemented/approximated as offset-swish + return x * torch.sigmoid(x - 1.0) + +class DoubleSwishFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + ctx.save_for_backward(x.detach()) + return _double_swish(x) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + # TODO: can make this more efficient. + x, = ctx.saved_tensors + x.requires_grad = True + with torch.enable_grad(): + y = _double_swish(x) + y.backward(gradient=y_grad) + return x.grad + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + return DoubleSwishFunction.apply(x) + + + +def _test_activation_balancer_sign(): + channel_dim = 0 + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + +def _test_activation_balancer_magnitude(): + channel_dim = 0 + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + + + + +if __name__ == '__main__': + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py new file mode 100644 index 0000000000..51b08e0725 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py @@ -0,0 +1,176 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from torch import Tensor +from typing import Tuple +from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(odim, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py new file mode 100755 index 0000000000..51858448d0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -0,0 +1,810 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir pruned_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 300 \ + --lr-factor 1.5 +""" + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from model import Transducer +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall import diagnostics + +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + transducer_stateless/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - attention_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + # parameters for decoder + "embedding_dim": 512, + # parameters for Noam + "warm_step": 30000, # For the 100h subset, use 8k + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.vocab_size, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.vocab_size, + inner_dim=params.embedding_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup_mode: bool = False +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup_mode=warmup_mode, + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup_mode=(params.batch_idx_train < params.model_warm_step) + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + params.warm_step = 8000 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + return 1.0 <= c.duration <= 20.0 + + num_in_total = len(train_cuts) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 + + logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info(f"After removing short and long utterances: {num_left}") + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup_mode=True # may use slightly more memory + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py new file mode 100644 index 0000000000..3fa847f4f2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py @@ -0,0 +1,418 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear + +from icefall.utils import make_pad_mask + + +class Transformer(EncoderInterface): + def __init__( + self, + num_features: int, + output_dim: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + output_dim: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder. + num_encoder_layers: + Number of encoder layers. + dropout: + Dropout in encoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + """ + super().__init__() + + self.num_features = num_features + self.output_dim = output_dim + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) + ) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() + + mask = make_pad_mask(lengths) + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + + logits = self.encoder_output_layer(x) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, lengths + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py index b295ce94bc..3d218dcd04 100644 --- a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py +++ b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py @@ -22,7 +22,7 @@ class EncoderInterface(nn.Module): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool + self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -32,8 +32,6 @@ def forward( x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. - warmup_mode: for training only, if true then train in - "warmup mode" (use this for the first few thousand minibatches). Returns: Return a tuple containing two tensors: - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) From 13db33ffa2dba26a528748979fa202b6949fc0e6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 15:53:53 +0800 Subject: [PATCH 096/185] Fix diagnostics-getting code --- .../ASR/pruned_transducer_stateless2/train.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 51858448d0..b7cd453346 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -115,7 +115,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless2/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -556,6 +556,9 @@ def train_one_epoch( optimizer.step() optimizer.zero_grad() + if params.print_diagnostics and batch_idx == 5: + return + if batch_idx % params.log_interval == 0: logging.info( f"Epoch {params.cur_epoch}, " @@ -665,7 +668,11 @@ def run(rank, world_size, args): if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + librispeech = LibriSpeechAsrDataModule(args) From acc0eda5b0b9b20b33ff1cdbb8bb467d6bc9fdbb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 16:09:35 +0800 Subject: [PATCH 097/185] Scale down pruned loss in warmup mode --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index b7cd453346..f95d8e73c8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -450,7 +450,9 @@ def compute_loss( lm_scale=params.lm_scale, warmup_mode=warmup_mode, ) - loss = params.simple_loss_scale * simple_loss + pruned_loss + loss = params.simple_loss_scale * simple_loss + if not warmup_mode: + loss = loss + pruned_loss * (0.1 if warmup_mode else 1.0) assert loss.requires_grad == is_training From cbe6b175d1d17bd6e20e2970fba46758249fa11c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 16:46:59 +0800 Subject: [PATCH 098/185] Reduce warmup scale on pruned loss form 0.1 to 0.01. --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index f95d8e73c8..f7eb15c01c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -452,7 +452,7 @@ def compute_loss( ) loss = params.simple_loss_scale * simple_loss if not warmup_mode: - loss = loss + pruned_loss * (0.1 if warmup_mode else 1.0) + loss = loss + (pruned_loss * 0.01 if warmup_mode else pruned_loss) assert loss.requires_grad == is_training From 6769087d702b3b8fed473e2da487772622be26c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:31:25 +0800 Subject: [PATCH 099/185] Remove scale_speed, make swish deriv more efficient. --- .../pruned_transducer_stateless2/conformer.py | 6 +- .../pruned_transducer_stateless2/decoder.py | 138 +---------- .../pruned_transducer_stateless2/scaling.py | 222 ++++++++++++++---- 3 files changed, 181 insertions(+), 185 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 245af05e37..cb46528409 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -410,7 +410,6 @@ def __init__( embed_dim: int, num_heads: int, dropout: float = 0.0, - scale_speed: float = 5.0 ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -430,16 +429,15 @@ def __init__( # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.scale_speed = scale_speed self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) self._reset_parameters() def _pos_bias_u(self): - return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() + return self.pos_bias_u * self.pos_bias_u_scale.exp() def _pos_bias_v(self): - return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() + return self.pos_bias_v * self.pos_bias_v_scale.exp() def _reset_parameters(self) -> None: nn.init.normal_(self.pos_bias_u, std=0.05) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 7836ca9992..47a519dc9a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from torch import Tensor from typing import Optional -from scaling import ScaledConv1d, ScaledLinear +from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding class Decoder(nn.Module): @@ -103,139 +103,3 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: embedding_out = embedding_out.permute(0, 2, 1) embedding_out = self.output_linear(F.relu(embedding_out)) return embedding_out - - - -class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` - is renormalized to have norm :attr:`max_norm`. - norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' - elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() - - - - def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - scale = (self.scale * self.scale_speed).exp() - if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale - else: - return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) - - def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' - if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' - if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' - if self.sparse is not False: - s += ', sparse=True' - return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index c8bc35fd17..f0e3fe1481 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn from torch import Tensor -from typing import Tuple +from typing import Tuple, Optional @@ -94,31 +94,25 @@ class BasicNorm(torch.nn.Module): to indicate the connection with conventional LayerNorm. learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. - eps_speed: a constant that determines how fast "eps" learns; - with Adam and variants, this should probably be >= 1, - e.g. 5.0. For SGD and variants, probably a value less than one, - like 0.1, would be suitable, to prevent instability. """ def __init__(self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. eps: float = 0.25, - learn_eps: bool = True, - eps_speed: float = 5.0): + learn_eps: bool = True) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.eps_speed = eps_speed if learn_eps: - self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) + self.register_buffer('eps', torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - (self.eps * self.eps_speed).exp()) ** -0.5 + self.eps.exp()) ** -0.5 return x * scales @@ -128,16 +122,13 @@ class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before use, via: - weight = self.weight * (self.weight_scale * self.scale_speed).exp() - bias = self.bias * (self.bias_scale * self.scale_speed).exp() + weight = self.weight * self.weight_scale.exp() + bias = self.bias * self.bias_scale.exp() Args: Accepts the standard args and kwargs that nn.Linear accepts e.g. in_features, out_features, bias=False. - scale_speed: a factor that affects how fast the weight_scale - and bias_scale learn; this value is suitable for Adam-type - optimizers. initial_scale: you can override this if you want to increase or decrease the initial magnitude of the module's output (affects the initialization of weight_scale and bias_scale). @@ -149,13 +140,11 @@ class ScaledLinear(nn.Linear): may be larger than optimal. """ def __init__(self, *args, - scale_speed: float = 5.0, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - self.scale_speed = scale_speed if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: @@ -172,14 +161,14 @@ def _reset_parameters(self): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: return torch.nn.functional.linear(input, self.get_weight(), @@ -187,11 +176,10 @@ def forward(self, input: Tensor) -> Tensor: class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, scale_speed = 5.0, + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) @@ -208,15 +196,15 @@ def _reset_parameters(self): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional @@ -230,10 +218,9 @@ def forward(self, input: Tensor) -> Tensor: class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv2d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) @@ -250,15 +237,15 @@ def _reset_parameters(self): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def _conv_forward(self, input, weight): F = torch.nn.functional @@ -323,6 +310,16 @@ def forward(self, x: Tensor) -> Tensor: self.max_factor, self.min_abs, self.max_abs) +# deriv of double_swish: +# double_swish(x) = x * torch.sigmoid(x-1) [this is a definition, originally +# motivated by its similarity to swish(swish(x), +# where swish(x) = x *sigmoid(x)]. +# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) +# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). +# Now, s'(x) = s(x) * (1-s(x)). +# double_swish'(x) = x * s'(x) + s(x). +# = x * s(x) * (1-s(x)) + s(x). +# = double_swish(x) * (1-s(x)) + s(x) def _double_swish(x: Tensor) -> Tensor: # double-swish, implemented/approximated as offset-swish @@ -331,18 +328,16 @@ def _double_swish(x: Tensor) -> Tensor: class DoubleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: - ctx.save_for_backward(x.detach()) - return _double_swish(x) + x = x.detach() + s = torch.sigmoid(x - 1.0) + y = x * s + ctx.save_for_backward(s, y) + return y @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - # TODO: can make this more efficient. - x, = ctx.saved_tensors - x.requires_grad = True - with torch.enable_grad(): - y = _double_swish(x) - y.backward(gradient=y_grad) - return x.grad + s, y = ctx.saved_tensors + return (y * (1-s) + s) * y_grad class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: @@ -353,6 +348,140 @@ def forward(self, x: Tensor) -> Tensor: + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.05) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log()) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = self.scale.exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + def _test_activation_balancer_sign(): channel_dim = 0 probs = torch.arange(0, 1, 0.01) @@ -409,10 +538,15 @@ def _test_basic_norm(): assert y_rms > 0.5 * x_rms - +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 0.5 + x.requires_grad = True + m = DoubleSwish() + torch.autograd.gradcheck(m, x) if __name__ == '__main__': _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() + _test_double_swish_deriv() From ba3611cefd1af82ef343beec9daef9d2e795f3a0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:35:48 +0800 Subject: [PATCH 100/185] Cosmetic changes to swish --- .../pruned_transducer_stateless2/scaling.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f0e3fe1481..d03bd09676 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -310,22 +310,22 @@ def forward(self, x: Tensor) -> Tensor: self.max_factor, self.min_abs, self.max_abs) -# deriv of double_swish: -# double_swish(x) = x * torch.sigmoid(x-1) [this is a definition, originally -# motivated by its similarity to swish(swish(x), -# where swish(x) = x *sigmoid(x)]. -# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) -# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). -# Now, s'(x) = s(x) * (1-s(x)). -# double_swish'(x) = x * s'(x) + s(x). -# = x * s(x) * (1-s(x)) + s(x). -# = double_swish(x) * (1-s(x)) + s(x) - -def _double_swish(x: Tensor) -> Tensor: - # double-swish, implemented/approximated as offset-swish - return x * torch.sigmoid(x - 1.0) class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() From 2dfcd8f1176851be3a8dbff5c7abde0ef0793cf0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:38:36 +0800 Subject: [PATCH 101/185] Double warm_step --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index f7eb15c01c..ae45db60f7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -615,7 +615,7 @@ def run(rank, world_size, args): params.update(vars(args)) if params.full_libri is False: params.valid_interval = 800 - params.warm_step = 8000 + params.warm_step = 16000 fix_random_seed(params.seed) if world_size > 1: From c9f1aeb7d18eaa33c5d8b7f1fe7365ac9a0ff971 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:40:24 +0800 Subject: [PATCH 102/185] Fix bug with import --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index d03bd09676..2d03313122 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -459,6 +459,7 @@ def reset_parameters(self) -> None: self.weight[self.padding_idx].fill_(0) def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: return F.embedding( From 188eada7ac8f761b130f4a3bdbbeb92e8160e38d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 21:28:34 +0800 Subject: [PATCH 103/185] Change initial std from 0.05 to 0.025. --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 2d03313122..d4aef5cdd2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -153,7 +153,7 @@ def __init__(self, *args, self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): - std = 0.05 + std = 0.025 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -188,7 +188,7 @@ def __init__(self, *args, self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - std = 0.05 + std = 0.025 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -229,7 +229,7 @@ def __init__(self, *args, initial_scale=1.0, **kwargs): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - std = 0.05 + std = 0.025 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: From 8cff994cd7da9880ca63de95212fb1bd7d0a2bc0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 21:30:05 +0800 Subject: [PATCH 104/185] Set also scale for embedding to 0.025. --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index d4aef5cdd2..b358e5fa2e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -451,8 +451,9 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log()) + std = 0.025 + nn.init.normal_(self.weight, std=std) + nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) if self.padding_idx is not None: with torch.no_grad(): From 0ee2404ff09057205812dc0d6b39495192a87c80 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 19 Mar 2022 14:01:45 +0800 Subject: [PATCH 105/185] Remove logging code that broke with newer Lhotse; fix bug with pruned_loss --- .../ASR/pruned_transducer_stateless2/train.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index ae45db60f7..851822aaeb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -450,9 +450,8 @@ def compute_loss( lm_scale=params.lm_scale, warmup_mode=warmup_mode, ) - loss = params.simple_loss_scale * simple_loss - if not warmup_mode: - loss = loss + (pruned_loss * 0.01 if warmup_mode else pruned_loss) + loss = (params.simple_loss_scale * simple_loss + + (pruned_loss * 0.01 if warmup_mode else pruned_loss)) assert loss.requires_grad == is_training @@ -687,18 +686,8 @@ def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds return 1.0 <= c.duration <= 20.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = librispeech.train_dataloaders(train_cuts) valid_cuts = librispeech.dev_clean_cuts() From 05b5e78d8f2298cf6b4b757a620df099dfc0841d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 15:55:11 +0800 Subject: [PATCH 106/185] Add norm+balancer to VggSubsampling --- .../ASR/pruned_transducer_stateless2/subsampling.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py index 51b08e0725..c2da23adc8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py @@ -158,6 +158,12 @@ def __init__(self, idim: int, odim: int) -> None: self.out = nn.Linear( block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim ) + self.out_norm = BasicNorm(odim, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -173,4 +179,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.layers(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x = self.out_norm(x) + x = self.out_balancer(x) return x From ccbf8ba0862347007fb6aed87fff6f152d1bc35f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 16:51:48 +0800 Subject: [PATCH 107/185] Incorporate changes from master into pruned_transducer_stateless2. --- .../pruned_transducer_stateless2/decode.py | 175 ++++++++++++++---- .../pruned_transducer_stateless2/decoder.py | 1 + .../ASR/pruned_transducer_stateless2/train.py | 121 ++++++++++-- icefall/diagnostics.py | 9 +- 4 files changed, 254 insertions(+), 52 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 86ec6172fd..ad76411c04 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -42,6 +42,17 @@ --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -49,16 +60,26 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple +import k2 import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + modified_beam_search, +) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -88,6 +109,17 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--avg-last-n", + type=int, + default=0, + help="""If positive, --epoch and --avg are ignored and it + will use the last n checkpoints exp_dir/checkpoint-xxx.pt + where xxx is the number of processed batches while + saving that checkpoint. + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -110,6 +142,7 @@ def get_parser(): - greedy_search - beam_search - modified_beam_search + - fast_beam_search """, ) @@ -117,8 +150,35 @@ def get_parser(): "--beam-size", type=int, default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, help="""Used only when --decoding-method is - beam_search or modified_beam_search""", + fast_beam_search""", ) parser.add_argument( @@ -144,6 +204,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -166,6 +227,9 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -184,36 +248,62 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) hyps = [] - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } else: - return {f"beam_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": hyps} def decode_dataset( @@ -221,6 +311,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -233,6 +324,9 @@ def decode_dataset( The neural model. sp: The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -260,6 +354,7 @@ def decode_dataset( params=params, model=model, sp=sp, + decoding_graph=decoding_graph, batch=batch, ) @@ -340,12 +435,17 @@ def main(): assert params.decoding_method in ( "greedy_search", "beam_search", + "fast_beam_search", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "beam_search" in params.decoding_method: + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" @@ -372,7 +472,12 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg == 1: + if params.avg_last_n > 0: + filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 @@ -388,6 +493,11 @@ def main(): model.eval() model.device = device + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -408,6 +518,7 @@ def main(): params=params, model=model, sp=sp, + decoding_graph=decoding_graph, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 47a519dc9a..13e45e03b7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -64,6 +64,7 @@ def __init__( assert context_size >= 1, context_size self.context_size = context_size + self.vocab_size = vocab_size if context_size > 1: self.conv = ScaledConv1d( in_channels=embedding_dim, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 851822aaeb..d28a8a060a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -36,7 +36,7 @@ import logging from pathlib import Path from shutil import copyfile -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple import k2 import sentencepiece as spm @@ -48,6 +48,7 @@ from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from torch import Tensor @@ -55,8 +56,9 @@ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall import diagnostics @@ -112,6 +114,15 @@ def get_parser(): """, ) + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -192,6 +203,30 @@ def get_parser(): help="Accumulate stats on activations, print them and exit.", ) + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + return parser @@ -320,15 +355,16 @@ def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: +) -> Optional[Dict[str, Any]]: """Load checkpoint from file. - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, and `best_valid_loss` in `params`. Args: @@ -338,20 +374,22 @@ def load_checkpoint_if_available( The training model. optimizer: The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. Returns: - Return None. + Return a dict containing previously saved training info. """ - if params.start_epoch <= 0: - return + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 0: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" saved_params = load_checkpoint( filename, model=model, optimizer=optimizer, - scheduler=scheduler, ) keys = [ @@ -360,10 +398,13 @@ def load_checkpoint_if_available( "batch_idx_train", "best_train_loss", "best_valid_loss", + "cur_batch_idx", ] for k in keys: params[k] = saved_params[k] + params["start_epoch"] = saved_params["cur_epoch"] + return saved_params @@ -371,7 +412,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -381,6 +422,10 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. """ if rank != 0: return @@ -390,7 +435,7 @@ def save_checkpoint( model=model, params=params, optimizer=optimizer, - scheduler=scheduler, + sampler=sampler, rank=rank, ) @@ -509,6 +554,7 @@ def train_one_epoch( valid_dl: torch.utils.data.DataLoader, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, + rank: int = 0, ) -> None: """Train the model for one epoch. @@ -531,12 +577,21 @@ def train_one_epoch( Writer to write log messages to tensorboard. world_size: Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. """ model.train() tot_loss = MetricsTracker() + cur_batch_idx = params.get("cur_batch_idx", 0) + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -560,6 +615,27 @@ def train_one_epoch( if params.print_diagnostics and batch_idx == 5: return + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + sampler=train_dl.sampler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + if batch_idx % params.log_interval == 0: logging.info( f"Epoch {params.cur_epoch}, " @@ -688,7 +764,14 @@ def remove_short_and_long_utt(c: Cut): train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = librispeech.train_dataloaders(train_cuts) + if checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) valid_cuts = librispeech.dev_clean_cuts() valid_cuts += librispeech.dev_other_cuts() @@ -728,6 +811,7 @@ def remove_short_and_long_utt(c: Cut): valid_dl=valid_dl, tb_writer=tb_writer, world_size=world_size, + rank=rank, ) if params.print_diagnostics: @@ -738,6 +822,7 @@ def remove_short_and_long_utt(c: Cut): params=params, model=model, optimizer=optimizer, + sampler=train_dl.sampler, rank=rank, ) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index fa9b98fa00..06eacd7361 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -135,8 +135,13 @@ def get_diagnostics_for_dim( return "" count = sum(counts) stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() + try: + eigs, _ = torch.symeig(stats) + stats = eigs.abs().sqrt() + except: + print("Error getting eigenvalues, trying another method") + eigs, _ = torch.eigs(stats) + stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: stats = torch.stack(stats).sum(dim=0) From 05e30d0c461f2428a12a8a13d980f14320bf13be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 21:15:00 +0800 Subject: [PATCH 108/185] Add max-abs=6, debugged version --- .../ASR/pruned_transducer_stateless2/conformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index cb46528409..c6470b4a2a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -176,13 +176,13 @@ def __init__( self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, - max_positive=0.55) + max_positive=0.55, + max_abs=6.0) self.dropout = nn.Dropout(dropout) @@ -232,7 +232,7 @@ def forward( # feed forward module src = src + self.dropout(self.feed_forward(src)) - src = self.balancer(self.norm_final(self.pre_norm_final(src))) + src = self.norm_final(self.balancer(src)) return src From 11a04c50ae15505c7c480963203531abe0c65e98 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 21:29:24 +0800 Subject: [PATCH 109/185] Change 0.025,0.05 to 0.01 in initializations --- .../ASR/pruned_transducer_stateless2/conformer.py | 4 ++-- .../ASR/pruned_transducer_stateless2/scaling.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index c6470b4a2a..f778c92266 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -440,8 +440,8 @@ def _pos_bias_v(self): return self.pos_bias_v * self.pos_bias_v_scale.exp() def _reset_parameters(self) -> None: - nn.init.normal_(self.pos_bias_u, std=0.05) - nn.init.normal_(self.pos_bias_v, std=0.05) + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) def forward( self, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index b358e5fa2e..f2423492f8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -153,7 +153,7 @@ def __init__(self, *args, self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): - std = 0.025 + std = 0.01 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -188,7 +188,7 @@ def __init__(self, *args, self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - std = 0.025 + std = 0.01 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -229,7 +229,7 @@ def __init__(self, *args, initial_scale=1.0, **kwargs): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - std = 0.025 + std = 0.01 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -451,7 +451,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona def reset_parameters(self) -> None: - std = 0.025 + std = 0.01 nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) From 2eef001d39dcd68f230a8072cac9350b78b9f950 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 23:59:26 +0800 Subject: [PATCH 110/185] Fix balancer code --- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index bf96b41f97..909f9a74ca 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -176,13 +176,13 @@ def __init__( self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, - max_positive=0.55) + max_positive=0.55, + max_positive=6.0) self.dropout = nn.Dropout(dropout) @@ -232,7 +232,7 @@ def forward( # feed forward module src = src + self.dropout(self.feed_forward(src)) - src = self.balancer(self.norm_final(self.pre_norm_final(src))) + src = self.norm_final(self.balancer(src)) return src From b7e84d5d77cb313579d54b58cc4be3f660af9038 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 23:59:53 +0800 Subject: [PATCH 111/185] Whitespace fix --- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 909f9a74ca..f7b96a6a1e 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -230,7 +230,7 @@ def forward( src = src + self.dropout(self.conv_module(src)) # feed forward module - src = src + self.dropout(self.feed_forward(src)) + src = src + self.dropout(self.feed_forward(src)) src = self.norm_final(self.balancer(src)) From b82a505dfc003ea9b919dc49d60d780837e927bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Mar 2022 12:30:48 +0800 Subject: [PATCH 112/185] Reduce initial pruned_loss scale from 0.01 to 0.0 --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d28a8a060a..b9409127e0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -496,7 +496,7 @@ def compute_loss( warmup_mode=warmup_mode, ) loss = (params.simple_loss_scale * simple_loss + - (pruned_loss * 0.01 if warmup_mode else pruned_loss)) + (pruned_loss * 0.0 if warmup_mode else pruned_loss)) assert loss.requires_grad == is_training From 4004ca81d84cda612265bd6919bea168e39601da Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Mar 2022 13:32:24 +0800 Subject: [PATCH 113/185] Increase warm_step (and valid_interval) --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index b9409127e0..096f93d77f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -295,7 +295,7 @@ def get_params() -> AttributeDict: # parameters for decoder "embedding_dim": 512, # parameters for Noam - "warm_step": 30000, # For the 100h subset, use 8k + "warm_step": 60000, # For the 100h subset, use 8k "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } @@ -689,8 +689,8 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) if params.full_libri is False: - params.valid_interval = 800 - params.warm_step = 16000 + params.valid_interval = 1600 + params.warm_step = 30000 fix_random_seed(params.seed) if world_size > 1: From cef634870300c6d2a00f6b538ccc5d64975d0766 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Mar 2022 13:50:54 +0800 Subject: [PATCH 114/185] Change max-abs from 6 to 10 --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index f778c92266..d90dd34e19 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -182,7 +182,7 @@ def __init__( self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55, - max_abs=6.0) + max_abs=10.0) self.dropout = nn.Dropout(dropout) From 9a8aa1f54ab4154571974eea3c795f0b7ad49758 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Mar 2022 15:36:20 +0800 Subject: [PATCH 115/185] Change how warmup works. --- .../pruned_transducer_stateless2/conformer.py | 221 +++--------------- .../ASR/pruned_transducer_stateless2/model.py | 7 +- .../ASR/pruned_transducer_stateless2/train.py | 13 +- 3 files changed, 38 insertions(+), 203 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index d90dd34e19..83bcc3f3e4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -88,7 +88,7 @@ def __init__( def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -97,6 +97,10 @@ def forward( x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. Returns: Return a tuple containing 2 tensors: - logits, its shape is (batch_size, output_seq_len, output_dim) @@ -113,7 +117,7 @@ def forward( mask = make_pad_mask(lengths) x = self.encoder(x, pos_emb, src_key_padding_mask=mask, - warmup_mode=warmup_mode) # (T, N, C) + warmup=warmup) # (T, N, C) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -193,6 +197,8 @@ def forward( pos_emb: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + position: float = 0.0 ) -> Tensor: """ Pass the input through the encoder layer. @@ -202,6 +208,11 @@ def forward( pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective activation of layers; if < 1.0, it's possible that + not all modules will be included. + position: the position of this module in the encoder stack (relates to + warmup); a value 0 <= position < 1.0. + Shape: src: (S, N, E). @@ -210,9 +221,9 @@ def forward( src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number """ - # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) + src = torch.add(src, self.dropout(self.feed_forward_macaron(src)), + alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0)) # multi-headed self-attention module @@ -224,13 +235,16 @@ def forward( attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = src + self.dropout(src_att) + src = torch.add(src, self.dropout(src_att), + alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0)) # convolution module - src = src + self.dropout(self.conv_module(src)) + src = torch.add(src, self.dropout(self.conv_module(src)), + alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0)) # feed forward module - src = src + self.dropout(self.feed_forward(src)) + src = torch.add(src, self.dropout(self.feed_forward(src)), + alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0)) src = self.norm_final(self.balancer(src)) @@ -262,10 +276,6 @@ def __init__(self, encoder_layer: nn.Module, num_layers: int, assert num_layers - 1 not in aux_layers self.num_layers = num_layers num_channels = encoder_layer.d_model - self.combiner = RandomCombine(num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0) def forward( self, @@ -273,7 +283,7 @@ def forward( pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - warmup_mode: bool = False + warmup: float = 1.0 ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -293,7 +303,7 @@ def forward( """ output = src - outputs = [] + num_layers = len(self.layers) for i, mod in enumerate(self.layers): output = mod( @@ -301,11 +311,10 @@ def forward( pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + position=(i / num_layers), ) - if i in self.aux_layers: - outputs.append(output) - output = self.combiner(outputs, warmup_mode) return output @@ -922,187 +931,9 @@ def forward(self, x: Tensor) -> Tensor: return x -class RandomCombine(torch.nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - def __init__(self, num_inputs: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0) -> None: - """ - Args: - num_inputs: The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - final_weight: The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: A standard deviation that we add to log-probs for computing - randomized weights. - - The method of choosing which layers, - or combinations of layers, to use, is conceptually as follows. - With probability `pure_prob`: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super(RandomCombine, self).__init__() - assert pure_prob >= 0 and pure_prob <= 1 - assert final_weight > 0 and final_weight < 1 - assert num_inputs >= 1 - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev= stddev - - self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() - - - def forward(self, inputs: Sequence[Tensor], - warmup_mode: bool) -> Tensor: - """ - Forward function. - Args: - inputs: a list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - a Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not (self.training and warmup_mode): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, - num_channels, - num_inputs)) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, - num_frames) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - - def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: - """ - Return a tensor of random weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), such that - ans.sum(dim=1) is all ones. - - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) - - def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with - exactly one weight equal to 1.0 on each frame. - """ - - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) - - indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, - final, nonfinal) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) - return ans - - - def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that - sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. - """ - logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev - logprobs[:,-1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") - num_inputs = 3 - num_channels = 50 - m = RandomCombine(num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev) - - x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] - - y = m(x, True) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. if __name__ == '__main__': - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - feature_dim = 50 c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) batch_size = 5 @@ -1110,4 +941,4 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): # Just make sure the forward pass runs. f = c(torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup_mode=True) + warmup=0.5) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index e83d18e3eb..faaebc477d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,7 +66,7 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - warmup_mode: bool = False + warmup: float = 1.0, ) -> torch.Tensor: """ Args: @@ -87,6 +87,9 @@ def forward( lm_scale: The scale to smooth the loss with lm (output of predictor network) part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. Returns: Return the transducer loss. @@ -102,7 +105,7 @@ def forward( assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode=warmup_mode) + encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 096f93d77f..d4a2e83d5f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -296,7 +296,7 @@ def get_params() -> AttributeDict: "embedding_dim": 512, # parameters for Noam "warm_step": 60000, # For the 100h subset, use 8k - "model_warm_step": 3000, # arg given to model, not for lrate + "model_warm_step": 4000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -454,7 +454,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - warmup_mode: bool = False + warmup: float = 1.0 ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -471,6 +471,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. """ device = model.device feature = batch["inputs"] @@ -493,10 +495,10 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, - warmup_mode=warmup_mode, + warmup=warmup, ) loss = (params.simple_loss_scale * simple_loss + - (pruned_loss * 0.0 if warmup_mode else pruned_loss)) + (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss)) assert loss.requires_grad == is_training @@ -601,7 +603,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - warmup_mode=(params.batch_idx_train < params.model_warm_step) + warmup=(params.batch_idx_train / params.model_warm_step) ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -855,7 +857,6 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup_mode=True # may use slightly more memory ) loss.backward() optimizer.step() From aab72bc2a546872ac08a4396b382810b90af1cba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Mar 2022 13:10:54 +0800 Subject: [PATCH 116/185] Add changes from master to decode.py, train.py --- .../pruned_transducer_stateless2/decode.py | 27 ++++++++++++++----- .../ASR/pruned_transducer_stateless2/train.py | 19 ++++++++++--- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index ad76411c04..8e924bf96c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -71,6 +71,7 @@ beam_search, fast_beam_search, greedy_search, + greedy_search_batch, modified_beam_search, ) from train import get_params, get_transducer_model @@ -191,7 +192,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) @@ -261,6 +262,24 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -280,12 +299,6 @@ def decode_one_batch( encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d4a2e83d5f..01cf289f59 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -398,12 +398,16 @@ def load_checkpoint_if_available( "batch_idx_train", "best_train_loss", "best_valid_loss", - "cur_batch_idx", ] for k in keys: params[k] = saved_params[k] - params["start_epoch"] = saved_params["cur_epoch"] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] return saved_params @@ -762,11 +766,20 @@ def run(rank, world_size, args): def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold return 1.0 <= c.duration <= 20.0 train_cuts = train_cuts.filter(remove_short_and_long_utt) - if checkpoints and "sampler" in checkpoints: + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None From 1f548548d2875ebc6ec7f7d526d0500c5e83b18e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Mar 2022 15:06:06 +0800 Subject: [PATCH 117/185] Simplify the warmup code; max_abs 10->6 --- .../pruned_transducer_stateless2/conformer.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 83bcc3f3e4..a817773535 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -186,7 +186,7 @@ def __init__( self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55, - max_abs=10.0) + max_abs=6.0) self.dropout = nn.Dropout(dropout) @@ -198,7 +198,6 @@ def forward( src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, warmup: float = 1.0, - position: float = 0.0 ) -> Tensor: """ Pass the input through the encoder layer. @@ -208,11 +207,10 @@ def forward( pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective activation of layers; if < 1.0, it's possible that - not all modules will be included. - position: the position of this module in the encoder stack (relates to - warmup); a value 0 <= position < 1.0. - + warmup: controls selective activation of layers; if < 0.5, it's possible that + not all modules will be included. Actually we add the + feed_forward_macaron and self_attn modules at warmup=0.0 + and the conv_module and feed_forward at warmup=0.5. Shape: src: (S, N, E). @@ -223,7 +221,7 @@ def forward( """ # macaron style feed forward module src = torch.add(src, self.dropout(self.feed_forward_macaron(src)), - alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0)) + alpha=(0.0 if warmup < 0.0 else 1.0)) # multi-headed self-attention module @@ -236,15 +234,15 @@ def forward( key_padding_mask=src_key_padding_mask, )[0] src = torch.add(src, self.dropout(src_att), - alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0)) + alpha=(0.0 if warmup < 0.0 else 1.0)) # convolution module src = torch.add(src, self.dropout(self.conv_module(src)), - alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0)) + alpha=(0.0 if warmup < 0.5 else 1.0)) # feed forward module src = torch.add(src, self.dropout(self.feed_forward(src)), - alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0)) + alpha=(0.0 if warmup < 0.5 else 1.0)) src = self.norm_final(self.balancer(src)) @@ -311,8 +309,7 @@ def forward( pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - position=(i / num_layers), + warmup=warmup-0.5*(i / num_layers) ) return output From 4b650e9f015a8cef28f5a2a0574b8b3d250fcea8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Mar 2022 20:34:33 +0800 Subject: [PATCH 118/185] Make warmup work by scaling layer contributions; leave residual layer-drop --- .../pruned_transducer_stateless2/conformer.py | 32 +++++++++++++------ .../ASR/pruned_transducer_stateless2/train.py | 11 +++++-- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index a817773535..64030ef904 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -219,9 +219,23 @@ def forward( src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number """ + src_orig = src + # when warmup == 0.0, alpha is always 0.1, but it gradually changes to + # always being 1.0 when warmup equals 1.0. The reason for using 0.1 and not + # 0.0 is that it gives us a gradient so we can learn something when we are not + # being very useful. The occasional 1.0 will ensure, via self.balancer, that + # the outputs of our modules don't get scaled up too much. + + # min(0.1, warmup) + # is used in place of warmup to ensure that even at the start of the warm-up + # period we sometimes use scale 1.0; this ensures that the modules do not + # compensate for the small scale by just producing larger output. + warmup = max(warmup, 0.1) + warmup = min(warmup, 0.95) # effectively, layer-drop. + alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 + # macaron style feed forward module - src = torch.add(src, self.dropout(self.feed_forward_macaron(src)), - alpha=(0.0 if warmup < 0.0 else 1.0)) + src = torch.add(src, self.dropout(self.feed_forward_macaron(src))) # multi-headed self-attention module @@ -233,19 +247,19 @@ def forward( attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = torch.add(src, self.dropout(src_att), - alpha=(0.0 if warmup < 0.0 else 1.0)) + src = torch.add(src, self.dropout(src_att)) # convolution module - src = torch.add(src, self.dropout(self.conv_module(src)), - alpha=(0.0 if warmup < 0.5 else 1.0)) + src = torch.add(src, self.dropout(self.conv_module(src))) # feed forward module - src = torch.add(src, self.dropout(self.feed_forward(src)), - alpha=(0.0 if warmup < 0.5 else 1.0)) + src = torch.add(src, self.dropout(self.feed_forward(src))) src = self.norm_final(self.balancer(src)) + if alpha != 1.0: + src = alpha * src + (1-alpha) * src_orig + return src @@ -309,7 +323,7 @@ def forward( pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - warmup=warmup-0.5*(i / num_layers) + warmup=warmup, ) return output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 01cf289f59..35991f5e96 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -296,7 +296,7 @@ def get_params() -> AttributeDict: "embedding_dim": 512, # parameters for Noam "warm_step": 60000, # For the 100h subset, use 8k - "model_warm_step": 4000, # arg given to model, not for lrate + "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -501,8 +501,15 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = (0.0 if warmup < 1.0 else + (0.1 if warmup > 1.0 and warmup < 2.0) else + 1.0) loss = (params.simple_loss_scale * simple_loss + - (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss)) + pruned_loss_scale * pruned_loss) assert loss.requires_grad == is_training From d2ed3dfc90fa05c63433c5cc7e627bb03de209cc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Mar 2022 20:35:11 +0800 Subject: [PATCH 119/185] Fix bug --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 35991f5e96..13ba990170 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -506,8 +506,8 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = (0.0 if warmup < 1.0 else - (0.1 if warmup > 1.0 and warmup < 2.0) else - 1.0) + (0.1 if warmup > 1.0 and warmup < 2.0 else + 1.0)) loss = (params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss) From 0e694739f2a33344eac9cf8b0398aa876f469853 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Mar 2022 23:28:52 +0800 Subject: [PATCH 120/185] Fix test mode with random layer dropout --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 64030ef904..fae91aa718 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -231,7 +231,8 @@ def forward( # period we sometimes use scale 1.0; this ensures that the modules do not # compensate for the small scale by just producing larger output. warmup = max(warmup, 0.1) - warmup = min(warmup, 0.95) # effectively, layer-drop. + if self.training: + warmup = min(warmup, 0.95) # effectively, layer-drop. alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 # macaron style feed forward module From 26a1730392163e64499672c2847ef2e10bf3bc5e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Mar 2022 14:46:27 +0800 Subject: [PATCH 121/185] Add random-number-setting function in dataloader --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a460c8eb84..a0356f68a2 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -22,6 +22,8 @@ from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import torch +import lhotse from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( @@ -301,12 +303,19 @@ def train_dataloaders( logging.info("Loading sampler state dict") train_sampler.load_state_dict(sampler_state_dict) + # 'seed' is derived from the current random state, which will have previously been + # set in the main process. + seed = torch.randint(0, 100000, ()).item() + def worker_init_fn(worker_id: int): + lhotse.utils.fix_random_seed(seed + worker_id) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl From 8a38d9a855b57be5e976727084d4980aa0fd5b2a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Mar 2022 15:43:47 +0800 Subject: [PATCH 122/185] Fix/patch how fix_random_seed() is imported. --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a0356f68a2..3efe7ec7a7 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import Any, Dict, Optional import torch -import lhotse +from lhotse.utils import fix_random_seed from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( @@ -307,7 +307,7 @@ def train_dataloaders( # set in the main process. seed = torch.randint(0, 100000, ()).item() def worker_init_fn(worker_id: int): - lhotse.utils.fix_random_seed(seed + worker_id) + fix_random_seed(seed + worker_id) train_dl = DataLoader( train, From b43468bb67502b87296387b6a65048a85558ab04 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Mar 2022 19:36:33 +0800 Subject: [PATCH 123/185] Reduce layer-drop prob --- .../pruned_transducer_stateless2/conformer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index fae91aa718..69a7af6a93 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -222,21 +222,20 @@ def forward( src_orig = src # when warmup == 0.0, alpha is always 0.1, but it gradually changes to # always being 1.0 when warmup equals 1.0. The reason for using 0.1 and not - # 0.0 is that it gives us a gradient so we can learn something when we are not - # being very useful. The occasional 1.0 will ensure, via self.balancer, that - # the outputs of our modules don't get scaled up too much. - + # 0.0 is that it gives us a gradient so we can learn something when we are turned + # off. + # # min(0.1, warmup) # is used in place of warmup to ensure that even at the start of the warm-up # period we sometimes use scale 1.0; this ensures that the modules do not # compensate for the small scale by just producing larger output. warmup = max(warmup, 0.1) if self.training: - warmup = min(warmup, 0.95) # effectively, layer-drop. + warmup = min(warmup, 0.98) # effectively, layer-drop with 1-in-50 prob. alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 # macaron style feed forward module - src = torch.add(src, self.dropout(self.feed_forward_macaron(src))) + src = src + self.dropout(self.feed_forward_macaron(src)) # multi-headed self-attention module @@ -248,13 +247,13 @@ def forward( attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = torch.add(src, self.dropout(src_att)) + src = src + self.dropout(src_att) # convolution module - src = torch.add(src, self.dropout(self.conv_module(src))) + src = src + self.dropout(self.conv_module(src)) # feed forward module - src = torch.add(src, self.dropout(self.feed_forward(src))) + src = src + self.dropout(self.feed_forward(src)) src = self.norm_final(self.balancer(src)) From 953aecf5e38811edc11123d26292cd1d397e11aa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Mar 2022 00:25:32 +0800 Subject: [PATCH 124/185] Reduce layer-drop prob after warmup to 1 in 100 --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 69a7af6a93..85a3b4575c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -231,7 +231,7 @@ def forward( # compensate for the small scale by just producing larger output. warmup = max(warmup, 0.1) if self.training: - warmup = min(warmup, 0.98) # effectively, layer-drop with 1-in-50 prob. + warmup = min(warmup, 0.99) # effectively, layer-drop with 1-in-100 prob. alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 # macaron style feed forward module From 8a8134b9e54e0a1b1cbda59cc1a38fe7cccb16b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 13:18:58 +0800 Subject: [PATCH 125/185] Change power of lr-schedule from -0.5 to -0.333 --- .../ASR/pruned_transducer_stateless2/transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py index 3fa847f4f2..aa091877c7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py @@ -391,7 +391,8 @@ def rate(self, step=None): return ( self.factor * self.model_size ** (-0.5) - * min(step ** (-0.5), step * self.warmup ** (-1.5)) + * self.warmup ** (-0.5 - -0.333) + * min(step ** (-0.333), step * self.warmup ** (-1.333)) ) def zero_grad(self): From 262388134d3b31dc6ec42fa15b99804d23e23d44 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Mar 2022 11:18:16 +0800 Subject: [PATCH 126/185] Increase model_warm_step to 4k --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 13ba990170..c1e836903d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -296,7 +296,7 @@ def get_params() -> AttributeDict: "embedding_dim": 512, # parameters for Noam "warm_step": 60000, # For the 100h subset, use 8k - "model_warm_step": 3000, # arg given to model, not for lrate + "model_warm_step": 4000, # arg given to model, not for lrate "env_info": get_env_info(), } ) From 2cde99509fda0dc6fec55eab7504cacc44b6c0fc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Mar 2022 23:21:42 +0800 Subject: [PATCH 127/185] Change max-keep-prob to 0.95 --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 85a3b4575c..9c8302926c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -231,7 +231,7 @@ def forward( # compensate for the small scale by just producing larger output. warmup = max(warmup, 0.1) if self.training: - warmup = min(warmup, 0.99) # effectively, layer-drop with 1-in-100 prob. + warmup = min(warmup, 0.95) # effectively, layer-drop with 1-in-20 prob. alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 # macaron style feed forward module From 11124b03eaed547e057f302bf02d0d75b91ae58b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Mar 2022 20:32:14 +0800 Subject: [PATCH 128/185] Refactoring and simplifying conformer and frontend --- .../pruned_transducer_stateless2/conformer.py | 115 +++++++++++++--- .../subsampling.py | 127 +++--------------- .../ASR/pruned_transducer_stateless2/train.py | 2 - .../transformer.py | 5 +- 4 files changed, 115 insertions(+), 134 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 9c8302926c..6b625513e4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -16,6 +16,7 @@ # limitations under the License. import copy +from encoder_interface import EncoderInterface import math import warnings from typing import Optional, Tuple, Sequence @@ -23,12 +24,11 @@ import torch from torch import Tensor, nn -from transformer import Transformer from icefall.utils import make_pad_mask -class Conformer(Transformer): +class Conformer(EncoderInterface): """ Args: num_features (int): Number of input features @@ -40,7 +40,6 @@ class Conformer(Transformer): num_encoder_layers (int): number of encoder layers dropout (float): dropout rate cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. vgg_frontend (bool): whether to use vgg frontend. """ @@ -55,22 +54,22 @@ def __init__( num_encoder_layers: int = 12, dropout: float = 0.1, cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, aux_layer_period: int = 3 ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - output_dim=output_dim, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - ) + super(Conformer, self).__init__() + + self.num_features = num_features + self.output_dim = output_dim + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, 128, d_model) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -80,11 +79,13 @@ def __init__( dim_feedforward, dropout, cnn_module_kernel, - normalize_before, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) - self.normalize_before = normalize_before + + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) + ) def forward( @@ -136,7 +137,6 @@ class ConformerEncoderLayer(nn.Module): dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -152,7 +152,6 @@ def __init__( dim_feedforward: int = 2048, dropout: float = 0.1, cnn_module_kernel: int = 31, - normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() self.d_model = d_model @@ -942,6 +941,80 @@ def forward(self, x: Tensor) -> Tensor: return x +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, in_channels: int, + out_channels: int, + layer1_channels: int = 64, + layer2_channels: int = 128) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + """ + assert in_channels >= 7 + super().__init__() + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, out_channels=layer1_channels, + kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, out_channels=layer2_channels, + kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x if __name__ == '__main__': diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py index c2da23adc8..12ca09a178 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py @@ -32,34 +32,43 @@ class Conv2dSubsampling(nn.Module): https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa """ - def __init__(self, idim: int, odim: int) -> None: + def __init__(self, in_channels: int, + out_channels: int, + layer1_channels: int = 64, + layer2_channels: int = 128) -> None: """ Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 """ - assert idim >= 7 + assert in_channels >= 7 super().__init__() self.conv = nn.Sequential( ScaledConv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 + in_channels=1, out_channels=layer1_channels, + kernel_size=3, stride=2 ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + in_channels=layer1_channels, out_channels=layer2_channels, + kernel_size=3, stride=2 ), ActivationBalancer(channel_dim=1), DoubleSwish(), ) - self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) + self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) # set learn_eps=False because out_norm is preceded by `out`, and `out` # itself has learned scale, so the extra degree of freedom is not # needed. - self.out_norm = BasicNorm(odim, learn_eps=False) + self.out_norm = BasicNorm(out_channels, learn_eps=False) # constrain median of output to be close to zero. self.out_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, @@ -86,99 +95,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.out_norm(x) x = self.out_balancer(x) return x - - -class VggSubsampling(nn.Module): - """Trying to follow the setup described in the following paper: - https://arxiv.org/pdf/1910.09799.pdf - - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. - - This uses 2 VGG blocks with 2 Conv2d layers each, - subsampling its input by a factor of 4 in the time dimensions. - - Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) - """ - super().__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the - # 2nd convolution, so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=0, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) - self.out_norm = BasicNorm(odim, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55) - - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - x = x.unsqueeze(1) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - x = self.out_norm(x) - x = self.out_balancer(x) - return x diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c1e836903d..237eb8bbdf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -291,7 +291,6 @@ def get_params() -> AttributeDict: "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, - "vgg_frontend": False, # parameters for decoder "embedding_dim": 512, # parameters for Noam @@ -314,7 +313,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, ) return encoder diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py index aa091877c7..a58702e1dc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py @@ -78,10 +78,7 @@ def __init__( # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, d_model) - else: - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_pos = PositionalEncoding(d_model, dropout) From 4e453a4bf9c77bfaa19a955921a9e7218548b2eb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Mar 2022 23:41:13 +0800 Subject: [PATCH 129/185] Rework conformer, remove some code. --- .../pruned_transducer_stateless2/conformer.py | 90 +++- .../subsampling.py | 97 ---- .../ASR/pruned_transducer_stateless2/train.py | 3 +- .../transformer.py | 416 ------------------ 4 files changed, 90 insertions(+), 516 deletions(-) delete mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py delete mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 6b625513e4..0b9d64ee9a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -69,7 +69,7 @@ def __init__( # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, 128, d_model) + self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -1017,6 +1017,94 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * self.warmup ** (-0.5 - -0.333) + * min(step ** (-0.333), step * self.warmup ** (-1.333)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + if __name__ == '__main__': feature_dim = 50 c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py deleted file mode 100644 index 12ca09a178..0000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn -from torch import Tensor -from typing import Tuple -from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__(self, in_channels: int, - out_channels: int, - layer1_channels: int = 64, - layer2_channels: int = 128) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - """ - assert in_channels >= 7 - super().__init__() - self.conv = nn.Sequential( - ScaledConv2d( - in_channels=1, out_channels=layer1_channels, - kernel_size=3, stride=2 - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer1_channels, out_channels=layer2_channels, - kernel_size=3, stride=2 - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ) - self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(out_channels, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55) - - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) - return x diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 237eb8bbdf..8d51429375 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -44,7 +44,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer +from conformer import Conformer, Noam from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -54,7 +54,6 @@ from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from transformer import Noam from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py deleted file mode 100644 index a58702e1dc..0000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear - -from icefall.utils import make_pad_mask - - -class Transformer(EncoderInterface): - def __init__( - self, - num_features: int, - output_dim: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - dropout: float = 0.1, - normalize_before: bool = True, - vgg_frontend: bool = False, - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - output_dim: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. - d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder. - num_encoder_layers: - Number of encoder layers. - dropout: - Dropout in encoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - vgg_frontend: - True to use vgg style frontend for subsampling. - """ - super().__init__() - - self.num_features = num_features - self.output_dim = output_dim - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - encoder_norm = nn.LayerNorm(d_model) - else: - encoder_norm = None - - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - norm=encoder_norm, - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) - ) - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - Returns: - Return a tuple containing 2 tensors: - - logits, its shape is (batch_size, output_seq_len, output_dim) - - logit_lens, a tensor of shape (batch_size,) containing the number - of frames in `logits` before padding. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 - assert x.size(0) == lengths.max().item() - - mask = make_pad_mask(lengths) - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) - - logits = self.encoder_output_layer(x) - logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return logits, lengths - - -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - normalize_before: - whether to use layer_norm before the first block. - - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) - - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src - - -def _get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is (1, T1, d_model). The shape of the input x - is (N, T, d_model). If T > T1, then we change the shape of self.pe - to (N, T, d_model). Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape (N, T, C). - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape (1, T, d_model), where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is (N, T, C) - - Returns: - Return a tensor of shape (N, T, C) - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * self.warmup ** (-0.5 - -0.333) - * min(step ** (-0.333), step * self.warmup ** (-1.333)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) From 1b8d7defd06c7e12b47e0bde0c8092a053d7f377 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 00:44:18 +0800 Subject: [PATCH 130/185] Reduce 1st conv channels from 64 to 32 --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 0b9d64ee9a..628d31d4b1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -954,7 +954,7 @@ class Conv2dSubsampling(nn.Module): def __init__(self, in_channels: int, out_channels: int, - layer1_channels: int = 64, + layer1_channels: int = 32, layer2_channels: int = 128) -> None: """ Args: From ca6337b78aaedff4404135558cf99f9ad7ab7123 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 11:11:32 +0800 Subject: [PATCH 131/185] Add another convolutional layer --- .../ASR/pruned_transducer_stateless2/conformer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 628d31d4b1..eb937e0c39 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -954,8 +954,9 @@ class Conv2dSubsampling(nn.Module): def __init__(self, in_channels: int, out_channels: int, - layer1_channels: int = 32, - layer2_channels: int = 128) -> None: + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128) -> None: """ Args: in_channels: @@ -973,7 +974,7 @@ def __init__(self, in_channels: int, self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, - kernel_size=3, stride=2 + kernel_size=3, ), ActivationBalancer(channel_dim=1), DoubleSwish(), @@ -983,8 +984,14 @@ def __init__(self, in_channels: int, ), ActivationBalancer(channel_dim=1), DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, out_channels=layer3_channels, + kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), ) - self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) + self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) # set learn_eps=False because out_norm is preceded by `out`, and `out` # itself has learned scale, so the extra degree of freedom is not # needed. From 21a099b110e2831664cd79b5fd982b87606c512c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 11:18:04 +0800 Subject: [PATCH 132/185] Fix padding bug --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index eb937e0c39..853d6747bd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -974,7 +974,7 @@ def __init__(self, in_channels: int, self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, - kernel_size=3, + kernel_size=3, padding=1, ), ActivationBalancer(channel_dim=1), DoubleSwish(), From 7c46c3b0d4ae7237ea1ac909d44a605507c27a77 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 11:20:04 +0800 Subject: [PATCH 133/185] Remove dropout in output layer --- .../ASR/pruned_transducer_stateless2/conformer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 0b9d64ee9a..a8475c21ea 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -83,10 +83,7 @@ def __init__( self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) - ) - + self.encoder_output_layer = ScaledLinear(d_model, output_dim) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 From 37ab0bcfa56fa063c7e5aadfe3f0f207c53e5518 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 11:46:23 +0800 Subject: [PATCH 134/185] Reduce speed of some components --- .../pruned_transducer_stateless2/conformer.py | 12 +++- .../pruned_transducer_stateless2/decoder.py | 7 ++ .../pruned_transducer_stateless2/scaling.py | 67 +++++++++++++------ 3 files changed, 64 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index a8475c21ea..0d3b0aa029 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -967,16 +967,24 @@ def __init__(self, in_channels: int, """ assert in_channels >= 7 super().__init__() + + # This initial_speed is to slightly slow down the relative speed of + # training during the warmup phase by increasing the magnitude of the + # initial parameter values. The intention is to allow us to + # use a higher lr_factor. + initial_speed = 0.5 self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, - kernel_size=3, stride=2 + kernel_size=3, stride=2, + initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=layer1_channels, out_channels=layer2_channels, - kernel_size=3, stride=2 + kernel_size=3, stride=2, + initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 13e45e03b7..3470b647f6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -55,10 +55,17 @@ def __init__( 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() + + # This initial_speed is to slightly slow down the relative speed of + # training during the warmup phase by increasing the magnitude of the + # initial parameter values. The intention is to allow us to + # use a higher lr_factor. + initial_speed = 0.5 self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=blank_id, + initial_speed=initial_speed ) self.blank_id = blank_id diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f2423492f8..4c45205ce3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -134,13 +134,18 @@ class ScaledLinear(nn.Linear): (affects the initialization of weight_scale and bias_scale). Another option, if you want to do something like this, is to re-initialize the parameters. - - Note: it uses the default initialization for the weight and bias, - inherited from nn.Linear. For modules with small fan-in, this - may be larger than optimal. + initial_speed: this affects how fast the parameter will + learn near the start of training; you can set it to a + value less than one if you suspect that a module + is contributing to instability near the start of training. + Nnote: regardless of the use of this option, it's best to + use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. """ def __init__(self, *args, initial_scale: float = 1.0, + initial_speed: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -150,10 +155,10 @@ def __init__(self, *args, else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in nn.Linear + self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -176,8 +181,11 @@ def forward(self, input: Tensor) -> Tensor: class ScaledConv1d(nn.Conv1d): + # See docs for ScaledLinear def __init__(self, *args, - initial_scale=1.0, **kwargs): + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) @@ -185,10 +193,10 @@ def __init__(self, *args, self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class + self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -218,7 +226,11 @@ def forward(self, input: Tensor) -> Tensor: class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, initial_scale=1.0, **kwargs): + # See docs for ScaledLinear + def __init__(self, *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) @@ -226,10 +238,10 @@ def __init__(self, *args, initial_scale=1.0, **kwargs): self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class + self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -350,7 +362,11 @@ def forward(self, x: Tensor) -> Tensor: class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + r"""This is a modified version of nn.Embedding that introduces a learnable scale + on the parameters. Note: due to how we initialize it, it's best used with + schedulers like Noam that have a warmup period. + + It is a simple lookup table that stores embeddings of a fixed dictionary and size. This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding @@ -369,6 +385,15 @@ class ScaledEmbedding(nn.Module): sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. + initial_speed (float, optional): This affects how fast the parameter will + learn near the start of training; you can set it to a value less than + one if you suspect that a module is contributing to instability near + the start of training. Nnote: regardless of the use of this option, + it's best to use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. + + Attributes: weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) initialized from :math:`\mathcal{N}(0, 1)` @@ -416,6 +441,7 @@ class ScaledEmbedding(nn.Module): [ 0.1535, -2.0309, 0.9315], [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) + """ __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'scale_grad_by_freq', 'sparse'] @@ -429,7 +455,8 @@ class ScaledEmbedding(nn.Module): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + sparse: bool = False, + initial_speed: float = 1.0) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim @@ -446,12 +473,12 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() + self.reset_parameters(initial_speed) - def reset_parameters(self) -> None: - std = 0.01 + def reset_parameters(self, initial_speed: float = 1.0) -> None: + std = 0.01 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) From 709c387ce63a9e1eeee2e34de831f27f2b72b9cf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 21:40:22 +0800 Subject: [PATCH 135/185] Initial refactoring to remove unnecessary vocab_size --- .../pruned_transducer_stateless2/conformer.py | 25 +++++++++++-------- .../pruned_transducer_stateless2/decoder.py | 3 +-- .../pruned_transducer_stateless2/joiner.py | 9 +++---- .../ASR/pruned_transducer_stateless2/model.py | 11 ++++++-- .../ASR/pruned_transducer_stateless2/train.py | 7 +++--- 5 files changed, 31 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index d8b1847523..03a47927fb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -32,9 +32,10 @@ class Conformer(EncoderInterface): """ Args: num_features (int): Number of input features - output_dim (int): Number of output dimension + output_dim (int): Model output dimension. If not equal to the encoder dimension, + we will project to the output. subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension + d_model (int): attention dimension, also the output dimension nhead (int): number of head dim_feedforward (int): feedforward dimention num_encoder_layers (int): number of encoder layers @@ -42,7 +43,6 @@ class Conformer(EncoderInterface): cnn_module_kernel (int): Kernel size of convolution module vgg_frontend (bool): whether to use vgg frontend. """ - def __init__( self, num_features: int, @@ -59,7 +59,6 @@ def __init__( super(Conformer, self).__init__() self.num_features = num_features - self.output_dim = output_dim self.subsampling_factor = subsampling_factor if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") @@ -83,7 +82,11 @@ def __init__( self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) - self.encoder_output_layer = ScaledLinear(d_model, output_dim) + if output_dim == d_model: + self.encoder_output_layer = Identity() + else: + self.encoder_output_layer = ScaledLinear(d_model, output_dim, + initial_speed=0.5) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 @@ -101,9 +104,9 @@ def forward( to turn modules on sequentially. Returns: Return a tuple containing 2 tensors: - - logits, its shape is (batch_size, output_seq_len, output_dim) - - logit_lens, a tensor of shape (batch_size,) containing the number - of frames in `logits` before padding. + - embeddings: its shape is (batch_size, output_seq_len, d_model) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. """ x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) @@ -117,10 +120,10 @@ def forward( x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) - logits = self.encoder_output_layer(x) - logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return logits, lengths + return x, lengths class ConformerEncoderLayer(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 3470b647f6..a442feeea2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -68,6 +68,7 @@ def __init__( initial_speed=initial_speed ) self.blank_id = blank_id + self.output_linear = ScaledLinear(embedding_dim, embedding_dim) assert context_size >= 1, context_size self.context_size = context_size @@ -81,8 +82,6 @@ def __init__( groups=embedding_dim, bias=False, ) - self.output_linear = ScaledLinear(embedding_dim, vocab_size) - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 61bfe81867..973a89bfe9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -20,11 +20,10 @@ from scaling import ScaledLinear class Joiner(nn.Module): - def __init__(self, input_dim: int, inner_dim: int, output_dim: int): + def __init__(self, input_dim: int, output_dim: int): super().__init__() - self.inner_linear = ScaledLinear(input_dim, inner_dim) - self.output_linear = ScaledLinear(inner_dim, output_dim) + self.output_linear = ScaledLinear(input_dim, output_dim) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor @@ -43,8 +42,6 @@ def forward( logit = encoder_out + decoder_out - logit = self.inner_linear(torch.tanh(logit)) - - output = self.output_linear(F.relu(logit)) + logit = self.output_linear(torch.tanh(logit)) return output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index faaebc477d..2f102bdf8d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos @@ -33,6 +34,8 @@ def __init__( encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + embedding_dim: int, + vocab_size: int ): """ Args: @@ -58,6 +61,10 @@ def __init__( self.decoder = decoder self.joiner = joiner + # could perhaps separate this into 2 linear projections, one + # for lm and one for am. + self.simple_joiner = nn.Linear(embedding_dim, vocab_size) + def forward( self, x: torch.Tensor, @@ -133,8 +140,8 @@ def forward( boundary[:, 3] = x_lens simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=decoder_out, - am=encoder_out, + lm=self.simple_joiner(decoder_out), + am=self.simple_joiner(encoder_out), symbols=y_padded, termination_symbol=blank_id, lm_only_scale=lm_scale, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 8d51429375..649234f0f9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -306,7 +306,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, - output_dim=params.vocab_size, + output_dim=params.embedding_dim, subsampling_factor=params.subsampling_factor, d_model=params.attention_dim, nhead=params.nhead, @@ -328,8 +328,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - input_dim=params.vocab_size, - inner_dim=params.embedding_dim, + input_dim=params.embedding_dim, output_dim=params.vocab_size, ) return joiner @@ -344,6 +343,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, + embedding_dim=params.embedding_dim, + vocab_size=params.vocab_size, ) return model From f87811e65c1f9cdace638122df5f29c150a50b60 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 21:41:46 +0800 Subject: [PATCH 136/185] Fix RE identity --- .../ASR/pruned_transducer_stateless2/conformer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 03a47927fb..528cc48f49 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -83,7 +83,7 @@ def __init__( aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) if output_dim == d_model: - self.encoder_output_layer = Identity() + self.encoder_output_layer = nn.Identity() else: self.encoder_output_layer = ScaledLinear(d_model, output_dim, initial_speed=0.5) @@ -936,10 +936,6 @@ def forward(self, x: Tensor) -> Tensor: return x.permute(2, 0, 1) -class Identity(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - return x - class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). From a2aca9f64371b3be66cba65bac9f6b60346a9126 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 21:42:15 +0800 Subject: [PATCH 137/185] Bug-fix --- egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 973a89bfe9..d76a913a55 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -44,4 +44,4 @@ def forward( logit = self.output_linear(torch.tanh(logit)) - return output + return logit From 0599f382810c9f9f2bbad39dacd3c8159bd43a06 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 11:53:54 +0800 Subject: [PATCH 138/185] Add final dropout to conformer --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 528cc48f49..8d4057e710 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -82,6 +82,7 @@ def __init__( self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) + self.final_dropout = nn.Dropout(p=dropout) if output_dim == d_model: self.encoder_output_layer = nn.Identity() else: @@ -120,6 +121,7 @@ def forward( x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) + x = self.final_dropout(x) x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) From f47fe8337aec12d8d7a005855e763b695f37e9d1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 12:16:08 +0800 Subject: [PATCH 139/185] Remove some un-used code --- .../ASR/pruned_transducer_stateless2/conformer.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 528cc48f49..abe30633c8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -54,7 +54,6 @@ def __init__( num_encoder_layers: int = 12, dropout: float = 0.1, cnn_module_kernel: int = 31, - aux_layer_period: int = 3 ) -> None: super(Conformer, self).__init__() @@ -79,8 +78,7 @@ def __init__( dropout, cnn_module_kernel, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) if output_dim == d_model: self.encoder_output_layer = nn.Identity() @@ -277,16 +275,13 @@ class ConformerEncoder(nn.Module): >>> out = conformer_encoder(src, pos_emb) """ - def __init__(self, encoder_layer: nn.Module, num_layers: int, - aux_layers: Sequence[int]) -> None: + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: super().__init__() self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) - self.aux_layers = set(aux_layers + [num_layers - 1]) - assert num_layers - 1 not in aux_layers self.num_layers = num_layers - num_channels = encoder_layer.d_model + def forward( self, From f75d40c725f6d9ebacc5e02581066dc5ec4de762 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 12:18:31 +0800 Subject: [PATCH 140/185] Replace nn.Linear with ScaledLinear in simple joiner --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 2f102bdf8d..f1a3d4d113 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -63,7 +63,7 @@ def __init__( # could perhaps separate this into 2 linear projections, one # for lm and one for am. - self.simple_joiner = nn.Linear(embedding_dim, vocab_size) + self.simple_joiner = ScaledLinear(embedding_dim, vocab_size) def forward( self, From c67ae0f3a132189da402eca9d4886e664699862d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:02:40 +0800 Subject: [PATCH 141/185] Make 2 projections.. --- .../ASR/pruned_transducer_stateless2/joiner.py | 3 ++- .../ASR/pruned_transducer_stateless2/model.py | 16 +++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index d76a913a55..b9c4653986 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -23,7 +23,8 @@ class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): super().__init__() - self.output_linear = ScaledLinear(input_dim, output_dim) + self.output_linear = ScaledLinear(input_dim, output_dim, + initial_speed=0.5) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index f1a3d4d113..ab729a4294 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -61,9 +61,15 @@ def __init__( self.decoder = decoder self.joiner = joiner - # could perhaps separate this into 2 linear projections, one - # for lm and one for am. - self.simple_joiner = ScaledLinear(embedding_dim, vocab_size) + self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size) + self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size) + with torch.no_grad(): + # Initialize the two projections to be the same; this will be + # convenient for the real joiner, which adds the endcoder + # (acoustic-model/am) and decoder (language-model/lm) embeddings + self.simple_lm_proj.weight[:] = self.simple_am_proj.weight + self.simple_lm_proj.bias[:] = self.simple_am_proj.bias + def forward( self, @@ -140,8 +146,8 @@ def forward( boundary[:, 3] = x_lens simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=self.simple_joiner(decoder_out), - am=self.simple_joiner(encoder_out), + lm=self.simple_lm_proj(decoder_out), + am=self.simple_am_proj(encoder_out), symbols=y_padded, termination_symbol=blank_id, lm_only_scale=lm_scale, From e59db01b7c599afb0c780a33289b4b86c6579afe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:03:26 +0800 Subject: [PATCH 142/185] Reduce initial_speed --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index ab729a4294..0355c4531b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -61,8 +61,10 @@ def __init__( self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size) - self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size) + self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size, + initial_speed=0.5) + self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size, + initial_speed=0.5) with torch.no_grad(): # Initialize the two projections to be the same; this will be # convenient for the real joiner, which adds the endcoder From ec54fa85cc9cd8cba6b87c0599464fa499523e27 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:04:09 +0800 Subject: [PATCH 143/185] Use initial_speed=0.5 --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index f1a3d4d113..9fef48fcc1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -63,7 +63,8 @@ def __init__( # could perhaps separate this into 2 linear projections, one # for lm and one for am. - self.simple_joiner = ScaledLinear(embedding_dim, vocab_size) + self.simple_joiner = ScaledLinear(embedding_dim, vocab_size, + initial_speed=0.5) def forward( self, From 025d6909951502ec35187129c4b05b7d40f3b85b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:39:56 +0800 Subject: [PATCH 144/185] Reduce initial_speed further from 0.5 to 0.25 --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 9fef48fcc1..83405be363 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -64,7 +64,7 @@ def __init__( # could perhaps separate this into 2 linear projections, one # for lm and one for am. self.simple_joiner = ScaledLinear(embedding_dim, vocab_size, - initial_speed=0.5) + initial_speed=0.25) def forward( self, From fcb0dba2cfe1d84ac10472dc3a745ed936053246 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:47:28 +0800 Subject: [PATCH 145/185] Reduce initial_speed from 0.5 to 0.25 --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 0355c4531b..47a7169b1a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -62,9 +62,9 @@ def __init__( self.joiner = joiner self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size, - initial_speed=0.5) + initial_speed=0.25) self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size, - initial_speed=0.5) + initial_speed=0.25) with torch.no_grad(): # Initialize the two projections to be the same; this will be # convenient for the real joiner, which adds the endcoder From e6637132584c1c0287e8cdc80cb0f0e5b22cce6b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 14:43:49 +0800 Subject: [PATCH 146/185] Change how warmup is applied. --- .../pruned_transducer_stateless2/conformer.py | 24 ++++++------------- .../ASR/pruned_transducer_stateless2/model.py | 2 +- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 12095810e9..704c17dd70 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -206,10 +206,8 @@ def forward( pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective activation of layers; if < 0.5, it's possible that - not all modules will be included. Actually we add the - feed_forward_macaron and self_attn modules at warmup=0.0 - and the conv_module and feed_forward at warmup=0.5. + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. Shape: src: (S, N, E). @@ -219,19 +217,11 @@ def forward( S is the source sequence length, N is the batch size, E is the feature number """ src_orig = src - # when warmup == 0.0, alpha is always 0.1, but it gradually changes to - # always being 1.0 when warmup equals 1.0. The reason for using 0.1 and not - # 0.0 is that it gives us a gradient so we can learn something when we are turned - # off. - # - # min(0.1, warmup) - # is used in place of warmup to ensure that even at the start of the warm-up - # period we sometimes use scale 1.0; this ensures that the modules do not - # compensate for the small scale by just producing larger output. - warmup = max(warmup, 0.1) - if self.training: - warmup = min(warmup, 0.95) # effectively, layer-drop with 1-in-20 prob. - alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely + # bypass it. + alpha = 0.1 if torch.rand(()).item() <= 0.9 else warmup_scale # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 83405be363..9fef48fcc1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -64,7 +64,7 @@ def __init__( # could perhaps separate this into 2 linear projections, one # for lm and one for am. self.simple_joiner = ScaledLinear(embedding_dim, vocab_size, - initial_speed=0.25) + initial_speed=0.5) def forward( self, From 8caa18e2fe1d03035dbfae1a60878cf727861d44 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 17:30:51 +0800 Subject: [PATCH 147/185] Bug fix to warmup_scale --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 704c17dd70..8778dc5baa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -221,7 +221,7 @@ def forward( warmup_scale = min(0.1 + warmup, 1.0) # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely # bypass it. - alpha = 0.1 if torch.rand(()).item() <= 0.9 else warmup_scale + alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) From 92ec2e356e02cbc9b5493048d6108b1148de40be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 1 Apr 2022 12:22:12 +0800 Subject: [PATCH 148/185] Fix test-mode --- .../ASR/pruned_transducer_stateless2/conformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 8778dc5baa..83de82056f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -221,7 +221,10 @@ def forward( warmup_scale = min(0.1 + warmup, 1.0) # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely # bypass it. - alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 + if self.training: + alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 + else: + alpha = 1.0 # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) From 45f872c27da3d52dee592552435a46f7ae2cd374 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 1 Apr 2022 19:33:20 +0800 Subject: [PATCH 149/185] Remove final dropout --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 83de82056f..7573addaae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -80,7 +80,6 @@ def __init__( ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.final_dropout = nn.Dropout(p=dropout) if output_dim == d_model: self.encoder_output_layer = nn.Identity() else: @@ -119,7 +118,6 @@ def forward( x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) - x = self.final_dropout(x) x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) From e0ba4ef3ec7c0ce3ec1b167820b4a711741da137 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 17:47:12 +0800 Subject: [PATCH 150/185] Make layer dropout rate 0.075, was 0.1. --- .../ASR/pruned_transducer_stateless2/conformer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 7573addaae..07ff0525a6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -40,6 +40,7 @@ class Conformer(EncoderInterface): dim_feedforward (int): feedforward dimention num_encoder_layers (int): number of encoder layers dropout (float): dropout rate + layer_dropout (float): layer-dropout rate. cnn_module_kernel (int): Kernel size of convolution module vgg_frontend (bool): whether to use vgg frontend. """ @@ -53,6 +54,7 @@ def __init__( dim_feedforward: int = 2048, num_encoder_layers: int = 12, dropout: float = 0.1, + layer_dropout: float = 0.075, cnn_module_kernel: int = 31, ) -> None: super(Conformer, self).__init__() @@ -76,6 +78,7 @@ def __init__( nhead, dim_feedforward, dropout, + layer_dropout, cnn_module_kernel, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) @@ -149,9 +152,13 @@ def __init__( nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + layer_dropout: float = 0.075, cnn_module_kernel: int = 31, ) -> None: super(ConformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + self.d_model = d_model self.self_attn = RelPositionMultiheadAttention( @@ -217,10 +224,10 @@ def forward( src_orig = src warmup_scale = min(0.1 + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely - # bypass it. + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. if self.training: - alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 + alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1 else: alpha = 1.0 From 8be10d3d6c39dbb51f932c8cebea7cb67055ed92 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 20:03:21 +0800 Subject: [PATCH 151/185] First draft of model rework --- .../pruned_transducer_stateless2/conformer.py | 11 +------ .../pruned_transducer_stateless2/decoder.py | 17 +++++----- .../pruned_transducer_stateless2/joiner.py | 16 ++++++--- .../ASR/pruned_transducer_stateless2/model.py | 33 ++++++++----------- .../ASR/pruned_transducer_stateless2/train.py | 22 ++++++++----- 5 files changed, 49 insertions(+), 50 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index c7ce3bec20..0deb960ad3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -32,8 +32,6 @@ class Conformer(EncoderInterface): """ Args: num_features (int): Number of input features - output_dim (int): Model output dimension. If not equal to the encoder dimension, - we will project to the output. subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head @@ -47,7 +45,6 @@ class Conformer(EncoderInterface): def __init__( self, num_features: int, - output_dim: int, subsampling_factor: int = 4, d_model: int = 256, nhead: int = 4, @@ -83,11 +80,6 @@ def __init__( ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - if output_dim == d_model: - self.encoder_output_layer = nn.Identity() - else: - self.encoder_output_layer = ScaledLinear(d_model, output_dim, - initial_speed=0.5) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 @@ -123,7 +115,6 @@ def forward( x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) - x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) return x, lengths @@ -1116,7 +1107,7 @@ def load_state_dict(self, state_dict): if __name__ == '__main__': feature_dim = 50 - c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) + c = Conformer(num_features=feature_dim, d_model=128, nhead=4) batch_size = 5 seq_len = 20 # Just make sure the forward pass runs. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index a442feeea2..25a36223df 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -46,8 +46,8 @@ def __init__( Args: vocab_size: Number of tokens of the modeling unit including blank. - embedding_dim: - Dimension of the input embedding. + decoder_dim: + Dimension of the input embedding, and of the decoder output. blank_id: The ID of the blank symbol. context_size: @@ -63,23 +63,22 @@ def __init__( initial_speed = 0.5 self.embedding = ScaledEmbedding( num_embeddings=vocab_size, - embedding_dim=embedding_dim, + embedding_dim=decoder_dim, padding_idx=blank_id, initial_speed=initial_speed ) self.blank_id = blank_id - self.output_linear = ScaledLinear(embedding_dim, embedding_dim) assert context_size >= 1, context_size self.context_size = context_size self.vocab_size = vocab_size if context_size > 1: self.conv = ScaledConv1d( - in_channels=embedding_dim, - out_channels=embedding_dim, + in_channels=decoder_dim, + out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=embedding_dim, + groups=decoder_dim, bias=False, ) @@ -92,7 +91,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: True to left pad the input. Should be True during training. False to not pad the input. Should be False during inference. Returns: - Return a tensor of shape (N, U, embedding_dim). + Return a tensor of shape (N, U, decoder_dim). """ y = y.to(torch.int64) embedding_out = self.embedding(y) @@ -108,5 +107,5 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = self.output_linear(F.relu(embedding_out)) + embedding_out = F.relu(embedding_out) return embedding_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index b9c4653986..64752b9a06 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -20,11 +20,19 @@ from scaling import ScaledLinear class Joiner(nn.Module): - def __init__(self, input_dim: int, output_dim: int): + def __init__(self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int): super().__init__() - self.output_linear = ScaledLinear(input_dim, output_dim, - initial_speed=0.5) + # We don't bother giving the 'initial_speed' arg to the decoder + # submodules, because it does not affect the initial convergence of the + # system (only the simple joiner is involved in that). + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) + self.output_linear = ScaledLinear(joiner_dim, vocab_size) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor @@ -41,7 +49,7 @@ def forward( assert encoder_out.ndim == decoder_out.ndim == 4 assert encoder_out.shape == decoder_out.shape - logit = encoder_out + decoder_out + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) logit = self.output_linear(torch.tanh(logit)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 0355c4531b..5d4c32ac40 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -34,23 +34,25 @@ def __init__( encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, - embedding_dim: int, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, vocab_size: int ): """ Args: encoder: It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, C) and + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and `logit_lens` of shape (N,). decoder: It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, C). It should contain + is (N, U) and its output shape is (N, U, decoder_dim). It should contain one attribute: `blank_id`. joiner: - It has two inputs with shapes: (N, T, C) and (N, U, C). Its - output shape is (N, T, U, C). Note that its output contains + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its + output shape is (N, T, U, vocab_size). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. """ super().__init__() @@ -61,17 +63,10 @@ def __init__( self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size, + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) - self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size, + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size, initial_speed=0.5) - with torch.no_grad(): - # Initialize the two projections to be the same; this will be - # convenient for the real joiner, which adds the endcoder - # (acoustic-model/am) and decoder (language-model/lm) embeddings - self.simple_lm_proj.weight[:] = self.simple_am_proj.weight - self.simple_lm_proj.bias[:] = self.simple_am_proj.bias - def forward( self, @@ -133,7 +128,7 @@ def forward( # sos_y_padded: [B, S + 1], start with SOS. sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - # decoder_out: [B, S + 1, C] + # decoder_out: [B, S + 1, decoder_dim] decoder_out = self.decoder(sos_y_padded) # Note: y does not start with SOS @@ -167,13 +162,13 @@ def forward( s_range=prune_range, ) - # am_pruned : [B, T, prune_range, C] - # lm_pruned : [B, T, prune_range, C] + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] am_pruned, lm_pruned = k2.do_rnnt_pruning( am=encoder_out, lm=decoder_out, ranges=ranges ) - # logits : [B, T, prune_range, C] + # logits : [B, T, prune_range, vocab_size] logits = self.joiner(am_pruned, lm_pruned) pruned_loss = k2.rnnt_loss_pruned( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c716d457ac..a027a5adc2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -268,7 +268,7 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - attention_dim: Hidden dim for multi-head attention model. + - encoder_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. @@ -287,12 +287,14 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "attention_dim": 512, + "encoder_dim": 512, "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, # parameters for decoder - "embedding_dim": 512, + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, # parameters for Noam "warm_step": 60000, # For the 100h subset, use 8k "model_warm_step": 4000, # arg given to model, not for lrate @@ -309,7 +311,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_features=params.feature_dim, output_dim=params.embedding_dim, subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, + d_model=params.encoder_dim, nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, @@ -329,8 +331,10 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - input_dim=params.embedding_dim, - output_dim=params.vocab_size, + encoder_dim=params.encoder_dim + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, ) return joiner @@ -344,7 +348,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - embedding_dim=params.embedding_dim, + encoder_dim=params.encoder_dim + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, ) return model @@ -748,7 +754,7 @@ def run(rank, world_size, args): optimizer = Noam( model.parameters(), - model_size=params.attention_dim, + model_size=params.encoder_dim, factor=params.lr_factor, warm_step=params.warm_step, ) From 34500afc43173444309902fb9aea2d6ad2b15d38 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 20:06:43 +0800 Subject: [PATCH 152/185] Various bug fixes --- .../ASR/pruned_transducer_stateless2/decoder.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 25a36223df..3291ad8775 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -38,7 +38,7 @@ class Decoder(nn.Module): def __init__( self, vocab_size: int, - embedding_dim: int, + decoder_dim: int, blank_id: int, context_size: int, ): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index a027a5adc2..e8fbb6a716 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -309,7 +309,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, - output_dim=params.embedding_dim, subsampling_factor=params.subsampling_factor, d_model=params.encoder_dim, nhead=params.nhead, @@ -322,7 +321,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, + decoder_dim=params.decoder_dim, blank_id=params.blank_id, context_size=params.context_size, ) @@ -331,7 +330,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=params.encoder_dim + encoder_dim=params.encoder_dim, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -348,7 +347,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=params.encoder_dim + encoder_dim=params.encoder_dim, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, From 807fcada683ab96aaa585427cc49ce4c21522146 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 20:15:11 +0800 Subject: [PATCH 153/185] Change learning speed of simple_lm_proj --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 5d4c32ac40..1dd20c5463 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -65,8 +65,7 @@ def __init__( self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size, - initial_speed=0.5) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( self, From 9f62a0296cd072083399f6862d1df6bee0134555 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 21:16:39 +0800 Subject: [PATCH 154/185] Revert transducer_stateless/ to state in upstream/master --- .../ASR/transducer_stateless/conformer.py | 396 +++++------------- .../ASR/transducer_stateless/decoder.py | 144 +------ .../transducer_stateless/encoder_interface.py | 2 +- .../ASR/transducer_stateless/joiner.py | 4 +- .../ASR/transducer_stateless/model.py | 3 +- .../ASR/transducer_stateless/train.py | 9 +- .../ASR/transducer_stateless/transformer.py | 4 +- 7 files changed, 108 insertions(+), 454 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index ae95d95b41..488c82386e 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -18,8 +18,7 @@ import copy import math import warnings -from typing import Optional, Tuple, Sequence -from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from typing import Optional, Tuple import torch from torch import Tensor, nn @@ -57,7 +56,6 @@ def __init__( cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - aux_layer_period: int = 3 ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -82,13 +80,17 @@ def __init__( cnn_module_kernel, normalize_before, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.normalize_before = normalize_before - + if self.normalize_before: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False + self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -115,8 +117,10 @@ def forward( assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask, - warmup_mode=warmup_mode) # (T, N, C) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) + + if self.normalize_before: + x = self.after_norm(x) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -154,41 +158,42 @@ def __init__( normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( d_model, nhead, dropout=0.0 ) self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), + nn.Linear(d_model, dim_feedforward), + Swish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + nn.Linear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), + nn.Linear(d_model, dim_feedforward), + Swish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + nn.Linear(dim_feedforward, d_model), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - self.norm_final = BasicNorm(d_model) + self.ff_scale = 0.5 - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_positive=6.0) + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) + self.normalize_before = normalize_before def forward( self, @@ -215,10 +220,19 @@ def forward( """ # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) - + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) src_att = self.self_attn( src, src, @@ -227,15 +241,28 @@ def forward( attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = src + self.dropout(src_att) + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) # convolution module - src = src + self.dropout(self.conv_module(src)) + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) # feed forward module - src = src + self.dropout(self.feed_forward(src)) + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) - src = self.norm_final(self.balancer(src)) + if self.normalize_before: + src = self.norm_final(src) return src @@ -255,20 +282,12 @@ class ConformerEncoder(nn.Module): >>> out = conformer_encoder(src, pos_emb) """ - def __init__(self, encoder_layer: nn.Module, num_layers: int, - aux_layers: Sequence[int]) -> None: + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: super().__init__() self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) - self.aux_layers = set(aux_layers + [num_layers - 1]) - assert num_layers - 1 not in aux_layers self.num_layers = num_layers - num_channels = encoder_layer.d_model - self.combiner = RandomCombine(num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0) def forward( self, @@ -276,7 +295,6 @@ def forward( pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - warmup_mode: bool = False ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -296,19 +314,14 @@ def forward( """ output = src - outputs = [] - - for i, mod in enumerate(self.layers): + for mod in self.layers: output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, ) - if i in self.aux_layers: - outputs.append(output) - output = self.combiner(outputs, warmup_mode) return output @@ -331,6 +344,7 @@ def __init__( """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model + self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) @@ -382,6 +396,7 @@ def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]: """ self.extend_pe(x) + x = x * self.xscale pos_emb = self.pe[ :, self.pe.size(1) // 2 @@ -413,7 +428,6 @@ def __init__( embed_dim: int, num_heads: int, dropout: float = 0.0, - scale_speed: float = 5.0 ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -424,29 +438,25 @@ def __init__( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.scale_speed = scale_speed - self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) - self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) - self._reset_parameters() - def _pos_bias_u(self): - return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() - - def _pos_bias_v(self): - return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() + self._reset_parameters() def _reset_parameters(self) -> None: - nn.init.normal_(self.pos_bias_u, std=0.05) - nn.init.normal_(self.pos_bias_v, std=0.05) + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) def forward( self, @@ -506,11 +516,11 @@ def forward( pos_emb, self.embed_dim, self.num_heads, - self.in_proj.get_weight(), - self.in_proj.get_bias(), + self.in_proj.weight, + self.in_proj.bias, self.dropout, - self.out_proj.get_weight(), - self.out_proj.get_bias(), + self.out_proj.weight, + self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -614,12 +624,13 @@ def multi_head_attention_forward( assert ( head_dim * num_heads == embed_dim ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -651,7 +662,6 @@ def multi_head_attention_forward( _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim @@ -670,7 +680,6 @@ def multi_head_attention_forward( _b = _b[_start:] v = nn.functional.linear(value, _w, _b) - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -720,7 +729,7 @@ def multi_head_attention_forward( ) key_padding_mask = key_padding_mask.to(torch.bool) - q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) @@ -741,11 +750,11 @@ def multi_head_attention_forward( p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - q_with_bias_u = (q + self._pos_bias_u()).transpose( + q_with_bias_u = (q + self.pos_bias_u).transpose( 1, 2 ) # (batch, head, time1, d_k) - q_with_bias_v = (q + self._pos_bias_v()).transpose( + q_with_bias_v = (q + self.pos_bias_v).transpose( 1, 2 ) # (batch, head, time1, d_k) @@ -765,7 +774,7 @@ def multi_head_attention_forward( attn_output_weights = ( matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + ) * scaling # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 @@ -840,7 +849,7 @@ def __init__( # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.pointwise_conv1 = ScaledConv1d( + self.pointwise_conv1 = nn.Conv1d( channels, 2 * channels, kernel_size=1, @@ -848,25 +857,7 @@ def __init__( padding=0, bias=bias, ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0, - min_positive=0.05, - max_positive=1.0) - - self.depthwise_conv = ScaledConv1d( + self.depthwise_conv = nn.Conv1d( channels, channels, kernel_size, @@ -875,22 +866,16 @@ def __init__( groups=channels, bias=bias, ) - - self.deriv_balancer2 = ActivationBalancer(channel_dim=1, - min_positive=0.05, - max_positive=1.0) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, - initial_scale=0.25 ) + self.activation = Swish() def forward(self, x: Tensor) -> Tensor: """Compute convolution module. @@ -907,14 +892,15 @@ def forward(self, x: Tensor) -> Tensor: # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) - x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) @@ -922,197 +908,13 @@ def forward(self, x: Tensor) -> Tensor: return x.permute(2, 0, 1) -class Identity(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - return x - +class Swish(torch.nn.Module): + """Construct an Swish object.""" -class RandomCombine(torch.nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - def __init__(self, num_inputs: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0) -> None: - """ - Args: - num_inputs: The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - final_weight: The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: A standard deviation that we add to log-probs for computing - randomized weights. - - The method of choosing which layers, - or combinations of layers, to use, is conceptually as follows. - With probability `pure_prob`: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super(RandomCombine, self).__init__() - assert pure_prob >= 0 and pure_prob <= 1 - assert final_weight > 0 and final_weight < 1 - assert num_inputs >= 1 - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev= stddev - - self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() - - - def forward(self, inputs: Sequence[Tensor], - warmup_mode: bool) -> Tensor: - """ - Forward function. - Args: - inputs: a list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - a Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not (self.training and warmup_mode): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, - num_channels, - num_inputs)) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, - num_frames) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - - def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: - """ - Return a tensor of random weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), such that - ans.sum(dim=1) is all ones. - - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) - - def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with - exactly one weight equal to 1.0 on each frame. - """ - - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) - - indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, - final, nonfinal) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) - return ans + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) - def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that - sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. - """ - logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev - logprobs[:,-1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") - num_inputs = 3 - num_channels = 50 - m = RandomCombine(num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev) - - x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] - - y = m(x, True) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. - - -if __name__ == '__main__': - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - - feature_dim = 50 - c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c(torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup_mode=True) +def identity(x): + return x diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index db51fb1cd7..b82fed37b7 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -17,9 +17,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional -from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -55,7 +52,7 @@ def __init__( 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() - self.embedding = ScaledEmbedding( + self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=blank_id, @@ -65,7 +62,7 @@ def __init__( assert context_size >= 1, context_size self.context_size = context_size if context_size > 1: - self.conv = ScaledConv1d( + self.conv = nn.Conv1d( in_channels=embedding_dim, out_channels=embedding_dim, kernel_size=context_size, @@ -85,7 +82,6 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: Returns: Return a tensor of shape (N, U, embedding_dim). """ - y = y.to(torch.int64) embedding_out = self.embedding(y) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) @@ -100,139 +96,3 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) return embedding_out - - - -class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` - is renormalized to have norm :attr:`max_norm`. - norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' - elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() - - - - def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - scale = (self.scale * self.scale_speed).exp() - if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale - else: - return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) - - def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' - if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' - if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' - if self.sparse is not False: - s += ', sparse=True' - return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py index 3d218dcd04..257facce4f 100644 --- a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py +++ b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py @@ -22,7 +22,7 @@ class EncoderInterface(nn.Module): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 241f405b60..b0ba7fd83f 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from subsampling import ScaledLinear + class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): @@ -24,7 +24,7 @@ def __init__(self, input_dim: int, output_dim: int): self.input_dim = input_dim self.output_dim = output_dim - self.output_linear = ScaledLinear(input_dim, output_dim) + self.output_linear = nn.Linear(input_dim, output_dim) def forward( self, diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index fc16f2631f..8281e1fb5f 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -65,7 +65,6 @@ def forward( x_lens: torch.Tensor, y: k2.RaggedTensor, modified_transducer_prob: float = 0.0, - warmup_mode: bool = False ) -> torch.Tensor: """ Args: @@ -88,7 +87,7 @@ def forward( assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode) + encoder_out, x_lens = self.encoder(x, x_lens) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index fa04109735..d6827c17cf 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -111,8 +111,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization.. - default="transducer_stateless/randcombine1_expscale3_rework2d", + default="transducer_stateless/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -223,7 +222,6 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 - "warmup_minibatches": 3000, # use warmup mode for 3k minibatches. # parameters for conformer "feature_dim": 80, "encoder_out_dim": 512, @@ -381,7 +379,6 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - is_warmup_mode: bool = False ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -418,7 +415,6 @@ def compute_loss( x_lens=feature_lens, y=y, modified_transducer_prob=params.modified_transducer_prob, - warmup_mode=is_warmup_mode ) assert loss.requires_grad == is_training @@ -455,7 +451,6 @@ def compute_validation_loss( sp=sp, batch=batch, is_training=False, - is_warmup_mode=False ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info @@ -517,7 +512,6 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - is_warmup_mode=(params.batch_idx_train Date: Mon, 4 Apr 2022 13:34:43 +0800 Subject: [PATCH 155/185] Fix to joiner to allow different dims --- egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 64752b9a06..a1226f7127 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -47,7 +47,7 @@ def forward( Return a tensor of shape (N, T, s_range, C). """ assert encoder_out.ndim == decoder_out.ndim == 4 - assert encoder_out.shape == decoder_out.shape + assert encoder_out.shape[:-1] == decoder_out.shape[:-1] logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) From 99e9d6c4b8ab035d9c1962fc5b6086586d336090 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 13:37:10 +0800 Subject: [PATCH 156/185] Some cleanups --- .../ASR/conformer_ctc/subsampling.py | 422 +----------------- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 - .../ASR/transducer_stateless/diagnostics.py | 338 -------------- 3 files changed, 5 insertions(+), 757 deletions(-) delete mode 100644 egs/librispeech/ASR/transducer_stateless/diagnostics.py diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 0a39b0f336..542fb0364e 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -17,8 +17,6 @@ import torch import torch.nn as nn -from torch import Tensor -from typing import Tuple class Conv2dSubsampling(nn.Module): @@ -44,27 +42,16 @@ def __init__(self, idim: int, odim: int) -> None: assert idim >= 7 super().__init__() self.conv = nn.Sequential( - ScaledConv2d( + nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( + nn.ReLU(), + nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), + nn.ReLU(), ) - self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(odim, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55) - + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -83,8 +70,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) return x @@ -174,400 +159,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) return x - - - - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - ) -> Tensor: - if x.requires_grad: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) - factor = factor1 + factor2 - if isinstance(factor, float): - factor = torch.zeros_like(proportion_positive) - - mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) - - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) - ctx.max_factor = max_factor - ctx.sum_dims = sum_dims - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: - factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors - dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) - - neg_delta_grad = x_grad.abs() * (factor + scale_factor) - return x_grad - neg_delta_grad, None, None, None, None, None, None - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - eps_speed: a constant that determines how fast "eps" learns; - with Adam and variants, this should probably be >= 1, - e.g. 5.0. For SGD and variants, probably a value less than one, - like 0.1, would be suitable, to prevent instability. - """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - eps_speed: float = 5.0): - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_speed = eps_speed - if learn_eps: - self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) - else: - self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) - - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - (self.eps * self.eps_speed).exp()) ** -0.5 - return x * scales - - - - -class ScaledLinear(nn.Linear): - """ - A modified version of nn.Linear where the parameters are scaled before - use, via: - weight = self.weight * (self.weight_scale * self.scale_speed).exp() - bias = self.bias * (self.bias_scale * self.scale_speed).exp() - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - scale_speed: a factor that affects how fast the weight_scale - and bias_scale learn; this value is suitable for Adam-type - optimizers. - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - - Note: it uses the default initialization for the weight and bias, - inherited from nn.Linear. For modules with small fan-in, this - may be larger than optimal. - """ - def __init__(self, *args, - scale_speed: float = 5.0, - initial_scale: float = 1.0, - **kwargs): - super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - self.scale_speed = scale_speed - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - - self._reset_parameters() # Overrides the reset_parameters in nn.Linear - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) - - -class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, scale_speed = 5.0, - initial_scale=1.0, **kwargs): - super(ScaledConv1d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - - - -class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): - super(ScaledConv2d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def _conv_forward(self, input, weight): - F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.get_weight()) - - - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - - Args: - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - min_abs: the minimum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - max_abs: the maximum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): - super(ActivationBalancer, self).__init__() - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - - def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) - - -def _double_swish(x: Tensor) -> Tensor: - # double-swish, implemented/approximated as offset-swish - return x * torch.sigmoid(x - 1.0) - -class DoubleSwishFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - ctx.save_for_backward(x.detach()) - return _double_swish(x) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - # TODO: can make this more efficient. - x, = ctx.saved_tensors - x.requires_grad = True - with torch.enable_grad(): - y = _double_swish(x) - y.backward(gradient=y_grad) - return x.grad - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - return DoubleSwishFunction.apply(x) - - - -def _test_deriv_balancer_sign(): - channel_dim = 0 - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_deriv_balancer_sign: x = ", x) - print("_test_deriv_balancer_sign: y grad = ", y_grad) - print("_test_deriv_balancer_sign: x grad = ", x.grad) - -def _test_deriv_balancer_magnitude(): - channel_dim = 0 - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_deriv_balancer_magnitude: x = ", x) - print("_test_deriv_balancer_magnitude: y grad = ", y_grad) - print("_test_deriv_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - - - - -if __name__ == '__main__': - _test_deriv_balancer_sign() - _test_deriv_balancer_magnitude() - _test_basic_norm() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 477afcecba..8dd1459ca5 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -22,8 +22,6 @@ from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional -import torch -from lhotse.utils import fix_random_seed import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py deleted file mode 100644 index 7fd83d56bc..0000000000 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ /dev/null @@ -1,338 +0,0 @@ -import torch -from torch import Tensor -from torch import nn -import math -import random -from typing import Tuple, List - - -class TensorDiagnosticOptions(object): - """ - Options object for tensor diagnostics: - - Args: - memory_limit: the maximum number of bytes we store per tensor (limits how many copies - of the tensor we cache). - max_eig_dim: the maximum dimension for which we print out eigenvalues - (limited for speed reasons). - """ - def __init__(self, - memory_limit: int = (2 ** 20), - max_eig_dim: int = 512): - - self.memory_limit = memory_limit - self.max_eig_dim = max_eig_dim - - def dim_is_summarized(self, size: int): - return size > 10 and size != 31 - - - -def get_tensor_stats(x: Tensor, dim: int, - stats_type: str) -> Tuple[Tensor, int]: - """ - Returns the specified transformation of the Tensor (either x or x.abs() - or (x > 0), summed over all but the index `dim`. - - Args: - x: Tensor, tensor to be analyzed - dim: dimension with 0 <= dim < x.ndim - stats_type: - "abs" -> take abs() before summing - "positive" -> take (x > 0) before summing - "rms" -> square before summing, we'll take sqrt later - "value -> just sum x itself - Returns (stats, count) - where stats is a Tensor of shape (x.shape[dim],), and the count - is an integer saying how many items were counted in each element - of stats. - """ - count = x.numel() // x.shape[dim] - - if stats_type == "eigs": - x = x.transpose(dim, -1) - x = x.reshape(-1, x.shape[-1]) - # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x. - return torch.matmul(x.transpose(0, 1), x), count - elif stats_type == "abs": - x = x.abs() - elif stats_type == "rms": - x = x ** 2 - elif stats_type == "positive": - x = (x > 0).to(dtype=torch.float) - else: - assert stats_type == "value" - - sum_dims = [ d for d in range(x.ndim) if d != dim ] - if len(sum_dims) > 0: - x = torch.sum(x, dim=sum_dims) - x = x.flatten() - return x, count - -def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], - options: TensorDiagnosticOptions, - sizes_same: bool, - stats_type: str): - """ - This function gets diagnostics for a dimension of a module. - Args: - dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim - options: options object - sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "abs" or "positive" or "eigs" or "value", - imdictates the type of stats - we accumulate, abs is mean absolute value, "positive" - is proportion of positive to nonnegative values, "eigs" - is eigenvalues after doing outer product on this dim, sum - over all other dimes. - Returns: - Diagnostic as a string, either percentiles or the actual values, - see the code. Will return the empty string if the diagnostics did - not make sense to print out for this dimension, e.g. dimension - mismatch and stats_type == "eigs" - """ - # stats_and_counts is a list of pair (Tensor, int) - stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] - stats = [ x[0] for x in stats_and_counts ] - counts = [ x[1] for x in stats_and_counts ] - - if stats_type == "eigs": - try: - stats = torch.stack(stats).sum(dim=0) - except: - return '' - count = sum(counts) - stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - elif sizes_same: - stats = torch.stack(stats).sum(dim=0) - count = sum(counts) - stats = stats / count - else: - stats = [ x[0] / x[1] for x in stats_and_counts ] - stats = torch.cat(stats, dim=0) - if stats_type == 'rms': - stats = stats.sqrt() - - # if `summarize` we print percentiles of the stats; else, - # we print out individual elements. - summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) - if summarize: - # print out percentiles. - stats = stats.sort()[0] - num_percentiles = 10 - size = stats.numel() - percentiles = [] - for i in range(num_percentiles + 1): - index = (i * (size - 1)) // num_percentiles - percentiles.append(stats[index].item()) - percentiles = [ '%.2g' % x for x in percentiles ] - percentiles = ' '.join(percentiles) - ans = f'percentiles: [{percentiles}]' - else: - ans = stats.tolist() - ans = [ '%.2g' % x for x in ans ] - ans = '[' + ' '.join(ans) + ']' - if stats_type == "value": - # This norm is useful because it is strictly less than the largest - # sqrt(eigenvalue) of the variance, which we print out, and shows, - # speaking in an approximate way, how much of that largest eigenvalue - # can be attributed to the mean of the distribution. - norm = (stats ** 2).sum().sqrt().item() - mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() - ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}' - else: - mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() - ans += f', mean={mean:.2g}, rms={rms:.2g}' - return ans - - - -def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], - options: TensorDiagnosticOptions): - ndim = tensors[0].ndim - if ndim > 1: - stats_types = ["abs", "positive", "value", "rms"] - if tensors[0].shape[dim] <= options.max_eig_dim: - stats_types.append("eigs") - else: - stats_types = [ "value", "abs" ] - - for stats_type in stats_types: - sizes = [ x.shape[dim] for x in tensors ] - sizes_same = all([ x == sizes[0] for x in sizes ]) - s = get_diagnostics_for_dim(dim, tensors, - options, sizes_same, - stats_type) - if s == '': - continue - - min_size = min(sizes) - max_size = max(sizes) - size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" - # stats_type will be "abs" or "positive". - print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") - - -class TensorDiagnostic(object): - """ - This class is not directly used by the user, it is responsible for collecting - diagnostics for a single parameter tensor of a torch.Module. - """ - def __init__(self, - opts: TensorDiagnosticOptions, - name: str): - self.name = name - self.opts = opts - self.saved_tensors = [] - - def accumulate(self, x): - if isinstance(x, Tuple): - x = x[0] - if not isinstance(x, Tensor): - return - if x.device == torch.device('cpu'): - x = x.detach().clone() - else: - x = x.detach().to('cpu', non_blocking=True) - self.saved_tensors.append(x) - l = len(self.saved_tensors) - if l & (l - 1) == 0: # power of 2.. - self._limit_memory() - - def _limit_memory(self): - if len(self.saved_tensors) > 1024: - self.saved_tensors = self.saved_tensors[-1024:] - return - - tot_mem = 0.0 - for i in reversed(range(len(self.saved_tensors))): - tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size() - if tot_mem > self.opts.memory_limit: - self.saved_tensors = self.saved_tensors[i:] - return - - def print_diagnostics(self): - if len(self.saved_tensors) == 0: - print("{name}: no stats".format(name=self.name)) - return - if self.saved_tensors[0].ndim == 0: - # ensure there is at least one dim. - self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] - - try: - device = torch.device('cuda') - torch.ones(1, 1, device) - except: - device = torch.device('cpu') - - ndim = self.saved_tensors[0].ndim - tensors = [x.to(device) for x in self.saved_tensors] - for dim in range(ndim): - print_diagnostics_for_dim(self.name, dim, - tensors, - self.opts) - - -class ModelDiagnostic(object): - def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()): - self.diagnostics = dict() - self.opts = opts - - def __getitem__(self, name: str): - if name not in self.diagnostics: - self.diagnostics[name] = TensorDiagnostic(self.opts, name) - return self.diagnostics[name] - - def print_diagnostics(self): - for k in sorted(self.diagnostics.keys()): - self.diagnostics[k].print_diagnostics() - - - -def attach_diagnostics(model: nn.Module, - opts: TensorDiagnosticOptions) -> ModelDiagnostic: - ans = ModelDiagnostic(opts) - for name, module in model.named_modules(): - if name == '': - name = "" - forward_diagnostic = TensorDiagnostic(opts, name + ".output") - backward_diagnostic = TensorDiagnostic(opts, name + ".grad") - - - # setting model_diagnostic=ans and n=name below, instead of trying to capture the variables, - # ensures that we use the current values. (matters for name, since - # the variable gets overwritten). these closures don't really capture - # by value, only by "the final value the variable got in the function" :-( - def forward_hook(_module, _input, _output, - _model_diagnostic=ans, _name=name): - if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.output"].accumulate(_output) - elif isinstance(_output, tuple): - for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o) - - def backward_hook(_module, _input, _output, - _model_diagnostic=ans, _name=name): - if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.grad"].accumulate(_output) - elif isinstance(_output, tuple): - for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o) - - module.register_forward_hook(forward_hook) - module.register_backward_hook(backward_hook) - - for name, parameter in model.named_parameters(): - - def param_backward_hook(grad, - _parameter=parameter, - _model_diagnostic=ans, - _name=name): - _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) - _model_diagnostic[f"{_name}.param_grad"].accumulate(grad) - - parameter.register_hook(param_backward_hook) - return ans - - - -def _test_tensor_diagnostic(): - opts = TensorDiagnosticOptions(2**20, 512) - - diagnostic = TensorDiagnostic(opts, "foo") - - for _ in range(10): - diagnostic.accumulate(torch.randn(50, 100) * 10.0) - - diagnostic.print_diagnostics() - - model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80)) - - diagnostic = attach_diagnostics(model, opts) - for _ in range(10): - T = random.randint(200, 300) - x = torch.randn(T, 100) - y = model(x) - y.sum().backward() - - diagnostic.print_diagnostics() - - - -if __name__ == '__main__': - _test_tensor_diagnostic() - - -def _test_func(): - ans = [] - for i in range(10): - x = list() - x.append(i) - def func(): - return x - ans.append(func) - return ans From a5bbcd7b71a9f519e8f6d7830e8265a2d1fc490c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 14:10:38 +0800 Subject: [PATCH 157/185] Make training more efficient, avoid redoing some projections. --- .../ASR/pruned_transducer_stateless2/joiner.py | 12 ++++++++++-- .../ASR/pruned_transducer_stateless2/model.py | 10 ++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index a1226f7127..752a5f7742 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -35,7 +35,8 @@ def __init__(self, self.output_linear = ScaledLinear(joiner_dim, vocab_size) def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, + project_input: bool = True ) -> torch.Tensor: """ Args: @@ -43,13 +44,20 @@ def forward( Output from the encoder. Its shape is (N, T, s_range, C). decoder_out: Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. Returns: Return a tensor of shape (N, T, s_range, C). """ assert encoder_out.ndim == decoder_out.ndim == 4 assert encoder_out.shape[:-1] == decoder_out.shape[:-1] - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + else: + logit = encoder_out + decoder_out logit = self.output_linear(torch.tanh(logit)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 1dd20c5463..a9178c8b38 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -164,11 +164,17 @@ def forward( # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=encoder_out, lm=decoder_out, ranges=ranges + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges ) # logits : [B, T, prune_range, vocab_size] - logits = self.joiner(am_pruned, lm_pruned) + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, + project_input=False) pruned_loss = k2.rnnt_loss_pruned( logits=logits, From 4929e4cf32f93860aad273223c86a6dc98d611df Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 17:09:25 +0800 Subject: [PATCH 158/185] Change how warm-step is set --- .../ASR/pruned_transducer_stateless2/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index e8fbb6a716..bf7f23fab2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -147,6 +147,13 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--warm-step", + type=float, + default=60000, + help="The number of warmup steps for the (modified) Noam optimizer", + ) + parser.add_argument( "--context-size", type=int, @@ -296,7 +303,6 @@ def get_params() -> AttributeDict: # parameters for joiner "joiner_dim": 512, # parameters for Noam - "warm_step": 60000, # For the 100h subset, use 8k "model_warm_step": 4000, # arg given to model, not for lrate "env_info": get_env_info(), } @@ -709,7 +715,6 @@ def run(rank, world_size, args): params.update(vars(args)) if params.full_libri is False: params.valid_interval = 1600 - params.warm_step = 30000 fix_random_seed(params.seed) if world_size > 1: From 72f4a673b106fefc9c88841a3e4ff3a9d1d6fd88 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 20:21:34 +0800 Subject: [PATCH 159/185] First draft of new approach to learning rates + init --- .../pruned_transducer_stateless2/conformer.py | 87 ------ .../ASR/pruned_transducer_stateless2/optim.py | 254 ++++++++++++++++++ .../pruned_transducer_stateless2/scaling.py | 12 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +++- 4 files changed, 299 insertions(+), 104 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/optim.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 0deb960ad3..4797cce08f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -1017,93 +1017,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * self.warmup ** (-0.5 - -0.333) - * min(step ** (-0.333), step * self.warmup ** (-1.333)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - if __name__ == '__main__': feature_dim = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py new file mode 100644 index 0000000000..edbebcceb6 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -0,0 +1,254 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torch.optim import Optimizer + + +class Eve(Optimizer): + r""" + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular specified value (generally 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, + target_rms=0.1): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict(lr=lr, betas=betas, eps=eps, + target_rms=target_rms) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + target_rms = group['target_rms'] + delta = exp_avg / denom + + # we'll be doing: p += delta * step_size. + # In the normal case delta_rms (the rms value of the elements of + # delta) will be very close to 1.0, but we compute it here so + # that if we don't use a particular parameter, its value won't + # shrink to zero. + # delta_var is the expected change in the variance of the parameter + # values, i.e. of E[param_elem^2], due to this step. It will + # be close to 1. + + # Let us define: + # delta_var_from_update = (delta**2).mean() * step_size * step_size + + # Suppose we are going to shrinkage with a small value epsilon (not the + # same as the eps above!), i.e. param *= (1-epsilon). Then + # if E[param_elem^2] == target_rms^2, + # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2), + # which we can put as: + # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. + # Setting delta_var_from_shrinkage = -delta_var_from_update + # because we want them to cancel, + # delta_var_from_update = 2 epsilon target_rms^2, or: + # epsilon = delta_var_from_update / (2 * target_rms^2) + # = (delta**2).mean() * 0.5 * (step_size / target_rms)**2. + # Note: step_size is close to the learning rate. For an example, if + # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where + # (delta**2).mean() == 1, we will have: + # epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06. + # Note that this is close to the "traditional" value used for weight + # decay. + + # this is the weight-decay amount... + weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2) + + p.mul_(1 - weight_decay) + p.add_(delta, alpha=-step_size) + + return loss + + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * self.warmup ** (-0.5 - -0.333) + * min(step ** (-0.333), step * self.warmup ** (-1.333)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 4c45205ce3..33b4ad9089 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -158,7 +158,10 @@ def __init__(self, *args, self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + # we plan to use Eve as the optimizer, which will eventually make the stddev approach + # 0.1 as that's the target_rms we set, but we initialize with a larger stddev + # to have the same effect as a warm-up period. + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -196,7 +199,7 @@ def __init__(self, *args, self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -241,7 +244,7 @@ def __init__(self, *args, self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -476,9 +479,8 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona self.reset_parameters(initial_speed) - def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.01 / initial_speed + std = 0.5 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index bf7f23fab2..9d074fdd47 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -28,7 +28,10 @@ --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 300 \ - --lr-factor 1.5 + --initial-lr 0.002 \ + --lr-decay-steps 10000 \ + --num-lr-decays 4 + """ @@ -52,6 +55,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer +from optim import Eve from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -141,17 +145,24 @@ def get_parser(): ) parser.add_argument( - "--lr-factor", + "--initial-lr", + type=float, + default=0.002, + help="The initial learning rate", + ) + + parser.add_argument( + "--lr-decay-steps", type=float, - default=5.0, - help="The lr_factor for Noam optimizer", + default=5000, + help="The number of steps before we decay (halve) the learning rate", ) parser.add_argument( - "--warm-step", + "--num-lr-decays", type=float, - default=60000, - help="The number of warmup steps for the (modified) Noam optimizer", + default=4, + help="The total number of times we decay (halve) the learning rate" ) parser.add_argument( @@ -426,6 +437,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -449,6 +461,7 @@ def save_checkpoint( model=model, params=params, optimizer=optimizer, + scheduler=scheduler, sampler=sampler, rank=rank, ) @@ -574,6 +587,7 @@ def train_one_epoch( params: AttributeDict, model: nn.Module, optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -594,6 +608,8 @@ def train_one_epoch( The model for training. optimizer: The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: @@ -636,6 +652,7 @@ def train_one_epoch( loss.backward() optimizer.step() optimizer.zero_grad() + lr_scheduler.step() if params.print_diagnostics and batch_idx == 5: return @@ -651,6 +668,7 @@ def train_one_epoch( model=model, params=params, optimizer=optimizer, + scheduler=scheduler, sampler=train_dl.sampler, rank=rank, ) @@ -756,17 +774,24 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) model.device = device - optimizer = Noam( + optimizer = Eve( model.parameters(), - model_size=params.encoder_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - ) + lr=params.initial_lr, betas=(0.9, 0.98), + eps=1e-9, target_rms=0.1) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ], + gamma=0.5) + if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) + if checkpoints and "scheduler" in checkpoints: + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( @@ -839,6 +864,7 @@ def remove_short_and_long_utt(c: Cut): params=params, model=model, optimizer=optimizer, + scheduler=scheduler, sp=sp, train_dl=train_dl, valid_dl=valid_dl, From d1f2f934605cebe7e438f483e4f7beee6bf0966e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 22:40:18 +0800 Subject: [PATCH 160/185] Some fixes.. --- .../ASR/pruned_transducer_stateless2/optim.py | 97 +------------------ .../ASR/pruned_transducer_stateless2/train.py | 12 ++- 2 files changed, 12 insertions(+), 97 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index edbebcceb6..6f19807dc9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -27,7 +27,7 @@ class Eve(Optimizer): r""" Implements Eve algorithm. This is a modified version of AdamW with a special way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular specified value (generally 0.1). This is + rms of the parameters approach a particular specified value (we suggest 0.1). This is for use with networks with 'scaled' versions of modules (see scaling.py), which will be close to invariant to the absolute scale on the parameter matrix. @@ -120,7 +120,7 @@ def step(self, closure=None): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps']) step_size = group['lr'] / bias_correction1 target_rms = group['target_rms'] @@ -141,7 +141,7 @@ def step(self, closure=None): # Suppose we are going to shrinkage with a small value epsilon (not the # same as the eps above!), i.e. param *= (1-epsilon). Then # if E[param_elem^2] == target_rms^2, - # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2), + # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2), # which we can put as: # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. # Setting delta_var_from_shrinkage = -delta_var_from_update @@ -157,98 +157,9 @@ def step(self, closure=None): # decay. # this is the weight-decay amount... - weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2) + weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2) p.mul_(1 - weight_decay) p.add_(delta, alpha=-step_size) return loss - - - -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * self.warmup ** (-0.5 - -0.333) - * min(step ** (-0.333), step * self.warmup ** (-1.333)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 9d074fdd47..9f73c8fbc6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -48,7 +48,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer, Noam +from conformer import Conformer from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -437,7 +437,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -652,7 +652,7 @@ def train_one_epoch( loss.backward() optimizer.step() optimizer.zero_grad() - lr_scheduler.step() + scheduler.step() if params.print_diagnostics and batch_idx == 5: return @@ -848,7 +848,7 @@ def remove_short_and_long_utt(c: Cut): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - cur_lr = optimizer._rate + cur_lr = scheduler.get_last_lr()[0] if tb_writer is not None: tb_writer.add_scalar( "train/learning_rate", cur_lr, params.batch_idx_train @@ -908,12 +908,16 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. loss, _ = compute_loss( params=params, model=model, sp=sp, batch=batch, is_training=True, + warmup = 0.0 ) loss.backward() optimizer.step() From 179d0605ea235fa92fee7289a3db1374f3ec2bcf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 23:34:39 +0800 Subject: [PATCH 161/185] Change initialization to 0.25 --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 3 ++- .../ASR/pruned_transducer_stateless2/scaling.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 6f19807dc9..17450def8c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -140,7 +140,8 @@ def step(self, closure=None): # Suppose we are going to shrinkage with a small value epsilon (not the # same as the eps above!), i.e. param *= (1-epsilon). Then - # if E[param_elem^2] == target_rms^2, + # if E[param_elem^2] == target_rms^2 (because we desire equilibrium when + # the RMS of the parameters equals target_rms), it follows that # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2), # which we can put as: # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 33b4ad9089..4b91bb04c1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -161,7 +161,7 @@ def _reset_parameters(self, initial_speed: float): # we plan to use Eve as the optimizer, which will eventually make the stddev approach # 0.1 as that's the target_rms we set, but we initialize with a larger stddev # to have the same effect as a warm-up period. - std = 0.5 / initial_speed + std = 0.25 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -199,7 +199,7 @@ def __init__(self, *args, self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.5 / initial_speed + std = 0.25 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -244,7 +244,7 @@ def __init__(self, *args, self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.5 / initial_speed + std = 0.25 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -480,7 +480,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.5 / initial_speed + std = 0.25 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) From 234366e51c450cb95e18acdcf9d6544d74155885 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 00:18:36 +0800 Subject: [PATCH 162/185] Fix type of parameter --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 9f73c8fbc6..83558a72b5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -160,7 +160,7 @@ def get_parser(): parser.add_argument( "--num-lr-decays", - type=float, + type=int, default=4, help="The total number of times we decay (halve) the learning rate" ) From 2b0727a355d73205c1a91b770902c0da04aec958 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 00:31:28 +0800 Subject: [PATCH 163/185] Fix weight decay formula by adding 1/1-beta --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 17450def8c..607a4e3505 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -156,9 +156,12 @@ def step(self, closure=None): # epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06. # Note that this is close to the "traditional" value used for weight # decay. - + # # this is the weight-decay amount... - weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2) + # + # Regarding the 1/1-beta factor below: this is to compensate for the deltas on successive + # frames being correlated. I have to figure out the justification. + weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta))) p.mul_(1 - weight_decay) p.add_(delta, alpha=-step_size) From 47d49f29d78742e9d22850c08bdd094d1c4bb6f9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 00:31:55 +0800 Subject: [PATCH 164/185] Fix weight decay formula by adding 1/1-beta --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 607a4e3505..eb77769389 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -161,7 +161,7 @@ def step(self, closure=None): # # Regarding the 1/1-beta factor below: this is to compensate for the deltas on successive # frames being correlated. I have to figure out the justification. - weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta))) + weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta1))) p.mul_(1 - weight_decay) p.add_(delta, alpha=-step_size) From 1548cc7462a59da00f3bddad7b51166c5a0a3b09 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 11:19:40 +0800 Subject: [PATCH 165/185] Fix checkpoint-writing --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 83558a72b5..c63c849c40 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -376,6 +376,7 @@ def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. @@ -395,6 +396,8 @@ def load_checkpoint_if_available( The training model. optimizer: The optimizer that we are using. + scheduler: + The scheduler that we are using. Returns: Return a dict containing previously saved training info. """ @@ -411,6 +414,7 @@ def load_checkpoint_if_available( filename, model=model, optimizer=optimizer, + scheduler=scheduler, ) keys = [ @@ -784,6 +788,7 @@ def run(rank, world_size, args): gamma=0.5) + if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) @@ -792,7 +797,6 @@ def run(rank, world_size, args): logging.info("Loading scheduler state dict") scheduler.load_state_dict(checkpoints["scheduler"]) - if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( 2 ** 22 @@ -881,6 +885,7 @@ def remove_short_and_long_utt(c: Cut): params=params, model=model, optimizer=optimizer, + scheduler=scheduler, sampler=train_dl.sampler, rank=rank, ) From 0f5957394bd346c9a0207b66110b7a1bce10f643 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 12:58:43 +0800 Subject: [PATCH 166/185] Fix to reading scheudler from optim --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c63c849c40..348e2dd472 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -793,7 +793,7 @@ def run(rank, world_size, args): logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) - if checkpoints and "scheduler" in checkpoints: + if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None: logging.info("Loading scheduler state dict") scheduler.load_state_dict(checkpoints["scheduler"]) From c3169222aee9db0780379a70f9dea9daf5254d78 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:23:02 +0800 Subject: [PATCH 167/185] Simplified optimizer, rework somet things.. --- .../ASR/pruned_transducer_stateless2/optim.py | 74 ++++++++----------- .../ASR/pruned_transducer_stateless2/train.py | 22 +++--- 2 files changed, 39 insertions(+), 57 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index eb77769389..b17ebba7ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -27,7 +27,7 @@ class Eve(Optimizer): r""" Implements Eve algorithm. This is a modified version of AdamW with a special way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular specified value (we suggest 0.1). This is + rms of the parameters approach a particular target_rms (default: 0.1). This is for use with networks with 'scaled' versions of modules (see scaling.py), which will be close to invariant to the absolute scale on the parameter matrix. @@ -43,10 +43,13 @@ class Eve(Optimizer): running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -57,7 +60,7 @@ class Eve(Optimizer): """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, - target_rms=0.1): + weight_decay=3e-4, target_rms=0.1): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -67,9 +70,12 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, target_rms=target_rms) super(Eve, self).__init__(params, defaults) @@ -94,6 +100,9 @@ def step(self, closure=None): if p.grad is None: continue + + + # Perform optimization step grad = p.grad if grad.is_sparse: @@ -124,46 +133,21 @@ def step(self, closure=None): step_size = group['lr'] / bias_correction1 target_rms = group['target_rms'] + weight_decay = group['weight_decay'] delta = exp_avg / denom - # we'll be doing: p += delta * step_size. - # In the normal case delta_rms (the rms value of the elements of - # delta) will be very close to 1.0, but we compute it here so - # that if we don't use a particular parameter, its value won't - # shrink to zero. - # delta_var is the expected change in the variance of the parameter - # values, i.e. of E[param_elem^2], due to this step. It will - # be close to 1. - - # Let us define: - # delta_var_from_update = (delta**2).mean() * step_size * step_size - - # Suppose we are going to shrinkage with a small value epsilon (not the - # same as the eps above!), i.e. param *= (1-epsilon). Then - # if E[param_elem^2] == target_rms^2 (because we desire equilibrium when - # the RMS of the parameters equals target_rms), it follows that - # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2), - # which we can put as: - # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. - # Setting delta_var_from_shrinkage = -delta_var_from_update - # because we want them to cancel, - # delta_var_from_update = 2 epsilon target_rms^2, or: - # epsilon = delta_var_from_update / (2 * target_rms^2) - # = (delta**2).mean() * 0.5 * (step_size / target_rms)**2. - # Note: step_size is close to the learning rate. For an example, if - # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where - # (delta**2).mean() == 1, we will have: - # epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06. - # Note that this is close to the "traditional" value used for weight - # decay. - # - # this is the weight-decay amount... - # - # Regarding the 1/1-beta factor below: this is to compensate for the deltas on successive - # frames being correlated. I have to figure out the justification. - weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta1))) - - p.mul_(1 - weight_decay) - p.add_(delta, alpha=-step_size) + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" (which are scalar). + is_below_target_rms = (p.norm() < (target_rms * (p.numel() ** 0.5))) + p.mul_(1 - (weight_decay * is_below_target_rms)) + p.addcdiv_(exp_avg, denom, value=-step_size) return loss + +# Note on avg-change per epoch.. +# suppose epoch is 4k iters. +# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1, +# then rm(diff) 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2, = 0.0004. So var(diff per minibatch) +# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04. +# +# .. 6e-05 is 1/5 of that... diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 348e2dd472..1340e09505 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -152,17 +152,17 @@ def get_parser(): ) parser.add_argument( - "--lr-decay-steps", + "--lr-num-steps", type=float, - default=5000, - help="The number of steps before we decay (halve) the learning rate", + default=3000, + help="Number of steps before we start to significantly decay the learning rate", ) parser.add_argument( - "--num-lr-decays", - type=int, - default=4, - help="The total number of times we decay (halve) the learning rate" + "--lr-power", + type=float, + default=0.5, + help="Power in LR-setting rule", ) parser.add_argument( @@ -781,12 +781,10 @@ def run(rank, world_size, args): optimizer = Eve( model.parameters(), lr=params.initial_lr, betas=(0.9, 0.98), - eps=1e-9, target_rms=0.1) - scheduler = torch.optim.lr_scheduler.MultiStepLR( + eps=1e-9, weight_decay=3e-04, target_rms=0.1) + scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ], - gamma=0.5) - + lambda step: (params.lr_num_steps/(step + params.lr_num_steps) ** params.lr_power)) if checkpoints and "optimizer" in checkpoints: From ed8eba91e14f35107fcfe52137015e0806ae6532 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:24:09 +0800 Subject: [PATCH 168/185] Reduce model_warm_step from 4k to 3k --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1340e09505..45b3ca1686 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -314,7 +314,7 @@ def get_params() -> AttributeDict: # parameters for joiner "joiner_dim": 512, # parameters for Noam - "model_warm_step": 4000, # arg given to model, not for lrate + "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } ) From d1a669162caf39fe15e318bfd3f51636cc8826bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:31:52 +0800 Subject: [PATCH 169/185] Fix bug in lambda --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 45b3ca1686..3b8f0499f8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -784,7 +784,7 @@ def run(rank, world_size, args): eps=1e-9, weight_decay=3e-04, target_rms=0.1) scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - lambda step: (params.lr_num_steps/(step + params.lr_num_steps) ** params.lr_power)) + lambda step: ((params.lr_num_steps/(step + params.lr_num_steps)) ** params.lr_power)) if checkpoints and "optimizer" in checkpoints: From 25724b5ce9f786f644e662de6e2636add523ce89 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:49:35 +0800 Subject: [PATCH 170/185] Bug-fix RE sign of target_rms --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index b17ebba7ce..2b40dda45c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -137,9 +137,10 @@ def step(self, closure=None): delta = exp_avg / denom if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" (which are scalar). - is_below_target_rms = (p.norm() < (target_rms * (p.numel() ** 0.5))) - p.mul_(1 - (weight_decay * is_below_target_rms)) + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5))) + p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) return loss @@ -149,5 +150,6 @@ def step(self, closure=None): # if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1, # then rm(diff) 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2, = 0.0004. So var(diff per minibatch) # = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04. +# Suggested lr_schedule? # # .. 6e-05 is 1/5 of that... From 2545237eb3ff801364151cd8a82ed01896445a17 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 18:00:54 +0800 Subject: [PATCH 171/185] Changing initial_speed from 0.25 to 01 --- .../ASR/pruned_transducer_stateless2/scaling.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 4b91bb04c1..98a56ce775 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -158,10 +158,7 @@ def __init__(self, *args, self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear def _reset_parameters(self, initial_speed: float): - # we plan to use Eve as the optimizer, which will eventually make the stddev approach - # 0.1 as that's the target_rms we set, but we initialize with a larger stddev - # to have the same effect as a warm-up period. - std = 0.25 / initial_speed + std = 0.1 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -199,7 +196,7 @@ def __init__(self, *args, self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.25 / initial_speed + std = 0.1 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -244,7 +241,7 @@ def __init__(self, *args, self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.25 / initial_speed + std = 0.1 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -480,7 +477,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.25 / initial_speed + std = 0.1 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) From a41e93437c608f2061f72796c7260e3d5ff7bc7c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 Apr 2022 12:36:58 +0800 Subject: [PATCH 172/185] Change some defaults in LR-setting rule. --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2b40dda45c..a2e0463dab 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -60,7 +60,7 @@ class Eve(Optimizer): """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, - weight_decay=3e-4, target_rms=0.1): + weight_decay=1e-3, target_rms=0.1): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 3b8f0499f8..306a2195b2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -161,7 +161,7 @@ def get_parser(): parser.add_argument( "--lr-power", type=float, - default=0.5, + default=0.75, help="Power in LR-setting rule", ) @@ -780,8 +780,7 @@ def run(rank, world_size, args): optimizer = Eve( model.parameters(), - lr=params.initial_lr, betas=(0.9, 0.98), - eps=1e-9, weight_decay=3e-04, target_rms=0.1) + lr=params.initial_lr) scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: ((params.lr_num_steps/(step + params.lr_num_steps)) ** params.lr_power)) From 61486a0f76d79e941257f87efc7b10188fb48b44 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 Apr 2022 13:17:26 +0800 Subject: [PATCH 173/185] Remove initial_speed --- .../ASR/pruned_transducer_stateless2/conformer.py | 8 -------- .../ASR/pruned_transducer_stateless2/decoder.py | 6 ------ .../ASR/pruned_transducer_stateless2/joiner.py | 3 --- 3 files changed, 17 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 4797cce08f..94c6aa90c0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -956,30 +956,22 @@ def __init__(self, in_channels: int, assert in_channels >= 7 super().__init__() - # This initial_speed is to slightly slow down the relative speed of - # training during the warmup phase by increasing the magnitude of the - # initial parameter values. The intention is to allow us to - # use a higher lr_factor. - initial_speed = 0.5 self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, kernel_size=3, padding=1, - initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=layer1_channels, out_channels=layer2_channels, kernel_size=3, stride=2, - initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, stride=2, - initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 3291ad8775..c23568ae9c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -56,16 +56,10 @@ def __init__( """ super().__init__() - # This initial_speed is to slightly slow down the relative speed of - # training during the warmup phase by increasing the magnitude of the - # initial parameter values. The intention is to allow us to - # use a higher lr_factor. - initial_speed = 0.5 self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, padding_idx=blank_id, - initial_speed=initial_speed ) self.blank_id = blank_id diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 752a5f7742..2299a0a8cb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -27,9 +27,6 @@ def __init__(self, vocab_size: int): super().__init__() - # We don't bother giving the 'initial_speed' arg to the decoder - # submodules, because it does not affect the initial convergence of the - # system (only the simple joiner is involved in that). self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) self.output_linear = ScaledLinear(joiner_dim, vocab_size) From 6ee32cf7afd110783b5872e431e30308583abb21 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 Apr 2022 16:10:06 +0800 Subject: [PATCH 174/185] Set new scheduler --- .../ASR/pruned_transducer_stateless2/train.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 306a2195b2..e06db45c08 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -28,15 +28,17 @@ --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 300 \ - --initial-lr 0.002 \ - --lr-decay-steps 10000 \ - --num-lr-decays 4 + --initial-lr 0.003 \ + --lr-begin-steps 20000 \ + --lr-end-steps 50000 + """ import argparse import logging +import math import warnings from pathlib import Path from shutil import copyfile @@ -147,22 +149,22 @@ def get_parser(): parser.add_argument( "--initial-lr", type=float, - default=0.002, + default=0.003, help="The initial learning rate", ) parser.add_argument( - "--lr-num-steps", + "--lr-begin-steps", type=float, - default=3000, - help="Number of steps before we start to significantly decay the learning rate", + default=20000, + help="Number of steps that affects how rapidly the learning rate initially decreases" ) parser.add_argument( - "--lr-power", + "--lr-end-steps", type=float, - default=0.75, - help="Power in LR-setting rule", + default=50000, + help="Number of steps that affects how rapidly the learning rate finally decreases" ) parser.add_argument( @@ -783,7 +785,8 @@ def run(rank, world_size, args): lr=params.initial_lr) scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - lambda step: ((params.lr_num_steps/(step + params.lr_num_steps)) ** params.lr_power)) + lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * + math.exp(-step / params.lr_end_steps))) if checkpoints and "optimizer" in checkpoints: From f587cd527dd0075349d2a2f3502d3f62945679be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 Apr 2022 16:24:21 +0800 Subject: [PATCH 175/185] Change exponential part of lrate to be epoch based --- .../ASR/pruned_transducer_stateless2/train.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index e06db45c08..d5da5d0e9d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -27,10 +27,8 @@ --start-epoch 0 \ --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ - --max-duration 300 \ - --initial-lr 0.003 \ - --lr-begin-steps 20000 \ - --lr-end-steps 50000 + --max-duration 300 + """ @@ -161,10 +159,10 @@ def get_parser(): ) parser.add_argument( - "--lr-end-steps", + "--lr-end-epochs", type=float, - default=50000, - help="Number of steps that affects how rapidly the learning rate finally decreases" + default=10, + help="Number of epochs that affects how rapidly the learning rate finally decreases" ) parser.add_argument( @@ -783,10 +781,13 @@ def run(rank, world_size, args): optimizer = Eve( model.parameters(), lr=params.initial_lr) + + # The `epoch` variable in the lambda expression binds to the value below + # in `for epoch in range(params.start_epoch, params.num_epochs):`. scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * - math.exp(-step / params.lr_end_steps))) + math.exp(-epoch / params.lr_end_epochs))) if checkpoints and "optimizer" in checkpoints: From 0f8ee68af22657c90c5e2762a5e48e1d09b7ce0c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 Apr 2022 16:53:42 +0800 Subject: [PATCH 176/185] Fix bug --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d5da5d0e9d..0384692820 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -784,6 +784,7 @@ def run(rank, world_size, args): # The `epoch` variable in the lambda expression binds to the value below # in `for epoch in range(params.start_epoch, params.num_epochs):`. + epoch = 0 scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * From db72aee1f0e4987bf79b0578967f1c45be562dbc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 Apr 2022 18:15:56 +0800 Subject: [PATCH 177/185] Set 2n rule.. --- .../ASR/pruned_transducer_stateless2/train.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 0384692820..92509f4eca 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -154,15 +154,17 @@ def get_parser(): parser.add_argument( "--lr-begin-steps", type=float, - default=20000, + default=25000, help="Number of steps that affects how rapidly the learning rate initially decreases" ) parser.add_argument( "--lr-end-epochs", type=float, - default=10, - help="Number of epochs that affects how rapidly the learning rate finally decreases" + default=-1, + help="""Number of epochs that affects how rapidly the learning rate finally decreases; + if -1, will be set the same as --num-epochs + """ ) parser.add_argument( @@ -783,12 +785,14 @@ def run(rank, world_size, args): lr=params.initial_lr) # The `epoch` variable in the lambda expression binds to the value below - # in `for epoch in range(params.start_epoch, params.num_epochs):`. + # in `for epoch in range(params.start_epoch, params.num_epochs):`. But set it to 0 + # here to avoid crash in constructor. epoch = 0 + lr_end_epochs = params.lr_end_epochs if params.lr_end_epochs > 0 else params.num_epochs scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * - math.exp(-epoch / params.lr_end_epochs))) + ((epoch + lr_end_epochs) / lr_end_epochs) ** -2.0)) if checkpoints and "optimizer" in checkpoints: From 4d41ee0caad855354bf95b6f1cab87072060a974 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 Apr 2022 18:37:03 +0800 Subject: [PATCH 178/185] Implement 2o schedule --- .../ASR/pruned_transducer_stateless2/optim.py | 12 ------------ .../ASR/pruned_transducer_stateless2/train.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index a2e0463dab..e47c08657e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -100,9 +100,6 @@ def step(self, closure=None): if p.grad is None: continue - - - # Perform optimization step grad = p.grad if grad.is_sparse: @@ -144,12 +141,3 @@ def step(self, closure=None): p.addcdiv_(exp_avg, denom, value=-step_size) return loss - -# Note on avg-change per epoch.. -# suppose epoch is 4k iters. -# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1, -# then rm(diff) 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2, = 0.0004. So var(diff per minibatch) -# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04. -# Suggested lr_schedule? -# -# .. 6e-05 is 1/5 of that... diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 92509f4eca..a114dd8f1e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -154,15 +154,15 @@ def get_parser(): parser.add_argument( "--lr-begin-steps", type=float, - default=25000, + default=5000, help="Number of steps that affects how rapidly the learning rate initially decreases" ) parser.add_argument( - "--lr-end-epochs", + "--lr-epochs", type=float, default=-1, - help="""Number of epochs that affects how rapidly the learning rate finally decreases; + help="""Number of epochs for purposes of the learning-rate schedule; if -1, will be set the same as --num-epochs """ ) @@ -784,15 +784,15 @@ def run(rank, world_size, args): model.parameters(), lr=params.initial_lr) - # The `epoch` variable in the lambda expression binds to the value below - # in `for epoch in range(params.start_epoch, params.num_epochs):`. But set it to 0 + # The `epoch` variable in the lambda expression picks up to the value below + # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0 # here to avoid crash in constructor. epoch = 0 - lr_end_epochs = params.lr_end_epochs if params.lr_end_epochs > 0 else params.num_epochs + lr_epochs = params.lr_epochs if params.lr_epochs > 0 else params.num_epochs scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * - ((epoch + lr_end_epochs) / lr_end_epochs) ** -2.0)) + lambda step: (((step**2 + params.lr_begin_steps**2) / params.lr_begin_steps**2) ** -0.25 * + ((epoch + lr_epochs) / lr_epochs) ** -0.5)) if checkpoints and "optimizer" in checkpoints: From da50525ca5f1cf5bb655adcac0e8ad5231aa7b5c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Apr 2022 13:25:40 +0800 Subject: [PATCH 179/185] Make lrate rule more symmetric --- .../ASR/pruned_transducer_stateless2/train.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index a114dd8f1e..a8aaa4ddef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -148,22 +148,22 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( - "--lr-begin-steps", + "--lr-steps", type=float, default=5000, - help="Number of steps that affects how rapidly the learning rate initially decreases" + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""" ) parser.add_argument( "--lr-epochs", type=float, - default=-1, - help="""Number of epochs for purposes of the learning-rate schedule; - if -1, will be set the same as --num-epochs + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. """ ) @@ -788,11 +788,10 @@ def run(rank, world_size, args): # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0 # here to avoid crash in constructor. epoch = 0 - lr_epochs = params.lr_epochs if params.lr_epochs > 0 else params.num_epochs scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - lambda step: (((step**2 + params.lr_begin_steps**2) / params.lr_begin_steps**2) ** -0.25 * - ((epoch + lr_epochs) / lr_epochs) ** -0.5)) + lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 * + (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)) if checkpoints and "optimizer" in checkpoints: From 82d58629eaa54b64b32e68cb44d519bb58e530e6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Apr 2022 13:50:31 +0800 Subject: [PATCH 180/185] Implement 2p version of learning rate schedule. --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index a8aaa4ddef..73ba17a719 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -791,7 +791,7 @@ def run(rank, world_size, args): scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 * - (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)) + (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25))) if checkpoints and "optimizer" in checkpoints: From d1e4ae788dcddbefd3840c3f5bbc598ec7e225b9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Apr 2022 15:25:27 +0800 Subject: [PATCH 181/185] Refactor how learning rate is set. --- .../ASR/pruned_transducer_stateless2/optim.py | 151 +++++++++++++++++- .../ASR/pruned_transducer_stateless2/train.py | 43 ++--- icefall/checkpoint.py | 11 +- 3 files changed, 174 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index e47c08657e..4f7392d3a6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -16,7 +16,7 @@ import random -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from torch import Tensor @@ -141,3 +141,152 @@ def step(self, closure=None): p.addcdiv_(exp_avg, denom, value=-step_size) return loss + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {'base_lrs': self.base_lrs, + 'epoch': self.epoch, + 'batch': self.batch} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """ Return last computed learning rate by current scheduler. Will be a list of float. + """ + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate. + """ + if is_verbose: + print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate' + f' of group {group} to {lr:.4e}.') + + +class Eden(LRScheduler): + """ + Eden scheduler. + lr = initial_lr = (((batch**2 + lr_batches**2) / lr_batchses**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) + + E.g. suggest initial-lr = 0.003 (passed to optimizer). + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6. + """ + def __init__(self, optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + verbose: bool = False): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + + def get_lr(self): + factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 * + (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25)) + return [ x * factor for x in self.base_lrs ] + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = Eve(m.parameters(), lr=0.003) + + scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + print("last lr = ", scheduler.get_last_lr()) + print("state dict = ", scheduler.state_dict()) + +if __name__ == '__main__': + _test_eden() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 73ba17a719..ddd2e8fb7d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -40,7 +40,7 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import k2 import sentencepiece as spm @@ -55,7 +55,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer -from optim import Eve +from optim import Eve, Eden from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -74,6 +74,7 @@ str2bool, ) +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): parser = argparse.ArgumentParser( @@ -152,7 +153,7 @@ def get_parser(): ) parser.add_argument( - "--lr-steps", + "--lr-batches", type=float, default=5000, help="""Number of steps that affects how rapidly the learning rate decreases. @@ -378,7 +379,7 @@ def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. @@ -443,7 +444,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -593,7 +594,7 @@ def train_one_epoch( params: AttributeDict, model: nn.Module, optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, + scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -656,17 +657,15 @@ def train_one_epoch( # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. loss.backward() + scheduler.step_batch(params.batch_idx_train) optimizer.step() optimizer.zero_grad() - scheduler.step() if params.print_diagnostics and batch_idx == 5: return - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): + if (params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0): params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, @@ -686,13 +685,17 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" ) if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr) + loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) @@ -784,14 +787,7 @@ def run(rank, world_size, args): model.parameters(), lr=params.initial_lr) - # The `epoch` variable in the lambda expression picks up to the value below - # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0 - # here to avoid crash in constructor. - epoch = 0 - scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, - lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 * - (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25))) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) if checkpoints and "optimizer" in checkpoints: @@ -854,19 +850,14 @@ def remove_short_and_long_utt(c: Cut): ) for epoch in range(params.start_epoch, params.num_epochs): + scheduler.step_epoch(epoch) fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = scheduler.get_last_lr()[0] if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - params.cur_epoch = epoch train_one_epoch( diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 251456c955..c0d4b3968a 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -28,15 +28,18 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler +# use duck typing for LRScheduler since we have different possibilities, see +# our class LRScheduler. +LRSchedulerType = object + def save_checkpoint( filename: Path, model: Union[nn.Module, DDP], params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, - scheduler: Optional[_LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, scaler: Optional[GradScaler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, @@ -89,7 +92,7 @@ def load_checkpoint( filename: Path, model: nn.Module, optimizer: Optional[Optimizer] = None, - scheduler: Optional[_LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, scaler: Optional[GradScaler] = None, sampler: Optional[CutSampler] = None, strict: bool = False, @@ -167,7 +170,7 @@ def save_checkpoint_with_global_batch_idx( model: Union[nn.Module, DDP], params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, - scheduler: Optional[_LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, scaler: Optional[GradScaler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, From 962cf868c960125170802294c79338adec391ffa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Apr 2022 15:31:46 +0800 Subject: [PATCH 182/185] Fix import --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index ddd2e8fb7d..62dc825b68 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -45,6 +45,7 @@ import k2 import sentencepiece as spm import torch +import optim # from . import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule From 46d52dda1080baad3f3468a2040eed962c7e73fd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 12:03:41 +0800 Subject: [PATCH 183/185] Fix dir names --- .../ASR/pruned_transducer_stateless2/decode.py | 18 +++++++++--------- .../ASR/pruned_transducer_stateless2/export.py | 12 ++++++------ .../ASR/pruned_transducer_stateless2/train.py | 2 +- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 8e924bf96c..38aff88340 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -18,36 +18,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless/decode.py \ +./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search -./pruned_transducer_stateless/decode.py \ +./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless/decode.py \ +./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search -./pruned_transducer_stateless/decode.py \ +./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 1500 \ --decoding-method fast_beam_search \ --beam 4 \ @@ -124,7 +124,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless2/exp", help="The experiment dir", ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 7d2a07817c..b5757ee8c4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -20,23 +20,23 @@ # to a single one using model averaging. """ Usage: -./pruned_transducer_stateless/export.py \ - --exp-dir ./pruned_transducer_stateless/exp \ +./pruned_transducer_stateless2/export.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --epoch 20 \ --avg 10 It will generate a file exp_dir/pretrained.pt -To use the generated file with `pruned_transducer_stateless/decode.py`, +To use the generated file with `pruned_transducer_stateless2/decode.py`, you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless/decode.py \ - --exp-dir ./pruned_transducer_stateless/exp \ + ./pruned_transducer_stateless2/decode.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ --epoch 9999 \ --avg 1 \ --max-duration 100 \ @@ -80,7 +80,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless2/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 62dc825b68..c24fbe9a17 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -116,7 +116,7 @@ def get_parser(): default=0, help="""Resume training from from this epoch. If it is positive, it will load checkpoint from - transducer_stateless/exp/epoch-{start_epoch-1}.pt + transducer_stateless2/exp/epoch-{start_epoch-1}.pt """, ) From d5f9d49e536d938b4ccc64bccb1a63bda4ea88fd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 12:35:29 +0800 Subject: [PATCH 184/185] Modify beam search to be efficient with current joienr --- .../beam_search.py | 766 +++++++++++++++++- 1 file changed, 765 insertions(+), 1 deletion(-) mode change 120000 => 100644 egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 120000 index 227d2247c0..0000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py new file mode 100644 index 0000000000..5876d51586 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -0,0 +1,765 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Optional + +import k2 +import torch +from model import Transducer + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def fast_beam_search( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + return hyps + + +def greedy_search( + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # Maximum symbols per utterance. + max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner(current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False) + # logits is (1, 1, 1, vocab_size) + + y = logits.argmax().item() + if y != blank_id: + hyp.append(y) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + return hyp + + +def greedy_search_batch( + model: Transducer, encoder_out: torch.Tensor +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + Returns: + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + device = model.device + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + hyps = [[blank_id] * context_size for _ in range(batch_size)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + encoder_out = model.joiner.encoder_proj(encoder_out) + + # decoder_out: (batch_size, 1, decoder_out_dim) + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1), + project_input=False) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + ans = [h[context_size:] for h in hyps] + return ans + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + beam: + Number of active paths during the beam search. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + ans = [h.ys[context_size:] for h in best_hyps] + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + +def beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False + ) + + # TODO(fangjun): Scale the blank posterior + log_prob = logits.log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i == blank_id: + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + return ys From 507833208868b8ed07a555437891f208274bbb3d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 14:58:15 +0800 Subject: [PATCH 185/185] Fix adding learning rate to tensorboard --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c24fbe9a17..b9ea0def6c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -695,7 +695,7 @@ def train_one_epoch( ) if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr) + tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train) loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train