From a6269d7744e4c50ac670ac1083b9456678b91547 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Tue, 2 May 2023 14:42:38 -0400 Subject: [PATCH 01/40] TDT model pull request, initial draft Signed-off-by: Hainan Xu --- .../conformer_tdt_transducer_bpe.yaml | 260 ++++++ nemo/collections/asr/losses/rnnt.py | 40 + nemo/collections/asr/losses/rnnt_pytorch.py | 117 +++ nemo/collections/asr/metrics/rnnt_wer.py | 90 ++- nemo/collections/asr/models/rnnt_models.py | 2 +- .../asr/parts/numba/rnnt_loss/__init__.py | 2 +- .../asr/parts/numba/rnnt_loss/rnnt.py | 128 +++ .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 230 +++++- .../rnnt_loss/utils/cuda_utils/gpu_rnnt.py | 313 ++++++++ .../utils/cuda_utils/gpu_rnnt_kernel.py | 497 ++++++++++++ .../parts/submodules/rnnt_greedy_decoding.py | 754 ++++++++++++++++++ .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 32 +- .../asr/test_asr_rnnt_encdec_model.py | 51 ++ 13 files changed, 2485 insertions(+), 31 deletions(-) create mode 100644 examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml diff --git a/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml b/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml new file mode 100644 index 000000000000..d8c814d185ee --- /dev/null +++ b/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml @@ -0,0 +1,260 @@ +# It contains the default values for training an TDT Conformer-Transducer ASR model with stateless decoders, large size (~120M) with Transducer loss and sub-word encoding. + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. +# +# +-------------+---------+---------+----------+--------------+--------------------------+ +# | Model | d_model | n_heads | n_layers | weight_decay | pred_hidden/joint_hidden | +# +=============+=========+========+===========+==============+==========================+ +# | Small (14M)| 176 | 4 | 16 | 0.0 | 320 | +# +-------------+---------+--------+-----------+--------------+--------------------------+ +# | Medium (32M)| 256 | 4 | 16 | 1e-3 | 640 | +# +-------------+---------+--------+-----------+--------------+--------------------------+ +# | Large (120M)| 512 | 8 | 17 | 1e-3 | 640 | +# +-----------------------------------------------------------+--------------------------+ +# + +# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer +# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html + +name: "TDT-Conformer-Transducer-BPE" + +model: + sample_rate: 16000 + compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. + log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + + train_ds: + manifest_filepath: ??? 
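+    # Each manifest line is a JSON record, e.g.
+    # {"audio_filepath": "/path/to/sample.wav", "duration": 3.2, "text": "the transcript"}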
+ sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling params + subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 4 # must be power of 2 for striding and vggnet + subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 31 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null # Currently only null is supported for 
export. + random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf + blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference. + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 1 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null # 'null' would set it automatically according to CPU/GPU device + preserve_memory: false # dramatically slows down training, but might preserve some memory + + # Fuses the computation of prediction net + joint net + loss + WER calculation + # to be run on sub-batches of size `fused_batch_size`. + # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size. + # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss. + # Using small values here will preserve a lot of memory during training, but will make training slower as well. + # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1. + # However, to preserve memory, this ratio can be 1:8 or even 1:16. + # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow. + fuse_loss_wer: true + fused_batch_size: 16 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + num_extra_outputs: 5 + + decoding: + strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. + + # this must not be None in order to use the TDT specific decoding method. + durations: [0, 1, 2, 3, 4] + + # greedy strategy config + greedy: + max_symbols: 10 + + # beam strategy config + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 # for Time Synchronous Decoding + alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + + loss: + loss_name: "tdt_rnnt" + + tdt_rnnt_kwargs: + # FastEmit regularization: https://arxiv.org/abs/2010.11148 + # You may enable FastEmit to reduce the latency of the model for streaming + fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + durations: [0, 1, 2, 3, 4] + sigma: 0.05 + omega: 0.0 # weight for regular RNN-T loss + + # Adds Gaussian noise to the gradients of the decoder to avoid overfitting + variational_noise: + start_step: 0 + std: 0.0 + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 500 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
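+  # Effective batch size = devices * num_nodes * train_ds.batch_size * accumulate_grad_batches;
+  # the learning-rate defaults in this file assume an effective batch size of roughly 2K (see the note at the top of this file).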
+ num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index ee89cb9e0f8e..8fb6909945eb 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -35,6 +35,7 @@ from omegaconf import DictConfig, OmegaConf from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch +from nemo.collections.asr.losses.rnnt_pytorch import TDTRNNTLossPytorch from nemo.core.classes import Loss, typecheck from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE @@ -49,6 +50,7 @@ try: from nemo.collections.asr.parts.numba.rnnt_loss import MultiblankRNNTLossNumba, RNNTLossNumba + from nemo.collections.asr.parts.numba.rnnt_loss import TDTRNNTLossNumba NUMBA_RNNT_AVAILABLE = True except (ImportError, ModuleNotFoundError): @@ -109,6 +111,20 @@ class RNNTLossConfig: is_available=True, installation_msg="Pure Pytorch implementation of Multiblank RNN-T loss. Slow and for debugging purposes only.", ), + "tdt_rnnt": RNNTLossConfig( + loss_name="tdt_rnnt", + lib_name="numba", + min_version='0.53.0', + is_available=NUMBA_RNNT_AVAILABLE, + installation_msg=NUMBA_INSTALLATION_MESSAGE, + ), + "tdt_rnnt_pytorch": RNNTLossConfig( + loss_name="pytorch", + lib_name="torch", + min_version='0.0', + is_available=True, + installation_msg="Pure Pytorch implementation of TDT RNN-T loss. 
Slow and for debugging purposes only.", + ), } RNNT_LOSS_RESOLVER['default'] = RNNT_LOSS_RESOLVER['warprnnt_numba'] @@ -214,6 +230,30 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) ) _warn_unused_additional_kwargs(loss_name, loss_kwargs) + elif loss_name == 'tdt_rnnt': + fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0) + clamp = loss_kwargs.pop('clamp', -1.0) + durations = loss_kwargs.pop('durations', None) + sigma = loss_kwargs.pop('sigma', 0.0) + omega = loss_kwargs.pop('omega', 0.0) + loss_func = TDTRNNTLossNumba( + blank=blank_idx, + durations=durations, + reduction='none', + fastemit_lambda=fastemit_lambda, + clamp=clamp, + sigma=sigma, + omega=omega, + ) + _warn_unused_additional_kwargs(loss_name, loss_kwargs) + + elif loss_name == 'tdt_rnnt_pytorch': + durations = loss_kwargs.pop('durations', None) + sigma = loss_kwargs.pop('sigma', 0.0) + loss_func = TDTRNNTLossPytorch(blank=blank_idx, durations=durations, reduction='none', sigma=sigma) + _warn_unused_additional_kwargs(loss_name, loss_kwargs) + + else: raise ValueError( f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :" f"{loss_function_names}" diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index ab0b5cf4f630..b20245c0ac68 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -112,6 +112,123 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens): return log_prob + +class TDTRNNTLossPytorch(Loss): + @property + def input_types(self): + """Input types definitions for CTCLoss. + """ + return { + "acts": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), + "labels": NeuralType(('B', 'T'), LabelsType()), + "act_lens": NeuralType(tuple('B'), LengthsType()), + "label_lens": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Output types definitions for CTCLoss. 
+ loss: + NeuralType(None) + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__(self, blank, durations, reduction, sigma): + super().__init__() + self.blank = blank + self.durations = durations + self.n_durations = len(durations) + self.reduction = reduction + self.sigma = sigma + + def forward(self, acts, labels, act_lens, label_lens): + label_acts = acts[:, :, :, : -self.n_durations] + duration_acts = acts[:, :, :, -self.n_durations :] + + label_acts = torch.log_softmax(label_acts, -1) - self.sigma + + duration_acts = torch.log_softmax(duration_acts, -1) + + forward_logprob = self.compute_forward_prob(label_acts, duration_acts, labels, act_lens, label_lens) + losses = -forward_logprob + if self.reduction == 'mean_batch': + losses = losses.mean() # global batch size average + elif self.reduction == 'mean': + losses = torch.div(losses, label_lens).mean() + elif self.reduction == 'sum': + losses = losses.sum() + elif self.reduction == 'mean_volume': + losses = losses.sum() / label_lens.sum() # same as above but longer samples weigh more + + return losses + + def logsumexp(self, a, b): + ret = torch.logsumexp(torch.stack([a, b]), dim=0) + return ret + + def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens): + B, T, U, _ = acts.shape + + log_alpha = torch.zeros(B, T, U) + log_alpha = log_alpha.cuda() + for b in range(B): + for t in range(T): + for u in range(U): + if u == 0: + if t == 0: + log_alpha[b, t, u] = 0.0 + else: + log_alpha[b, t, u] = -1000.0 + for n, l in enumerate(self.durations): + if t - l >= 0 and l > 0: # blank emission, l has to be at least 1 + tmp = ( + log_alpha[b, t - l, u] + + acts[b, t - l, u, self.blank] + + duration_acts[b, t - l, u, n] + ) + log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + + else: + log_alpha[b, t, u] = -1000.0 + for n, l in enumerate(self.durations): + if t - l >= 0: + if l > 0: + tmp = ( + log_alpha[b, t - l, u] + + acts[b, t - l, u, self.blank] + + duration_acts[b, t - l, u, n] + ) + log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + + tmp = ( + log_alpha[b, t - l, u - 1] + + acts[b, t - l, u - 1, labels[b, u - 1]] + + duration_acts[b, t - l, u - 1, n] + ) + log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + + log_probs = [] + for b in range(B): + tt = torch.Tensor([-1000.0]) + tt = tt.cuda() + tt = tt[0] + for n, l in enumerate(self.durations): + if act_lens[b] - l >= 0 and l > 0: + bb = ( + log_alpha[b, act_lens[b] - l, label_lens[b]] + + acts[b, act_lens[b] - l, label_lens[b], self.blank] + + duration_acts[b, act_lens[b] - l, label_lens[b], n] + ) + + tt = self.logsumexp(bb, 1.0 * tt) + + log_probs.append(tt) + + log_prob = torch.stack(log_probs) + + return log_prob + + class MultiblankRNNTLossPytorch(Loss): """ Pure Python implementation of multi-blank transducer loss (https://arxiv.org/pdf/2211.03541.pdf) diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 00cacbf863d4..75901fad1276 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -203,6 +203,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): self.blank_id = blank_id self.num_extra_outputs = joint.num_extra_outputs self.big_blank_durations = self.cfg.get("big_blank_durations", None) + self.durations = self.cfg.get("durations", None) self.compute_hypothesis_token_set = self.cfg.get("compute_hypothesis_token_set", False) self.compute_langs = 
decoding_cfg.get('compute_langs', False) self.preserve_alignments = self.cfg.get('preserve_alignments', None) @@ -210,6 +211,10 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): self.compute_timestamps = self.cfg.get('compute_timestamps', None) self.word_seperator = self.cfg.get('word_seperator', ' ') + if self.durations is not None: + assert blank_id != 0, 'blank_id must equal len(non_blank_vocabs) for multi-blank RNN-T models' + + if self.big_blank_durations is not None: if blank_id == 0: raise ValueError("blank_id must equal len(vocabs) for multi-blank RNN-T models") @@ -253,17 +258,31 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): if self.cfg.strategy == 'greedy': if self.big_blank_durations is None: - self.decoding = greedy_decode.GreedyRNNTInfer( - decoder_model=decoder, - joint_model=joint, - blank_index=self.blank_id, - max_symbols_per_step=( - self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) - ), - preserve_alignments=self.preserve_alignments, - preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, - ) + if self.durations is not None: + self.decoding = greedy_decode.GreedyTDTRNNTInfer( + decoder_model=decoder, + joint_model=joint, + blank_index=self.blank_id, + durations=self.durations, + max_symbols_per_step=( + self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) + ), + preserve_alignments=self.preserve_alignments, + preserve_frame_confidence=self.preserve_frame_confidence, + confidence_method_cfg=self.confidence_method_cfg, + ) + else: + self.decoding = greedy_decode.GreedyRNNTInfer( + decoder_model=decoder, + joint_model=joint, + blank_index=self.blank_id, + max_symbols_per_step=( + self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) + ), + preserve_alignments=self.preserve_alignments, + preserve_frame_confidence=self.preserve_frame_confidence, + confidence_method_cfg=self.confidence_method_cfg, + ) else: self.decoding = greedy_decode.GreedyMultiblankRNNTInfer( decoder_model=decoder, @@ -280,17 +299,33 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): elif self.cfg.strategy == 'greedy_batch': if self.big_blank_durations is None: - self.decoding = greedy_decode.GreedyBatchedRNNTInfer( - decoder_model=decoder, - joint_model=joint, - blank_index=self.blank_id, - max_symbols_per_step=( - self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) - ), - preserve_alignments=self.preserve_alignments, - preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, - ) + if self.durations is not None: + self.decoding = greedy_decode.GreedyBatchedTDTRNNTInfer( + decoder_model=decoder, + joint_model=joint, + blank_index=self.blank_id, + durations=self.durations, + max_symbols_per_step=( + self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) + ), + preserve_alignments=self.preserve_alignments, + preserve_frame_confidence=self.preserve_frame_confidence, + confidence_method_cfg=self.confidence_method_cfg, + ) + else: + self.decoding = greedy_decode.GreedyBatchedRNNTInfer( + decoder_model=decoder, + joint_model=joint, + blank_index=self.blank_id, + max_symbols_per_step=( + self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) + ), + preserve_alignments=self.preserve_alignments, + 
preserve_frame_confidence=self.preserve_frame_confidence, + confidence_method_cfg=self.confidence_method_cfg, + ) + + else: self.decoding = greedy_decode.GreedyBatchedMultiblankRNNTInfer( decoder_model=decoder, @@ -481,10 +516,13 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp # RNN-T sample level is already preprocessed by implicit RNNT decoding # Simply remove any blank and possibly big blank tokens if self.blank_id != 0: - num_extra_outputs = 0 - if self.big_blank_durations is not None: - num_extra_outputs += len(self.big_blank_durations) + # use < here works both for standard and multi-blank RNN-T. + prediction = [p for p in prediction if p < self.blank_id] + elif self.blank_id != 0 and self.big_blank_durations is not None: + num_extra_outputs = len(self.big_blank_durations) prediction = [p for p in prediction if p < self.blank_id - num_extra_outputs] + + # Simply remove any blank and possibly big blank tokens else: prediction = [p for p in prediction if p != self.blank_id] @@ -1084,7 +1122,7 @@ def decode_tokens_to_str(self, tokens: List[int]) -> str: Args: tokens: List of int representing the token ids. - Returns: + eturns: A decoded string. """ hypothesis = ''.join(self.decode_ids_to_tokens(tokens)) diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index f4e227f510af..9be9b9d9eb85 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -72,7 +72,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None)) self.loss = RNNTLoss( - num_classes=self.joint.num_classes_with_blank - 1, + num_classes=self.joint.num_classes_with_blank - 1 - self.joint.num_extra_outputs, loss_name=loss_name, loss_kwargs=loss_kwargs, reduction=self.cfg.get("rnnt_reduction", "mean_batch"), diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py b/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py index 66e30c77590a..eecfaf785b3e 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. from nemo.collections.asr.parts.numba.rnnt_loss.rnnt import rnnt_loss_cpu, rnnt_loss_gpu -from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import MultiblankRNNTLossNumba, RNNTLossNumba +from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import MultiblankRNNTLossNumba, RNNTLossNumba, TDTRNNTLossNumba diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py index 64c8955006ed..20ac579d3196 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py @@ -236,6 +236,134 @@ def rnnt_loss_gpu( return True +def tdt_rnnt_loss_gpu( + label_acts: torch.Tensor, + duration_acts: torch.Tensor, + labels: torch.Tensor, + input_lengths: torch.Tensor, + label_lengths: torch.Tensor, + costs: torch.Tensor, + label_grads: torch.Tensor, + duration_grads: torch.Tensor, + blank_label: int, + durations: list, + fastemit_lambda: float, + clamp: float, + num_threads: int, + sigma: float, + omega: float, +): + """ + Wrapper method for accessing GPU Multi-blank RNNT loss (https://arxiv.org/pdf/2211.03541.pdf). + + CUDA implementation ported from [HawkAaron/warp-transducer](https://github.com/HawkAaron/warp-transducer). 
+ + Args: + acts: Activation tensor of shape [B, T, U, V+1]. + labels: Ground truth labels of shape [B, U]. + input_lengths: Lengths of the acoustic sequence as a vector of ints [B]. + label_lengths: Lengths of the target sequence as a vector of ints [B]. + costs: Zero vector of length [B] in which costs will be set. + grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set. + blank_label: Index of the standard blank token in the vocabulary. + durations: A list of supported durations for big blank symbols + in the model, e.g. [2, 4, 8]. Note we only include durations for ``big + blanks'' here and it should not include 1 for the standard blank. + Those big blanks have vocabulary indices after the standard blank index. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. + num_threads: Number of threads for OpenMP. + sigma: logit-undernormalization weight used in the multi-blank model. Refer to + the multi-blank paper https://arxiv.org/pdf/2211.03541 for detailed explanations. + omega: weight for regular RNN-T loss + """ + minibatch_size = label_acts.shape[0] + maxT = label_acts.shape[1] + maxU = label_acts.shape[2] + alphabet_size = label_acts.shape[3] + + if hasattr(cuda, 'external_stream'): + stream = cuda.external_stream(torch.cuda.current_stream(label_acts.device).cuda_stream) + else: + stream = cuda.default_stream() + + if num_threads < 0: + num_threads = multiprocessing.cpu_count() + + num_threads = max(1, num_threads) # have to use at least 1 thread + + gpu_size, status = rnnt_helper.get_workspace_size(maxT, maxU, minibatch_size, gpu=True) + + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Invalid parameter passed when calculating working space memory") + + # Select GPU index + cuda.select_device(label_acts.device.index) + gpu_workspace = torch.zeros(gpu_size, device=label_acts.device, dtype=label_acts.dtype, requires_grad=False) + + tdt_workspace = torch.zeros(len(durations), device=label_acts.device, dtype=torch.long, requires_grad=False) + + for i in range(0, len(durations)): + tdt_workspace[i] = durations[i] + + ### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ### + label_acts, label_acts_shape = rnnt_helper.flatten_tensor(label_acts) + duration_acts, duration_acts_shape = rnnt_helper.flatten_tensor(duration_acts) + + wrapper = gpu_rnnt.TDTGPURNNT( + minibatch=minibatch_size, + maxT=maxT, + maxU=maxU, + alphabet_size=alphabet_size, + workspace=gpu_workspace, + tdt_workspace=tdt_workspace, + num_durations=len(durations), + blank=blank_label, + fastemit_lambda=fastemit_lambda, + clamp=clamp, + num_threads=num_threads, + stream=stream, + sigma=sigma, + omega=omega, + ) + + if label_grads is None: + status = wrapper.score_forward( + label_acts=label_acts.data, + duration_acts=duration_acts.data, + costs=costs.data, + pad_labels=labels.data, + label_lengths=label_lengths.data, + input_lengths=input_lengths.data, + ) + + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Could not calculate forward scores") + + else: + ### FLATTEN GRAD TENSOR ### + label_grads, label_grads_shape = rnnt_helper.flatten_tensor(label_grads) + duration_grads, duration_grads_shape = rnnt_helper.flatten_tensor(duration_grads) + + status = wrapper.cost_and_grad( + label_acts=label_acts.data, + duration_acts=duration_acts.data, 
+ label_grads=label_grads.data, + duration_grads=duration_grads.data, + costs=costs.data, + pad_labels=labels.data, + label_lengths=label_lengths.data, + input_lengths=input_lengths.data, + ) + + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Could not calculate forward scores") + + del gpu_workspace, tdt_workspace, wrapper + return True + + def multiblank_rnnt_loss_gpu( acts: torch.Tensor, labels: torch.Tensor, diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 3ed9b82bf996..dd92db87c879 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -34,7 +34,7 @@ from nemo.collections.asr.parts.numba.rnnt_loss import rnnt from nemo.collections.asr.parts.numba.rnnt_loss.utils.cpu_utils import cpu_rnnt -__all__ = ['rnnt_loss', 'RNNTLossNumba', 'MultiblankRNNTLossNumba'] +__all__ = ['rnnt_loss', 'RNNTLossNumba', 'MultiblankRNNTLossNumba', 'TDTRNNTLossNumba'] class _RNNTNumba(Function): @@ -91,6 +91,106 @@ def backward(ctx, grad_output): return ctx.grads.mul_(grad_output), None, None, None, None, None, None, None +class _TDTRNNTNumba(Function): + """ + Numba class for multi-blank transducer loss (https://arxiv.org/pdf/2211.03541.pdf) + """ + + @staticmethod + def forward( + ctx, + label_acts, + duration_acts, + labels, + act_lens, + label_lens, + blank, + durations, + reduction, + fastemit_lambda, + clamp, + sigma, + omega, + ): + """ + durations: list of durations for multi-blank transducer, e.g. + [2, 4, 8]. + sigma: hyper-parameter for logit under-normalization method for training + multi-blank transducers. Recommended value 0.05. + omega: weight for standard RNN-T loss + Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations for + the above parameters; + For other parameters for this class, refer to comment for class _RNNTNumba + """ + is_cuda = label_acts.is_cuda + + certify_inputs(label_acts, labels, act_lens, label_lens) + if clamp < 0: + raise ValueError("`clamp` must be 0.0 or positive float value.") + + if is_cuda: + loss_func = rnnt.tdt_rnnt_loss_gpu + else: + exit(-1) + label_grads = torch.zeros_like(label_acts) if label_acts.requires_grad else None + duration_grads = torch.zeros_like(duration_acts) if duration_acts.requires_grad else None + minibatch_size = label_acts.size(0) + costs = torch.zeros(minibatch_size, device=label_acts.device, dtype=label_acts.dtype) + + loss_func( + label_acts, + duration_acts, + labels=labels, + input_lengths=act_lens, + label_lengths=label_lens, + costs=costs, + label_grads=label_grads, + duration_grads=duration_grads, + blank_label=blank, + durations=durations, + fastemit_lambda=fastemit_lambda, + clamp=clamp, + sigma=sigma, + omega=omega, + num_threads=0, + ) + + if reduction in ['sum', 'mean']: + costs = costs.sum().unsqueeze_(-1) + if reduction == 'mean': + costs /= minibatch_size + + if label_grads is not None: + label_grads /= minibatch_size + duration_grads /= minibatch_size + + ctx.label_grads = label_grads + ctx.duration_grads = duration_grads + + return costs + + @staticmethod + def backward(ctx, grad_output): + if grad_output is not None and ctx.label_grads is not None: + grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.label_grads) + return ( + ctx.label_grads.mul_(grad_output), + ctx.duration_grads.mul_(grad_output), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + + + class 
_MultiblankRNNTNumba(Function): """ Numba class for multi-blank transducer loss (https://arxiv.org/pdf/2211.03541.pdf) @@ -237,6 +337,54 @@ def multiblank_rnnt_loss( ) +def tdt_rnnt_loss( + acts, + labels, + act_lens, + label_lens, + blank, + durations=[], + reduction='mean', + fastemit_lambda: float = 0.0, + clamp: float = 0.0, +): + """ + TDT RNN Transducer (https://arxiv.org/pdf/2211.03541.pdf) Loss (functional form) + Args: + acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + blank (int): standard blank label. + durations: list of durations for multi-blank transducer, e.g. + [2, 4, 8]. + sigma: hyper-parameter for logit under-normalization method for training + multi-blank transducers. Recommended value 0.05. + Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations for + the last two params. + reduction (string, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, + 'mean': the output losses will be divided by the target lengths and + then the mean over the batch is taken. Default: 'mean' + """ + if not acts.is_cuda: + # Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping + # *after* we have obtained the gradients of loss(logsoftmax()). + # This is highly wasteful since it requires a copy of the entire joint tensor which is expensive. + # CUDA version is much more efficient since it performs an inplace logsoftmax, and therefore + # can inplace clamp the gradient. + if clamp > 0.0: + acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, clamp) + + # NOTE manually done log_softmax for CPU version, + # log_softmax is computed within GPU version. + acts = torch.nn.functional.log_softmax(acts, -1) + + return _TDTRNNTNumba.apply(acts, labels, act_lens, label_lens, blank, durations, reduction, fastemit_lambda, clamp) + + + + class RNNTLossNumba(Module): """ Parameters: @@ -353,6 +501,86 @@ def forward(self, acts, labels, act_lens, label_lens): self.sigma, ) +class TDTRNNTLossNumba(Module): + """ + Parameters: + blank (int): standard blank label. + durations: list of durations for multi-blank transducer, e.g. + [2, 4, 8]. + sigma: hyper-parameter for logit under-normalization method for training + multi-blank transducers. Recommended value 0.05. + Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations for + the above parameters; + reduction (string, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, + 'mean': the output losses will be divided by the target lengths and + then the mean over the batch is taken. Default: 'mean' + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. 
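+    omega: probability of computing the standard RNN-T loss for a given minibatch instead of the TDT loss
+        (loss sampling performed in TDTGPURNNT.compute_cost_and_score); 0.0 disables the sampling.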
+ """ + + def __init__( + self, + blank, + durations=[], + reduction='mean', + fastemit_lambda: float = 0.0, + clamp: float = -1, + sigma: float = 0.0, + omega: float = 0.0, + ): + super(TDTRNNTLossNumba, self).__init__() + self.blank = blank + self.durations = durations + self.fastemit_lambda = fastemit_lambda + self.clamp = float(clamp) if clamp > 0 else 0.0 + self.reduction = reduction + self.loss = _TDTRNNTNumba.apply + self.sigma = sigma + self.omega = omega + + def forward(self, acts, labels, act_lens, label_lens): + """ + log_probs: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + """ + + label_acts = acts[:, :, :, : -len(self.durations)].contiguous() + duration_acts = torch.nn.functional.log_softmax(acts[:, :, :, -len(self.durations) :], dim=-1).contiguous() + + # if not acts.is_cuda: + # # Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping + # # *after* we have obtained the gradients of loss(logsoftmax()). + # # This is highly wasteful since it requires a copy of the entire joint tensor which is expensive. + # # CUDA version is much more efficient since it performs an inplace logsoftmax, and therefore + # # can inplace clamp the gradient. + # if self.clamp > 0.0: + # acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, self.clamp) + # + # # NOTE manually done log_softmax for CPU version, + # # log_softmax is computed within GPU version. + # acts = torch.nn.functional.log_softmax(acts, -1) + + return self.loss( + label_acts, + duration_acts, + labels, + act_lens, + label_lens, + self.blank, + self.durations, + self.reduction, + self.fastemit_lambda, + self.clamp, + self.sigma, + self.omega, + ) + + + def check_type(var, t, name): if var.dtype is not t: diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py index dca4e732c062..c10fa9fba030 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py @@ -27,6 +27,7 @@ # limitations under the License. import multiprocessing +import random from typing import Optional, Tuple import numba @@ -520,3 +521,315 @@ def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]): bigblank_durations = self.big_blank_workspace[: self.num_big_blanks] return used_offset, (denom, alphas, betas, llForward, llBackward, bigblank_durations) + + +class TDTGPURNNT(GPURNNT): + def __init__( + self, + sigma: float, + omega: float, + num_durations: int, + minibatch: int, + maxT: int, + maxU: int, + alphabet_size: int, + workspace, + tdt_workspace, + blank: int, + fastemit_lambda: float, + clamp: float, + num_threads: int, + stream, + ): + """ + Helper class to launch the CUDA Kernels to compute TDT Transducer Loss (https://arxiv.org/pdf/2211.03541). + + Args: + sigma: Hyper-parameter related to the logit-normalization method in training tdt transducers. + omega: Hyper-parameter related to the sampled training. + num_durations: Number of big blank symbols the model has. This should not include the standard blank symbol. + minibatch: Int representing the batch size. + maxT: The maximum possible acoustic sequence length. 
Represents T in the logprobs tensor. + maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. + alphabet_size: The vocabulary dimension V + 1 + num-big-blanks + workspace: An allocated chunk of memory that will be sliced off and reshaped into required + blocks used as working memory. + tdt_workspace: An allocated chunk of memory that will be sliced off and reshaped into required + blocks used as working memory specifically for the tdt related computations. + blank: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. + num_threads: Number of OMP threads to launch. + stream: Numba Cuda Stream. + """ + super().__init__( + minibatch, maxT, maxU, alphabet_size, workspace, blank, fastemit_lambda, clamp, num_threads, stream + ) + self.tdt_workspace = cuda.as_cuda_array( + tdt_workspace + ) # a flat vector of integer numbers that represents allocated memory slices + + self.num_durations = num_durations + self.sigma = sigma + self.omega = omega + + def compute_cost_and_score( + self, + label_acts: torch.Tensor, + duration_acts: torch.Tensor, + label_grads: Optional[torch.Tensor], + duration_grads: Optional[torch.Tensor], + costs: torch.Tensor, + labels: torch.Tensor, + label_lengths: torch.Tensor, + input_lengths: torch.Tensor, + ) -> global_constants.RNNTStatus: + """ + Compute both the loss and the gradients. + + Args: + acts: A flattened tensor of shape [B, T, U, V+1] representing the activation matrix. + grad: A flattented zero tensor of same shape as acts. + costs: A zero vector of length B which will be updated inplace with the log probability costs. + flat_labels: A flattened matrix of labels of shape [B, U] + label_lengths: A vector of length B that contains the original lengths of the acoustic sequence. + input_lengths: A vector of length B that contains the original lengths of the target sequence. + + Updates: + This will launch kernels that will update inline the following variables: + - grads: Gradients of the activation matrix wrt the costs vector. + - costs: Negative log likelihood of the forward variable. + + Returns: + An enum that either represents a successful RNNT operation or failure. 
+ """ + training = label_grads is not None + + if training: + label_grads *= 0.0 # zero grads + duration_grads *= 0.0 # zero grads + + _, (denom, alphas, betas, llForward, llBackward, durations) = self._prepare_workspace() + + ######## START EXECUTION ######## + self.log_softmax(label_acts, denom) + + r = random.uniform(0, 1) + if r < self.omega: + # Compute alphas + gpu_rnnt_kernel.compute_alphas_kernel[self.minibatch_, self.maxU_, self.stream_, 0]( + label_acts, + denom, + alphas, + llForward, + input_lengths, + label_lengths, + labels, + self.minibatch_, + self.maxT_, + self.maxU_, + self.alphabet_size_, + self.blank_, + ) + else: + # Compute alphas + gpu_rnnt_kernel.compute_tdt_alphas_kernel[self.minibatch_, self.maxU_, self.stream_, 0]( + label_acts, + duration_acts, + denom, + self.sigma, + alphas, + llForward, + input_lengths, + label_lengths, + labels, + self.minibatch_, + self.maxT_, + self.maxU_, + self.alphabet_size_, + self.blank_, + durations, + self.num_durations, + ) + + if training: + # Compute betas + if r < self.omega: + gpu_rnnt_kernel.compute_betas_kernel[self.minibatch_, self.maxU_, self.stream_, 0]( + label_acts, + denom, + betas, + llBackward, + input_lengths, + label_lengths, + labels, + self.minibatch_, + self.maxT_, + self.maxU_, + self.alphabet_size_, + self.blank_, + ) + + # Compute gradient + grad_blocks_per_grid = self.minibatch_ * self.maxT_ * self.maxU_ + grad_threads_per_block = gpu_rnnt_kernel.GPU_RNNT_THREAD_SIZE + gpu_rnnt_kernel.compute_grad_kernel[grad_blocks_per_grid, grad_threads_per_block, self.stream_, 0]( + label_grads, + label_acts, + denom, + alphas, + betas, + llForward, + input_lengths, + label_lengths, + labels, + self.minibatch_, + self.maxT_, + self.maxU_, + self.alphabet_size_, + self.blank_, + self.fastemit_lambda_, + self.clamp_, + ) + else: + gpu_rnnt_kernel.compute_tdt_betas_kernel[self.minibatch_, self.maxU_, self.stream_, 0]( + label_acts, + duration_acts, + denom, + self.sigma, + betas, + llBackward, + input_lengths, + label_lengths, + labels, + self.minibatch_, + self.maxT_, + self.maxU_, + self.alphabet_size_, + self.blank_, + durations, + self.num_durations, + ) + + # Compute gradient + grad_blocks_per_grid = self.minibatch_ * self.maxT_ * self.maxU_ + grad_threads_per_block = gpu_rnnt_kernel.GPU_RNNT_THREAD_SIZE + gpu_rnnt_kernel.compute_tdt_grad_kernel[grad_blocks_per_grid, grad_threads_per_block, self.stream_, 0]( + label_grads, + duration_grads, + label_acts, + duration_acts, + denom, + self.sigma, + alphas, + betas, + llForward, + input_lengths, + label_lengths, + labels, + self.minibatch_, + self.maxT_, + self.maxU_, + self.alphabet_size_, + self.blank_, + durations, + self.num_durations, + self.fastemit_lambda_, + self.clamp_, + ) + + # // cost copy, negate (for log likelihood) and update with additional regularizers + # This needs to be done via CUDA, because we used temporary memory llForward + # passed to alpha, which was updated with log likelihoods. + # But copying this data into a pytorch pointer is more difficult (numba api is one way) + # Therefore launch a pointwise CUDA kernel to update the costs inplace from data of llForward + # Then negate to compute the loglikelihood. 
+ threadsperblock = min(costs.shape[0], 32) + blockspergrid = (costs.shape[0] + (threadsperblock - 1)) // threadsperblock + rnnt_helper.compute_costs_data[blockspergrid, threadsperblock, self.stream_, 0]( + llForward, costs, self.fastemit_lambda_ + ) + self.stream_.synchronize() + + return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS + + def cost_and_grad( + self, + label_acts: torch.Tensor, + duration_acts: torch.Tensor, + label_grads: torch.Tensor, + duration_grads: torch.Tensor, + costs: torch.Tensor, + pad_labels: torch.Tensor, + label_lengths: torch.Tensor, + input_lengths: torch.Tensor, + ): + if ( + duration_acts is None + or label_acts is None + or label_grads is None + or duration_grads is None + or costs is None + or pad_labels is None + or label_lengths is None + or input_lengths is None + ): + return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE + + return self.compute_cost_and_score( + label_acts, duration_acts, label_grads, duration_grads, costs, pad_labels, label_lengths, input_lengths + ) + + def score_forward( + self, + label_acts: torch.Tensor, + duration_acts: torch.Tensor, + costs: torch.Tensor, + pad_labels: torch.Tensor, + label_lengths: torch.Tensor, + input_lengths: torch.Tensor, + ): + if ( + label_acts is None + or duration_acts is None + or costs is None + or pad_labels is None + or label_lengths is None + or input_lengths is None + ): + return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE + + return self.compute_cost_and_score( + label_acts, duration_acts, None, costs, pad_labels, label_lengths, input_lengths + ) + + def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]): + """ + Helper method that uses the workspace and constructs slices of it that can be used. + + Returns: + An int, representing the offset of the used workspace (practically, the slice of the workspace consumed) + A tuple of tensors representing the shared workspace. + """ + used_offset = 0 + + # // denom + denom = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] + used_offset += self.maxT_ * self.maxU_ * self.minibatch_ + + # // alphas & betas + alphas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] + used_offset += self.maxT_ * self.maxU_ * self.minibatch_ + betas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] + used_offset += self.maxT_ * self.maxU_ * self.minibatch_ + + # // logllh + llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_] + used_offset += self.minibatch_ + llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_] + used_offset += self.minibatch_ + + durations = self.tdt_workspace[: self.num_durations] + + return used_offset, (denom, alphas, betas, llForward, llBackward, durations) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index dbeb1544e7e3..36cbb94a0eb6 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -35,6 +35,8 @@ GPU_RNNT_THREAD_SIZE = 256 +INF = 99999.9 + @cuda.jit(device=True, inline=True) def logp( @@ -875,3 +877,498 @@ def compute_multiblank_grad_kernel( # update internal index through the thread_buffer; # until idx < V + 1, such that entire vocabulary has been updated. 
idx += GPU_RNNT_THREAD_SIZE + + + +@cuda.jit(device=True, inline=True) +def logp_duration(acts: torch.Tensor, maxT: int, maxU: int, num_durations: int, mb: int, t: int, u: int, v: int): + col = (mb * maxT + t) * maxU + u + return acts[col * num_durations + v] + + +@cuda.jit() +def compute_tdt_alphas_kernel( + acts: torch.Tensor, + duration_acts: torch.Tensor, + denom: torch.Tensor, + sigma: float, + alphas: torch.Tensor, + llForward: torch.Tensor, + xlen: torch.Tensor, + ylen: torch.Tensor, + mlabels: torch.Tensor, # [B] + minibatch: int, + maxT: int, + maxU: int, + alphabet_size: int, + blank_: int, + durations: torch.Tensor, + num_durations: int, +): + """ + Compute alpha (forward variable) probabilities over the transduction step. + + Args: + acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. + denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor + across entire vocabulary. + alphas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel with the forward variable + probabilities. + llForward: Zero tensor of shape [B]. Represents the log-likelihood of the forward pass. + Returned as the forward pass loss that is reduced by the optimizer. + xlen: Vector of length B which contains the actual acoustic sequence lengths in the padded + activation tensor. + ylen: Vector of length B which contains the actual target sequence lengths in the padded + activation tensor. + mlabels: Matrix of shape [B, U+1] (+1 here is due to token - usually the RNNT blank). + The matrix contains the padded target transcription that must be predicted. + minibatch: Int representing the batch size. + maxT: The maximum possible acoustic sequence length. Represents T in the logprobs tensor. + maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. + alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). + blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + big_blank_: Index of the RNNT big blank token in the vocabulary. Generally the first or last token in the vocab. + + Updates: + Kernel inplace updates the following inputs: + - alphas: forward variable scores. + - llForward: log-likelihood of forward variable. 
+ """ + # // launch B blocks, each block has U threads + b = cuda.blockIdx.x # // batch id + u = cuda.threadIdx.x # label id, u + T = xlen[b] # select AM length of current sample + U = ylen[b] + 1 # select target length of current sample, +1 for the blank token + + labels: torch.Tensor = mlabels[b] # mb label start point, equivalent to mlabels + b * (maxU - 1) + offset = b * maxT * maxU # pointer indexing offset + + # alphas += offset # pointer offset, ignored since we explicitly add offset + + # Initilize alpha[b, t=0, u=0] for all b in B + if u == 0: + alphas[offset] = 0 + + # sync until all alphas are initialized + cuda.syncthreads() + + # Ordinary alpha calculations, broadcast across B=b and U=u + # Look up forward variable calculation from rnnt_numpy.forward_pass() + for n in range(1, T + U - 1): + t = n - u + + if u == 0: + # for t in range(1, T) step to initialize alphas[b, t, 0] + if t > 0 and t < T: + alphas[offset + t * maxU + u] = -INF + + for i in range(1, num_durations): # skip 0 since blank emission has to advance by at least one + if t >= durations[i]: + alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp( + alphas[offset + t * maxU + u], + alphas[offset + (t - durations[i]) * maxU + u] + + logp(denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u, blank_) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, t - durations[i], u, i), + ) + else: + break # durations are in ascending order + + elif u < U: + # for u in range(1, U) step to initialize alphas[b, 0, u] + if t == 0: + alphas[offset + u] = ( + alphas[offset + u - 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, t, u - 1, labels[u - 1]) + - sigma + + logp_duration( + duration_acts, maxT, maxU, num_durations, b, t, u - 1, 0 + ) # t = 0 so it must be duration = 0 only when duration_id = 0 + ) + + # for t in range(1, T) for u in range(1, U) step to compute alphas[b, t, u] + elif t > 0 and t < T: + no_emit = -INF + for i in range(1, num_durations): + if t >= durations[i]: + no_emit = rnnt_helper.log_sum_exp( + no_emit, + alphas[offset + (t - durations[i]) * maxU + u] + + logp(denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u, blank_) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, t - durations[i], u, i), + ) + else: + break + + emit = -INF + for i in range(0, num_durations): + if t >= durations[i]: + emit = rnnt_helper.log_sum_exp( + emit, + alphas[offset + (t - durations[i]) * maxU + u - 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u - 1, labels[u - 1]) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, t - durations[i], u - 1, i), + ) + else: + break + + alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit) + + # sync across all B=b and U=u + cuda.syncthreads() + + # After final sync, alphas[b, T-1, U - 1] + logprobs[b, T-1, U-1, blank] + denom[b, T-1, U-1] gives + # log-likelihood of forward pass. 
+ if u == 0: + loglike = ( + alphas[offset + (T - 1) * maxU + U - 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, U - 1, 1) + ) + + for i in range(2, num_durations): + if T >= durations[i]: + big_blank_loglike = ( + alphas[offset + (T - durations[i]) * maxU + U - 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, T - durations[i], U - 1, blank_) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - durations[i], U - 1, i) + ) + loglike = rnnt_helper.log_sum_exp(loglike, big_blank_loglike) + else: + break + + llForward[b] = loglike + + +@cuda.jit() +def compute_tdt_betas_kernel( + acts: torch.Tensor, + duration_acts: torch.Tensor, + denom: torch.Tensor, + sigma: float, + betas: torch.Tensor, + llBackward: torch.Tensor, + xlen: torch.Tensor, + ylen: torch.Tensor, + mlabels: torch.Tensor, # [B, U] + minibatch: int, + maxT: int, + maxU: int, + alphabet_size: int, + blank_: int, + durations: torch.Tensor, + num_durations: int, +): + """ + Compute beta (backward variable) probabilities over the transduction step. + + Args: + acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. + denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor + across entire vocabulary. + betas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel with the backward variable + probabilities. + llBackward: Zero tensor of shape [B]. Represents the log-likelihood of the backward pass. + Returned as the backward pass loss that is reduced by the optimizer. + xlen: Vector of length B which contains the actual acoustic sequence lengths in the padded + activation tensor. + ylen: Vector of length B which contains the actual target sequence lengths in the padded + activation tensor. + mlabels: Matrix of shape [B, U+1] (+1 here is due to token - usually the RNNT blank). + The matrix contains the padded target transcription that must be predicted. + minibatch: Int representing the batch size. + maxT: The maximum possible acoustic sequence length. Represents T in the logprobs tensor. + maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. + alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). + blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + + Updates: + Kernel inplace updates the following inputs: + - betas: backward variable scores. + - llBackward: log-likelihood of backward variable. 
+ """ + # // launch B blocks, each block has U threads + b = cuda.blockIdx.x # // batch id + u = cuda.threadIdx.x # label id, u + T = xlen[b] # select AM length of current sample + U = ylen[b] + 1 # select target length of current sample, +1 for the blank token + + labels: torch.Tensor = mlabels[b] # mb label start point, equivalent to mlabels + b * (maxU - 1) + offset = b * maxT * maxU # pointer indexing offset + + # betas += offset # pointer offset, ignored since we explicitly add offset + + # Initilize beta[b, t=T-1, u=U-1] for all b in B with log_probs[b, t=T-1, u=U-1, blank] + if u == 0: + betas[offset + (T - 1) * maxU + U - 1] = ( + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, U - 1, 1) + ) + + # sync until all betas are initialized + cuda.syncthreads() + + # Ordinary beta calculations, broadcast across B=b and U=u + # Look up backward variable calculation from rnnt_numpy.backward_pass() + for n in range(T + U - 2, -1, -1): + t = n - u + + if u == (U - 1): + # for t in reversed(range(T - 1)) step to initialize betas[b, t, U-1] + if t >= 0 and t + 1 < T: + betas[offset + t * maxU + U - 1] = -INF + for i in range(1, num_durations): + if t + durations[i] < T: # recursive beta computation. + betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp( + betas[offset + t * maxU + U - 1], + betas[offset + (t + durations[i]) * maxU + U - 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_) + + logp_duration(duration_acts, maxT, maxU, num_durations, b, t, U - 1, i) + - sigma, + ) + elif t + durations[i] == T: # beta base case + betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp( + betas[offset + t * maxU + U - 1], + logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_) + + logp_duration(duration_acts, maxT, maxU, num_durations, b, t, U - 1, i) + - sigma, + ) + + elif u < U: + if t == T - 1: + # for u in reversed(range(U - 1)) step to initialize betas[b, T-1, u] + betas[offset + (T - 1) * maxU + u] = ( + betas[offset + (T - 1) * maxU + u + 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u]) + + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, u, 0) + - sigma + ) + + elif (t >= 0) and (t < T - 1): + # for t in reversed(range(T - 1)) for u in reversed(range(U - 1)) step to compute betas[b, t, u] + no_emit = -INF + for i in range(1, num_durations): + if t + durations[i] < T: + no_emit = rnnt_helper.log_sum_exp( + no_emit, + betas[offset + (t + durations[i]) * maxU + u] + + logp(denom, acts, maxT, maxU, alphabet_size, b, t, u, blank_) + + logp_duration(duration_acts, maxT, maxU, num_durations, b, t, u, i) + - sigma, + ) + + emit = -INF + for i in range(0, num_durations): + if t + durations[i] < T: + emit = rnnt_helper.log_sum_exp( + emit, + betas[offset + (t + durations[i]) * maxU + u + 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, t, u, labels[u]) + + logp_duration(duration_acts, maxT, maxU, num_durations, b, t, u, i) + - sigma, + ) + + betas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit) + + # sync across all B=b and U=u + cuda.syncthreads() + + # After final sync, betas[b, 0, 0] gives + # log-likelihood of backward pass. 
+ if u == 0: + llBackward[b] = betas[offset] + + +@cuda.jit() +def compute_tdt_grad_kernel( + label_grads: torch.Tensor, + duration_grads: torch.Tensor, + acts: torch.Tensor, + duration_acts: torch.Tensor, + denom: torch.Tensor, + sigma: float, + alphas: torch.Tensor, + betas: torch.Tensor, + logll: torch.Tensor, + xlen: torch.Tensor, + ylen: torch.Tensor, + mlabels: torch.Tensor, # [B, U] + minibatch: int, + maxT: int, + maxU: int, + alphabet_size: int, + blank_: int, + durations: torch.Tensor, + num_durations: int, + fastemit_lambda: float, + clamp: float, +): + """ + Compute gradients over the transduction step. + + Args: + grads: Zero Tensor of shape [B, T, U, V+1]. Is updated by this kernel to contain the gradients + of this batch of samples. + acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. + denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor + across entire vocabulary. + alphas: Alpha variable, contains forward probabilities. A tensor of shape [B, T, U]. + betas: Beta varoable, contains backward probabilities. A tensor of shape [B, T, U]. + logll: Log-likelihood of the forward variable, represented as a vector of shape [B]. + Represents the log-likelihood of the forward pass. + xlen: Vector of length B which contains the actual acoustic sequence lengths in the padded + activation tensor. + ylen: Vector of length B which contains the actual target sequence lengths in the padded + activation tensor. + mlabels: Matrix of shape [B, U+1] (+1 here is due to token - usually the RNNT blank). + The matrix contains the padded target transcription that must be predicted. + minibatch: Int representing the batch size. + maxT: The maximum possible acoustic sequence length. Represents T in the logprobs tensor. + maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. + alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). + blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. + + Updates: + Kernel inplace updates the following inputs: + - grads: Gradients with respect to the log likelihood (logll). + """ + # Kernel call: + # blocks_per_grid = minibatch (b) * maxT (t) * maxU (u) + # threads_per_block = constant buffer size of parallel threads (v :: Constant) + tid = cuda.threadIdx.x # represents v, taking steps of some constant size + idx = tid # index of v < V+1; in steps of constant buffer size + col = cuda.blockIdx.x # represents a fused index of b * t * u + + # Decompose original indices from fused `col` + u = col % maxU # (b * t * u) % u = u + bt = (col - u) // maxU # (b * t * u - u) // U = b * t + t = bt % maxT # (b * t) % t = t + mb = (bt - t) // maxT # (b * t - t) // T = b + + # constants + T = xlen[mb] # select AM length of current sample + U = ylen[mb] + 1 # select target length of current sample, +1 for the blank token + labels: torch.Tensor = mlabels[mb] # labels = mlabels + mb * (maxU - 1); + + # Buffered gradient calculations, broadcast across B=b, T=t and U=u, looped over V with some constant stride. 
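+    # Worked example for the index decomposition above (hypothetical sizes): with maxT=3, maxU=2
+    # and col=7, u = 7 % 2 = 1, bt = (7 - 1) // 2 = 3, t = 3 % 3 = 0, mb = (3 - 0) // 3 = 1,
+    # i.e. col=7 addresses cell (b=1, t=0, u=1), consistent with (1 * 3 + 0) * 2 + 1 = 7.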
+ # Look up gradient calculation from rnnt_numpy.compute_gradient() + + if t < T and u < U: + logpk_blank = denom[col] + acts[col * alphabet_size + blank_] - sigma + + if idx < num_durations: + grad = 0.0 + if t + durations[idx] < T and u < U - 1: # for label + logpk_label = denom[col] + acts[col * alphabet_size + labels[u]] - sigma + grad -= math.exp(alphas[col] + betas[col + 1 + durations[idx] * maxU] + logpk_label - logll[mb]) + + if t + durations[idx] < T and idx > 0: # for blank in the middle + grad -= math.exp(alphas[col] + betas[col + durations[idx] * maxU] + logpk_blank - logll[mb]) + + if t + durations[idx] == T and idx >= 1 and u == U - 1: # for blank as the last symbol + grad -= math.exp(alphas[col] + logpk_blank - logll[mb]) + + grad = grad * math.exp(duration_acts[col * num_durations + idx]) + duration_grads[col * num_durations + idx] = grad + + # For cuda kernels, maximum number of threads per block is limited to some value. + # However, it may be the case that vocabulary size is larger than this limit + # To work around this, an arbitrary thread buffer size is chosen such that, + # 1) each element within the thread pool operates independently of the other + # 2) An inner while loop moves the index of each buffer element by the size of the buffer itself, + # such that all elements of the vocabulary size are covered in (V + 1 // thread_buffer) number of steps. + # As such, each thread will perform the while loop at least (V + 1 // thread_buffer) number of times + while idx < alphabet_size: + # remember, `col` represents the tri-index [b, t, u] + # therefore; logpk = denom[b, t, u] + acts[b, t, u, v] + logpk = denom[col] + acts[col * alphabet_size + idx] # - sigma + # initialize the grad of the sample acts[b, t, u, v] + grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb]) # * math.exp(sigma) + + # If FastEmit regularization is enabled, calculate the gradeint of probability of predicting the next label + # at the current timestep. + # The formula for this is Equation 9 in https://arxiv.org/abs/2010.11148, multiplied by the log probability + # of the current step (t, u), normalized by the total log likelihood. + # Once the gradient has been calculated, scale it by `fastemit_lambda`, as in Equation 10. + if fastemit_lambda > 0.0 and u < U - 1: + fastemit_grad = 0.0 + + for i in range(0, num_durations): + if t + durations[i] < T: + fastemit_grad += fastemit_lambda * math.exp( + alphas[col] # alphas(t, u) + + (denom[col] + acts[col * alphabet_size + labels[u]]) + + duration_acts[col * num_durations + i] + + betas[col + 1 + durations[i] * maxU] # betas(t, u+1) + + logpk # log Pr(k|t, u) + - sigma # y_hat(t, u) + - logll[mb] # total log likelihood for normalization + ) + else: + fastemit_grad = 0.0 + + # Update the gradient of act[b, t, u, v] with the gradient from FastEmit regularization + grad = grad + fastemit_grad + + # // grad to last blank transition + # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u) + logpk - logll[b]) + if (idx == blank_) and (u == U - 1): + for i in range(1, num_durations): + if t == T - durations[i]: + grad -= math.exp( + alphas[col] + logpk - sigma - logll[mb] + duration_acts[col * num_durations + i] + ) + + # grad of blank across t < T; + # grad[b, t 0.0: + g = label_grads[col * (alphabet_size) + idx] + g = min(g, clamp) + g = max(g, -clamp) + label_grads[col * (alphabet_size) + idx] = g + + # update internal index through the thread_buffer; + # until idx < V + 1, such that entire vocabulary has been updated. 
+        idx += GPU_RNNT_THREAD_SIZE
diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
index 5e98b03f2fe2..ad3b010f29be 100644
--- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -2202,3 +2202,757 @@ class GreedyBatchedRNNTInferConfig:
     preserve_alignments: bool = False
     preserve_frame_confidence: bool = False
     confidence_method_cfg: Optional[ConfidenceMethodConfig] = None
+
+
+class GreedyTDTRNNTInfer(_GreedyRNNTInfer):
+    """A greedy transducer decoder.
+
+    Sequence level greedy decoding, performed auto-regressively.
+
+    Args:
+        decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+        joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+        blank_index: int index of the blank token. Must be len(vocabulary) for TDT models.
+        durations: a list containing durations for TDT; must include 0 and 1.
+        max_symbols_per_step: Optional int. The maximum number of symbols that can be added
+            to a sequence in a single time step; if set to None then there is
+            no limit.
+        preserve_alignments: Bool flag which preserves the history of alignments generated during
+            greedy decoding (sample / batched). When set to true, the Hypothesis will contain
+            the non-null value for `alignments` in it. Here, `alignments` is a List of List of
+            Tuple(Tensor (of length V + 1 + num-big-blanks), Tensor(scalar, label after argmax)).
+            The length of the list corresponds to the Acoustic Length (T).
+            Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary.
+            U is the number of target tokens for the current timestep Ti.
+        preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated
+            during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
+            the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of List of floats.
+            The length of the list corresponds to the Acoustic Length (T).
+            Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores.
+            U is the number of target tokens for the current timestep Ti.
+        confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
+            confidence scores.
+            name: The method name (str).
+                Supported values:
+                    - 'max_prob' for using the maximum token probability as a confidence.
+                    - 'entropy' for using normalized entropy of a log-likelihood vector.
+            entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
+                Supported values:
+                    - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided,
+                        the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
+                        Note that for this entropy, the temperature should comply with the following inequality:
+                        1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size.
+                    - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
+                        Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
+                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
+                        More: https://en.wikipedia.org/wiki/Tsallis_entropy
+                    - 'renui' for the Rényi entropy.
+                        Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
+                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
+ More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy + temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the temperature equals one, scaling is not applied to 'max_prob', + and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) + entropy_norm: A mapping of the entropy value to the interval [0,1]. + Supported values: + - 'lin' for using the linear mapping. + - 'exp' for using exponential mapping with linear shift. + """ + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + durations: list, + max_symbols_per_step: Optional[int] = None, + preserve_alignments: bool = False, + preserve_frame_confidence: bool = False, + confidence_method_cfg: Optional[DictConfig] = None, + ): + super().__init__( + decoder_model=decoder_model, + joint_model=joint_model, + blank_index=blank_index, + max_symbols_per_step=max_symbols_per_step, + preserve_alignments=preserve_alignments, + preserve_frame_confidence=preserve_frame_confidence, + confidence_method_cfg=confidence_method_cfg, + ) + self.durations = durations + + @typecheck() + def forward( + self, + encoder_output: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, + ): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + Args: + encoder_output: A tensor of size (batch, features, timesteps). + encoded_lengths: list of int representing the length of each sequence + output sequence. + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.inference_mode(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + + self.decoder.eval() + self.joint.eval() + + hypotheses = [] + # Process each sequence independently + with self.decoder.as_frozen(), self.joint.as_frozen(): + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] + + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) + hypotheses.append(hypothesis) + + # Pack results into Hypotheses + packed_result = pack_hypotheses(hypotheses, encoded_lengths) + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (packed_result,) + + @torch.no_grad() + def _greedy_decode( + self, x: torch.Tensor, out_len: torch.Tensor, partial_hypotheses: Optional[rnnt_utils.Hypothesis] = None + ): + # x: [T, 1, D] + # out_len: [seq_len] + + # Initialize blank state and empty label set in Hypothesis + hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None) + + if partial_hypotheses is not None: + hypothesis.last_token = partial_hypotheses.last_token + hypothesis.y_sequence = ( + partial_hypotheses.y_sequence.cpu().tolist() + if isinstance(partial_hypotheses.y_sequence, torch.Tensor) + else partial_hypotheses.y_sequence + ) + if partial_hypotheses.dec_state is not None: + hypothesis.dec_state = self.decoder.batch_concat_states([partial_hypotheses.dec_state]) + 
hypothesis.dec_state = _states_to_device(hypothesis.dec_state, x.device) + + if self.preserve_alignments: + # Alignments is a 2-dimensional dangling list representing T x U + hypothesis.alignments = [[]] + + if self.preserve_frame_confidence: + hypothesis.frame_confidence = [[]] + + # For timestep t in X_t + duration = 1 + + time_idx = 0 + while time_idx < out_len: + # Extract encoder embedding at timestep t + # f = x[time_idx, :, :].unsqueeze(0) # [1, 1, D] + f = x.narrow(dim=0, start=time_idx, length=1) + + # Setup exit flags and counter + not_blank = True + symbols_added = 0 + + need_loop = True + # While blank is not predicted, or we dont run out of max symbols per timestep + while need_loop and (self.max_symbols is None or symbols_added < self.max_symbols): + # In the first timestep, we initialize the network with RNNT Blank + # In later timesteps, we provide previous predicted label as input. + if hypothesis.last_token is None and hypothesis.dec_state is None: + last_label = self._SOS + else: + last_label = label_collate([[hypothesis.last_token]]) + + # Perform prediction network and joint network steps. + g, hidden_prime = self._pred_step(last_label, hypothesis.dec_state) + # If preserving per-frame confidence, log_normalize must be true + logits = self._joint_step(f, g, log_normalize=False) + logp = logits[0, 0, 0, : -len(self.durations)] + if self.preserve_frame_confidence: + logp = torch.log_softmax(logp, -1) + + duration_logp = torch.softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1) + del g + + # torch.max(0) op doesnt exist for FP 16. + if logp.dtype != torch.float32: + logp = logp.float() + + # get index k, of max prob + v, k = logp.max(0) + k = k.item() # K is the label at timestep t_s in inner loop, s >= 0. + + d_v, d_k = duration_logp.max(0) + d_k = d_k.item() + + skip = self.durations[d_k] + + if self.preserve_alignments: + # insert logprobs into last timestep + hypothesis.alignments[-1].append((logp.to('cpu'), torch.tensor(k, dtype=torch.int32))) + + if self.preserve_frame_confidence: + # insert confidence into last timestep + hypothesis.frame_confidence[-1].append(self._get_confidence(logp)) + + del logp + + # If blank token is predicted, exit inner loop, move onto next timestep t + if k == self._blank_index: + not_blank = False + if skip == 0: + skip = 1 + + if self.preserve_alignments: + # convert Ti-th logits into a torch array + hypothesis.alignments.append([]) # blank buffer for next timestep + + if self.preserve_frame_confidence: + hypothesis.frame_confidence.append([]) # blank buffer for next timestep + else: + # Append token to label set, update RNN state. + hypothesis.y_sequence.append(k) + hypothesis.score += float(v) + hypothesis.timestep.append(time_idx) + hypothesis.dec_state = hidden_prime + hypothesis.last_token = k + + # Increment token counter. + symbols_added += 1 + time_idx += skip + need_loop = skip == 0 + + if symbols_added == self.max_symbols: + time_idx += 1 + + # Remove trailing empty list of Alignments + if self.preserve_alignments: + if len(hypothesis.alignments[-1]) == 0: + del hypothesis.alignments[-1] + + # Remove trailing empty list of per-frame confidence + if self.preserve_frame_confidence: + if len(hypothesis.frame_confidence[-1]) == 0: + del hypothesis.frame_confidence[-1] + + # Unpack the hidden states + hypothesis.dec_state = self.decoder.batch_select_state(hypothesis.dec_state, 0) + + return hypothesis + + +class GreedyBatchedTDTRNNTInfer(_GreedyRNNTInfer): + """A batch level greedy transducer decoder. 
+    Batch level greedy decoding, performed auto-regressively.
+    Args:
+        decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+        joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+        blank_index: int index of the blank token. Must be len(vocabulary) for TDT models.
+        durations: a list containing durations for TDT; must include 0 and 1.
+        max_symbols_per_step: Optional int. The maximum number of symbols that can be added
+            to a sequence in a single time step; if set to None then there is
+            no limit.
+        preserve_alignments: Bool flag which preserves the history of alignments generated during
+            greedy decoding (sample / batched). When set to true, the Hypothesis will contain
+            the non-null value for `alignments` in it. Here, `alignments` is a List of List of
+            Tuple(Tensor (of length V + 1 + num-big-blanks), Tensor(scalar, label after argmax)).
+            The length of the list corresponds to the Acoustic Length (T).
+            Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary.
+            U is the number of target tokens for the current timestep Ti.
+        preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated
+            during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
+            the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of List of floats.
+            The length of the list corresponds to the Acoustic Length (T).
+            Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores.
+            U is the number of target tokens for the current timestep Ti.
+        confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
+            confidence scores.
+            name: The method name (str).
+                Supported values:
+                    - 'max_prob' for using the maximum token probability as a confidence.
+                    - 'entropy' for using normalized entropy of a log-likelihood vector.
+            entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
+                Supported values:
+                    - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided,
+                        the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
+                        Note that for this entropy, the temperature should comply with the following inequality:
+                        1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size.
+                    - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
+                        Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
+                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
+                        More: https://en.wikipedia.org/wiki/Tsallis_entropy
+                    - 'renui' for the Rényi entropy.
+                        Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
+                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
+                        More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy
+        temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
+            When the temperature equals one, scaling is not applied to 'max_prob',
+            and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))
+        entropy_norm: A mapping of the entropy value to the interval [0,1].
+            Supported values:
+                - 'lin' for using the linear mapping.
+                - 'exp' for using exponential mapping with linear shift.
+ """ + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + durations: List[int], + max_symbols_per_step: Optional[int] = None, + preserve_alignments: bool = False, + preserve_frame_confidence: bool = False, + confidence_method_cfg: Optional[DictConfig] = None, + ): + super().__init__( + decoder_model=decoder_model, + joint_model=joint_model, + blank_index=blank_index, + max_symbols_per_step=max_symbols_per_step, + preserve_alignments=preserve_alignments, + preserve_frame_confidence=preserve_frame_confidence, + confidence_method_cfg=confidence_method_cfg, + ) + self.durations = durations + + # Depending on availability of `blank_as_pad` support + # switch between more efficient batch decoding technique + if self.decoder.blank_as_pad: + self._greedy_decode = self._greedy_decode_blank_as_pad + else: + self._greedy_decode = self._greedy_decode_masked + + @typecheck() + def forward( + self, + encoder_output: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, + ): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + Args: + encoder_output: A tensor of size (batch, features, timesteps). + encoded_lengths: list of int representing the length of each sequence + output sequence. + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.inference_mode(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + logitlen = encoded_lengths + + self.decoder.eval() + self.joint.eval() + + with self.decoder.as_frozen(), self.joint.as_frozen(): + inseq = encoder_output # [B, T, D] + hypotheses = self._greedy_decode( + inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses + ) + + # Pack the hypotheses results + packed_result = pack_hypotheses(hypotheses, logitlen) + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (packed_result,) + + def _greedy_decode_blank_as_pad( + self, + x: torch.Tensor, + out_len: torch.Tensor, + device: torch.device, + partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, + ): + if partial_hypotheses is not None: + raise NotImplementedError("`partial_hypotheses` support is not supported") + + with torch.inference_mode(): + # x: [B, T, D] + # out_len: [B] + # device: torch.device + + # Initialize list of Hypothesis + batchsize = x.shape[0] + hypotheses = [ + rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize) + ] + + # Initialize Hidden state matrix (shared by entire batch) + hidden = None + + # If alignments need to be preserved, register a danling list to hold the values + if self.preserve_alignments: + # alignments is a 3-dimensional dangling list representing B x T x U + for hyp in hypotheses: + hyp.alignments = [[]] + + # If confidence scores need to be preserved, register a danling list to hold the values + if self.preserve_frame_confidence: + # frame_confidence is a 3-dimensional dangling list representing B x T x U + for hyp in hypotheses: + hyp.frame_confidence = [[]] + hyp.y_3best = [[]] + hyp.frame_confidence_3best = [[[]]] + hyp.logp = [[]] + + # Last Label buffer + Last Label without 
blank buffer + # batch level equivalent of the last_label + last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device) + + # Mask buffers + blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) + + # mask for if the utterance in the batch should stay in the same frame. +# stay_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) + + # Get max sequence length + max_out_len = out_len.max() + + skip = 1 + for time_idx in range(max_out_len): + if skip > 1: + skip -= 1 + continue + f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] + + need_to_stay = True + symbols_added = 0 + + # Reset blank mask + blank_mask.mul_(False) + + # Update blank mask with time mask + # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch) + # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len + blank_mask = time_idx >= out_len + + # Start inner loop + while need_to_stay and (self.max_symbols is None or symbols_added < self.max_symbols): + # Batch prediction and joint network steps + # If very first prediction step, submit SOS tag (blank) to pred_step. + # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state + if time_idx == 0 and symbols_added == 0 and hidden is None: + g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize) + else: + # Perform batch step prediction of decoder, getting new states and scores ("g") + g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize) + + # Batched joint step - Output = [B, V + 1 + num-big-blanks] + # If preserving per-frame confidence, log_normalize must be true + joined = self._joint_step(f, g, log_normalize=None) + logp = joined[:, 0, 0, :-len(self.durations)] + duration_logp = joined[:, 0, 0, -len(self.durations):] + + if logp.dtype != torch.float32: + logp = logp.float() + duration_logp = duration_logp.float() + + # Get index k, of max prob for batch + v, k = logp.max(1) + dv, dk = duration_logp.max(1) + + skip = self.durations[int(torch.min(dk))] + + if blank_mask.all(): +# print("SKIP is", skip) + if skip == 0: + skip = 1 + need_to_stay = skip == 0 + del g + + # Update blank mask with current predicted blanks + # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U) + k_is_blank = k == self._blank_index + blank_mask.bitwise_or_(k_is_blank) + + del k_is_blank + del logp, duration_logp + + # If all samples predict / have predicted prior blanks, exit loop early + # This is equivalent to if single sample predicted k + if not blank_mask.all(): + # Collect batch indices where blanks occurred now/past + blank_indices = (blank_mask == 1).nonzero(as_tuple=False) + + # Recover prior state for all samples which predicted blank now/past + if hidden is not None: + # LSTM has 2 states + hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices) + + elif len(blank_indices) > 0 and hidden is None: + # Reset state if there were some blank and other non-blank predictions in batch + # Original state is filled with zeros so we just multiply + # LSTM has 2 states + hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0) + + # Recover prior predicted label for all samples which predicted blank now/past + k[blank_indices] = last_label[blank_indices, 0] + + # Update new label and hidden state for next iteration + last_label = k.clone().view(-1, 1) + hidden = hidden_prime + + # Update 
predicted labels, accounting for time mask + # If blank was predicted even once, now or in the past, + # Force the current predicted label to also be blank + # This ensures that blanks propogate across all timesteps + # once they have occured (normally stopping condition of sample level loop). + for kidx, ki in enumerate(k): + if blank_mask[kidx] == 0: + hypotheses[kidx].y_sequence.append(ki) + hypotheses[kidx].timestep.append(time_idx) + hypotheses[kidx].score += float(v[kidx]) + + symbols_added += 1 + + # Remove trailing empty list of alignments at T_{am-len} x Uj + if self.preserve_alignments: + for batch_idx in range(batchsize): + if len(hypotheses[batch_idx].alignments[-1]) == 0: + del hypotheses[batch_idx].alignments[-1] + + # Remove trailing empty list of confidence scores at T_{am-len} x Uj + if self.preserve_frame_confidence: + for batch_idx in range(batchsize): + if len(hypotheses[batch_idx].frame_confidence[-1]) == 0: + del hypotheses[batch_idx].frame_confidence[-1] + del hypotheses[batch_idx].y_3best[-1] + del hypotheses[batch_idx].frame_confidence_3best[-1] + del hypotheses[batch_idx].logp[-1] + + # Preserve states + for batch_idx in range(batchsize): + hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx) + + return hypotheses + + def _greedy_decode_masked( + self, + x: torch.Tensor, + out_len: torch.Tensor, + device: torch.device, + partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, + ): + if partial_hypotheses is not None: + raise NotImplementedError("`partial_hypotheses` support is not supported") + + # x: [B, T, D] + # out_len: [B] + # device: torch.device + + # Initialize state + batchsize = x.shape[0] + hypotheses = [ + rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize) + ] + + # Initialize Hidden state matrix (shared by entire batch) + hidden = None + + # If alignments need to be preserved, register a danling list to hold the values + if self.preserve_alignments: + # alignments is a 3-dimensional dangling list representing B x T x U + for hyp in hypotheses: + hyp.alignments = [[]] + else: + alignments = None + + # If confidence scores need to be preserved, register a danling list to hold the values + if self.preserve_frame_confidence: + # frame_confidence is a 3-dimensional dangling list representing B x T x U + for hyp in hypotheses: + hyp.frame_confidence = [[]] + + # Last Label buffer + Last Label without blank buffer + # batch level equivalent of the last_label + last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device) + last_label_without_blank = last_label.clone() + + # Mask buffers + blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) + + # Get max sequence length + max_out_len = out_len.max() + + with torch.inference_mode(): + for time_idx in range(max_out_len): + f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] + + # Prepare t timestamp batch variables + not_blank = True + symbols_added = 0 + + # Reset blank mask + blank_mask.mul_(False) + + # Update blank mask with time mask + # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch) + # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len + blank_mask = time_idx >= out_len + + # Start inner loop + while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols): + # Batch prediction and joint network steps + # If very first prediction step, submit SOS tag (blank) to 
pred_step. + # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state + if time_idx == 0 and symbols_added == 0 and hidden is None: + g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize) + else: + # Set a dummy label for the blank value + # This value will be overwritten by "blank" again the last label update below + # This is done as vocabulary of prediction network does not contain "blank" token of RNNT + last_label_without_blank_mask = last_label >= self._blank_index + last_label_without_blank[last_label_without_blank_mask] = 0 # temp change of label + last_label_without_blank[~last_label_without_blank_mask] = last_label[ + ~last_label_without_blank_mask + ] + + # Perform batch step prediction of decoder, getting new states and scores ("g") + g, hidden_prime = self._pred_step(last_label_without_blank, hidden, batch_size=batchsize) + + # Batched joint step - Output = [B, V + 1 + num-big-blanks] + # If preserving per-frame confidence, log_normalize must be true + logp = self._joint_step(f, g, log_normalize=True if self.preserve_frame_confidence else None)[ + :, 0, 0, : + ] + + if logp.dtype != torch.float32: + logp = logp.float() + + # Get index k, of max prob for batch + v, k = logp.max(1) + del g + + # Update blank mask with current predicted blanks + # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U) + k_is_blank = k == self._blank_index + blank_mask.bitwise_or_(k_is_blank) + + # If preserving alignments, check if sequence length of sample has been reached + # before adding alignment + if self.preserve_alignments: + # Insert logprobs into last timestep per sample + logp_vals = logp.to('cpu') + logp_ids = logp_vals.max(1)[1] + for batch_idx in range(batchsize): + if time_idx < out_len[batch_idx]: + hypotheses[batch_idx].alignments[-1].append( + (logp_vals[batch_idx], logp_ids[batch_idx]) + ) + del logp_vals + + # If preserving per-frame confidence, check if sequence length of sample has been reached + # before adding confidence scores + if self.preserve_frame_confidence: + # Insert probabilities into last timestep per sample + confidence = self._get_confidence(logp) + for batch_idx in range(batchsize): + if time_idx < out_len[batch_idx]: + hypotheses[batch_idx].frame_confidence[-1].append(confidence[batch_idx]) + del logp + + # If all samples predict / have predicted prior blanks, exit loop early + # This is equivalent to if single sample predicted k + if blank_mask.all(): + not_blank = False + + # If preserving alignments, convert the current Uj alignments into a torch.Tensor + # Then preserve U at current timestep Ti + # Finally, forward the timestep history to Ti+1 for that sample + # All of this should only be done iff the current time index <= sample-level AM length. + # Otherwise ignore and move to next sample / next timestep. + if self.preserve_alignments: + + # convert Ti-th logits into a torch array + for batch_idx in range(batchsize): + + # this checks if current timestep <= sample-level AM length + # If current timestep > sample-level AM length, no alignments will be added + # Therefore the list of Uj alignments is empty here. 
+ if len(hypotheses[batch_idx].alignments[-1]) > 0: + hypotheses[batch_idx].alignments.append([]) # blank buffer for next timestep + + # Do the same if preserving per-frame confidence + if self.preserve_frame_confidence: + + for batch_idx in range(batchsize): + if len(hypotheses[batch_idx].frame_confidence[-1]) > 0: + hypotheses[batch_idx].frame_confidence.append([]) # blank buffer for next timestep + else: + # Collect batch indices where blanks occurred now/past + blank_indices = (blank_mask == 1).nonzero(as_tuple=False) + + # Recover prior state for all samples which predicted blank now/past + if hidden is not None: + # LSTM has 2 states + hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices) + + elif len(blank_indices) > 0 and hidden is None: + # Reset state if there were some blank and other non-blank predictions in batch + # Original state is filled with zeros so we just multiply + # LSTM has 2 states + hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0) + + # Recover prior predicted label for all samples which predicted blank now/past + k[blank_indices] = last_label[blank_indices, 0] + + # Update new label and hidden state for next iteration + last_label = k.view(-1, 1) + hidden = hidden_prime + + # Update predicted labels, accounting for time mask + # If blank was predicted even once, now or in the past, + # Force the current predicted label to also be blank + # This ensures that blanks propogate across all timesteps + # once they have occured (normally stopping condition of sample level loop). + for kidx, ki in enumerate(k): + if blank_mask[kidx] == 0: + hypotheses[kidx].y_sequence.append(ki) + hypotheses[kidx].timestep.append(time_idx) + hypotheses[kidx].score += float(v[kidx]) + + symbols_added += 1 + + # Remove trailing empty list of alignments at T_{am-len} x Uj + if self.preserve_alignments: + for batch_idx in range(batchsize): + if len(hypotheses[batch_idx].alignments[-1]) == 0: + del hypotheses[batch_idx].alignments[-1] + + # Remove trailing empty list of confidence scores at T_{am-len} x Uj + if self.preserve_frame_confidence: + for batch_idx in range(batchsize): + if len(hypotheses[batch_idx].frame_confidence[-1]) == 0: + del hypotheses[batch_idx].frame_confidence[-1] + + # Preserve states + for batch_idx in range(batchsize): + hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx) + + return hypotheses + + diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index 7764649bf1fa..4878974926a4 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -18,9 +18,9 @@ import pytest import torch -from nemo.collections.asr.losses.rnnt import MultiblankRNNTLossPytorch, RNNTLossPytorch +from nemo.collections.asr.losses.rnnt import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTRNNTLossPytorch from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_numpy import RNNTLoss as RNNTLoss_Numpy -from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import MultiblankRNNTLossNumba, RNNTLossNumba +from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import MultiblankRNNTLossNumba, RNNTLossNumba, TDTRNNTLossNumba from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ @@ -460,6 +460,34 @@ def zero_grad(): assert np.allclose(pt_grads1_p_2, np_grads1 + np_grads2, atol=1e-5) 
+class TestTDTRNNTLoss: + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_randomized_act_label(self, device): + if device == 'cuda': + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + B, T, U, V = 2, 3, 3, 2 # here V is number of non blank labels + durations = [0, 1] + sigma = 0.05 + + acts = torch.rand([B, T, U, V + 1 + len(durations)]) + labels = [[random.randrange(0, V) for i in range(U - 1)] for j in range(B)] + + fn_pt = TDTRNNTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + + fn_ag = TDTRNNTLossPytorch( + blank=V, reduction='sum', durations=durations, sigma=sigma + ) # ag for automatic gradient computation + ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) + + assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt-blank costs mismatch." + assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "tdt-blank gradient mismatch." + + + + class TestMultiblankRNNTLoss: @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index 5b30489f846c..6c1fcdd22401 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -363,6 +363,57 @@ def test_multiblank_rnnt_greedy_decoding(self, greedy_class): with torch.no_grad(): _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) + + + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "greedy_class", [greedy_decode.GreedyMultiblankRNNTInfer, greedy_decode.GreedyBatchedMultiblankRNNTInfer], + ) + def test_multiblank_rnnt_greedy_decoding(self, greedy_class): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + big_blank_durations = [2, 4] + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = RNNTDecoder(prednet_cfg, vocab_size) + joint_net = RNNTJoint( + jointnet_cfg, vocab_size, vocabulary=token_list, num_extra_outputs=len(big_blank_durations) + ) + + greedy = greedy_class( + decoder, + joint_net, + blank_index=len(token_list), + big_blank_durations=big_blank_durations, + max_symbols_per_step=5, + ) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) + @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) From 952d2614e66e4f9599f51622577f4896b017d347 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 May 2023 18:50:15 +0000 Subject: [PATCH 02/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/collections/asr/losses/rnnt.py | 7 ++----- nemo/collections/asr/losses/rnnt_pytorch.py | 1 - 
nemo/collections/asr/metrics/rnnt_wer.py | 14 ++++++++------ .../asr/parts/numba/rnnt_loss/__init__.py | 6 +++++- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 7 +------ .../rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py | 1 - .../asr/parts/submodules/rnnt_greedy_decoding.py | 10 ++++------ .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 8 +++++--- 8 files changed, 25 insertions(+), 29 deletions(-) diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 8fb6909945eb..f5fce7ab289f 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -34,8 +34,7 @@ import torch from omegaconf import DictConfig, OmegaConf -from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch -from nemo.collections.asr.losses.rnnt_pytorch import TDTRNNTLossPytorch +from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTRNNTLossPytorch from nemo.core.classes import Loss, typecheck from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE @@ -49,8 +48,7 @@ WARP_RNNT_AVAILABLE = False try: - from nemo.collections.asr.parts.numba.rnnt_loss import MultiblankRNNTLossNumba, RNNTLossNumba - from nemo.collections.asr.parts.numba.rnnt_loss import TDTRNNTLossNumba + from nemo.collections.asr.parts.numba.rnnt_loss import MultiblankRNNTLossNumba, RNNTLossNumba, TDTRNNTLossNumba NUMBA_RNNT_AVAILABLE = True except (ImportError, ModuleNotFoundError): @@ -253,7 +251,6 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) loss_func = TDTRNNTLossPytorch(blank=blank_idx, durations=durations, reduction='none', sigma=sigma) _warn_unused_additional_kwargs(loss_name, loss_kwargs) - else: raise ValueError( f"Invalid value of `loss_name`: {loss_name}. 
Allowed loss names are :" f"{loss_function_names}" diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index b20245c0ac68..c8891f11f6ec 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -112,7 +112,6 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens): return log_prob - class TDTRNNTLossPytorch(Loss): @property def input_types(self): diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 75901fad1276..90611c8f4ba9 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -214,7 +214,6 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): if self.durations is not None: assert blank_id != 0, 'blank_id must equal len(non_blank_vocabs) for multi-blank RNN-T models' - if self.big_blank_durations is not None: if blank_id == 0: raise ValueError("blank_id must equal len(vocabs) for multi-blank RNN-T models") @@ -265,7 +264,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): blank_index=self.blank_id, durations=self.durations, max_symbols_per_step=( - self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) + self.cfg.greedy.get('max_symbols', None) + or self.cfg.greedy.get('max_symbols_per_step', None) ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, @@ -277,7 +277,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): joint_model=joint, blank_index=self.blank_id, max_symbols_per_step=( - self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) + self.cfg.greedy.get('max_symbols', None) + or self.cfg.greedy.get('max_symbols_per_step', None) ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, @@ -306,7 +307,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): blank_index=self.blank_id, durations=self.durations, max_symbols_per_step=( - self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) + self.cfg.greedy.get('max_symbols', None) + or self.cfg.greedy.get('max_symbols_per_step', None) ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, @@ -318,14 +320,14 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): joint_model=joint, blank_index=self.blank_id, max_symbols_per_step=( - self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) + self.cfg.greedy.get('max_symbols', None) + or self.cfg.greedy.get('max_symbols_per_step', None) ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, confidence_method_cfg=self.confidence_method_cfg, ) - else: self.decoding = greedy_decode.GreedyBatchedMultiblankRNNTInfer( decoder_model=decoder, diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py b/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py index eecfaf785b3e..cfded44c78ba 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py @@ -13,4 +13,8 @@ # limitations under the License. 
from nemo.collections.asr.parts.numba.rnnt_loss.rnnt import rnnt_loss_cpu, rnnt_loss_gpu -from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import MultiblankRNNTLossNumba, RNNTLossNumba, TDTRNNTLossNumba +from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import ( + MultiblankRNNTLossNumba, + RNNTLossNumba, + TDTRNNTLossNumba, +) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index dd92db87c879..fc479e5c8a74 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -189,8 +189,6 @@ def backward(ctx, grad_output): ) - - class _MultiblankRNNTNumba(Function): """ Numba class for multi-blank transducer loss (https://arxiv.org/pdf/2211.03541.pdf) @@ -383,8 +381,6 @@ def tdt_rnnt_loss( return _TDTRNNTNumba.apply(acts, labels, act_lens, label_lens, blank, durations, reduction, fastemit_lambda, clamp) - - class RNNTLossNumba(Module): """ Parameters: @@ -501,6 +497,7 @@ def forward(self, acts, labels, act_lens, label_lens): self.sigma, ) + class TDTRNNTLossNumba(Module): """ Parameters: @@ -580,8 +577,6 @@ def forward(self, acts, labels, act_lens, label_lens): ) - - def check_type(var, t, name): if var.dtype is not t: raise TypeError("{} must be {}".format(name, t)) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 36cbb94a0eb6..0f7eb65b2544 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -879,7 +879,6 @@ def compute_multiblank_grad_kernel( idx += GPU_RNNT_THREAD_SIZE - @cuda.jit(device=True, inline=True) def logp_duration(acts: torch.Tensor, maxT: int, maxU: int, num_durations: int, mb: int, t: int, u: int, v: int): col = (mb * maxT + t) * maxU + u diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index ad3b010f29be..939e45bed006 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2627,7 +2627,7 @@ def _greedy_decode_blank_as_pad( blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) # mask for if the utterance in the batch should stay in the same frame. 
-# stay_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) + # stay_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) # Get max sequence length max_out_len = out_len.max() @@ -2664,8 +2664,8 @@ def _greedy_decode_blank_as_pad( # Batched joint step - Output = [B, V + 1 + num-big-blanks] # If preserving per-frame confidence, log_normalize must be true joined = self._joint_step(f, g, log_normalize=None) - logp = joined[:, 0, 0, :-len(self.durations)] - duration_logp = joined[:, 0, 0, -len(self.durations):] + logp = joined[:, 0, 0, : -len(self.durations)] + duration_logp = joined[:, 0, 0, -len(self.durations) :] if logp.dtype != torch.float32: logp = logp.float() @@ -2678,7 +2678,7 @@ def _greedy_decode_blank_as_pad( skip = self.durations[int(torch.min(dk))] if blank_mask.all(): -# print("SKIP is", skip) + # print("SKIP is", skip) if skip == 0: skip = 1 need_to_stay = skip == 0 @@ -2954,5 +2954,3 @@ def _greedy_decode_masked( hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx) return hypotheses - - diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index 4878974926a4..1a0a5a4fb507 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -20,7 +20,11 @@ from nemo.collections.asr.losses.rnnt import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTRNNTLossPytorch from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_numpy import RNNTLoss as RNNTLoss_Numpy -from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import MultiblankRNNTLossNumba, RNNTLossNumba, TDTRNNTLossNumba +from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import ( + MultiblankRNNTLossNumba, + RNNTLossNumba, + TDTRNNTLossNumba, +) from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ @@ -486,8 +490,6 @@ def test_case_randomized_act_label(self, device): assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "tdt-blank gradient mismatch." - - class TestMultiblankRNNTLoss: @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) From a6893a64de36b82d93316968ec8c57eb4613c4ea Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 4 May 2023 15:42:04 -0400 Subject: [PATCH 03/40] TDT PR WIP Signed-off-by: Hainan Xu --- .../conformer_tdt_transducer_bpe.yaml | 28 +++------- nemo/collections/asr/losses/rnnt.py | 8 ++- nemo/collections/asr/losses/rnnt_pytorch.py | 4 +- nemo/collections/asr/metrics/rnnt_wer.py | 12 +++- nemo/collections/asr/models/rnnt_models.py | 6 +- .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 55 +++++++++---------- 6 files changed, 56 insertions(+), 57 deletions(-) diff --git a/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml b/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml index d8c814d185ee..684732e9406f 100644 --- a/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml +++ b/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml @@ -1,23 +1,6 @@ # It contains the default values for training an TDT Conformer-Transducer ASR model with stateless decoders, large size (~120M) with Transducer loss and sub-word encoding. -# Architecture and training config: -# Default learning parameters in this config are set for effective batch size of 2K. 
To train it with smaller effective -# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. -# Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. -# -# +-------------+---------+---------+----------+--------------+--------------------------+ -# | Model | d_model | n_heads | n_layers | weight_decay | pred_hidden/joint_hidden | -# +=============+=========+========+===========+==============+==========================+ -# | Small (14M)| 176 | 4 | 16 | 0.0 | 320 | -# +-------------+---------+--------+-----------+--------------+--------------------------+ -# | Medium (32M)| 256 | 4 | 16 | 1e-3 | 640 | -# +-------------+---------+--------+-----------+--------------+--------------------------+ -# | Large (120M)| 512 | 8 | 17 | 1e-3 | 640 | -# +-----------------------------------------------------------+--------------------------+ -# - -# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer -# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html +# You can find detailed info about TDT models at https://arxiv.org/abs/2304.06795. name: "TDT-Conformer-Transducer-BPE" @@ -185,16 +168,19 @@ model: alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding loss: + # This is the main different between a TDT model and a conventional RNNT model -- the loss function. loss_name: "tdt_rnnt" tdt_rnnt_kwargs: # FastEmit regularization: https://arxiv.org/abs/2010.11148 # You may enable FastEmit to reduce the latency of the model for streaming - fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + fastemit_lambda: 0.001 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + + # refer to https://arxiv.org/abs/2304.06795 for the meaning of the following three configs. durations: [0, 1, 2, 3, 4] - sigma: 0.05 - omega: 0.0 # weight for regular RNN-T loss + sigma: 0.05 # hyper-param for under-normalization. + omega: 0.0 # weight for regular RNN-T loss. # Adds Gaussian noise to the gradients of the decoder to avoid overfitting variational_noise: diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 8fb6909945eb..5d749c2f315a 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -319,7 +319,13 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str = Args: num_classes: Number of target classes for the joint network to predict. - (Excluding the RNN-T blank token). + In all cases (conventional RNNT, multi-blank RNNT, and TDT model), this equals the token-id + for the standard "blank" symbol. In particular, say V is the number of non-blank tokens in + the vocabulary, then in the case of, + standard RNNT: num_classes = V + multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before + standard blank, and the standard blank is the last symbol in the vocab) + TDT: num_classes = V. Note, V here does not include any of the "duration outputs". reduction: Type of reduction to perform on loss. Possible values are `mean_batch`, 'mean_volume`, `mean`, `sum` or None. 
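As a small illustration of the num_classes convention described in the docstring above (the
numbers and variable names below are hypothetical and not part of the patch):

    V = 1024                       # non-blank tokens produced by the tokenizer
    num_big_blanks = 2             # multi-blank RNNT only
    durations = [0, 1, 2, 3, 4]    # TDT only

    num_classes_standard   = V                    # standard blank has token-id V
    num_classes_multiblank = V + num_big_blanks   # big blanks are stored before the standard blank
    num_classes_tdt        = V                    # duration outputs are not counted here

    # The joint output width still differs: V + 1 for standard RNNT, V + 1 + num_big_blanks for
    # multi-blank, and V + 1 + len(durations) for TDT (duration logits are appended at the end).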
diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index b20245c0ac68..274579c650ae 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -209,9 +209,7 @@ def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens log_probs = [] for b in range(B): - tt = torch.Tensor([-1000.0]) - tt = tt.cuda() - tt = tt[0] + tt = torch.Tensor([-1000.0]).cuda()[0] for n, l in enumerate(self.durations): if act_lens[b] - l >= 0 and l > 0: bb = ( diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 75901fad1276..63501f3a3b34 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -212,12 +212,18 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): self.word_seperator = self.cfg.get('word_seperator', ' ') if self.durations is not None: - assert blank_id != 0, 'blank_id must equal len(non_blank_vocabs) for multi-blank RNN-T models' - + if blank_id == 0: + raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models") + if self.big_blank_durations is not None: + raise ValueError("duration and big_blank_durations can't both be not None") + if self.cfg.strategy not in ['greedy', 'greedy_batch']: + raise ValueError("currently only greedy and greedy_batch inference is supported for TDT models") if self.big_blank_durations is not None: if blank_id == 0: raise ValueError("blank_id must equal len(vocabs) for multi-blank RNN-T models") + if self.cfg.strategy not in ['greedy', 'greedy_batch']: + raise ValueError("currently only greedy and greedy_batch inference is supported for multi-blank models") possible_strategies = ['greedy', 'greedy_batch', 'beam', 'tsd', 'alsd', 'maes'] if self.cfg.strategy not in possible_strategies: @@ -1122,7 +1128,7 @@ def decode_tokens_to_str(self, tokens: List[int]) -> str: Args: tokens: List of int representing the token ids. - eturns: + Returns: A decoded string. 
""" hypothesis = ''.join(self.decode_ids_to_tokens(tokens)) diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 9be9b9d9eb85..a9a4a71499f7 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -71,8 +71,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup RNNT Loss loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None)) + num_classes = self.joint.num_classes_with_blank - 1 # for standard RNNT and multi-blank + if loss_name == 'tdt_rnnt': + num_classes = num_classes - self.joint.num_extra_outputs + self.loss = RNNTLoss( - num_classes=self.joint.num_classes_with_blank - 1 - self.joint.num_extra_outputs, + num_classes=num_classes, loss_name=loss_name, loss_kwargs=loss_kwargs, reduction=self.cfg.get("rnnt_reduction", "mean_batch"), diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index 4878974926a4..d27caf7d915b 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -460,34 +460,6 @@ def zero_grad(): assert np.allclose(pt_grads1_p_2, np_grads1 + np_grads2, atol=1e-5) -class TestTDTRNNTLoss: - @pytest.mark.unit - @pytest.mark.parametrize('device', DEVICES) - def test_case_randomized_act_label(self, device): - if device == 'cuda': - numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) - - B, T, U, V = 2, 3, 3, 2 # here V is number of non blank labels - durations = [0, 1] - sigma = 0.05 - - acts = torch.rand([B, T, U, V + 1 + len(durations)]) - labels = [[random.randrange(0, V) for i in range(U - 1)] for j in range(B)] - - fn_pt = TDTRNNTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) - pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) - - fn_ag = TDTRNNTLossPytorch( - blank=V, reduction='sum', durations=durations, sigma=sigma - ) # ag for automatic gradient computation - ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) - - assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt-blank costs mismatch." - assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "tdt-blank gradient mismatch." - - - - class TestMultiblankRNNTLoss: @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) @@ -522,5 +494,32 @@ def test_case_randomized_act_label(self, device): assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "multi-blank gradient mismatch." +class TestTDTRNNTLoss: + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_randomized_act_label(self, device): + if device == 'cuda': + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + B, T, U, V = 4, 8, 4, 8 # here V is number of non blank labels + durations = [0, 1, 2, 3, 4, 5] + sigma = 0.05 + + acts = torch.rand([B, T, U, V + 1 + len(durations)]) + labels = [[random.randrange(0, V) for i in range(U - 1)] for j in range(B)] + + fn_pt = TDTRNNTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + + fn_ag = TDTRNNTLossPytorch( + blank=V, reduction='sum', durations=durations, sigma=sigma + ) # ag for automatic gradient computation + ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) + + assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt-blank costs mismatch." 
+ assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "tdt-blank gradient mismatch." + + + if __name__ == "__main__": pytest.main([__file__]) From 657ed9d2d8702717fdec6d8ec1fe318a9db4034e Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 4 May 2023 16:04:14 -0400 Subject: [PATCH 04/40] TDT PR WIP Signed-off-by: Hainan Xu --- .../collections/asr/parts/numba/rnnt_loss/rnnt.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py index 20ac579d3196..ae0a1c08d93f 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py @@ -254,28 +254,27 @@ def tdt_rnnt_loss_gpu( omega: float, ): """ - Wrapper method for accessing GPU Multi-blank RNNT loss (https://arxiv.org/pdf/2211.03541.pdf). + Wrapper method for accessing GPU TDT loss (https://arxiv.org/abs/2304.06795). CUDA implementation ported from [HawkAaron/warp-transducer](https://github.com/HawkAaron/warp-transducer). Args: - acts: Activation tensor of shape [B, T, U, V+1]. + label_acts: Activation tensor of shape [B, T, U, V], where V includes the blank symbol. + duration_acts: Activation tensor of shape [B, T, U, D], where D is the number of durations. labels: Ground truth labels of shape [B, U]. input_lengths: Lengths of the acoustic sequence as a vector of ints [B]. label_lengths: Lengths of the target sequence as a vector of ints [B]. costs: Zero vector of length [B] in which costs will be set. - grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set. + label_grads: Zero tensor of shape [B, T, U, V] where the gradient to label_acts will be set. + duration_grads: Zero tensor of shape [B, T, U, D] where the gradient to duration_acts will be set. blank_label: Index of the standard blank token in the vocabulary. - durations: A list of supported durations for big blank symbols - in the model, e.g. [2, 4, 8]. Note we only include durations for ``big - blanks'' here and it should not include 1 for the standard blank. - Those big blanks have vocabulary indices after the standard blank index. + durations: A list of supported durations for TDT. Must include 0 and 1. fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. num_threads: Number of threads for OpenMP. sigma: logit-undernormalization weight used in the multi-blank model. Refer to - the multi-blank paper https://arxiv.org/pdf/2211.03541 for detailed explanations. + the multi-blank paper https://arxiv.org/abs/2304.06795 for detailed explanations. 
omega: weight for regular RNN-T loss """ minibatch_size = label_acts.shape[0] From 8b01b4262fb161836feecad7e0927cc542a91dbd Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 4 May 2023 16:14:09 -0400 Subject: [PATCH 05/40] TDT PR WIP Signed-off-by: Hainan Xu --- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index fc479e5c8a74..4af12bf10786 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -93,7 +93,7 @@ def backward(ctx, grad_output): class _TDTRNNTNumba(Function): """ - Numba class for multi-blank transducer loss (https://arxiv.org/pdf/2211.03541.pdf) + Numba class for TDT loss (https://arxiv.org/abs/2304.06795) """ @staticmethod @@ -113,12 +113,12 @@ def forward( omega, ): """ - durations: list of durations for multi-blank transducer, e.g. - [2, 4, 8]. + durations: list of durations for TDT model, must include 0 and 1, e.g. + [0, 1, 2, 3, 4]. sigma: hyper-parameter for logit under-normalization method for training - multi-blank transducers. Recommended value 0.05. + TDT models. Recommended value 0.05. omega: weight for standard RNN-T loss - Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations for + Refer to https://arxiv.org/abs/2304.06795 for detailed explanations for the above parameters; For other parameters for this class, refer to comment for class _RNNTNumba """ @@ -132,6 +132,7 @@ def forward( loss_func = rnnt.tdt_rnnt_loss_gpu else: exit(-1) + label_grads = torch.zeros_like(label_acts) if label_acts.requires_grad else None duration_grads = torch.zeros_like(duration_acts) if duration_acts.requires_grad else None minibatch_size = label_acts.size(0) @@ -347,18 +348,18 @@ def tdt_rnnt_loss( clamp: float = 0.0, ): """ - TDT RNN Transducer (https://arxiv.org/pdf/2211.03541.pdf) Loss (functional form) + TDT RNN Transducer (https://arxiv.org/abs/2304.06795) Loss (functional form) Args: acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network labels: 2 dimensional Tensor containing all the targets of the batch with zero padded act_lens: Tensor of size (batch) containing size of each output sequence from the network label_lens: Tensor of (batch) containing label length of each example blank (int): standard blank label. - durations: list of durations for multi-blank transducer, e.g. - [2, 4, 8]. + durations: list of durations for TDT model, e.g. + [0,1,2,3,4]. sigma: hyper-parameter for logit under-normalization method for training multi-blank transducers. Recommended value 0.05. - Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations for + Refer to https://arxiv.org/abs/2304.06795 for detailed explanations for the last two params. reduction (string, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 
'none': no reduction will be applied, From 51eed215b5301d906fbc59b35e31783b2ac097ba Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 4 May 2023 20:21:00 -0400 Subject: [PATCH 06/40] TDT WIP Signed-off-by: Hainan Xu --- examples/asr/speech_to_text_eval.py | 2 +- nemo/collections/asr/metrics/rnnt_wer.py | 20 ++++++++++---------- nemo/collections/asr/metrics/rnnt_wer_bpe.py | 11 +++++++++-- nemo/collections/asr/models/rnnt_models.py | 1 + 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index d846157b6513..a0ff02c7996e 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch -import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict +import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization from nemo.core.config import hydra_runner diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 5cbc3d9b15ef..ef175d578296 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -526,15 +526,12 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp # RNN-T sample level is already preprocessed by implicit RNNT decoding # Simply remove any blank and possibly big blank tokens - if self.blank_id != 0: - # use < here works both for standard and multi-blank RNN-T. - prediction = [p for p in prediction if p < self.blank_id] - elif self.blank_id != 0 and self.big_blank_durations is not None: + if self.big_blank_durations is not None: # multi-blank RNNT num_extra_outputs = len(self.big_blank_durations) prediction = [p for p in prediction if p < self.blank_id - num_extra_outputs] - - # Simply remove any blank and possibly big blank tokens - else: + elif self.durations is not None: # TDT model. + prediction = [p for p in prediction if p < self.blank_id] + else: # standard RNN-T prediction = [p for p in prediction if p != self.blank_id] # De-tokenize the integer tokens; if not computing timestamps @@ -1102,9 +1099,12 @@ class RNNTDecoding(AbstractRNNTDecoding): def __init__( self, decoding_cfg, decoder, joint, vocabulary, ): - blank_id = ( - len(vocabulary) + joint.num_extra_outputs - ) # we need to ensure blank is the last token in the vocab. This is needed for multi-blank RNN-T models. + # we need to ensure blank is the last token in the vocab for the case of RNNT and Multi-blank RNNT. + blank_id = len(vocabulary) + joint.num_extra_outputs + + if 'durations' in decoding_cfg: # this means it's a TDT model. + blank_id = len(vocabulary) + self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) super(RNNTDecoding, self).__init__( diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index c59b65552842..d69ed9e58984 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -196,11 +196,18 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): """ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): - blank_id = tokenizer.tokenizer.vocab_size + blank_id = tokenizer.tokenizer.vocab_size # RNNT or Multi-blank RNNT. + + # TDT model. 
+ if 'durations' in decoding_cfg: + blank_id = tokenizer.tokenizer.vocab_size + elif 'big_blank_durations' in decoding_cfg: + blank_id = tokenizer.tokenizer.vocab_size + joint.num_extra_outputs + self.tokenizer = tokenizer super(RNNTBPEDecoding, self).__init__( - decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id + joint.num_extra_outputs + decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id ) if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer): diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index a9a4a71499f7..8240fbc35875 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -72,6 +72,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None)) num_classes = self.joint.num_classes_with_blank - 1 # for standard RNNT and multi-blank + if loss_name == 'tdt_rnnt': num_classes = num_classes - self.joint.num_extra_outputs From 452f45b0552c9887e24efee6b206b344d9135395 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 May 2023 00:22:58 +0000 Subject: [PATCH 07/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/asr/speech_to_text_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index a0ff02c7996e..d846157b6513 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch +import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict -import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization from nemo.core.config import hydra_runner From e3947e16c2e6c569ad182efcb1ebb9478a1edf61 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 4 May 2023 20:35:44 -0400 Subject: [PATCH 08/40] TDT WIP Signed-off-by: Hainan Xu --- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 15 +-------------- .../numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py | 2 +- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 4af12bf10786..847d372d043b 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -131,7 +131,7 @@ def forward( if is_cuda: loss_func = rnnt.tdt_rnnt_loss_gpu else: - exit(-1) + ValueError("TDT is not yet implemented for non CUDA computation.") label_grads = torch.zeros_like(label_acts) if label_acts.requires_grad else None duration_grads = torch.zeros_like(duration_acts) if duration_acts.requires_grad else None @@ -549,19 +549,6 @@ def forward(self, acts, labels, act_lens, label_lens): label_acts = acts[:, :, :, : -len(self.durations)].contiguous() duration_acts = torch.nn.functional.log_softmax(acts[:, :, :, -len(self.durations) :], dim=-1).contiguous() - # if not acts.is_cuda: - # # Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping - # # *after* we have obtained the gradients of loss(logsoftmax()). 
- # # This is highly wasteful since it requires a copy of the entire joint tensor which is expensive. - # # CUDA version is much more efficient since it performs an inplace logsoftmax, and therefore - # # can inplace clamp the gradient. - # if self.clamp > 0.0: - # acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, self.clamp) - # - # # NOTE manually done log_softmax for CPU version, - # # log_softmax is computed within GPU version. - # acts = torch.nn.functional.log_softmax(acts, -1) - return self.loss( label_acts, duration_acts, diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py index c10fa9fba030..e63b399750d8 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py @@ -801,7 +801,7 @@ def score_forward( return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE return self.compute_cost_and_score( - label_acts, duration_acts, None, costs, pad_labels, label_lengths, input_lengths + label_acts, duration_acts, None, None, costs, pad_labels, label_lengths, input_lengths ) def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]): From fd9c8ac1c02e9bbe214bcc4fba755c343aad98ee Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 5 May 2023 12:38:13 -0400 Subject: [PATCH 09/40] TDT WIP Signed-off-by: Hainan Xu --- nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 847d372d043b..19bcce713798 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -131,7 +131,7 @@ def forward( if is_cuda: loss_func = rnnt.tdt_rnnt_loss_gpu else: - ValueError("TDT is not yet implemented for non CUDA computation.") + raise ValueError("TDT is not yet implemented for non CUDA computation.") label_grads = torch.zeros_like(label_acts) if label_acts.requires_grad else None duration_grads = torch.zeros_like(duration_acts) if duration_acts.requires_grad else None From 23c57599281e255ca69430bb38f84092a08d9096 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 5 May 2023 16:34:11 -0400 Subject: [PATCH 10/40] TDT WIP Signed-off-by: Hainan Xu --- .../conformer_tdt_transducer_bpe.yaml | 10 +++---- nemo/collections/asr/losses/rnnt.py | 22 ++++++++-------- nemo/collections/asr/losses/rnnt_pytorch.py | 2 +- nemo/collections/asr/metrics/rnnt_wer.py | 4 +-- nemo/collections/asr/models/rnnt_models.py | 2 +- .../asr/parts/numba/rnnt_loss/__init__.py | 2 +- .../asr/parts/numba/rnnt_loss/rnnt.py | 4 +-- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 26 ++++++++++--------- .../rnnt_loss/utils/cuda_utils/gpu_rnnt.py | 16 +++++++----- .../utils/cuda_utils/gpu_rnnt_kernel.py | 10 +++---- .../parts/submodules/rnnt_greedy_decoding.py | 4 +-- .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 10 +++---- 12 files changed, 58 insertions(+), 54 deletions(-) diff --git a/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml b/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml index 684732e9406f..0bfb39771b50 100644 --- a/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml +++ b/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml @@ -1,8 +1,8 @@ -# It contains the default values for training an TDT 
Conformer-Transducer ASR model with stateless decoders, large size (~120M) with Transducer loss and sub-word encoding. +# This file contains the default values for training an TDT Conformer-Transducer ASR model, large size (~120M) with sub-word encoding. # You can find detailed info about TDT models at https://arxiv.org/abs/2304.06795. -name: "TDT-Conformer-Transducer-BPE" +name: "Conformer-TDT-BPE" model: sample_rate: 16000 @@ -169,9 +169,9 @@ model: loss: # This is the main different between a TDT model and a conventional RNNT model -- the loss function. - loss_name: "tdt_rnnt" + loss_name: "tdt" - tdt_rnnt_kwargs: + tdt_kwargs: # FastEmit regularization: https://arxiv.org/abs/2010.11148 # You may enable FastEmit to reduce the latency of the model for streaming fastemit_lambda: 0.001 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. @@ -180,7 +180,7 @@ model: # refer to https://arxiv.org/abs/2304.06795 for the meaning of the following three configs. durations: [0, 1, 2, 3, 4] sigma: 0.05 # hyper-param for under-normalization. - omega: 0.0 # weight for regular RNN-T loss. + omega: 0.1 # weight for regular RNN-T loss. # Adds Gaussian noise to the gradients of the decoder to avoid overfitting variational_noise: diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index cd4ef7102d09..bb89e432215a 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -34,7 +34,7 @@ import torch from omegaconf import DictConfig, OmegaConf -from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTRNNTLossPytorch +from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTLossPytorch from nemo.core.classes import Loss, typecheck from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE @@ -48,7 +48,7 @@ WARP_RNNT_AVAILABLE = False try: - from nemo.collections.asr.parts.numba.rnnt_loss import MultiblankRNNTLossNumba, RNNTLossNumba, TDTRNNTLossNumba + from nemo.collections.asr.parts.numba.rnnt_loss import MultiblankRNNTLossNumba, RNNTLossNumba, TDTLossNumba NUMBA_RNNT_AVAILABLE = True except (ImportError, ModuleNotFoundError): @@ -109,19 +109,19 @@ class RNNTLossConfig: is_available=True, installation_msg="Pure Pytorch implementation of Multiblank RNN-T loss. Slow and for debugging purposes only.", ), - "tdt_rnnt": RNNTLossConfig( - loss_name="tdt_rnnt", + "tdt": RNNTLossConfig( + loss_name="tdt", lib_name="numba", min_version='0.53.0', is_available=NUMBA_RNNT_AVAILABLE, installation_msg=NUMBA_INSTALLATION_MESSAGE, ), - "tdt_rnnt_pytorch": RNNTLossConfig( - loss_name="pytorch", + "tdt_pytorch": RNNTLossConfig( + loss_name="tdt_pytorch", lib_name="torch", min_version='0.0', is_available=True, - installation_msg="Pure Pytorch implementation of TDT RNN-T loss. Slow and for debugging purposes only.", + installation_msg="Pure Pytorch implementation of TDT loss. 
Slow and for debugging purposes only.", ), } @@ -228,13 +228,13 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) ) _warn_unused_additional_kwargs(loss_name, loss_kwargs) - elif loss_name == 'tdt_rnnt': + elif loss_name == 'tdt': fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0) clamp = loss_kwargs.pop('clamp', -1.0) durations = loss_kwargs.pop('durations', None) sigma = loss_kwargs.pop('sigma', 0.0) omega = loss_kwargs.pop('omega', 0.0) - loss_func = TDTRNNTLossNumba( + loss_func = TDTLossNumba( blank=blank_idx, durations=durations, reduction='none', @@ -245,10 +245,10 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) ) _warn_unused_additional_kwargs(loss_name, loss_kwargs) - elif loss_name == 'tdt_rnnt_pytorch': + elif loss_name == 'tdt_pytorch': durations = loss_kwargs.pop('durations', None) sigma = loss_kwargs.pop('sigma', 0.0) - loss_func = TDTRNNTLossPytorch(blank=blank_idx, durations=durations, reduction='none', sigma=sigma) + loss_func = TDTLossPytorch(blank=blank_idx, durations=durations, reduction='none', sigma=sigma) _warn_unused_additional_kwargs(loss_name, loss_kwargs) else: diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index 646ce8297549..45dedc972da1 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -112,7 +112,7 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens): return log_prob -class TDTRNNTLossPytorch(Loss): +class TDTLossPytorch(Loss): @property def input_types(self): """Input types definitions for CTCLoss. diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index ef175d578296..b4bfa479988d 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -267,7 +267,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): if self.cfg.strategy == 'greedy': if self.big_blank_durations is None: if self.durations is not None: - self.decoding = greedy_decode.GreedyTDTRNNTInfer( + self.decoding = greedy_decode.GreedyTDTInfer( decoder_model=decoder, joint_model=joint, blank_index=self.blank_id, @@ -310,7 +310,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): elif self.cfg.strategy == 'greedy_batch': if self.big_blank_durations is None: if self.durations is not None: - self.decoding = greedy_decode.GreedyBatchedTDTRNNTInfer( + self.decoding = greedy_decode.GreedyBatchedTDTInfer( decoder_model=decoder, joint_model=joint, blank_index=self.blank_id, diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 8240fbc35875..c54cc8e4ff14 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -73,7 +73,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): num_classes = self.joint.num_classes_with_blank - 1 # for standard RNNT and multi-blank - if loss_name == 'tdt_rnnt': + if loss_name == 'tdt': num_classes = num_classes - self.joint.num_extra_outputs self.loss = RNNTLoss( diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py b/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py index cfded44c78ba..055d7aeb5fd9 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py @@ -16,5 +16,5 @@ from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import ( 
MultiblankRNNTLossNumba, RNNTLossNumba, - TDTRNNTLossNumba, + TDTLossNumba, ) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py index ae0a1c08d93f..118ee88acbfe 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py @@ -236,7 +236,7 @@ def rnnt_loss_gpu( return True -def tdt_rnnt_loss_gpu( +def tdt_loss_gpu( label_acts: torch.Tensor, duration_acts: torch.Tensor, labels: torch.Tensor, @@ -310,7 +310,7 @@ def tdt_rnnt_loss_gpu( label_acts, label_acts_shape = rnnt_helper.flatten_tensor(label_acts) duration_acts, duration_acts_shape = rnnt_helper.flatten_tensor(duration_acts) - wrapper = gpu_rnnt.TDTGPURNNT( + wrapper = gpu_rnnt.GPUTDT( minibatch=minibatch_size, maxT=maxT, maxU=maxU, diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 19bcce713798..12f49eb023f5 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -34,7 +34,7 @@ from nemo.collections.asr.parts.numba.rnnt_loss import rnnt from nemo.collections.asr.parts.numba.rnnt_loss.utils.cpu_utils import cpu_rnnt -__all__ = ['rnnt_loss', 'RNNTLossNumba', 'MultiblankRNNTLossNumba', 'TDTRNNTLossNumba'] +__all__ = ['rnnt_loss', 'RNNTLossNumba', 'MultiblankRNNTLossNumba', 'TDTLossNumba'] class _RNNTNumba(Function): @@ -91,7 +91,7 @@ def backward(ctx, grad_output): return ctx.grads.mul_(grad_output), None, None, None, None, None, None, None -class _TDTRNNTNumba(Function): +class _TDTNumba(Function): """ Numba class for TDT loss (https://arxiv.org/abs/2304.06795) """ @@ -129,7 +129,7 @@ def forward( raise ValueError("`clamp` must be 0.0 or positive float value.") if is_cuda: - loss_func = rnnt.tdt_rnnt_loss_gpu + loss_func = rnnt.tdt_loss_gpu else: raise ValueError("TDT is not yet implemented for non CUDA computation.") @@ -336,7 +336,7 @@ def multiblank_rnnt_loss( ) -def tdt_rnnt_loss( +def tdt_loss( acts, labels, act_lens, @@ -379,7 +379,7 @@ def tdt_rnnt_loss( # log_softmax is computed within GPU version. acts = torch.nn.functional.log_softmax(acts, -1) - return _TDTRNNTNumba.apply(acts, labels, act_lens, label_lens, blank, durations, reduction, fastemit_lambda, clamp) + return _TDTNumba.apply(acts, labels, act_lens, label_lens, blank, durations, reduction, fastemit_lambda, clamp) class RNNTLossNumba(Module): @@ -499,16 +499,18 @@ def forward(self, acts, labels, act_lens, label_lens): ) -class TDTRNNTLossNumba(Module): +class TDTLossNumba(Module): """ Parameters: blank (int): standard blank label. - durations: list of durations for multi-blank transducer, e.g. - [2, 4, 8]. + durations: list of durations for TDT model, e.g. + [0, 1, 2, 3, 4]. sigma: hyper-parameter for logit under-normalization method for training - multi-blank transducers. Recommended value 0.05. - Refer to https://arxiv.org/pdf/2211.03541 for detailed explanations for + TDT. Recommended value 0.05. + omega: hyper-parameter for RNN-T loss for loss combination. + Refer to https://arxiv.org/abs/2304.06795 for detailed explanations for the above parameters; + reduction (string, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 
'none': no reduction will be applied, 'mean': the output losses will be divided by the target lengths and @@ -528,13 +530,13 @@ def __init__( sigma: float = 0.0, omega: float = 0.0, ): - super(TDTRNNTLossNumba, self).__init__() + super(TDTLossNumba, self).__init__() self.blank = blank self.durations = durations self.fastemit_lambda = fastemit_lambda self.clamp = float(clamp) if clamp > 0 else 0.0 self.reduction = reduction - self.loss = _TDTRNNTNumba.apply + self.loss = _TDTNumba.apply self.sigma = sigma self.omega = omega diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py index e63b399750d8..fa44c6f7c75c 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py @@ -523,7 +523,7 @@ def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]): return used_offset, (denom, alphas, betas, llForward, llBackward, bigblank_durations) -class TDTGPURNNT(GPURNNT): +class GPUTDT(GPURNNT): def __init__( self, sigma: float, @@ -542,12 +542,12 @@ def __init__( stream, ): """ - Helper class to launch the CUDA Kernels to compute TDT Transducer Loss (https://arxiv.org/pdf/2211.03541). + Helper class to launch the CUDA Kernels to compute TDT Loss (https://arxiv.org/pdf/2211.03541). Args: sigma: Hyper-parameter related to the logit-normalization method in training tdt transducers. omega: Hyper-parameter related to the sampled training. - num_durations: Number of big blank symbols the model has. This should not include the standard blank symbol. + num_durations: Number of durations the model supports. minibatch: Int representing the batch size. maxT: The maximum possible acoustic sequence length. Represents T in the logprobs tensor. maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. @@ -556,7 +556,7 @@ def __init__( blocks used as working memory. tdt_workspace: An allocated chunk of memory that will be sliced off and reshaped into required blocks used as working memory specifically for the tdt related computations. - blank: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + blank: Index of the blank token in the vocabulary. Must be the last token in the vocab. fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. @@ -589,8 +589,10 @@ def compute_cost_and_score( Compute both the loss and the gradients. Args: - acts: A flattened tensor of shape [B, T, U, V+1] representing the activation matrix. - grad: A flattented zero tensor of same shape as acts. + label_acts: A flattened tensor of shape [B, T, U, V] representing the activation matrix for tokens. + duration_acts: A flattened tensor of shape [B, T, U, D] representing the activation matrix for durations. + label_grad: A flattented zero tensor of same shape as label_acts. + duration_grad: A flattented zero tensor of same shape as duration_acts. costs: A zero vector of length B which will be updated inplace with the log probability costs. flat_labels: A flattened matrix of labels of shape [B, U] label_lengths: A vector of length B that contains the original lengths of the acoustic sequence. 
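For context on the label/duration activation layout used by these kernels, the split happens on the joint output before the GPU loss is invoked. A sketch mirroring `TDTLossNumba.forward` later in this patch (sizes are examples, not prescribed values):

```python
import torch

B, T, U, V, D = 2, 8, 4, 16, 5             # D = number of supported durations
acts = torch.randn(B, T, U, V + 1 + D)     # V non-blank tokens + blank, then D duration logits

label_acts = acts[..., : V + 1].contiguous()                                # [B, T, U, V+1]
duration_acts = torch.log_softmax(acts[..., V + 1 :], dim=-1).contiguous()  # [B, T, U, D]

label_grads = torch.zeros_like(label_acts)       # filled in place by the kernels
duration_grads = torch.zeros_like(duration_acts)
costs = torch.zeros(B)                           # per-utterance negative log-likelihood
```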
@@ -598,7 +600,7 @@ def compute_cost_and_score( Updates: This will launch kernels that will update inline the following variables: - - grads: Gradients of the activation matrix wrt the costs vector. + - *_grads: Gradients of the activation matrix wrt the costs vector. - costs: Negative log likelihood of the forward variable. Returns: diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 0f7eb65b2544..5964f85599fd 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -908,9 +908,10 @@ def compute_tdt_alphas_kernel( Compute alpha (forward variable) probabilities over the transduction step. Args: - acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. - denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor - across entire vocabulary. + acts: Tensor of shape [B, T, U, V] flattened. Represents the logprobs activation tensor for tokens. + acts: Tensor of shape [B, T, U, D] flattened. Represents the logprobs activation tensor for duration. + denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor for tokens. + alphas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel with the forward variable probabilities. llForward: Zero tensor of shape [B]. Represents the log-likelihood of the forward pass. @@ -925,8 +926,7 @@ def compute_tdt_alphas_kernel( maxT: The maximum possible acoustic sequence length. Represents T in the logprobs tensor. maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). - blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. - big_blank_: Index of the RNNT big blank token in the vocabulary. Generally the first or last token in the vocab. + blank_: Index of the TDT blank token in the vocabulary. Must be the last token in the vocab. Updates: Kernel inplace updates the following inputs: diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 939e45bed006..ad98d49b8b37 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2204,7 +2204,7 @@ class GreedyBatchedRNNTInferConfig: confidence_method_cfg: Optional[ConfidenceMethodConfig] = None -class GreedyTDTRNNTInfer(_GreedyRNNTInfer): +class GreedyTDTInfer(_GreedyRNNTInfer): """A greedy transducer decoder. Sequence level greedy decoding, performed auto-repressively. @@ -2456,7 +2456,7 @@ def _greedy_decode( return hypothesis -class GreedyBatchedTDTRNNTInfer(_GreedyRNNTInfer): +class GreedyBatchedTDTInfer(_GreedyRNNTInfer): """A batch level greedy transducer decoder. Batch level greedy decoding, performed auto-repressively. 
Args: diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index 26e44d1fb10e..3003a8c07d2c 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -18,12 +18,12 @@ import pytest import torch -from nemo.collections.asr.losses.rnnt import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTRNNTLossPytorch +from nemo.collections.asr.losses.rnnt import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTLossPytorch from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_numpy import RNNTLoss as RNNTLoss_Numpy from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import ( MultiblankRNNTLossNumba, RNNTLossNumba, - TDTRNNTLossNumba, + TDTLossNumba, ) from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ @@ -498,7 +498,7 @@ def test_case_randomized_act_label(self, device): assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "multi-blank gradient mismatch." -class TestTDTRNNTLoss: +class TestTDTLoss: @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) def test_case_randomized_act_label(self, device): @@ -512,10 +512,10 @@ def test_case_randomized_act_label(self, device): acts = torch.rand([B, T, U, V + 1 + len(durations)]) labels = [[random.randrange(0, V) for i in range(U - 1)] for j in range(B)] - fn_pt = TDTRNNTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) + fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) - fn_ag = TDTRNNTLossPytorch( + fn_ag = TDTLossPytorch( blank=V, reduction='sum', durations=durations, sigma=sigma ) # ag for automatic gradient computation ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) From c5032b9a5201f3628dc79fa6bddca7ae12da15c1 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 5 May 2023 16:37:21 -0400 Subject: [PATCH 11/40] TDT WIP Signed-off-by: Hainan Xu --- .../conformer_tdt_bpe.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/asr/conf/conformer/{conformer_tdt_transducer_bpe.yaml => tdt/conformer_tdt_bpe.yaml} (100%) diff --git a/examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml similarity index 100% rename from examples/asr/conf/conformer/conformer_tdt_transducer_bpe.yaml rename to examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml From 4053e3d9b375ecdd88fae3eeab876b93224cc922 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 5 May 2023 16:42:14 -0400 Subject: [PATCH 12/40] TDT WIP Signed-off-by: Hainan Xu --- examples/asr/speech_to_text_eval.py | 2 +- tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index d846157b6513..a0ff02c7996e 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch -import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict +import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization from nemo.core.config import hydra_runner diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py 
b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index 3003a8c07d2c..d05cd6536112 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -520,8 +520,8 @@ def test_case_randomized_act_label(self, device): ) # ag for automatic gradient computation ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) - assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt-blank costs mismatch." - assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "tdt-blank gradient mismatch." + assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt costs mismatch." + assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "td gradient mismatch." if __name__ == "__main__": From 23251fa2815ef0cb000b27310b4e673944e7ddba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 May 2023 20:43:12 +0000 Subject: [PATCH 13/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/asr/speech_to_text_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index a0ff02c7996e..d846157b6513 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch +import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict -import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization from nemo.core.config import hydra_runner From 06175ae896278584da496305f80053a46acff6b9 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 5 May 2023 16:43:41 -0400 Subject: [PATCH 14/40] TDT WIP Signed-off-by: Hainan Xu --- examples/asr/speech_to_text_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index a0ff02c7996e..d846157b6513 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch +import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict -import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization from nemo.core.config import hydra_runner From 26d630781a1aa5a0c42feb3ac8dfbc26400a45fa Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 5 May 2023 16:59:26 -0400 Subject: [PATCH 15/40] TDT WIP Signed-off-by: Hainan Xu --- nemo/collections/asr/metrics/rnnt_wer_bpe.py | 8 +++----- .../utils/cuda_utils/gpu_rnnt_kernel.py | 13 ++++++++----- .../asr/parts/submodules/rnnt_greedy_decoding.py | 16 +++++----------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index d69ed9e58984..34fe2ca3e31b 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -196,12 +196,10 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): """ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): - blank_id = tokenizer.tokenizer.vocab_size # RNNT or Multi-blank RNNT. + blank_id = tokenizer.tokenizer.vocab_size # RNNT or TDT models. - # TDT model. 
- if 'durations' in decoding_cfg: - blank_id = tokenizer.tokenizer.vocab_size - elif 'big_blank_durations' in decoding_cfg: + # multi-blank RNNTs + if 'big_blank_durations' in decoding_cfg: blank_id = tokenizer.tokenizer.vocab_size + joint.num_extra_outputs self.tokenizer = tokenizer diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 5964f85599fd..656afdb1512f 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -909,7 +909,7 @@ def compute_tdt_alphas_kernel( Args: acts: Tensor of shape [B, T, U, V] flattened. Represents the logprobs activation tensor for tokens. - acts: Tensor of shape [B, T, U, D] flattened. Represents the logprobs activation tensor for duration. + duration_acts: Tensor of shape [B, T, U, D] flattened. Represents the logprobs activation tensor for duration. denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor for tokens. alphas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel with the forward variable @@ -1066,7 +1066,8 @@ def compute_tdt_betas_kernel( Compute beta (backward variable) probabilities over the transduction step. Args: - acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. + acts: Tensor of shape [B, T, U, V] flattened. Represents the logprobs activation tensor for tokens. + duration_acts: Tensor of shape [B, T, U, D] flattened. Represents the logprobs activation tensor for duations. denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor across entire vocabulary. betas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel with the backward variable @@ -1211,9 +1212,11 @@ def compute_tdt_grad_kernel( Compute gradients over the transduction step. Args: - grads: Zero Tensor of shape [B, T, U, V+1]. Is updated by this kernel to contain the gradients - of this batch of samples. - acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. + grads: Zero Tensor of shape [B, T, U, V] to store gradients for tokens. + duration_grads: Zero Tensor of shape [B, T, U, D] to store gradients for durations. + + acts: Tensor of shape [B, T, U, V] flattened. Represents the logprobs activation tensor for tokens. + duration_acts: Tensor of shape [B, T, U, D] flattened. Represents the logprobs activation tensor for durations. denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor across entire vocabulary. alphas: Alpha variable, contains forward probabilities. A tensor of shape [B, T, U]. diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index ad98d49b8b37..9e7fb28d22b3 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2205,14 +2205,14 @@ class GreedyBatchedRNNTInferConfig: class GreedyTDTInfer(_GreedyRNNTInfer): - """A greedy transducer decoder. + """A greedy TDT decoder. Sequence level greedy decoding, performed auto-repressively. Args: decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. joint_model: rnnt_utils.AbstractRNNTJoint implementation. - blank_index: int index of the blank token. 
Must be len(vocabulary) for multi=blank RNNTs. + blank_index: int index of the blank token. Must be len(vocabulary) for TDT models. durations: a list containing durations for TDT. max_symbols_per_step: Optional int. The maximum number of symbols that can be added to a sequence in a single time step; if set to None then there is @@ -2355,9 +2355,6 @@ def _greedy_decode( if self.preserve_frame_confidence: hypothesis.frame_confidence = [[]] - # For timestep t in X_t - duration = 1 - time_idx = 0 while time_idx < out_len: # Extract encoder embedding at timestep t @@ -2457,12 +2454,12 @@ def _greedy_decode( class GreedyBatchedTDTInfer(_GreedyRNNTInfer): - """A batch level greedy transducer decoder. + """A batch level greedy TDT decoder. Batch level greedy decoding, performed auto-repressively. Args: decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. joint_model: rnnt_utils.AbstractRNNTJoint implementation. - blank_index: int index of the blank token. Must be len(vocabulary) for multi-blank RNNTs. + blank_index: int index of the blank token. Must be len(vocabulary) for TDT models. durations: a list containing durations. max_symbols_per_step: Optional int. The maximum number of symbols that can be added to a sequence in a single time step; if set to None then there is @@ -2626,9 +2623,6 @@ def _greedy_decode_blank_as_pad( # Mask buffers blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) - # mask for if the utterance in the batch should stay in the same frame. - # stay_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) - # Get max sequence length max_out_len = out_len.max() @@ -2779,7 +2773,7 @@ def _greedy_decode_masked( for hyp in hypotheses: hyp.alignments = [[]] else: - alignments = None + hyp.alignments = None # If confidence scores need to be preserved, register a danling list to hold the values if self.preserve_frame_confidence: From 656ea9ef42d6ce9edcc60445b98cfb16a0e681d7 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 5 May 2023 17:03:06 -0400 Subject: [PATCH 16/40] TDT WIP Signed-off-by: Hainan Xu --- nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 12f49eb023f5..00111efe8fee 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -548,6 +548,8 @@ def forward(self, acts, labels, act_lens, label_lens): label_lens: Tensor of (batch) containing label length of each example """ + # TODO(hainan): in the future, we could further optimize this so that we don't need to + # make copies of the acts tensor. 
label_acts = acts[:, :, :, : -len(self.durations)].contiguous() duration_acts = torch.nn.functional.log_softmax(acts[:, :, :, -len(self.durations) :], dim=-1).contiguous() From 1e55f9267410188ac0060a1cd94e0f3296106189 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Tue, 9 May 2023 17:27:15 -0400 Subject: [PATCH 17/40] TDT WIP Signed-off-by: Hainan Xu --- nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 9e7fb28d22b3..927639f9a244 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2672,7 +2672,6 @@ def _greedy_decode_blank_as_pad( skip = self.durations[int(torch.min(dk))] if blank_mask.all(): - # print("SKIP is", skip) if skip == 0: skip = 1 need_to_stay = skip == 0 From 31100defe66f99568e06e32d5683aa3602e5254c Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Wed, 10 May 2023 19:02:29 -0400 Subject: [PATCH 18/40] TDT WIP Signed-off-by: Hainan Xu --- nemo/collections/asr/metrics/rnnt_wer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index b4bfa479988d..b48cba7d69e9 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -211,7 +211,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): self.compute_timestamps = self.cfg.get('compute_timestamps', None) self.word_seperator = self.cfg.get('word_seperator', ' ') - if self.durations is not None: + if self.durations is not None: # this means it's a TDT model. if blank_id == 0: raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models") if self.big_blank_durations is not None: @@ -219,7 +219,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): if self.cfg.strategy not in ['greedy', 'greedy_batch']: raise ValueError("currently only greedy and greedy_batch inference is supported for TDT models") - if self.big_blank_durations is not None: + if self.big_blank_durations is not None: # this means it's a multi-blank model. if blank_id == 0: raise ValueError("blank_id must equal len(vocabs) for multi-blank RNN-T models") if self.cfg.strategy not in ['greedy', 'greedy_batch']: From 59a2a5264d559e3f64b59ac10f85e396abd54aae Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 11 May 2023 13:13:57 -0400 Subject: [PATCH 19/40] addressed some review comments, part1 Signed-off-by: Hainan Xu --- .../conf/conformer/tdt/conformer_tdt_bpe.yaml | 52 +++++++++++++++++-- nemo/collections/asr/losses/rnnt_pytorch.py | 20 +++++-- nemo/collections/asr/metrics/rnnt_wer.py | 18 +++---- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 23 +++++--- .../rnnt_loss/utils/cuda_utils/gpu_rnnt.py | 36 +------------ .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 2 +- .../asr/test_asr_rnnt_encdec_model.py | 3 -- 7 files changed, 93 insertions(+), 61 deletions(-) diff --git a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml index 0bfb39771b50..cb67de7224f1 100644 --- a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml +++ b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml @@ -2,6 +2,47 @@ # You can find detailed info about TDT models at https://arxiv.org/abs/2304.06795. 
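To make the greedy-decoding changes above easier to follow, here is a rough per-frame sketch of how TDT greedy search uses the duration head to skip frames. This is illustrative pseudocode, assuming the blank index equals the number of non-blank tokens; it is not the NeMo implementation:

```python
import torch

def tdt_greedy_step(joint_out: torch.Tensor, blank_id: int, durations: list):
    # joint_out: V + 1 token logits (blank last) followed by len(durations) duration logits.
    token_logits = joint_out[: blank_id + 1]
    duration_logits = joint_out[blank_id + 1 :]
    token = int(token_logits.argmax())
    skip = durations[int(duration_logits.argmax())]
    if token == blank_id and skip == 0:
        skip = 1   # a blank predicted with duration 0 must still advance time
    return token, skip   # caller emits token if non-blank and advances time_idx by skip
```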
+# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. +# +# +--------------+---------+---------+----------+------------------+--------------+--------------------------+-----------------+ +# | Model | d_model | n_heads | n_layers | conv_kernel_size | weight_decay | pred_hidden/joint_hidden | pred_rnn_layers | +# +==============+=========+========+===========+==================+==============+==========================+=================+ +# | Small (14M)| 176 | 4 | 16 | 31 | 0.0 | 320 | 1 | +# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ +# | Medium (32M)| 256 | 4 | 16 | 31 | 1e-3 | 640 | 1 | +# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ +# | Large (120M)| 512 | 8 | 17 | 31 | 1e-3 | 640 | 1 | +# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ +# | XLarge (644M)| 1024 | 8 | 24 | 5 | 1e-3 | 640 | 2 | +# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ + +# Default learning parameters in this config are set for global batch size of 2K while you may use lower values. +# To increase the global batch size with limited number of GPUs, you may use higher accumulate_grad_batches. +# However accumulate_grad_batches is better to be avoided as long as the global batch size is large enough and training is stable. + +# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer +# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html +# The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large + +# We suggest to use trainer.precision=bf16 for GPUs which support it otherwise trainer.precision=16 is recommended. +# Using bf16 or 16 would make it possible to double the batch size and speedup training/inference. If fp16 is not stable and model diverges after some epochs, you may use fp32. +# Here are the suggested batch size per GPU for each precision and memory sizes: +# +-----------+------------+------------+ +# | Precision | GPU Memory | Batch Size | +# +===========+============+============+ +# | 32 | 16GB | 8 | +# | | 32GB | 16 | +# | | 80GB | 32 | +# +-----------+------------+------------+ +# | 16 or | 16GB | 16 | +# | bf16 | 32GB | 32 | +# | | 80GB | 64 | +# +-----------+------------+------------+ +# Note: They are based on the assumption of max_duration of 20. If you have longer or shorter max_duration, then batch sizes may need to get updated accordingly. + name: "Conformer-TDT-BPE" model: @@ -15,6 +56,11 @@ model: pred_hidden: 640 joint_hidden: 640 + # variables for TDT configs. + tdt_durations: [0, 1, 2, 3, 4] + num_tdt_durations: 5 + + train_ds: manifest_filepath: ??? 
sample_rate: ${model.sample_rate} @@ -147,13 +193,13 @@ model: joint_hidden: ${model.model_defaults.joint_hidden} activation: "relu" dropout: 0.2 - num_extra_outputs: 5 + num_extra_outputs: ${model.model_defaults.num_tdt_durations} decoding: strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. # this must not be None in order to use the TDT specific decoding method. - durations: [0, 1, 2, 3, 4] + durations: ${model.model_defaults.tdt_durations} # greedy strategy config greedy: @@ -178,7 +224,7 @@ model: clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. # refer to https://arxiv.org/abs/2304.06795 for the meaning of the following three configs. - durations: [0, 1, 2, 3, 4] + durations: ${model.model_defaults.tdt_durations} sigma: 0.05 # hyper-param for under-normalization. omega: 0.1 # weight for regular RNN-T loss. diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index 45dedc972da1..ffaa7c17a02d 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -14,6 +14,7 @@ # limitations under the License. import torch +from typing import List from nemo.core.classes import Loss from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType @@ -113,6 +114,9 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens): class TDTLossPytorch(Loss): + """ + Pure Python implementation of TDT loss (https://arxiv.org/pdf/2304.06795.pdf) + """ @property def input_types(self): """Input types definitions for CTCLoss. @@ -132,7 +136,7 @@ def output_types(self): """ return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, blank, durations, reduction, sigma): + def __init__(self, blank: int, durations: List[int]=[], reduction: str='sum', sigma: float=0.0): super().__init__() self.blank = blank self.durations = durations @@ -144,6 +148,7 @@ def forward(self, acts, labels, act_lens, label_lens): label_acts = acts[:, :, :, : -self.n_durations] duration_acts = acts[:, :, :, -self.n_durations :] + # the - self.sigma here is for logit-undernormalization. Check the paper for details. label_acts = torch.log_softmax(label_acts, -1) - self.sigma duration_acts = torch.log_softmax(duration_acts, -1) @@ -166,6 +171,9 @@ def logsumexp(self, a, b): return ret def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens): + """This function implements Equation 7 in the TDT paper https://arxiv.org/pdf/2304.06795.pdf, + Simply put, for each alpha(t, u), it sums over the contribution from all incoming blank arcs and non-blank arcs. + """ B, T, U, _ = acts.shape log_alpha = torch.zeros(B, T, U) @@ -175,11 +183,13 @@ def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens for u in range(U): if u == 0: if t == 0: + # both t and u are 0, this is the base case for alphas. log_alpha[b, t, u] = 0.0 else: + # u = 0 and t != 0: only considers blank emissions. 
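# Log-space recurrence implemented by this loop (cf. Equation 7 of the TDT paper),
# with -1000.0 below acting as an effective log(0) initializer:
#   alpha[t, u] = logsumexp over supported durations d (with t - d >= 0) of
#     alpha[t - d, u]     + log P(blank | t - d, u)   + log P(d | t - d, u)      (blank arcs, d > 0)
#     alpha[t - d, u - 1] + log P(y_u | t - d, u - 1) + log P(d | t - d, u - 1)  (label arcs, d >= 0, u > 0)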
log_alpha[b, t, u] = -1000.0 for n, l in enumerate(self.durations): - if t - l >= 0 and l > 0: # blank emission, l has to be at least 1 + if t - l >= 0 and l > 0: # checking conditions for blank emission, l has to be at least 1 tmp = ( log_alpha[b, t - l, u] + acts[b, t - l, u, self.blank] @@ -188,10 +198,11 @@ def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) else: + # u != 0 here, need to consider both blanks and non-blanks. log_alpha[b, t, u] = -1000.0 for n, l in enumerate(self.durations): if t - l >= 0: - if l > 0: + if l > 0: # for blank emissions. Need to ensure index is not out-of-bound. tmp = ( log_alpha[b, t - l, u] + acts[b, t - l, u, self.blank] @@ -199,6 +210,7 @@ def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens ) log_alpha[b, t, u] = self.logsumexp(tmp, 1.0 * log_alpha[b, t, u]) + # non-blank emissions. tmp = ( log_alpha[b, t - l, u - 1] + acts[b, t - l, u - 1, labels[b, u - 1]] @@ -209,6 +221,8 @@ def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens log_probs = [] for b in range(B): tt = torch.Tensor([-1000.0]).cuda()[0] + + # need to loop over all possible ways that blank with different durations contributes to the final loss. for n, l in enumerate(self.durations): if act_lens[b] - l >= 0 and l > 0: bb = ( diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index b48cba7d69e9..fff87de8e47e 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -266,12 +266,11 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): if self.cfg.strategy == 'greedy': if self.big_blank_durations is None: - if self.durations is not None: - self.decoding = greedy_decode.GreedyTDTInfer( + if self.durations is None: + self.decoding = greedy_decode.GreedyRNNTInfer( decoder_model=decoder, joint_model=joint, blank_index=self.blank_id, - durations=self.durations, max_symbols_per_step=( self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) @@ -281,10 +280,11 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): confidence_method_cfg=self.confidence_method_cfg, ) else: - self.decoding = greedy_decode.GreedyRNNTInfer( + self.decoding = greedy_decode.GreedyTDTInfer( decoder_model=decoder, joint_model=joint, blank_index=self.blank_id, + durations=self.durations, max_symbols_per_step=( self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) @@ -309,12 +309,11 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): elif self.cfg.strategy == 'greedy_batch': if self.big_blank_durations is None: - if self.durations is not None: - self.decoding = greedy_decode.GreedyBatchedTDTInfer( + if self.durations is None: + self.decoding = greedy_decode.GreedyBatchedRNNTInfer( decoder_model=decoder, joint_model=joint, blank_index=self.blank_id, - durations=self.durations, max_symbols_per_step=( self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) @@ -324,10 +323,11 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): confidence_method_cfg=self.confidence_method_cfg, ) else: - self.decoding = greedy_decode.GreedyBatchedRNNTInfer( + self.decoding = greedy_decode.GreedyBatchedTDTInfer( decoder_model=decoder, joint_model=joint, blank_index=self.blank_id, + durations=self.durations, max_symbols_per_step=( 
self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None) @@ -1102,7 +1102,7 @@ def __init__( # we need to ensure blank is the last token in the vocab for the case of RNNT and Multi-blank RNNT. blank_id = len(vocabulary) + joint.num_extra_outputs - if 'durations' in decoding_cfg: # this means it's a TDT model. + if 'durations' in decoding_cfg and decoding_cfg['durations'] is not None: # this means it's a TDT model. blank_id = len(vocabulary) self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 00111efe8fee..8a78d7a7687b 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -93,7 +93,7 @@ def backward(ctx, grad_output): class _TDTNumba(Function): """ - Numba class for TDT loss (https://arxiv.org/abs/2304.06795) + Numba class for Token-and-Duration Transducer (TDT) loss (https://arxiv.org/abs/2304.06795) """ @staticmethod @@ -113,14 +113,20 @@ def forward( omega, ): """ + log_probs: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + durations: list of durations for TDT model, must include 0 and 1, e.g. [0, 1, 2, 3, 4]. sigma: hyper-parameter for logit under-normalization method for training TDT models. Recommended value 0.05. - omega: weight for standard RNN-T loss + omega: probability for sampling the standard RNN-T loss Refer to https://arxiv.org/abs/2304.06795 for detailed explanations for the above parameters; - For other parameters for this class, refer to comment for class _RNNTNumba """ is_cuda = label_acts.is_cuda @@ -523,7 +529,7 @@ class TDTLossNumba(Module): def __init__( self, blank, - durations=[], + durations=None, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = -1, @@ -532,7 +538,7 @@ def __init__( ): super(TDTLossNumba, self).__init__() self.blank = blank - self.durations = durations + self.durations = durations if durations is not None else [] self.fastemit_lambda = fastemit_lambda self.clamp = float(clamp) if clamp > 0 else 0.0 self.reduction = reduction @@ -549,9 +555,10 @@ def forward(self, acts, labels, act_lens, label_lens): """ # TODO(hainan): in the future, we could further optimize this so that we don't need to - # make copies of the acts tensor. - label_acts = acts[:, :, :, : -len(self.durations)].contiguous() - duration_acts = torch.nn.functional.log_softmax(acts[:, :, :, -len(self.durations) :], dim=-1).contiguous() + # make contiguous copies of the acts tensor. 
+ label_acts, duration_acts = torch.split(acts, [acts.shape[-1] - len(self.durations), len(self.durations)], dim=-1) + label_acts = label_acts.contiguous() + duration_acts = torch.nn.functional.log_softmax(duration_acts, dim=-1).contiguous() return self.loss( label_acts, diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py index fa44c6f7c75c..70ffb459cb97 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py @@ -500,23 +500,7 @@ def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]): An int, representing the offset of the used workspace (practically, the slice of the workspace consumed) A tuple of tensors representing the shared workspace. """ - used_offset = 0 - - # // denom - denom = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] - used_offset += self.maxT_ * self.maxU_ * self.minibatch_ - - # // alphas & betas - alphas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] - used_offset += self.maxT_ * self.maxU_ * self.minibatch_ - betas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] - used_offset += self.maxT_ * self.maxU_ * self.minibatch_ - - # // logllh - llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_] - used_offset += self.minibatch_ - llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_] - used_offset += self.minibatch_ + used_offset, (denom, alphas, betas, llForward, llBackward) = super()._prepare_workspace() bigblank_durations = self.big_blank_workspace[: self.num_big_blanks] @@ -814,23 +798,7 @@ def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]): An int, representing the offset of the used workspace (practically, the slice of the workspace consumed) A tuple of tensors representing the shared workspace. 
""" - used_offset = 0 - - # // denom - denom = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] - used_offset += self.maxT_ * self.maxU_ * self.minibatch_ - - # // alphas & betas - alphas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] - used_offset += self.maxT_ * self.maxU_ * self.minibatch_ - betas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_] - used_offset += self.maxT_ * self.maxU_ * self.minibatch_ - - # // logllh - llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_] - used_offset += self.minibatch_ - llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_] - used_offset += self.minibatch_ + used_offset, (denom, alphas, betas, llForward, llBackward) = super()._prepare_workspace() durations = self.tdt_workspace[: self.num_durations] diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index d05cd6536112..e2a8f83ded7f 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -503,7 +503,7 @@ class TestTDTLoss: @pytest.mark.parametrize('device', DEVICES) def test_case_randomized_act_label(self, device): if device == 'cuda': - numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) +# numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) B, T, U, V = 4, 8, 4, 8 # here V is number of non blank labels durations = [0, 1, 2, 3, 4, 5] diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index 173e846e91f9..68f1e38f797b 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -363,9 +363,6 @@ def test_multiblank_rnnt_greedy_decoding(self, greedy_class): with torch.no_grad(): _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) - @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', - ) @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) From fe52c954f23af7dd2d2c8fc3dcbc9631f095b894 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 11 May 2023 13:14:23 -0400 Subject: [PATCH 20/40] addressed some review comments, part1, one line fix Signed-off-by: Hainan Xu --- tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index e2a8f83ded7f..d05cd6536112 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -503,7 +503,7 @@ class TestTDTLoss: @pytest.mark.parametrize('device', DEVICES) def test_case_randomized_act_label(self, device): if device == 'cuda': -# numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) B, T, U, V = 4, 8, 4, 8 # here V is number of non blank labels durations = [0, 1, 2, 3, 4, 5] From 4b0036525a1eeee9a6cd9a39701f7e56fcf74007 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 May 2023 
17:15:30 +0000 Subject: [PATCH 21/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/collections/asr/losses/rnnt_pytorch.py | 10 +++++++--- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 4 +++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index ffaa7c17a02d..12040b42e419 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch from typing import List +import torch + from nemo.core.classes import Loss from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType @@ -117,6 +118,7 @@ class TDTLossPytorch(Loss): """ Pure Python implementation of TDT loss (https://arxiv.org/pdf/2304.06795.pdf) """ + @property def input_types(self): """Input types definitions for CTCLoss. @@ -136,7 +138,7 @@ def output_types(self): """ return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, blank: int, durations: List[int]=[], reduction: str='sum', sigma: float=0.0): + def __init__(self, blank: int, durations: List[int] = [], reduction: str = 'sum', sigma: float = 0.0): super().__init__() self.blank = blank self.durations = durations @@ -189,7 +191,9 @@ def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens # u = 0 and t != 0: only considers blank emissions. log_alpha[b, t, u] = -1000.0 for n, l in enumerate(self.durations): - if t - l >= 0 and l > 0: # checking conditions for blank emission, l has to be at least 1 + if ( + t - l >= 0 and l > 0 + ): # checking conditions for blank emission, l has to be at least 1 tmp = ( log_alpha[b, t - l, u] + acts[b, t - l, u, self.blank] diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 8a78d7a7687b..c4351271bd62 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -556,7 +556,9 @@ def forward(self, acts, labels, act_lens, label_lens): # TODO(hainan): in the future, we could further optimize this so that we don't need to # make contiguous copies of the acts tensor. 
- label_acts, duration_acts = torch.split(acts, [acts.shape[-1] - len(self.durations), len(self.durations)], dim=-1) + label_acts, duration_acts = torch.split( + acts, [acts.shape[-1] - len(self.durations), len(self.durations)], dim=-1 + ) label_acts = label_acts.contiguous() duration_acts = torch.nn.functional.log_softmax(duration_acts, dim=-1).contiguous() From 462f69edf98c951e9a3e4b161dd3220fafb8eb04 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Tue, 16 May 2023 21:26:40 -0400 Subject: [PATCH 22/40] add tests for comparing TDT alphas with pytorch VS kernel computation Signed-off-by: Hainan Xu --- .../conf/conformer/tdt/conformer_tdt_bpe.yaml | 6 +- examples/asr/speech_to_text_eval.py | 2 +- nemo/collections/asr/losses/rnnt_pytorch.py | 18 ++-- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 4 +- .../parts/submodules/rnnt_greedy_decoding.py | 6 +- .../rnnt_loss/utils/test_gpu_rnnt_kernel.py | 97 +++++++++++++++++++ 6 files changed, 121 insertions(+), 12 deletions(-) diff --git a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml index cb67de7224f1..427a7af463a9 100644 --- a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml +++ b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml @@ -193,10 +193,12 @@ model: joint_hidden: ${model.model_defaults.joint_hidden} activation: "relu" dropout: 0.2 - num_extra_outputs: ${model.model_defaults.num_tdt_durations} + num_extra_outputs: ${model.model_defaults.num_tdt_durations} decoding: - strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. + # Using greedy decoding is highly recommended for TDT models. Using greedy-batch will give very bad results + # if omega is 0; even if omega is non-zero, greedy-batch results are still going to be inaccurate. + strategy: "greedy" # this must not be None in order to use the TDT specific decoding method. durations: ${model.model_defaults.tdt_durations} diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index f8dcbcf81bbd..ebc1b9b259a8 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch -import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict +import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index ffaa7c17a02d..3607cddd18bc 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch from typing import List +import torch + from nemo.core.classes import Loss from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType @@ -117,6 +118,7 @@ class TDTLossPytorch(Loss): """ Pure Python implementation of TDT loss (https://arxiv.org/pdf/2304.06795.pdf) """ + @property def input_types(self): """Input types definitions for CTCLoss. 
@@ -136,7 +138,7 @@ def output_types(self): """ return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, blank: int, durations: List[int]=[], reduction: str='sum', sigma: float=0.0): + def __init__(self, blank: int, durations: List[int] = [], reduction: str = 'sum', sigma: float = 0.0): super().__init__() self.blank = blank self.durations = durations @@ -153,7 +155,7 @@ def forward(self, acts, labels, act_lens, label_lens): duration_acts = torch.log_softmax(duration_acts, -1) - forward_logprob = self.compute_forward_prob(label_acts, duration_acts, labels, act_lens, label_lens) + forward_logprob, _ = self.compute_forward_prob(label_acts, duration_acts, labels, act_lens, label_lens) losses = -forward_logprob if self.reduction == 'mean_batch': losses = losses.mean() # global batch size average @@ -189,7 +191,9 @@ def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens # u = 0 and t != 0: only considers blank emissions. log_alpha[b, t, u] = -1000.0 for n, l in enumerate(self.durations): - if t - l >= 0 and l > 0: # checking conditions for blank emission, l has to be at least 1 + if ( + t - l >= 0 and l > 0 + ): # checking conditions for blank emission, l has to be at least 1 tmp = ( log_alpha[b, t - l, u] + acts[b, t - l, u, self.blank] @@ -237,7 +241,7 @@ def compute_forward_prob(self, acts, duration_acts, labels, act_lens, label_lens log_prob = torch.stack(log_probs) - return log_prob + return log_prob, log_alpha class MultiblankRNNTLossPytorch(Loss): @@ -273,7 +277,7 @@ def __init__(self, blank, big_blank_durations, reduction, sigma): def forward(self, acts, labels, act_lens, label_lens): acts = torch.log_softmax(acts, -1) - self.sigma - forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens) + forward_logprob, _ = self.compute_forward_prob(acts, labels, act_lens, label_lens) losses = -forward_logprob if self.reduction == 'mean_batch': @@ -362,4 +366,4 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens): log_probs.append(to_append) log_prob = torch.stack(log_probs) - return log_prob + return log_prob, log_alpha diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 8a78d7a7687b..c4351271bd62 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -556,7 +556,9 @@ def forward(self, acts, labels, act_lens, label_lens): # TODO(hainan): in the future, we could further optimize this so that we don't need to # make contiguous copies of the acts tensor. 
- label_acts, duration_acts = torch.split(acts, [acts.shape[-1] - len(self.durations), len(self.durations)], dim=-1) + label_acts, duration_acts = torch.split( + acts, [acts.shape[-1] - len(self.durations), len(self.durations)], dim=-1 + ) label_acts = label_acts.contiguous() duration_acts = torch.nn.functional.log_softmax(duration_acts, dim=-1).contiguous() diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 927639f9a244..b1df2d360405 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2383,7 +2383,7 @@ def _greedy_decode( if self.preserve_frame_confidence: logp = torch.log_softmax(logp, -1) - duration_logp = torch.softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1) + duration_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1) del g # torch.max(0) op doesnt exist for FP 16. @@ -2412,6 +2412,10 @@ def _greedy_decode( # If blank token is predicted, exit inner loop, move onto next timestep t if k == self._blank_index: not_blank = False + + # this rarely happens, but we manually increment the `skip` number + # if blank is emitted and duration=0 is predicted. This prevents possible + # infinite loops. if skip == 0: skip = 1 diff --git a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py index acab5963fa72..8ff2fffcbdbb 100644 --- a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py +++ b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py @@ -17,6 +17,7 @@ import torch from numba import cuda +from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, TDTLossPytorch from nemo.collections.asr.parts.numba.rnnt_loss import rnnt_numpy from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import certify_inputs from nemo.collections.asr.parts.numba.rnnt_loss.utils.cuda_utils import gpu_rnnt_kernel, reduce @@ -504,3 +505,99 @@ def test_compute_grads_kernel_clamp(self): assert np.abs(diff).mean() <= 1e-5 assert np.square(diff).mean() <= 1e-10 + + +class TestTDTCUDAKernels: + @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") + @pytest.mark.unit + def test_compute_alphas_kernel(self): + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + random = np.random.RandomState(0) + original_shape = [1, 15, 11, 3] + durations = [0, 1, 2] + B, T, U, V = original_shape + Vd = len(durations) + + duration_act_shape = [B, T, U, Vd] + sigma = 0.0 + + # for passing into the kernel function -- it expected unnormalized logits + x = random.randn(*original_shape) + # for passing into the pytorch function -- it expected normalized logits + normalized_x = log_softmax(x, axis=-1) + + xd = random.randn(*duration_act_shape) + # duration logits are normalized before passing into the loss computation. 
+ xd = log_softmax(xd, axis=-1) + + labels = np.array([[1, 1, 1, 1, 0, 0, 1, 0, 0, 1]]) # [1, 10] + blank_idx = V - 1 + + pytorch_tdt_loss = TDTLossPytorch(blank_idx, durations, sigma=sigma) + + # Pytorch kernel + device = torch.device('cuda') + if hasattr(cuda, 'external_stream'): + stream = cuda.external_stream(torch.cuda.current_stream(device).cuda_stream) + else: + stream = cuda.default_stream() + + x = torch.tensor(x, device=device, dtype=torch.float32) + xd = torch.tensor(xd, device=device, dtype=torch.float32) + labels = torch.tensor(labels, device=device, dtype=torch.long) + durations = torch.tensor(durations, device=device, dtype=torch.long) + + # Allocate workspace memory + denom = torch.zeros(B * T * U, device=device, dtype=x.dtype) + alphas = torch.zeros(B * T * U, device=device, dtype=x.dtype) + llForward = torch.zeros(B, device=device, dtype=x.dtype) + input_lengths = torch.tensor([T], dtype=torch.long, device=device) + label_lengths = torch.tensor([U - 1], dtype=torch.long, device=device) + + ground_log_likelihood, ground_alphas = pytorch_tdt_loss.compute_forward_prob( + normalized_x, xd, labels, input_lengths, label_lengths + ) + + # certify input data + certify_inputs(x, labels, input_lengths, label_lengths) + + # flatten activation tensor (for pointer based indexing) + x = x.view([-1]) + xd = xd.view([-1]) + + # call kernel + # log softmax reduction + reduce.reduce_max(x, denom, rows=V, cols=B * T * U, minus=False, stream=stream) + reduce.reduce_exp(x, denom, rows=V, cols=B * T * U, minus=True, stream=stream) + + # alpha kernel + gpu_rnnt_kernel.compute_tdt_alphas_kernel[B, U, stream, 0]( + x, + xd, + denom, + sigma, + alphas, + llForward, + input_lengths, + label_lengths, + labels, + B, + T, + U, + V, + blank_idx, + durations, + Vd, + ) + + # sync kernel + stream.synchronize() + + # reshape alphas + alphas = alphas.view([B, T, U]) + diff = torch.norm(ground_alphas - alphas) + ll_diff = torch.norm(ground_log_likelihood - llForward) + + assert diff <= 1e-3 + assert ll_diff <= 1e-3 From 45e025ccba042ac3efbcef49434de4c5743a5aa7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 May 2023 01:28:27 +0000 Subject: [PATCH 23/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/asr/speech_to_text_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index ebc1b9b259a8..f8dcbcf81bbd 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch +import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict -import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner From 7d29d445d1807326bc2548bf9892fd1591f51c2d Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Tue, 16 May 2023 21:59:08 -0400 Subject: [PATCH 24/40] add tests for comparing multiblank alphas with pytorch VS kernel computation Signed-off-by: Hainan Xu --- examples/asr/speech_to_text_eval.py | 2 +- nemo/collections/asr/losses/rnnt_pytorch.py | 2 +- .../rnnt_loss/utils/test_gpu_rnnt_kernel.py | 94 ++++++++++++++++++- 3 files changed, 94 insertions(+), 4 deletions(-) diff --git a/examples/asr/speech_to_text_eval.py 
b/examples/asr/speech_to_text_eval.py index f8dcbcf81bbd..ebc1b9b259a8 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch -import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict +import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index 3607cddd18bc..bc6e5a25a3b2 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -268,7 +268,7 @@ def output_types(self): """ return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, blank, big_blank_durations, reduction, sigma): + def __init__(self, blank, big_blank_durations, reduction: str = "sum", sigma: float = 0.0): super().__init__() self.blank = blank self.big_blank_durations = big_blank_durations diff --git a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py index 8ff2fffcbdbb..230b6b7c099f 100644 --- a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py +++ b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py @@ -520,12 +520,12 @@ def test_compute_alphas_kernel(self): Vd = len(durations) duration_act_shape = [B, T, U, Vd] - sigma = 0.0 + sigma = 0.05 # for passing into the kernel function -- it expected unnormalized logits x = random.randn(*original_shape) # for passing into the pytorch function -- it expected normalized logits - normalized_x = log_softmax(x, axis=-1) + normalized_x = log_softmax(x, axis=-1) - 0.05 xd = random.randn(*duration_act_shape) # duration logits are normalized before passing into the loss computation. 
@@ -544,6 +544,7 @@ def test_compute_alphas_kernel(self): stream = cuda.default_stream() x = torch.tensor(x, device=device, dtype=torch.float32) + normalized_x = torch.tensor(normalized_x, device=device, dtype=torch.float32) xd = torch.tensor(xd, device=device, dtype=torch.float32) labels = torch.tensor(labels, device=device, dtype=torch.long) durations = torch.tensor(durations, device=device, dtype=torch.long) @@ -601,3 +602,92 @@ def test_compute_alphas_kernel(self): assert diff <= 1e-3 assert ll_diff <= 1e-3 + + +class TestMultiblankRNNTCUDAKernels: + @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") + @pytest.mark.unit + def test_compute_alphas_kernel(self): + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + random = np.random.RandomState(0) + original_shape = [1, 15, 11, 6] + big_blank_durations = [2, 3, 4] + B, T, U, V = original_shape + num_big_blanks = len(big_blank_durations) + + sigma = 0.05 + + # for passing into the kernel function -- it expected unnormalized logits + x = random.randn(*original_shape) + # for passing into the pytorch function -- it expected normalized logits + normalized_x = log_softmax(x, axis=-1) - sigma + + labels = np.array([[1, 1, 1, 1, 0, 0, 1, 0, 0, 1]]) # [1, 10] + blank_idx = V - 1 + + pytorch_multiblank_loss = MultiblankRNNTLossPytorch(blank_idx, big_blank_durations, sigma=sigma) + + # Pytorch kernel + device = torch.device('cuda') + if hasattr(cuda, 'external_stream'): + stream = cuda.external_stream(torch.cuda.current_stream(device).cuda_stream) + else: + stream = cuda.default_stream() + + x = torch.tensor(x, device=device, dtype=torch.float32) + normalized_x = torch.tensor(normalized_x, device=device, dtype=torch.float32) + labels = torch.tensor(labels, device=device, dtype=torch.long) + big_blank_durations = torch.tensor(big_blank_durations, device=device, dtype=torch.long) + + # Allocate workspace memory + denom = torch.zeros(B * T * U, device=device, dtype=x.dtype) + alphas = torch.zeros(B * T * U, device=device, dtype=x.dtype) + llForward = torch.zeros(B, device=device, dtype=x.dtype) + input_lengths = torch.tensor([T], dtype=torch.long, device=device) + label_lengths = torch.tensor([U - 1], dtype=torch.long, device=device) + + ground_log_likelihood, ground_alphas = pytorch_multiblank_loss.compute_forward_prob( + normalized_x, labels, input_lengths, label_lengths + ) + + # certify input data + certify_inputs(x, labels, input_lengths, label_lengths) + + # flatten activation tensor (for pointer based indexing) + x = x.view([-1]) + + # call kernel + # log softmax reduction + reduce.reduce_max(x, denom, rows=V, cols=B * T * U, minus=False, stream=stream) + reduce.reduce_exp(x, denom, rows=V, cols=B * T * U, minus=True, stream=stream) + + # alpha kernel + gpu_rnnt_kernel.compute_multiblank_alphas_kernel[B, U, stream, 0]( + x, + denom, + sigma, + alphas, + llForward, + input_lengths, + label_lengths, + labels, + B, + T, + U, + V, + blank_idx, + big_blank_durations, + num_big_blanks, + ) + + # sync kernel + stream.synchronize() + + # reshape alphas + alphas = alphas.view([B, T, U]) + diff = torch.norm(ground_alphas - alphas) + ll_diff = torch.norm(ground_log_likelihood - llForward) + + assert diff <= 1e-3 + assert ll_diff <= 1e-3 From aff4318e81b8626d96e72df092528245ddb5806f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 May 2023 02:00:15 +0000 Subject: [PATCH 25/40] [pre-commit.ci] auto 
fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/asr/speech_to_text_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index ebc1b9b259a8..f8dcbcf81bbd 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch +import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict -import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner From 68ca68c29634cc1978678a5b6e2160cf4a15be1f Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Wed, 17 May 2023 15:32:33 -0400 Subject: [PATCH 26/40] add tests for fixed case computation for TDT Signed-off-by: Hainan Xu --- examples/asr/speech_to_text_eval.py | 2 +- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 2 +- .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 48 +++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index f8dcbcf81bbd..ebc1b9b259a8 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch -import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict +import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index c4351271bd62..2ffe08be361e 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -124,7 +124,7 @@ def forward( [0, 1, 2, 3, 4]. sigma: hyper-parameter for logit under-normalization method for training TDT models. Recommended value 0.05. - omega: probability for sampling the standard RNN-T loss + omega: probability for sampling the standard RNN-T loss. Refer to https://arxiv.org/abs/2304.06795 for detailed explanations for the above parameters; """ diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index d05cd6536112..de1e66c641f4 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -524,5 +524,53 @@ def test_case_randomized_act_label(self, device): assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "td gradient mismatch." 
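The quantity these TDT loss tests pin down is the forward variable of Equation 7 in https://arxiv.org/abs/2304.06795. Below is a minimal, standalone sketch of that recursion for a single utterance; it mirrors the logic of TDTLossPytorch.compute_forward_prob above rather than reproducing it, and every name, shape, and value in it is made up for illustration.

import torch

def tdt_forward_logprob(label_lp, dur_lp, labels, durations, blank):
    # label_lp: [T, U, V+1] label log-probs (already log-softmaxed and shifted by -sigma),
    # dur_lp: [T, U, D] duration log-probs, labels: [U-1] target ids; single utterance only.
    T, U, _ = label_lp.shape
    alpha = torch.full((T, U), -1e4)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U):
            if t == 0 and u == 0:
                continue
            acc = torch.tensor(-1e4)
            for n, d in enumerate(durations):
                if t - d < 0:
                    continue
                if d > 0:  # blank arc: stays at label index u, advances d frames
                    acc = torch.logaddexp(acc, alpha[t - d, u] + label_lp[t - d, u, blank] + dur_lp[t - d, u, n])
                if u > 0:  # label arc: consumes labels[u-1], advances d frames (d may be 0)
                    acc = torch.logaddexp(
                        acc, alpha[t - d, u - 1] + label_lp[t - d, u - 1, labels[u - 1]] + dur_lp[t - d, u - 1, n]
                    )
            alpha[t, u] = acc
    # termination: a final blank with duration d >= 1 emitted from frame T-d at the last label index
    total = torch.tensor(-1e4)
    for n, d in enumerate(durations):
        if d > 0 and T - d >= 0:
            total = torch.logaddexp(total, alpha[T - d, U - 1] + label_lp[T - d, U - 1, blank] + dur_lp[T - d, U - 1, n])
    return total

# toy usage with made-up sizes
T, U, V, durations, blank, sigma = 4, 3, 5, [0, 1, 2], 5, 0.05
label_lp = torch.randn(T, U, V + 1).log_softmax(-1) - sigma
dur_lp = torch.randn(T, U, len(durations)).log_softmax(-1)
labels = torch.randint(0, V, (U - 1,))
loss = -tdt_forward_logprob(label_lp, dur_lp, labels, durations, blank)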
+ @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_fixed_case_act_label(self, device): + if device == 'cuda': +# numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + B, T, U, V = 1, 3, 2, 3 # here V is number of non blank labels + durations = [0, 1, 2] + sigma = 0.05 + + acts = torch.zeros([B, T, U, V + 1 + len(durations)]) + labels = [[(i + j) % (V - 1) for i in range(U - 1)] for j in range(B)] + + fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + +# fn_ag = TDTLossPytorch( +# blank=V, reduction='sum', durations=durations, sigma=sigma +# ) # ag for automatic gradient computation +# ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) +# +# f = open('/tmp/1.txt', 'w') +# ag_cost = np.array2string(ag_cost, separator=',') +# ag_grads = np.array2string(ag_grads, separator=',') +# print(ag_cost, file=f) +# print(ag_grads, file=f) + + expected_cost = 4.155739 + expected_grads =[[[[-0.64962804, 0.25 , 0.25 , 0.14962798, 0.2672583 , + -0.16792619,-0.09933221], + [ 0.01651875, 0.01651875, 0.01651875,-0.04955626, 0.022025 , + -0.01227201,-0.009753 ]], + + [[-0.04892651, 0.01714851, 0.01714851, 0.01462949,-0.01143234, + -0.01143234, 0.02286467], + [ 0.12531489, 0.12531489, 0.12531489,-0.37594467, 0.16708651, + 0.13027048,-0.29735702]], + + [[-0.02572276, 0.00857425, 0.00857425, 0.00857425,-0.02286468, + 0.01143234, 0.01143234], + [ 0.13388914, 0.13388914, 0.13388914,-0.40166742, 0.17851885, + -0.35703772, 0.17851885]]]] + + assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "tdt costs mismatch." + assert np.allclose(pt_grads, expected_grads, rtol=1e-2), "td gradient mismatch." + + + if __name__ == "__main__": pytest.main([__file__]) From 8fe52cf076258a798013ef0869901c8a18ae4388 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 May 2023 19:35:52 +0000 Subject: [PATCH 27/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/asr/speech_to_text_eval.py | 2 +- .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 54 +++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index 0713d5cc5e36..f4d2a66ffec0 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch +import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict -import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index de1e66c641f4..4fb15a2f7707 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -523,12 +523,11 @@ def test_case_randomized_act_label(self, device): assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt costs mismatch." assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "td gradient mismatch." 
- @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) def test_case_fixed_case_act_label(self, device): if device == 'cuda': -# numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + # numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) B, T, U, V = 1, 3, 2, 3 # here V is number of non blank labels durations = [0, 1, 2] @@ -540,37 +539,38 @@ def test_case_fixed_case_act_label(self, device): fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) -# fn_ag = TDTLossPytorch( -# blank=V, reduction='sum', durations=durations, sigma=sigma -# ) # ag for automatic gradient computation -# ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) -# -# f = open('/tmp/1.txt', 'w') -# ag_cost = np.array2string(ag_cost, separator=',') -# ag_grads = np.array2string(ag_grads, separator=',') -# print(ag_cost, file=f) -# print(ag_grads, file=f) + # fn_ag = TDTLossPytorch( + # blank=V, reduction='sum', durations=durations, sigma=sigma + # ) # ag for automatic gradient computation + # ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) + # + # f = open('/tmp/1.txt', 'w') + # ag_cost = np.array2string(ag_cost, separator=',') + # ag_grads = np.array2string(ag_grads, separator=',') + # print(ag_cost, file=f) + # print(ag_grads, file=f) expected_cost = 4.155739 - expected_grads =[[[[-0.64962804, 0.25 , 0.25 , 0.14962798, 0.2672583 , - -0.16792619,-0.09933221], - [ 0.01651875, 0.01651875, 0.01651875,-0.04955626, 0.022025 , - -0.01227201,-0.009753 ]], - - [[-0.04892651, 0.01714851, 0.01714851, 0.01462949,-0.01143234, - -0.01143234, 0.02286467], - [ 0.12531489, 0.12531489, 0.12531489,-0.37594467, 0.16708651, - 0.13027048,-0.29735702]], - - [[-0.02572276, 0.00857425, 0.00857425, 0.00857425,-0.02286468, - 0.01143234, 0.01143234], - [ 0.13388914, 0.13388914, 0.13388914,-0.40166742, 0.17851885, - -0.35703772, 0.17851885]]]] + expected_grads = [ + [ + [ + [-0.64962804, 0.25, 0.25, 0.14962798, 0.2672583, -0.16792619, -0.09933221], + [0.01651875, 0.01651875, 0.01651875, -0.04955626, 0.022025, -0.01227201, -0.009753], + ], + [ + [-0.04892651, 0.01714851, 0.01714851, 0.01462949, -0.01143234, -0.01143234, 0.02286467], + [0.12531489, 0.12531489, 0.12531489, -0.37594467, 0.16708651, 0.13027048, -0.29735702], + ], + [ + [-0.02572276, 0.00857425, 0.00857425, 0.00857425, -0.02286468, 0.01143234, 0.01143234], + [0.13388914, 0.13388914, 0.13388914, -0.40166742, 0.17851885, -0.35703772, 0.17851885], + ], + ] + ] assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "tdt costs mismatch." assert np.allclose(pt_grads, expected_grads, rtol=1e-2), "td gradient mismatch." 
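A note on the sigma=0.05 appearing in these tests: it is the logit under-normalization constant of the TDT loss, and subtracting it after log_softmax simply leaves the per-frame label distribution under-normalized, so the probability mass sums to exp(-sigma) instead of 1. A quick check with toy values:

import torch

sigma = 0.05
logp = torch.log_softmax(torch.randn(8), dim=-1) - sigma  # as done in TDTLossPytorch.forward
print(float(logp.exp().sum()))                            # ~ exp(-0.05) ~ 0.951, deliberately < 1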
- if __name__ == "__main__": pytest.main([__file__]) From 7d41d362ffa04bb6771975d9d63f08ff3c16cfce Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Wed, 17 May 2023 15:48:59 -0400 Subject: [PATCH 28/40] add more comments for greedy-batch decoding for TDT Signed-off-by: Hainan Xu --- .../parts/submodules/rnnt_greedy_decoding.py | 214 ++---------------- .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 45 ++-- 2 files changed, 30 insertions(+), 229 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index b1df2d360405..42b14fd7b8bf 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2630,13 +2630,16 @@ def _greedy_decode_blank_as_pad( # Get max sequence length max_out_len = out_len.max() + # skip means the number of frames the next decoding step should "jump" to. When skip == 1 + # it means the next decoding step will just use the next input frame. skip = 1 for time_idx in range(max_out_len): - if skip > 1: + if skip > 1: # if skip > 1 at the current step, we decrement it and skip the current frame. skip -= 1 continue f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] + # need_to_stay is a boolean indicating whether the next decoding step should remain in the same frame. need_to_stay = True symbols_added = 0 @@ -2660,7 +2663,8 @@ def _greedy_decode_blank_as_pad( g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize) # Batched joint step - Output = [B, V + 1 + num-big-blanks] - # If preserving per-frame confidence, log_normalize must be true + # Note: log_normalize must not be True here since the joiner output is a concatenation of both token logits and duration logits, + # and they need to be normalized independently. joined = self._joint_step(f, g, log_normalize=None) logp = joined[:, 0, 0, : -len(self.durations)] duration_logp = joined[:, 0, 0, -len(self.durations) :] @@ -2669,15 +2673,20 @@ def _greedy_decode_blank_as_pad( logp = logp.float() duration_logp = duration_logp.float() - # Get index k, of max prob for batch + # get the max for both token and duration predictions. v, k = logp.max(1) dv, dk = duration_logp.max(1) + # here we set the skip value to be the minimum of all predicted durations, hence the "torch.min(dk)" call there. + # Please refer to Section 5.2 of our paper https://arxiv.org/pdf/2304.06795.pdf for explanation of this. skip = self.durations[int(torch.min(dk))] + # this is a special case: if all utterances in the batch emit blanks, we require that skip be at least 1 + # so we don't loop forever at the current frame. 
if blank_mask.all(): if skip == 0: skip = 1 + need_to_stay = skip == 0 del g @@ -2697,7 +2706,6 @@ def _greedy_decode_blank_as_pad( # Recover prior state for all samples which predicted blank now/past if hidden is not None: - # LSTM has 2 states hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices) elif len(blank_indices) > 0 and hidden is None: @@ -2754,200 +2762,4 @@ def _greedy_decode_masked( device: torch.device, partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, ): - if partial_hypotheses is not None: - raise NotImplementedError("`partial_hypotheses` support is not supported") - - # x: [B, T, D] - # out_len: [B] - # device: torch.device - - # Initialize state - batchsize = x.shape[0] - hypotheses = [ - rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize) - ] - - # Initialize Hidden state matrix (shared by entire batch) - hidden = None - - # If alignments need to be preserved, register a danling list to hold the values - if self.preserve_alignments: - # alignments is a 3-dimensional dangling list representing B x T x U - for hyp in hypotheses: - hyp.alignments = [[]] - else: - hyp.alignments = None - - # If confidence scores need to be preserved, register a danling list to hold the values - if self.preserve_frame_confidence: - # frame_confidence is a 3-dimensional dangling list representing B x T x U - for hyp in hypotheses: - hyp.frame_confidence = [[]] - - # Last Label buffer + Last Label without blank buffer - # batch level equivalent of the last_label - last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device) - last_label_without_blank = last_label.clone() - - # Mask buffers - blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) - - # Get max sequence length - max_out_len = out_len.max() - - with torch.inference_mode(): - for time_idx in range(max_out_len): - f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] - - # Prepare t timestamp batch variables - not_blank = True - symbols_added = 0 - - # Reset blank mask - blank_mask.mul_(False) - - # Update blank mask with time mask - # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch) - # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len - blank_mask = time_idx >= out_len - - # Start inner loop - while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols): - # Batch prediction and joint network steps - # If very first prediction step, submit SOS tag (blank) to pred_step. 
- # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state - if time_idx == 0 and symbols_added == 0 and hidden is None: - g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize) - else: - # Set a dummy label for the blank value - # This value will be overwritten by "blank" again the last label update below - # This is done as vocabulary of prediction network does not contain "blank" token of RNNT - last_label_without_blank_mask = last_label >= self._blank_index - last_label_without_blank[last_label_without_blank_mask] = 0 # temp change of label - last_label_without_blank[~last_label_without_blank_mask] = last_label[ - ~last_label_without_blank_mask - ] - - # Perform batch step prediction of decoder, getting new states and scores ("g") - g, hidden_prime = self._pred_step(last_label_without_blank, hidden, batch_size=batchsize) - - # Batched joint step - Output = [B, V + 1 + num-big-blanks] - # If preserving per-frame confidence, log_normalize must be true - logp = self._joint_step(f, g, log_normalize=True if self.preserve_frame_confidence else None)[ - :, 0, 0, : - ] - - if logp.dtype != torch.float32: - logp = logp.float() - - # Get index k, of max prob for batch - v, k = logp.max(1) - del g - - # Update blank mask with current predicted blanks - # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U) - k_is_blank = k == self._blank_index - blank_mask.bitwise_or_(k_is_blank) - - # If preserving alignments, check if sequence length of sample has been reached - # before adding alignment - if self.preserve_alignments: - # Insert logprobs into last timestep per sample - logp_vals = logp.to('cpu') - logp_ids = logp_vals.max(1)[1] - for batch_idx in range(batchsize): - if time_idx < out_len[batch_idx]: - hypotheses[batch_idx].alignments[-1].append( - (logp_vals[batch_idx], logp_ids[batch_idx]) - ) - del logp_vals - - # If preserving per-frame confidence, check if sequence length of sample has been reached - # before adding confidence scores - if self.preserve_frame_confidence: - # Insert probabilities into last timestep per sample - confidence = self._get_confidence(logp) - for batch_idx in range(batchsize): - if time_idx < out_len[batch_idx]: - hypotheses[batch_idx].frame_confidence[-1].append(confidence[batch_idx]) - del logp - - # If all samples predict / have predicted prior blanks, exit loop early - # This is equivalent to if single sample predicted k - if blank_mask.all(): - not_blank = False - - # If preserving alignments, convert the current Uj alignments into a torch.Tensor - # Then preserve U at current timestep Ti - # Finally, forward the timestep history to Ti+1 for that sample - # All of this should only be done iff the current time index <= sample-level AM length. - # Otherwise ignore and move to next sample / next timestep. - if self.preserve_alignments: - - # convert Ti-th logits into a torch array - for batch_idx in range(batchsize): - - # this checks if current timestep <= sample-level AM length - # If current timestep > sample-level AM length, no alignments will be added - # Therefore the list of Uj alignments is empty here. 
- if len(hypotheses[batch_idx].alignments[-1]) > 0: - hypotheses[batch_idx].alignments.append([]) # blank buffer for next timestep - - # Do the same if preserving per-frame confidence - if self.preserve_frame_confidence: - - for batch_idx in range(batchsize): - if len(hypotheses[batch_idx].frame_confidence[-1]) > 0: - hypotheses[batch_idx].frame_confidence.append([]) # blank buffer for next timestep - else: - # Collect batch indices where blanks occurred now/past - blank_indices = (blank_mask == 1).nonzero(as_tuple=False) - - # Recover prior state for all samples which predicted blank now/past - if hidden is not None: - # LSTM has 2 states - hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices) - - elif len(blank_indices) > 0 and hidden is None: - # Reset state if there were some blank and other non-blank predictions in batch - # Original state is filled with zeros so we just multiply - # LSTM has 2 states - hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0) - - # Recover prior predicted label for all samples which predicted blank now/past - k[blank_indices] = last_label[blank_indices, 0] - - # Update new label and hidden state for next iteration - last_label = k.view(-1, 1) - hidden = hidden_prime - - # Update predicted labels, accounting for time mask - # If blank was predicted even once, now or in the past, - # Force the current predicted label to also be blank - # This ensures that blanks propogate across all timesteps - # once they have occured (normally stopping condition of sample level loop). - for kidx, ki in enumerate(k): - if blank_mask[kidx] == 0: - hypotheses[kidx].y_sequence.append(ki) - hypotheses[kidx].timestep.append(time_idx) - hypotheses[kidx].score += float(v[kidx]) - - symbols_added += 1 - - # Remove trailing empty list of alignments at T_{am-len} x Uj - if self.preserve_alignments: - for batch_idx in range(batchsize): - if len(hypotheses[batch_idx].alignments[-1]) == 0: - del hypotheses[batch_idx].alignments[-1] - - # Remove trailing empty list of confidence scores at T_{am-len} x Uj - if self.preserve_frame_confidence: - for batch_idx in range(batchsize): - if len(hypotheses[batch_idx].frame_confidence[-1]) == 0: - del hypotheses[batch_idx].frame_confidence[-1] - - # Preserve states - for batch_idx in range(batchsize): - hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx) - - return hypotheses + raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.") diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index de1e66c641f4..3fbfcf6df54b 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -523,12 +523,11 @@ def test_case_randomized_act_label(self, device): assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt costs mismatch." assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "td gradient mismatch." 
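Returning to the greedy-batch decoding comments added earlier in this patch: since every utterance in the batch has to advance by the same number of frames, decoding advances by the smallest predicted duration and forces a skip of 1 when the whole batch emits blank with duration 0. Stripped of the decoder and joint, and with made-up predictions, the control flow looks roughly like the following sketch:

import torch

durations = [0, 1, 2, 3, 4]   # same duration bins as model_defaults.tdt_durations
blank_id = 28                 # made-up blank index
max_symbols = 10

def predict(time_idx, symbols_added):
    # stand-in for running the decoder + joint at this frame: token ids and
    # duration-bin ids for a batch of 3 (values are invented for the sketch)
    k = torch.tensor([blank_id, 5, 7])
    dk = torch.tensor([2, 0, 1]) if symbols_added == 0 else torch.tensor([2, 1, 1])
    return k, dk

time_idx, T = 0, 20
while time_idx < T:
    need_to_stay, symbols_added = True, 0
    while need_to_stay and symbols_added < max_symbols:
        k, dk = predict(time_idx, symbols_added)
        skip = durations[int(torch.min(dk))]   # whole batch advances by the smallest predicted duration
        if (k == blank_id).all() and skip == 0:
            skip = 1                           # all-blank with duration 0: force progress, no infinite loop
        need_to_stay = skip == 0               # duration 0 predicted somewhere: emit again from this frame
        symbols_added += 1
    time_idx += skip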
- @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) def test_case_fixed_case_act_label(self, device): if device == 'cuda': -# numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) B, T, U, V = 1, 3, 2, 3 # here V is number of non blank labels durations = [0, 1, 2] @@ -540,37 +539,27 @@ def test_case_fixed_case_act_label(self, device): fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) -# fn_ag = TDTLossPytorch( -# blank=V, reduction='sum', durations=durations, sigma=sigma -# ) # ag for automatic gradient computation -# ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) -# -# f = open('/tmp/1.txt', 'w') -# ag_cost = np.array2string(ag_cost, separator=',') -# ag_grads = np.array2string(ag_grads, separator=',') -# print(ag_cost, file=f) -# print(ag_grads, file=f) - expected_cost = 4.155739 - expected_grads =[[[[-0.64962804, 0.25 , 0.25 , 0.14962798, 0.2672583 , - -0.16792619,-0.09933221], - [ 0.01651875, 0.01651875, 0.01651875,-0.04955626, 0.022025 , - -0.01227201,-0.009753 ]], - - [[-0.04892651, 0.01714851, 0.01714851, 0.01462949,-0.01143234, - -0.01143234, 0.02286467], - [ 0.12531489, 0.12531489, 0.12531489,-0.37594467, 0.16708651, - 0.13027048,-0.29735702]], - - [[-0.02572276, 0.00857425, 0.00857425, 0.00857425,-0.02286468, - 0.01143234, 0.01143234], - [ 0.13388914, 0.13388914, 0.13388914,-0.40166742, 0.17851885, - -0.35703772, 0.17851885]]]] + expected_grads = [ + [ + [ + [-0.64962804, 0.25, 0.25, 0.14962798, 0.2672583, -0.16792619, -0.09933221], + [0.01651875, 0.01651875, 0.01651875, -0.04955626, 0.022025, -0.01227201, -0.009753], + ], + [ + [-0.04892651, 0.01714851, 0.01714851, 0.01462949, -0.01143234, -0.01143234, 0.02286467], + [0.12531489, 0.12531489, 0.12531489, -0.37594467, 0.16708651, 0.13027048, -0.29735702], + ], + [ + [-0.02572276, 0.00857425, 0.00857425, 0.00857425, -0.02286468, 0.01143234, 0.01143234], + [0.13388914, 0.13388914, 0.13388914, -0.40166742, 0.17851885, -0.35703772, 0.17851885], + ], + ] + ] assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "tdt costs mismatch." assert np.allclose(pt_grads, expected_grads, rtol=1e-2), "td gradient mismatch." - if __name__ == "__main__": pytest.main([__file__]) From fc0d3c36cd518d2c53c3e27d7cb1fc4d340a1456 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Wed, 17 May 2023 16:10:36 -0400 Subject: [PATCH 29/40] include config for TDT model with stateless decoders Signed-off-by: Hainan Xu --- .../tdt/conformer_tdt_bpe_stateless.yaml | 293 ++++++++++++++++++ 1 file changed, 293 insertions(+) create mode 100644 examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml diff --git a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml new file mode 100644 index 000000000000..65f4267fec66 --- /dev/null +++ b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml @@ -0,0 +1,293 @@ +# This file contains the default values for training an TDT Conformer-Transducer ASR model, large size (~120M) with sub-word encoding. + +# You can find detailed info about TDT models at https://arxiv.org/abs/2304.06795. + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. 
To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. +# +# +--------------+---------+---------+----------+------------------+--------------+--------------------------+-----------------+ +# | Model | d_model | n_heads | n_layers | conv_kernel_size | weight_decay | pred_hidden/joint_hidden | pred_rnn_layers | +# +==============+=========+========+===========+==================+==============+==========================+=================+ +# | Small (14M)| 176 | 4 | 16 | 31 | 0.0 | 320 | 1 | +# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ +# | Medium (32M)| 256 | 4 | 16 | 31 | 1e-3 | 640 | 1 | +# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ +# | Large (120M)| 512 | 8 | 17 | 31 | 1e-3 | 640 | 1 | +# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ +# | XLarge (644M)| 1024 | 8 | 24 | 5 | 1e-3 | 640 | 2 | +# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ + +# Default learning parameters in this config are set for global batch size of 2K while you may use lower values. +# To increase the global batch size with limited number of GPUs, you may use higher accumulate_grad_batches. +# However accumulate_grad_batches is better to be avoided as long as the global batch size is large enough and training is stable. + +# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer +# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html +# The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large + +# We suggest to use trainer.precision=bf16 for GPUs which support it otherwise trainer.precision=16 is recommended. +# Using bf16 or 16 would make it possible to double the batch size and speedup training/inference. If fp16 is not stable and model diverges after some epochs, you may use fp32. +# Here are the suggested batch size per GPU for each precision and memory sizes: +# +-----------+------------+------------+ +# | Precision | GPU Memory | Batch Size | +# +===========+============+============+ +# | 32 | 16GB | 8 | +# | | 32GB | 16 | +# | | 80GB | 32 | +# +-----------+------------+------------+ +# | 16 or | 16GB | 16 | +# | bf16 | 32GB | 32 | +# | | 80GB | 64 | +# +-----------+------------+------------+ +# Note: They are based on the assumption of max_duration of 20. If you have longer or shorter max_duration, then batch sizes may need to get updated accordingly. + +name: "Conformer-TDT-BPE-Stateless" + +model: + sample_rate: 16000 + compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. 
+ log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + + # variables for TDT configs. + tdt_durations: [0, 1, 2, 3, 4] + num_tdt_durations: 5 + + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling params + subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 4 # must be power of 2 for striding and vggnet + subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 31 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + 
### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + decoder: + _target_: nemo.collections.asr.modules.StatelessTransducerDecoder + context_size: 2 # The Stateless decoder uses 2 words as context by default. + normalization_mode: layer # This helps stabilize training for Stateless decoders. + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 1 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null # 'null' would set it automatically according to CPU/GPU device + preserve_memory: false # dramatically slows down training, but might preserve some memory + + # Fuses the computation of prediction net + joint net + loss + WER calculation + # to be run on sub-batches of size `fused_batch_size`. + # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size. + # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss. + # Using small values here will preserve a lot of memory during training, but will make training slower as well. + # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1. + # However, to preserve memory, this ratio can be 1:8 or even 1:16. + # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow. + fuse_loss_wer: true + fused_batch_size: 16 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + num_extra_outputs: ${model.model_defaults.num_tdt_durations} + + decoding: + # Using greedy decoding is highly recommended for TDT models. Using greedy-batch will give very bad results + # if omega is 0; even if omega is non-zero, greedy-batch results are still going to be inaccurate. + strategy: "greedy" + + # this must not be None in order to use the TDT specific decoding method. + durations: ${model.model_defaults.tdt_durations} + + # greedy strategy config + greedy: + max_symbols: 10 + + # beam strategy config + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 # for Time Synchronous Decoding + alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + + loss: + # This is the main different between a TDT model and a conventional RNNT model -- the loss function. + loss_name: "tdt" + + tdt_kwargs: + # FastEmit regularization: https://arxiv.org/abs/2010.11148 + # You may enable FastEmit to reduce the latency of the model for streaming + fastemit_lambda: 0.001 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + + # refer to https://arxiv.org/abs/2304.06795 for the meaning of the following three configs. + durations: ${model.model_defaults.tdt_durations} + sigma: 0.05 # hyper-param for under-normalization. + omega: 0.1 # weight for regular RNN-T loss. 
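+      # Illustration (assumed reading; see the paper above for the exact combination): sigma is
+      # subtracted from each emission log-probability inside the loss lattice, so with sigma=0.05
+      # a token log-prob of -1.20 enters the lattice as -1.25; omega weights an auxiliary regular
+      # RNN-T loss computed on the token logits alongside the TDT loss.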
+ + # Adds Gaussian noise to the gradients of the decoder to avoid overfitting + variational_noise: + start_step: 0 + std: 0.0 + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 500 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + From 709ab0e29c332a765bd0c32465f22473edc13e5f Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Wed, 17 May 2023 16:20:49 -0400 Subject: [PATCH 30/40] add reference to TDT in Readme Signed-off-by: Hainan Xu --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 1335620ead25..c402a76baa52 100644 --- a/README.rst +++ b/README.rst @@ -84,7 +84,7 @@ Key Features * CTC * Transducer/RNNT * Hybrid Transducer/CTC - * NeMo Original `Multi-blank Transducers `_ + * NeMo Original `Multi-blank Transducers `_ and `Token-and-Duration Transducers (TDT) `_ * Streaming/Buffered ASR (CTC/Transducer) - `Chunked Inference Examples `_ * Cache-aware Streaming Conformer - ``_ * Beam Search decoding From c2b23c020cdee251589a2cdde19ad867acc26e58 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 19 May 2023 12:03:52 -0400 Subject: [PATCH 31/40] slight modification of config file comments Signed-off-by: Hainan Xu --- .../conf/conformer/tdt/conformer_tdt_bpe.yaml | 29 ++++----------- .../tdt/conformer_tdt_bpe_stateless.yaml | 36 +++++-------------- 2 files changed, 15 insertions(+), 50 deletions(-) diff --git a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml index 427a7af463a9..cd10d0e63a14 100644 --- a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml +++ b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml @@ -1,4 +1,4 @@ -# This file contains the default values for training an TDT 
Conformer-Transducer ASR model, large size (~120M) with sub-word encoding. +# This file contains the default values for training a Conformer-TDT ASR model, large size (~120M) with sub-word encoding. # You can find detailed info about TDT models at https://arxiv.org/abs/2304.06795. @@ -6,7 +6,12 @@ # Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective # batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. # Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. -# + +# Note: the added duration outputs from the joiner make TDT models slightly larger than corresponding conventional RNN-T models, +# although the difference is tiny -- the added number of params is roughly num-durations X (joint_hidden + pred_hidden), typically in the +# order of thousands of params. This is negligible even with the "Small" config with around 14 million params. +# Recommended duraction config is [0, 1, 2, ... , n] where optimal n is usually between 4 and 8 depending on the dataset. + # +--------------+---------+---------+----------+------------------+--------------+--------------------------+-----------------+ # | Model | d_model | n_heads | n_layers | conv_kernel_size | weight_decay | pred_hidden/joint_hidden | pred_rnn_layers | # +==============+=========+========+===========+==================+==============+==========================+=================+ @@ -23,26 +28,6 @@ # To increase the global batch size with limited number of GPUs, you may use higher accumulate_grad_batches. # However accumulate_grad_batches is better to be avoided as long as the global batch size is large enough and training is stable. -# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer -# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html -# The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large - -# We suggest to use trainer.precision=bf16 for GPUs which support it otherwise trainer.precision=16 is recommended. -# Using bf16 or 16 would make it possible to double the batch size and speedup training/inference. If fp16 is not stable and model diverges after some epochs, you may use fp32. -# Here are the suggested batch size per GPU for each precision and memory sizes: -# +-----------+------------+------------+ -# | Precision | GPU Memory | Batch Size | -# +===========+============+============+ -# | 32 | 16GB | 8 | -# | | 32GB | 16 | -# | | 80GB | 32 | -# +-----------+------------+------------+ -# | 16 or | 16GB | 16 | -# | bf16 | 32GB | 32 | -# | | 80GB | 64 | -# +-----------+------------+------------+ -# Note: They are based on the assumption of max_duration of 20. If you have longer or shorter max_duration, then batch sizes may need to get updated accordingly. 
- name: "Conformer-TDT-BPE" model: diff --git a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml index 65f4267fec66..6c10081a9289 100644 --- a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml +++ b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml @@ -6,42 +6,22 @@ # Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective # batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. # Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. -# + +# Note: the added duration outputs from the joiner make TDT models slightly larger than corresponding conventional RNN-T models, +# although the difference is tiny -- the added number of params is roughly num-durations X (joint_hidden + pred_hidden), typically in the +# order of thousands of params. This is negligible even with the "Small" config with around 14 million params. +# Recommended duraction config is [0, 1, 2, ... , n] where optimal n is usually between 4 and 8 depending on the dataset. + # +--------------+---------+---------+----------+------------------+--------------+--------------------------+-----------------+ -# | Model | d_model | n_heads | n_layers | conv_kernel_size | weight_decay | pred_hidden/joint_hidden | pred_rnn_layers | +# | Model | d_model | n_heads | n_layers | conv_kernel_size | weight_decay | pred_hidden/joint_hidden | decoder_context | # +==============+=========+========+===========+==================+==============+==========================+=================+ -# | Small (14M)| 176 | 4 | 16 | 31 | 0.0 | 320 | 1 | -# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ -# | Medium (32M)| 256 | 4 | 16 | 31 | 1e-3 | 640 | 1 | -# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ -# | Large (120M)| 512 | 8 | 17 | 31 | 1e-3 | 640 | 1 | +# | Large (117M)| 512 | 8 | 17 | 31 | 1e-3 | 640 | 2 | # +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ -# | XLarge (644M)| 1024 | 8 | 24 | 5 | 1e-3 | 640 | 2 | -# +--------------+---------+--------+-----------+------------------+--------------+--------------------------+-----------------+ # Default learning parameters in this config are set for global batch size of 2K while you may use lower values. # To increase the global batch size with limited number of GPUs, you may use higher accumulate_grad_batches. # However accumulate_grad_batches is better to be avoided as long as the global batch size is large enough and training is stable. -# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer -# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html -# The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large - -# We suggest to use trainer.precision=bf16 for GPUs which support it otherwise trainer.precision=16 is recommended. 
-# Using bf16 or 16 would make it possible to double the batch size and speedup training/inference. If fp16 is not stable and model diverges after some epochs, you may use fp32. -# Here are the suggested batch size per GPU for each precision and memory sizes: -# +-----------+------------+------------+ -# | Precision | GPU Memory | Batch Size | -# +===========+============+============+ -# | 32 | 16GB | 8 | -# | | 32GB | 16 | -# | | 80GB | 32 | -# +-----------+------------+------------+ -# | 16 or | 16GB | 16 | -# | bf16 | 32GB | 32 | -# | | 80GB | 64 | -# +-----------+------------+------------+ -# Note: They are based on the assumption of max_duration of 20. If you have longer or shorter max_duration, then batch sizes may need to get updated accordingly. name: "Conformer-TDT-BPE-Stateless" From dcf4cbba49fece6b516adb19a062dea94351a083 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Tue, 23 May 2023 15:16:35 -0400 Subject: [PATCH 32/40] addressed more comments Signed-off-by: Hainan Xu --- .../tdt/conformer_tdt_bpe_stateless.yaml | 3 + examples/asr/speech_to_text_eval.py | 2 +- .../utils/cuda_utils/gpu_rnnt_kernel.py | 134 +++++++++++------- 3 files changed, 84 insertions(+), 55 deletions(-) diff --git a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml index 6c10081a9289..46b13ef09ebd 100644 --- a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml +++ b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml @@ -172,6 +172,9 @@ model: joint_hidden: ${model.model_defaults.joint_hidden} activation: "relu" dropout: 0.2 + + # this variable is non-zero for this TDT model, as well as multi-blank models. It represents the number of + # additional outputs from the joiner, besides all tokens in the BPE vocab plus the (standard) blank symbol. 
num_extra_outputs: ${model.model_defaults.num_tdt_durations} decoding: diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index f4d2a66ffec0..0713d5cc5e36 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch -import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict +import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 656afdb1512f..5de553290ab6 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -35,7 +35,7 @@ GPU_RNNT_THREAD_SIZE = 256 -INF = 99999.9 +INF = 10000.0 @cuda.jit(device=True, inline=True) @@ -64,6 +64,12 @@ def logp( return denom[col] + acts[col * alphabet_size + v] +@cuda.jit(device=True, inline=True) +def logp_duration(acts: torch.Tensor, maxT: int, maxU: int, num_durations: int, mb: int, t: int, u: int, v: int): + col = (mb * maxT + t) * maxU + u + return acts[col * num_durations + v] + + @cuda.jit() def compute_alphas_kernel( acts: torch.Tensor, @@ -879,12 +885,6 @@ def compute_multiblank_grad_kernel( idx += GPU_RNNT_THREAD_SIZE -@cuda.jit(device=True, inline=True) -def logp_duration(acts: torch.Tensor, maxT: int, maxU: int, num_durations: int, mb: int, t: int, u: int, v: int): - col = (mb * maxT + t) * maxU + u - return acts[col * num_durations + v] - - @cuda.jit() def compute_tdt_alphas_kernel( acts: torch.Tensor, @@ -957,70 +957,87 @@ def compute_tdt_alphas_kernel( t = n - u if u == 0: - # for t in range(1, T) step to initialize alphas[b, t, 0] + # when u == 0, we only consider blank emissions. if t > 0 and t < T: alphas[offset + t * maxU + u] = -INF for i in range(1, num_durations): # skip 0 since blank emission has to advance by at least one if t >= durations[i]: alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp( - alphas[offset + t * maxU + u], - alphas[offset + (t - durations[i]) * maxU + u] - + logp(denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u, blank_) - - sigma - + logp_duration(duration_acts, maxT, maxU, num_durations, b, t - durations[i], u, i), + alphas[offset + t * maxU + u], # the current alpha value + alphas[offset + (t - durations[i]) * maxU + u] # alpha(t - duration, u) + + logp( + denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u, blank_ + ) # logp of blank emission + - sigma # logit under-normalization + + logp_duration( + duration_acts, maxT, maxU, num_durations, b, t - durations[i], u, i + ), # logp of duration ) else: - break # durations are in ascending order + break # since durations are in ascending order, when we encounter a duration that is too large, then + # there is no need to check larger durations after that. elif u < U: - # for u in range(1, U) step to initialize alphas[b, 0, u] + # when t == 0, we only consider the non-blank emission. 
if t == 0: alphas[offset + u] = ( - alphas[offset + u - 1] - + logp(denom, acts, maxT, maxU, alphabet_size, b, t, u - 1, labels[u - 1]) - - sigma + alphas[offset + u - 1] # alpha(t, u - 1) + + logp( + denom, acts, maxT, maxU, alphabet_size, b, t, u - 1, labels[u - 1] + ) # logp of token emission + - sigma # logit under-normalization + logp_duration( duration_acts, maxT, maxU, num_durations, b, t, u - 1, 0 - ) # t = 0 so it must be duration = 0 only when duration_id = 0 + ) # t = 0, so it must be duration = 0. Therefore the last argument passed to logp_duration() is 0. ) - # for t in range(1, T) for u in range(1, U) step to compute alphas[b, t, u] + # now we have t != 0 and u != 0, and we need to consider both non-blank and blank emissions. elif t > 0 and t < T: - no_emit = -INF + no_emit = -INF # no_emit stores the score for all blank emissions. for i in range(1, num_durations): if t >= durations[i]: no_emit = rnnt_helper.log_sum_exp( - no_emit, - alphas[offset + (t - durations[i]) * maxU + u] - + logp(denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u, blank_) - - sigma - + logp_duration(duration_acts, maxT, maxU, num_durations, b, t - durations[i], u, i), + no_emit, # current score + alphas[offset + (t - durations[i]) * maxU + u] # alpha(t - duration, u) + + logp( + denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u, blank_ + ) # logp of blank emission + - sigma # logit under-normalization + + logp_duration( + duration_acts, maxT, maxU, num_durations, b, t - durations[i], u, i + ), # logp of duration ) else: - break + break # we can exit the loop early here, same as the case for u == 0 above. - emit = -INF + emit = -INF # emit stores the score for non-blank emissions. for i in range(0, num_durations): if t >= durations[i]: emit = rnnt_helper.log_sum_exp( - emit, - alphas[offset + (t - durations[i]) * maxU + u - 1] - + logp(denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u - 1, labels[u - 1]) - - sigma - + logp_duration(duration_acts, maxT, maxU, num_durations, b, t - durations[i], u - 1, i), + emit, # current score + alphas[offset + (t - durations[i]) * maxU + u - 1] # alpha(t - duration, u - 1) + + logp( + denom, acts, maxT, maxU, alphabet_size, b, t - durations[i], u - 1, labels[u - 1] + ) # logp of non-blank emission + - sigma # logit under-normalization + + logp_duration( + duration_acts, maxT, maxU, num_durations, b, t - durations[i], u - 1, i + ), # logp of duration ) else: - break + break # we can exit the loop early here, same as the case for u == 0 above. + # combining blank and non-blank emissions. alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit) # sync across all B=b and U=u cuda.syncthreads() - # After final sync, alphas[b, T-1, U - 1] + logprobs[b, T-1, U-1, blank] + denom[b, T-1, U-1] gives - # log-likelihood of forward pass. + # After final sync, the forward log-likelihood can be computed as the summataion of + # alpha(T - duration, U - 1) + logp(blank, duration | t - duration, U - 1), over different durations. if u == 0: + # first we consider duration = 1 loglike = ( alphas[offset + (T - 1) * maxU + U - 1] + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) @@ -1028,6 +1045,7 @@ def compute_tdt_alphas_kernel( + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, U - 1, 1) ) + # then we add the scores for duration > 1, if such durations are possible given the audio lengths. 
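+            # (durations[0] == 0 is never used here: a blank emission must advance time,
+            # so only non-zero durations can close out the utterance.)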
for i in range(2, num_durations): if T >= durations[i]: big_blank_loglike = ( @@ -1118,30 +1136,37 @@ def compute_tdt_betas_kernel( for n in range(T + U - 2, -1, -1): t = n - u - if u == (U - 1): - # for t in reversed(range(T - 1)) step to initialize betas[b, t, U-1] + if u == U - 1: + # u == U - 1, we only consider blank emissions. if t >= 0 and t + 1 < T: betas[offset + t * maxU + U - 1] = -INF for i in range(1, num_durations): - if t + durations[i] < T: # recursive beta computation. + # although similar, the computation for beta's is slightly more complex for boundary cases. + # the following two cases correspond to whether t is exactly certain duration away from T. + # and they have slightly different update rules. + + if t + durations[i] < T: betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp( betas[offset + t * maxU + U - 1], - betas[offset + (t + durations[i]) * maxU + U - 1] + betas[ + offset + (t + durations[i]) * maxU + U - 1 + ] # beta[t, U - 1] uses the value beta[t + duration, U - 1] here. + logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_) + logp_duration(duration_acts, maxT, maxU, num_durations, b, t, U - 1, i) - sigma, ) - elif t + durations[i] == T: # beta base case + elif t + durations[i] == T: betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp( betas[offset + t * maxU + U - 1], + # we could see this as having "0" here as beta[t + duration, U - 1], which isn't defined since t + duration is out of bound. logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_) + logp_duration(duration_acts, maxT, maxU, num_durations, b, t, U - 1, i) - sigma, ) - elif u < U: + elif u < U - 1: if t == T - 1: - # for u in reversed(range(U - 1)) step to initialize betas[b, T-1, u] + # t == T - 1, so we only consider non-blank with duration 0. betas[offset + (T - 1) * maxU + u] = ( betas[offset + (T - 1) * maxU + u + 1] + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u]) @@ -1149,8 +1174,8 @@ def compute_tdt_betas_kernel( - sigma ) - elif (t >= 0) and (t < T - 1): - # for t in reversed(range(T - 1)) for u in reversed(range(U - 1)) step to compute betas[b, t, u] + elif t >= 0 and t < T - 1: + # now we need to consider both blank andnon-blanks. Similar to alphas, we first compute them separately with no_emit and emit. no_emit = -INF for i in range(1, num_durations): if t + durations[i] < T: @@ -1178,8 +1203,7 @@ def compute_tdt_betas_kernel( # sync across all B=b and U=u cuda.syncthreads() - # After final sync, betas[b, 0, 0] gives - # log-likelihood of backward pass. + # After final sync, betas[b, 0, 0] gives log-likelihood of backward pass, same with conventional Transducers. if u == 0: llBackward[b] = betas[offset] @@ -1264,7 +1288,9 @@ def compute_tdt_grad_kernel( # Look up gradient calculation from rnnt_numpy.compute_gradient() if t < T and u < U: - logpk_blank = denom[col] + acts[col * alphabet_size + blank_] - sigma + logpk_blank = ( + denom[col] + acts[col * alphabet_size + blank_] - sigma + ) # whenever sigma is used, it is for logit under-normalization. 
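+        # Threads with idx < num_durations update the duration-logit gradients in the block below;
+        # the while-loop over the vocabulary further down updates the token-logit gradients.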
if idx < num_durations: grad = 0.0 @@ -1291,9 +1317,9 @@ def compute_tdt_grad_kernel( while idx < alphabet_size: # remember, `col` represents the tri-index [b, t, u] # therefore; logpk = denom[b, t, u] + acts[b, t, u, v] - logpk = denom[col] + acts[col * alphabet_size + idx] # - sigma + logpk = denom[col] + acts[col * alphabet_size + idx] # initialize the grad of the sample acts[b, t, u, v] - grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb]) # * math.exp(sigma) + grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb]) # If FastEmit regularization is enabled, calculate the gradeint of probability of predicting the next label # at the current timestep. @@ -1311,7 +1337,7 @@ def compute_tdt_grad_kernel( + duration_acts[col * num_durations + i] + betas[col + 1 + durations[i] * maxU] # betas(t, u+1) + logpk # log Pr(k|t, u) - - sigma # y_hat(t, u) + - sigma # for logit under-normalization - logll[mb] # total log likelihood for normalization ) else: @@ -1320,9 +1346,9 @@ def compute_tdt_grad_kernel( # Update the gradient of act[b, t, u, v] with the gradient from FastEmit regularization grad = grad + fastemit_grad - # // grad to last blank transition - # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u) + logpk - logll[b]) - if (idx == blank_) and (u == U - 1): + # grad to last blank transition + # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u] + logpk - sigma - logll[b] + logp(duration) for all possible non-zero durations. + if idx == blank_ and u == U - 1: for i in range(1, num_durations): if t == T - durations[i]: grad -= math.exp( From ab47b8873c8a0bd57abbebc4e168df6f3406269c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 May 2023 19:17:33 +0000 Subject: [PATCH 33/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/asr/speech_to_text_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index 0713d5cc5e36..f4d2a66ffec0 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch +import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict -import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner From d7b7307c0b24a6194008856eeb20ed656330f431 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 25 May 2023 14:38:07 -0400 Subject: [PATCH 34/40] more detailed comments for tdt kernel Signed-off-by: Hainan Xu --- .../utils/cuda_utils/gpu_rnnt_kernel.py | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 5de553290ab6..3e0513092bb8 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -1150,28 +1150,33 @@ def compute_tdt_betas_kernel( betas[offset + t * maxU + U - 1], betas[ offset + (t + durations[i]) * maxU + U - 1 - ] # beta[t, U - 1] uses the value beta[t + duration, U - 1] here. 
- + logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_) - + logp_duration(duration_acts, maxT, maxU, num_durations, b, t, U - 1, i) - - sigma, + ] # beta[t, U - 1] depends on the value beta[t + duration, U - 1] here. + + logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_) # log prob of blank + + logp_duration( + duration_acts, maxT, maxU, num_durations, b, t, U - 1, i + ) # log prob of duration (durations[i]) + - sigma, # for logit undernormalization ) elif t + durations[i] == T: betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp( betas[offset + t * maxU + U - 1], - # we could see this as having "0" here as beta[t + duration, U - 1], which isn't defined since t + duration is out of bound. - logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_) - + logp_duration(duration_acts, maxT, maxU, num_durations, b, t, U - 1, i) - - sigma, + # here we have one fewer term than the "if" block above. This could be seen as having "0" here since + # beta[t + duration, U - 1] isn't defined because t + duration is out of bound. + logp(denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_) # log prob of blank + + logp_duration( + duration_acts, maxT, maxU, num_durations, b, t, U - 1, i + ) # log prob of duration (durations[i]) + - sigma, # for logit undernormalization. Basically every time sigma shows up is because of logit undernormalization. ) elif u < U - 1: if t == T - 1: - # t == T - 1, so we only consider non-blank with duration 0. + # t == T - 1, so we only consider non-blank with duration 0. (Note, we can't have blank emissions with duration = 0) betas[offset + (T - 1) * maxU + u] = ( betas[offset + (T - 1) * maxU + u + 1] - + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u]) - + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, u, 0) - - sigma + + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u]) # non-blank log prob + + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, u, 0) # log prob of duration 0 + - sigma, ) elif t >= 0 and t < T - 1: @@ -1198,6 +1203,7 @@ def compute_tdt_betas_kernel( - sigma, ) + # combining all blank emissions and all non-blank emissions. betas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit) # sync across all B=b and U=u @@ -1333,8 +1339,8 @@ def compute_tdt_grad_kernel( if t + durations[i] < T: fastemit_grad += fastemit_lambda * math.exp( alphas[col] # alphas(t, u) - + (denom[col] + acts[col * alphabet_size + labels[u]]) - + duration_acts[col * num_durations + i] + + (denom[col] + acts[col * alphabet_size + labels[u]]) # log prob of token emission + + duration_acts[col * num_durations + i] # duration log-prob + betas[col + 1 + durations[i] * maxU] # betas(t, u+1) + logpk # log Pr(k|t, u) - sigma # for logit under-normalization @@ -1356,7 +1362,7 @@ def compute_tdt_grad_kernel( ) # grad of blank across t < T; - # grad[b, t 0.0: - g = label_grads[col * (alphabet_size) + idx] + g = label_grads[col * alphabet_size + idx] g = min(g, clamp) g = max(g, -clamp) - label_grads[col * (alphabet_size) + idx] = g + label_grads[col * alphabet_size + idx] = g # update internal index through the thread_buffer; # until idx < V + 1, such that entire vocabulary has been updated. 
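For readers following the kernel comments added in this patch, the same recursion is easier to see in plain NumPy. The sketch below is illustrative only: the array shapes, argument names, and the helper itself are assumptions made for this note rather than NeMo APIs; the actual implementation fuses this math into compute_tdt_alphas_kernel above.

import numpy as np

def log_add(a, b):
    # numerically stable log(exp(a) + exp(b))
    m = max(a, b)
    if m == -np.inf:
        return -np.inf
    return m + np.log(np.exp(a - m) + np.exp(b - m))

def tdt_forward_loglike(token_logprobs, duration_logprobs, labels, durations, blank, sigma):
    # token_logprobs:    [T, U + 1, V + 1] log-softmax over tokens plus blank
    # duration_logprobs: [T, U + 1, D]     log-softmax over the duration set
    # labels: length-U target ids; durations: e.g. [0, 1, 2, 3, 4]; sigma: under-normalization constant
    T, U_plus_1, _ = token_logprobs.shape
    U = U_plus_1 - 1
    alphas = np.full((T, U + 1), -np.inf)
    alphas[0, 0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            score = -np.inf
            for i, d in enumerate(durations):
                if t - d < 0:
                    continue
                if d > 0:  # blank emission: keeps u, must advance time by d >= 1
                    score = log_add(
                        score,
                        alphas[t - d, u]
                        + token_logprobs[t - d, u, blank] - sigma
                        + duration_logprobs[t - d, u, i],
                    )
                if u > 0:  # label emission: consumes labels[u - 1], may advance time by d >= 0
                    score = log_add(
                        score,
                        alphas[t - d, u - 1]
                        + token_logprobs[t - d, u - 1, labels[u - 1]] - sigma
                        + duration_logprobs[t - d, u - 1, i],
                    )
            alphas[t, u] = score
    # terminate with a blank whose duration lands exactly at T, summed over all non-zero durations
    loglike = -np.inf
    for i, d in enumerate(durations):
        if d >= 1 and T - d >= 0:
            loglike = log_add(
                loglike,
                alphas[T - d, U]
                + token_logprobs[T - d, U, blank] - sigma
                + duration_logprobs[T - d, U, i],
            )
    return loglike

The only differences from a standard RNN-T forward pass are the inner sum over durations, the extra duration log-probabilities, and the constant sigma subtracted from every emission log-probability.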
From 91670328d664bdfa71d5f047b42721a77a6b8f45 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Thu, 1 Jun 2023 13:15:35 -0400 Subject: [PATCH 35/40] one line fix Signed-off-by: Hainan Xu --- .../parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 3e0513092bb8..4153af060941 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -1176,7 +1176,7 @@ def compute_tdt_betas_kernel( betas[offset + (T - 1) * maxU + u + 1] + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u]) # non-blank log prob + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, u, 0) # log prob of duration 0 - - sigma, + - sigma ) elif t >= 0 and t < T - 1: From fb5f73e9f8ab3e4fc57a31e9b826f2c5517707cd Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 2 Jun 2023 10:04:39 -0400 Subject: [PATCH 36/40] fixed small bug that results in test fails for rnnt_decoding Signed-off-by: Hainan Xu --- nemo/collections/asr/metrics/rnnt_wer.py | 5 ++++- nemo/collections/asr/metrics/rnnt_wer_bpe.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index cddcb6f34b50..0de59e67d5ce 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -1107,7 +1107,7 @@ def __init__( # we need to ensure blank is the last token in the vocab for the case of RNNT and Multi-blank RNNT. blank_id = len(vocabulary) + joint.num_extra_outputs - if 'durations' in decoding_cfg and decoding_cfg['durations'] is not None: # this means it's a TDT model. + if decoding_cfg.durations is not None: # this means it's a TDT model. blank_id = len(vocabulary) self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) @@ -1320,3 +1320,6 @@ class RNNTDecodingConfig: # can be used to change temperature for decoding temperature: float = 1.0 + + big_blank_durations: Optional[List] = None + durations: Optional[List] = None diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index 2880844697d9..2c71b0885c9b 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -199,7 +199,7 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): blank_id = tokenizer.tokenizer.vocab_size # RNNT or TDT models. 
# multi-blank RNNTs - if 'big_blank_durations' in decoding_cfg: + if decoding_cfg.big_blank_durations is not None: blank_id = tokenizer.tokenizer.vocab_size + joint.num_extra_outputs self.tokenizer = tokenizer From 3d81ff17f6c6ace33acbe5b1e26898449407bcd5 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 2 Jun 2023 10:58:05 -0400 Subject: [PATCH 37/40] fixed small bug that results in test fails for rnnt_decoding Signed-off-by: Hainan Xu --- .../multiblank/conformer_multiblank_transducer_bpe.yaml | 1 + examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml | 2 ++ .../conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml | 2 ++ examples/asr/speech_to_text_eval.py | 2 +- nemo/collections/asr/metrics/rnnt_wer.py | 9 ++++----- nemo/collections/asr/metrics/rnnt_wer_bpe.py | 2 +- tests/collections/asr/decoding/test_rnnt_decoding.py | 2 +- 7 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/asr/conf/conformer/multiblank/conformer_multiblank_transducer_bpe.yaml b/examples/asr/conf/conformer/multiblank/conformer_multiblank_transducer_bpe.yaml index 84d767e4a3b5..51e57e72e2ad 100644 --- a/examples/asr/conf/conformer/multiblank/conformer_multiblank_transducer_bpe.yaml +++ b/examples/asr/conf/conformer/multiblank/conformer_multiblank_transducer_bpe.yaml @@ -179,6 +179,7 @@ model: decoding: strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. + model_type: "multiblank" # this must not be None in order to use the multi-blank specific decoding method. # you could set this to [1, 1, 1] so that big blanks are treated the same diff --git a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml index cd10d0e63a14..0210bd5a2dad 100644 --- a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml +++ b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml @@ -185,6 +185,8 @@ model: # if omega is 0; even if omega is non-zero, greedy-batch results are still going to be inaccurate. strategy: "greedy" + model_type: "tdt" + # this must not be None in order to use the TDT specific decoding method. durations: ${model.model_defaults.tdt_durations} diff --git a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml index 46b13ef09ebd..fefbd6f8f56c 100644 --- a/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml +++ b/examples/asr/conf/conformer/tdt/conformer_tdt_bpe_stateless.yaml @@ -182,6 +182,8 @@ model: # if omega is 0; even if omega is non-zero, greedy-batch results are still going to be inaccurate. strategy: "greedy" + model_type: "tdt" + # this must not be None in order to use the TDT specific decoding method. 
durations: ${model.model_defaults.tdt_durations} diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index f4d2a66ffec0..0713d5cc5e36 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch -import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict +import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 0de59e67d5ce..f85956324a30 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -15,7 +15,7 @@ import copy import re from abc import abstractmethod -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass, field, is_dataclass from typing import Callable, Dict, List, Optional, Tuple, Union import editdistance @@ -1107,7 +1107,7 @@ def __init__( # we need to ensure blank is the last token in the vocab for the case of RNNT and Multi-blank RNNT. blank_id = len(vocabulary) + joint.num_extra_outputs - if decoding_cfg.durations is not None: # this means it's a TDT model. + if decoding_cfg.model_type == 'tdt': blank_id = len(vocabulary) self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) @@ -1288,7 +1288,9 @@ def compute(self): @dataclass class RNNTDecodingConfig: + model_type: str = "rnnt" # one of "rnnt", "multiblank" or "tdt" strategy: str = "greedy_batch" + compute_hypothesis_token_set: bool = False # preserve decoding alignments @@ -1320,6 +1322,3 @@ class RNNTDecodingConfig: # can be used to change temperature for decoding temperature: float = 1.0 - - big_blank_durations: Optional[List] = None - durations: Optional[List] = None diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index 2c71b0885c9b..391a0b090e5f 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -199,7 +199,7 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): blank_id = tokenizer.tokenizer.vocab_size # RNNT or TDT models. 
# multi-blank RNNTs - if decoding_cfg.big_blank_durations is not None: + if decoding_cfg.model_type == 'multiblank': blank_id = tokenizer.tokenizer.vocab_size + joint.num_extra_outputs self.tokenizer = tokenizer diff --git a/tests/collections/asr/decoding/test_rnnt_decoding.py b/tests/collections/asr/decoding/test_rnnt_decoding.py index 9dd955c24a70..ac90e62036e0 100644 --- a/tests/collections/asr/decoding/test_rnnt_decoding.py +++ b/tests/collections/asr/decoding/test_rnnt_decoding.py @@ -130,7 +130,7 @@ def test_constructor(self): @pytest.mark.unit def test_constructor_subword(self, tmp_tokenizer): - cfg = RNNTBPEDecodingConfig() + cfg = RNNTDecodingConfig() vocab = tmp_tokenizer.vocab decoder = get_rnnt_decoder(vocab_size=len(vocab)) joint = get_rnnt_joint(vocab_size=len(vocab)) From 4eb8bfa24256460ff1506744de8e082b3ba45bff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 15:00:44 +0000 Subject: [PATCH 38/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/asr/speech_to_text_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index 0713d5cc5e36..f4d2a66ffec0 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -62,9 +62,9 @@ from typing import Optional import torch +import transcribe_speech from omegaconf import MISSING, OmegaConf, open_dict -import transcribe_speech from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig from nemo.core.config import hydra_runner From 9b8f0577b89653a4b540c269ced153a4ef7f6632 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 2 Jun 2023 12:02:16 -0400 Subject: [PATCH 39/40] fixed small bug that results in test fails for rnnt_decoding Signed-off-by: Hainan Xu --- nemo/collections/asr/metrics/rnnt_wer.py | 2 +- nemo/collections/asr/metrics/rnnt_wer_bpe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index f85956324a30..02e1a343ab00 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -1107,7 +1107,7 @@ def __init__( # we need to ensure blank is the last token in the vocab for the case of RNNT and Multi-blank RNNT. blank_id = len(vocabulary) + joint.num_extra_outputs - if decoding_cfg.model_type == 'tdt': + if hasattr(decoding_cfg, 'model_type') and decoding_cfg.model_type == 'tdt': blank_id = len(vocabulary) self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index 391a0b090e5f..0870eb180776 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -199,7 +199,7 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): blank_id = tokenizer.tokenizer.vocab_size # RNNT or TDT models. 
# multi-blank RNNTs - if decoding_cfg.model_type == 'multiblank': + if hasattr(decoding_cfg, 'model_type') and decoding_cfg.model_type == 'multiblank': blank_id = tokenizer.tokenizer.vocab_size + joint.num_extra_outputs self.tokenizer = tokenizer From c85d143dfa2240c22d8b1d338d4746bf4b030017 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Fri, 2 Jun 2023 13:21:00 -0400 Subject: [PATCH 40/40] remove unused import Signed-off-by: Hainan Xu --- nemo/collections/asr/metrics/rnnt_wer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 02e1a343ab00..55f9f4b5ea9f 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -15,7 +15,7 @@ import copy import re from abc import abstractmethod -from dataclasses import dataclass, field, is_dataclass +from dataclasses import dataclass, is_dataclass from typing import Callable, Dict, List, Optional, Tuple, Union import editdistance
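The net effect of patches 36-40 is that blank-index selection now keys off the explicit model_type field (with hasattr guards for configs that predate it) instead of probing for optional keys. A rough, illustrative summary of that logic follows; the function name and signature are assumptions for this note, not NeMo code.

def infer_blank_id(vocab_size: int, num_extra_outputs: int, model_type: str = "rnnt") -> int:
    # TDT keeps its duration logits in a separate distribution, so the blank stays right after the vocab.
    if model_type == "tdt":
        return vocab_size
    # Plain RNN-T (num_extra_outputs == 0) and multi-blank models (big blanks appended before the
    # standard blank) place the blank after any extra outputs.
    return vocab_size + num_extra_outputs

# e.g. with a 1024-token vocab:
#   infer_blank_id(1024, 0, "rnnt")       -> 1024
#   infer_blank_id(1024, 5, "tdt")        -> 1024
#   infer_blank_id(1024, 2, "multiblank") -> 1026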