diff --git a/Jenkinsfile b/Jenkinsfile index 1381d0c41c82..9f174ce1c908 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -562,6 +562,54 @@ pipeline { // sh 'rm -rf examples/asr/speech_to_text_rnnt_wpe_results' // } // } + // stage('L3: Speech to Text Hybrid Transducer-CTC WPE') { + // steps { + // sh 'STRICT_NUMBA_COMPAT_CHECK=false python examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py \ + // --config-path="../conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc/" --config-name="conformer_hybrid_transducer_ctc_bpe.yaml" \ + // model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ + // model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ + // model.encoder.n_layers= 2 \ + // model.train_ds.batch_size=2 \ + // model.validation_ds.batch_size=2 \ + // model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \ + // model.tokenizer.type="wpe" \ + // trainer.devices=[0] \ + // trainer.accelerator="gpu" \ + // +trainer.fast_dev_run=True \ + // exp_manager.exp_dir=examples/asr/speech_to_text_hybrid_transducer_ctc_wpe_results' + // sh 'rm -rf examples/asr/speech_to_text_hybrid_transducer_ctc_wpe_results' + // } + // } + // } + // } + + // stage('L2: Hybrid ASR RNNT-CTC dev run') { + // when { + // anyOf { + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // parallel { + // stage('Speech to Text Hybrid Transducer-CTC WPE') { + // steps { + // sh 'STRICT_NUMBA_COMPAT_CHECK=false python examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py \ + // --config-path="../conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc/" --config-name="conformer_hybrid_transducer_ctc_bpe.yaml" \ + // model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ + // model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ + // model.encoder.n_layers= 2 \ + // model.train_ds.batch_size=2 \ + // model.validation_ds.batch_size=2 \ + // model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \ + // model.tokenizer.type="wpe" \ + // trainer.devices=[0] \ + // trainer.accelerator="gpu" \ + // +trainer.fast_dev_run=True \ + // exp_manager.exp_dir=examples/asr/speech_to_text_hybrid_transducer_ctc_wpe_results' + // sh 'rm -rf examples/asr/speech_to_text_hybrid_transducer_ctc_wpe_results' + // } + // } // } // } diff --git a/docs/source/asr/models.rst b/docs/source/asr/models.rst index f2b26e84fea9..2df995340539 100644 --- a/docs/source/asr/models.rst +++ b/docs/source/asr/models.rst @@ -181,7 +181,7 @@ You may find the example config files of cache-aware streaming Conformer models ``/examples/asr/conf/conformer/streaming/conformer_transducer_bpe_streaming.yaml`` for Transducer variant and at ``/examples/asr/conf/conformer/streaming/conformer_ctc_bpe.yaml`` for CTC variant. -To simulate cache-aware stremaing, you may use the script at ``/examples/asr/asr_streaming/speech_to_text_streaming_infer.py``. It can simulate streaming in single stream or multi-stream mode (in batches) for an ASR model. +To simulate cache-aware streaming, you may use the script at ``/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py``. It can simulate streaming in single stream or multi-stream mode (in batches) for an ASR model. 
This script can be used for models trained offline with full-context but the accuracy would not be great unless the chunk size is large enough which would result in high latency. It is recommended to train a model in streaming mode with limited context for this script. More info can be found in the script. @@ -236,6 +236,27 @@ You may find the example config files of Squeezeformer-CTC model with character- ``/examples/asr/conf/squeezeformer/squeezeformer_ctc_char.yaml`` and with sub-word encoding at ``/examples/asr/conf/squeezeformer/squeezeformer_ctc_bpe.yaml``. +.. _Hybrid-Transducer_CTC_model: + +Hybrid-Transducer-CTC +--------------------- + +Hybrid RNNT-CTC models are a group of models with both RNNT and CTC decoders. Training a unified model speeds up the convergence of the CTC decoder and enables +the user to use a single model which works as both a CTC and an RNNT model. This category can be used with any of the ASR models. +Hybrid models use two decoders, CTC and RNNT, on top of the encoder. The default decoding strategy after training is RNNT. +Users may use ``asr_model.change_decoding_strategy(decoder_type='ctc' or 'rnnt')`` to change the default decoding. + +The variant with sub-word encoding is a BPE-based model +which can be instantiated using the :class:`~nemo.collections.asr.models.EncDecHybridRNNTCTCBPEModel` class, while the +character-based variant is based on :class:`~nemo.collections.asr.models.EncDecHybridRNNTCTCModel`. + +You may use the example scripts under ``/examples/asr/asr_hybrid_transducer_ctc`` for both the char-based encoding and sub-word encoding. +These examples can be used to train any Hybrid ASR model like Conformer, Citrinet, QuartzNet, etc. + +You may find the example config files of the Conformer variant of such hybrid models with character-based encoding at +``/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_char.yaml`` and +with sub-word encoding at ``/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml``. + References ---------- diff --git a/examples/asr/asr_streaming/speech_to_text_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py similarity index 100% rename from examples/asr/asr_streaming/speech_to_text_streaming_infer.py rename to examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py diff --git a/examples/asr/asr_hybrid_transducer_ctc/README.md b/examples/asr/asr_hybrid_transducer_ctc/README.md new file mode 100644 index 000000000000..bb52de45a9b4 --- /dev/null +++ b/examples/asr/asr_hybrid_transducer_ctc/README.md @@ -0,0 +1,32 @@ +# ASR with Hybrid Transducer/CTC Models + +This directory contains example scripts to train ASR models with two decoders: Transducer (RNNT) and CTC. + +Currently supported models are - + +* Character based Hybrid RNNT/CTC model +* Subword based Hybrid RNNT/CTC model + +# Model execution overview + +The training scripts in this directory execute in the following order. When preparing your own training-from-scratch / fine-tuning scripts, please follow this order for correct training/inference.
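A minimal sketch of this order, mirroring the `speech_to_text_hybrid_rnnt_ctc_bpe.py` example added later in this diff (logging and the optional test pass are omitted; the final decoder switch corresponds to the `change_decoding_strategy` call documented in the models.rst section above):

```python
# Minimal sketch of the execution order above, following the example script added in this PR.
import pytorch_lightning as pl

from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(
    config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc_bpe"
)
def main(cfg):
    trainer = pl.Trainer(**cfg.trainer)  # Trainer built from the Hydra/YAML config
    exp_manager(trainer, cfg.get("exp_manager", None))  # ExpManager attached to the Trainer
    asr_model = EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=trainer)  # vocab, data loaders, optim set up here
    asr_model.maybe_init_from_pretrained_checkpoint(cfg)  # optional init of weights from a pretrained checkpoint
    trainer.fit(asr_model)

    # After training, RNNT decoding is the default; the auxiliary CTC decoder can be selected like this
    # (decoding_cfg=None keeps the decoding config already stored in the model).
    asr_model.change_decoding_strategy(decoding_cfg=None, decoder_type="ctc")


if __name__ == "__main__":
    main()
```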
+ +```mermaid + +graph TD + A[Hydra Overrides + Yaml Config] --> B{Config} + B --> |Init| C[Trainer] + C --> D[ExpManager] + B --> D[ExpManager] + C --> E[Model] + B --> |Init| E[Model] + E --> |Constructor| F1(Change Vocabulary) + F1 --> F2(Setup Adapters if available) + F2 --> G(Setup Train + Validation + Test Data loaders) + G --> H1(Setup Optimization) + H1 --> H2(Change Transducer Decoding Strategy) + H2 --> I[Maybe init from pretrained] + I --> J["trainer.fit(model)"] +``` + +During restoration of the model, you may pass the Trainer to the restore_from / from_pretrained call, or set it after the model has been initialized by using `model.set_trainer(Trainer)`. \ No newline at end of file diff --git a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py new file mode 100644 index 000000000000..2de150c71328 --- /dev/null +++ b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Preparing the Tokenizer for the dataset +Use the `process_asr_text_tokenizer.py` script under /scripts/tokenizers/ in order to prepare the tokenizer. 
+ +```sh +python /scripts/tokenizers/process_asr_text_tokenizer.py \ + --manifest= + OR + --data_file= \ + --data_root="" \ + --vocab_size= \ + --tokenizer=<"spe" or "wpe"> \ + --no_lower_case \ + --spe_type=<"unigram", "bpe", "char" or "word"> \ + --spe_character_coverage=1.0 \ + --log +``` + +# Training the model +```sh +python speech_to_text_hybrid_rnnt_ctc_bpe.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath= \ + model.validation_ds.manifest_filepath= \ + model.tokenizer.dir= \ + model.tokenizer.type= \ + model.aux_ctc.ctc_loss_weight=0.3 \ + trainer.devices=-1 \ + trainer.max_epochs=100 \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 \ + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" +``` + +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + +""" + +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner( + config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc_bpe" +) +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + asr_model = EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=trainer) + + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + + trainer.fit(asr_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if asr_model.prepare_test(trainer): + trainer.test(asr_model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py new file mode 100644 index 000000000000..532e2c9ed0be --- /dev/null +++ b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py @@ -0,0 +1,100 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +""" +# Training the model + +Basic run (on CPU for 50 epochs): + python examples/asr/asr_transducer/speech_to_text_hybrid_rnnt_ctc.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + trainer.devices=1 \ + trainer.accelerator='cpu' \ + trainer.max_epochs=50 + + +Add PyTorch Lightning Trainer arguments from CLI: + python speech_to_text_rnnt.py \ + ... \ + +trainer.fast_dev_run=true + +Hydra logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/.hydra)" +PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/lightning_logs)" + +Override some args of optimizer: + python speech_to_text_hybrid_rnnt_ctc.py \ + --config-path="../conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc" \ + --config-name="config_rnnt" \ + model.train_ds.manifest_filepath="./an4/train_manifest.json" \ + model.validation_ds.manifest_filepath="./an4/test_manifest.json" \ + trainer.devices=2 \ + model.aux_ctc.ctc_loss_weight=0.3 \ + trainer.precision=16 \ + trainer.max_epochs=2 \ + model.optim.betas=[0.8,0.5] \ + model.optim.weight_decay=0.0001 + +Override optimizer entirely + python speech_to_text_hybrid_rnnt_ctc.py \ + --config-path="../conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc" \ + --config-name="config_rnnt" \ + model.train_ds.manifest_filepath="./an4/train_manifest.json" \ + model.validation_ds.manifest_filepath="./an4/test_manifest.json" \ + model.aux_ctc.ctc_loss_weight=0.3 \ + trainer.devices=2 \ + trainer.precision=16 \ + trainer.max_epochs=2 \ + model.optim.name=adamw \ + model.optim.lr=0.001 \ + ~model.optim.args \ + +model.optim.args.betas=[0.8,0.5]\ + +model.optim.args.weight_decay=0.0005 + +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + +""" + +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecHybridRNNTCTCModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + asr_model = EncDecHybridRNNTCTCModel(cfg=cfg.model, trainer=trainer) + + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + + trainer.fit(asr_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if asr_model.prepare_test(trainer): + trainer.test(asr_model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/asr/conf/conformer/streaming/conformer_ctc_bpe_streaming.yaml b/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe_streaming.yaml similarity index 100% rename from examples/asr/conf/conformer/streaming/conformer_ctc_bpe_streaming.yaml rename to examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe_streaming.yaml diff --git a/examples/asr/conf/conformer/streaming/conformer_transducer_bpe_streaming.yaml b/examples/asr/conf/conformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml 
similarity index 100% rename from examples/asr/conf/conformer/streaming/conformer_transducer_bpe_streaming.yaml rename to examples/asr/conf/conformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml diff --git a/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml b/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml new file mode 100644 index 000000000000..3e03d3495174 --- /dev/null +++ b/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_bpe.yaml @@ -0,0 +1,275 @@ +# It contains the default values for training a Conformer-Hybrid-Transducer-CTC ASR model, large size (~120M) with Transducer loss and sub-word encoding. +# The model would have two decoders: RNNT (Transducer) and CTC + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. +# +# +-------------+---------+---------+----------+--------------+--------------------------+ +# | Model | d_model | n_heads | n_layers | weight_decay | pred_hidden/joint_hidden | +# +=============+=========+========+===========+==============+==========================+ +# | Small (14M)| 176 | 4 | 16 | 0.0 | 320 | +# +-------------+---------+--------+-----------+--------------+--------------------------+ +# | Medium (32M)| 256 | 4 | 16 | 1e-3 | 640 | +# +-------------+---------+--------+-----------+--------------+--------------------------+ +# | Large (120M)| 512 | 8 | 17 | 1e-3 | 640 | +# +-----------------------------------------------------------+--------------------------+ +# + +# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer +# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html +# The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large + +name: "Conformer-Hybrid-Transducer-CTC-BPE" + +model: + sample_rate: 16000 + compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. + log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: ??? 
+ sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling parameters + subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 4 # must be power of 2 for striding and vggnet + subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 31 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null # Currently only null is supported for export. 
+ random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf + blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference. + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 1 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null # 'null' would set it automatically according to CPU/GPU device + preserve_memory: false # dramatically slows down training, but might preserve some memory + + # Fuses the computation of prediction net + joint net + loss + WER calculation + # to be run on sub-batches of size `fused_batch_size`. + # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size. + # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss. + # Using small values here will preserve a lot of memory during training, but will make training slower as well. + # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1. + # However, to preserve memory, this ratio can be 1:8 or even 1:16. + # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow. + fuse_loss_wer: true + fused_batch_size: 16 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + + decoding: + strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. + + # greedy strategy config + greedy: + max_symbols: 10 + + # beam strategy config + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 # for Time Synchronous Decoding + alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + + # The section which would contain the decoder and decoding configs of the auxiliary CTC decoder + aux_ctc: + ctc_loss_weight: 0.5 # the weight used to combine the CTC loss with the RNNT loss + use_cer: false + ctc_reduction: 'mean_batch' + decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: null + num_classes: -1 + vocabulary: [] + decoding: + strategy: "greedy" + + loss: + loss_name: "default" + + warprnnt_numba_kwargs: + # FastEmit regularization: https://arxiv.org/abs/2010.11148 + # You may enable FastEmit to reduce the latency of the model for streaming + fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + + # Adds Gaussian noise to the gradients of the decoder to avoid overfitting + variational_noise: + start_step: 0 + std: 0.0 + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 500 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. 
+ enable_progress_bar: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_char.yaml b/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_char.yaml new file mode 100644 index 000000000000..dbbde6875383 --- /dev/null +++ b/examples/asr/conf/conformer/hybrid_transducer_ctc/conformer_hybrid_transducer_ctc_char.yaml @@ -0,0 +1,272 @@ +# It contains the default values for training a Conformer-Hybrid-Transducer-CTC ASR model, large size (~120M) with Transducer loss and char-based vocabulary. + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. +# +# +-------------+---------+---------+----------+--------------+--------------------------+ +# | Model | d_model | n_heads | n_layers | weight_decay | pred_hidden/joint_hidden | +# +=============+=========+========+===========+==============+==========================+ +# | Small (14M)| 176 | 4 | 16 | 0.0 | 320 | +# +-------------+---------+--------+-----------+--------------+--------------------------+ +# | Medium (32M)| 256 | 4 | 16 | 1e-3 | 640 | +# +-------------+---------+--------+-----------+--------------+--------------------------+ +# | Large (120M)| 512 | 8 | 17 | 1e-3 | 640 | +# +-----------------------------------------------------------+--------------------------+ +# + +# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer +# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html +# The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large + +name: "Conformer-Hybrid-Transducer-CTC-Char" + +model: + sample_rate: &sample_rate 16000 + compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. 
+ log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + labels: [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", + "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"] + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + trim_silence: false + max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: *sample_rate + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling params + subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 4 # must be power of 2 for striding and vggnet + subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. 
+ reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 31 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null # Currently only null is supported for export. + random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf + blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference. + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 1 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null # 'null' would set it automatically according to CPU/GPU device + preserve_memory: false # dramatically slows down training, but might preserve some memory + + # Fuses the computation of prediction net + joint net + loss + WER calculation + # to be run on sub-batches of size `fused_batch_size`. + # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size. + # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss. + # Using small values here will preserve a lot of memory during training, but will make training slower as well. + # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1. + # However, to preserve memory, this ratio can be 1:8 or even 1:16. + # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow. + fuse_loss_wer: true + fused_batch_size: 16 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + + decoding: + strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. 
+ + # greedy strategy config + greedy: + max_symbols: 10 + + # beam strategy config + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 # for Time Synchronous Decoding + alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + + # The section which would contain the decoder and decoding configs of the auxiliary CTC decoder + aux_ctc: + ctc_loss_weight: 0.5 # the weight used to combine the CTC loss with the RNNT loss + use_cer: false + ctc_reduction: 'mean_batch' + decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: null + num_classes: -1 + vocabulary: ${model.labels} + decoding: + strategy: "greedy" + + loss: + loss_name: "default" + + warprnnt_numba_kwargs: + # FastEmit regularization: https://arxiv.org/abs/2010.11148 + # You may enable FastEmit to reduce the latency of the model for streaming + fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + + # Adds Gaussian noise to the gradients of the decoder to avoid overfitting + variational_noise: + start_step: 0 + std: 0.0 + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 500 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
+ num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + + # you need to set these two to True to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 76ff8e76dc40..f937acc89854 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -116,6 +116,9 @@ class TranscriptionConfig: # Decoding strategy for RNNT models rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1) + # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Joint RNNT/CTC models + decoder_type: Optional[str] = None + @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> TranscriptionConfig: @@ -157,8 +160,12 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: # Setup decoding strategy if hasattr(asr_model, 'change_decoding_strategy'): + if cfg.decoder_type is not None: + asr_model.change_decoding_strategy( + cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding, decoder_type=cfg.decoder_type + ) # Check if ctc or rnnt model - if hasattr(asr_model, 'joint'): # RNNT model + elif hasattr(asr_model, 'joint'): # RNNT model cfg.rnnt_decoding.fused_batch_size = -1 cfg.rnnt_decoding.compute_langs = cfg.compute_langs asr_model.change_decoding_strategy(cfg.rnnt_decoding) diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index c346555a10f5..82e7b0697bae 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -17,6 +17,8 @@ from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel +from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel from nemo.collections.asr.models.k2_sequence_models import EncDecK2SeqModel, EncDecK2SeqModelBPE from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel from nemo.collections.asr.models.msdd_models import EncDecDiarLabelModel diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py new file mode 100644 index 000000000000..3b94084e0d8b --- /dev/null +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -0,0 +1,379 @@ +# Copyright (c) 2022, NVIDIA 
CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +from typing import Optional, Union + +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer + +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.losses.rnnt import RNNTLoss +from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEWER, RNNTBPEDecoding, RNNTBPEDecodingConfig +from nemo.collections.asr.metrics.wer_bpe import WERBPE, CTCBPEDecoding, CTCBPEDecodingConfig +from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.core.classes.common import PretrainedModelInfo +from nemo.utils import logging, model_utils + + +class EncDecHybridRNNTCTCBPEModel(EncDecHybridRNNTCTCModel, ASRBPEMixin): + """Base class for encoder decoder RNNT-based models with auxiliary CTC decoder/loss and subword tokenization.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Convert to Hydra 1.0 compatible DictConfig + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + # Tokenizer is necessary for this model + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + if not isinstance(cfg, DictConfig): + cfg = OmegaConf.create(cfg) + + # Setup the tokenizer + self._setup_tokenizer(cfg.tokenizer) + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + with open_dict(cfg): + cfg.labels = ListConfig(list(vocabulary)) + + with open_dict(cfg.decoder): + cfg.decoder.vocab_size = len(vocabulary) + + with open_dict(cfg.joint): + cfg.joint.num_classes = len(vocabulary) + cfg.joint.vocabulary = ListConfig(list(vocabulary)) + cfg.joint.jointnet.encoder_hidden = cfg.model_defaults.enc_hidden + cfg.joint.jointnet.pred_hidden = cfg.model_defaults.pred_hidden + + # setup auxiliary CTC decoder + if 'aux_ctc' not in cfg: + raise ValueError( + "The config need to have a section for the CTC decoder named as aux_ctc for Hybrid models." 
+ ) + + with open_dict(cfg): + if self.tokenizer_type == "agg": + cfg.aux_ctc.decoder.vocabulary = ListConfig(vocabulary) + else: + cfg.aux_ctc.decoder.vocabulary = ListConfig(list(vocabulary.keys())) + + if cfg.aux_ctc.decoder["num_classes"] < 1: + logging.info( + "\nReplacing placholder number of classes ({}) with actual number of classes - {}".format( + cfg.aux_ctc.decoder["num_classes"], len(vocabulary) + ) + ) + cfg.aux_ctc.decoder["num_classes"] = len(vocabulary) + + super().__init__(cfg=cfg, trainer=trainer) + + # Setup decoding object + self.decoding = RNNTBPEDecoding( + decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + ) + + # Setup wer object + self.wer = RNNTBPEWER( + decoding=self.decoding, + batch_dim_index=0, + use_cer=self.cfg.get('use_cer', False), + log_prediction=self.cfg.get('log_prediction', True), + dist_sync_on_step=True, + ) + + # Setup fused Joint step if flag is set + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + # Setup CTC decoding + ctc_decoding_cfg = self.cfg.aux_ctc.get('decoding', None) + if ctc_decoding_cfg is None: + ctc_decoding_cfg = OmegaConf.structured(CTCBPEDecodingConfig) + with open_dict(self.cfg.aux_ctc): + self.cfg.aux_ctc.decoding = ctc_decoding_cfg + self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer) + + # Setup CTC WER + self.ctc_wer = WERBPE( + decoding=self.ctc_decoding, + use_cer=self.cfg.aux_ctc.get('use_cer', False), + dist_sync_on_step=True, + log_prediction=self.cfg.get("log_prediction", False), + ) + + # setting the RNNT decoder as the default one + self.use_rnnt_decoder = True + + def change_vocabulary( + self, + new_tokenizer_dir: Union[str, DictConfig], + new_tokenizer_type: str, + decoding_cfg: Optional[DictConfig] = None, + ctc_decoding_cfg: Optional[DictConfig] = None, + ): + """ + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + Args: + new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) + new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`. + decoding_cfg: A config for the decoder, which is optional. If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + ctc_decoding_cfg: A config for auxiliary CTC decoding, which is optional and can be used to change the decoding type. + + Returns: None + + """ + if isinstance(new_tokenizer_dir, DictConfig): + if new_tokenizer_type == 'agg': + new_tokenizer_cfg = new_tokenizer_dir + else: + raise ValueError( + f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}' + ) + else: + new_tokenizer_cfg = None + + if new_tokenizer_cfg is not None: + tokenizer_cfg = new_tokenizer_cfg + else: + if not os.path.isdir(new_tokenizer_dir): + raise NotADirectoryError( + f'New tokenizer dir must be non-empty path to a directory. 
But I got: {new_tokenizer_dir}' + ) + + if new_tokenizer_type.lower() not in ('bpe', 'wpe'): + raise ValueError(f'New tokenizer type must be either `bpe` or `wpe`') + + tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type}) + + # Setup the tokenizer + self._setup_tokenizer(tokenizer_cfg) + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + joint_config = self.joint.to_config_dict() + new_joint_config = copy.deepcopy(joint_config) + if self.tokenizer_type == "agg": + new_joint_config["vocabulary"] = ListConfig(vocabulary) + else: + new_joint_config["vocabulary"] = ListConfig(list(vocabulary.keys())) + + new_joint_config['num_classes'] = len(vocabulary) + del self.joint + self.joint = EncDecHybridRNNTCTCBPEModel.from_config_dict(new_joint_config) + + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config.vocab_size = len(vocabulary) + del self.decoder + self.decoder = EncDecHybridRNNTCTCBPEModel.from_config_dict(new_decoder_config) + + del self.loss + self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1) + + if decoding_cfg is None: + # Assume same decoding config as before + decoding_cfg = self.cfg.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(RNNTBPEDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.decoding = RNNTBPEDecoding( + decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + ) + + self.wer = RNNTBPEWER( + decoding=self.decoding, + batch_dim_index=self.wer.batch_dim_index, + use_cer=self.wer.use_cer, + log_prediction=self.wer.log_prediction, + dist_sync_on_step=True, + ) + + # Setup fused Joint step + if self.joint.fuse_loss_wer or ( + self.decoding.joint_fused_batch_size is not None and self.decoding.joint_fused_batch_size > 0 + ): + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + # Update config + with open_dict(self.cfg.joint): + self.cfg.joint = new_joint_config + + with open_dict(self.cfg.decoder): + self.cfg.decoder = new_decoder_config + + with open_dict(self.cfg.decoding): + self.cfg.decoding = decoding_cfg + + logging.info(f"Changed tokenizer of the RNNT decoder to {self.joint.vocabulary} vocabulary.") + + # set up the new tokenizer for the CTC decoder + if hasattr(self, 'ctc_decoder'): + ctc_decoder_config = copy.deepcopy(self.ctc_decoder.to_config_dict()) + # sidestepping the potential overlapping tokens issue in aggregate tokenizers + if self.tokenizer_type == "agg": + ctc_decoder_config.vocabulary = ListConfig(vocabulary) + else: + ctc_decoder_config.vocabulary = ListConfig(list(vocabulary.keys())) + + decoder_num_classes = ctc_decoder_config['num_classes'] + # Override number of classes if placeholder provided + logging.info( + "\nReplacing old number of classes ({}) with new number of classes - {}".format( + decoder_num_classes, len(vocabulary) + ) + ) + ctc_decoder_config['num_classes'] = len(vocabulary) + + del self.ctc_decoder + self.ctc_decoder = EncDecHybridRNNTCTCBPEModel.from_config_dict(ctc_decoder_config) + del self.ctc_loss + self.ctc_loss = CTCLoss( + num_classes=self.ctc_decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self.cfg.aux_ctc.get("ctc_reduction", "mean_batch"), + ) + + if ctc_decoding_cfg is None: + # Assume same decoding config as before + 
ctc_decoding_cfg = self.cfg.aux_ctc.decoding + + # Assert the decoding config with all hyper parameters + ctc_decoding_cls = OmegaConf.structured(CTCBPEDecodingConfig) + ctc_decoding_cls = OmegaConf.create(OmegaConf.to_container(ctc_decoding_cls)) + ctc_decoding_cfg = OmegaConf.merge(ctc_decoding_cls, ctc_decoding_cfg) + + self.ctc_decoding = CTCBPEDecoding(decoding_cfg=ctc_decoding_cfg, tokenizer=self.tokenizer) + + self.ctc_wer = WERBPE( + decoding=self.ctc_decoding, + use_cer=self.cfg.aux_ctc.get('use_cer', False), + log_prediction=self.cfg.get("log_prediction", False), + dist_sync_on_step=True, + ) + + # Update config + with open_dict(self.cfg.aux_ctc): + self.cfg.aux_ctc.decoder = ctc_decoder_config + + with open_dict(self.cfg.aux_ctc): + self.cfg.aux_ctc.decoding = ctc_decoding_cfg + + logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.") + + def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = None): + """ + Changes decoding strategy used during RNNT decoding process. + Args: + decoding_cfg: A config for the decoder, which is optional. If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + decoder_type: (str) Can be set to 'rnnt' or 'ctc' to switch between appropriate decoder in a + model having both RNN-T and CTC decoders. Defaults to None, in which case RNN-T decoder is + used. If set to 'ctc', it raises error if 'ctc_decoder' is not an attribute of the model. + """ + if decoder_type is None or decoder_type == 'rnnt': + if decoding_cfg is None: + # Assume same decoding config as before + logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config") + decoding_cfg = self.cfg.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(RNNTBPEDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.decoding = RNNTBPEDecoding( + decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + ) + + self.wer = RNNTBPEWER( + decoding=self.decoding, + batch_dim_index=self.wer.batch_dim_index, + use_cer=self.wer.use_cer, + log_prediction=self.wer.log_prediction, + dist_sync_on_step=True, + ) + + # Setup fused Joint step + if self.joint.fuse_loss_wer or ( + self.decoding.joint_fused_batch_size is not None and self.decoding.joint_fused_batch_size > 0 + ): + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + # Update config + with open_dict(self.cfg.decoding): + self.cfg.decoding = decoding_cfg + + logging.info(f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + elif decoder_type == 'ctc': + if not hasattr(self, 'ctc_decoding'): + raise ValueError("The model does not have the ctc_decoding module and does not support ctc decoding.") + if decoding_cfg is None: + # Assume same decoding config as before + logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config") + decoding_cfg = self.cfg.aux_ctc.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(CTCBPEDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer) + + self.ctc_wer = 
WERBPE( + decoding=self.ctc_decoding, + use_cer=self.ctc_wer.use_cer, + log_prediction=self.ctc_wer.log_prediction, + dist_sync_on_step=True, + ) + + # Update config + with open_dict(self.cfg.aux_ctc.decoding): + self.cfg.aux_ctc.decoding = decoding_cfg + + self.use_rnnt_decoder = False + logging.info( + f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}" + ) + + else: + raise ValueError(f"decoder_type={decoder_type} is not supported. Supported values: [ctc,rnnt]") + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + results = [] + return results diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py new file mode 100644 index 000000000000..a393b65ccd79 --- /dev/null +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -0,0 +1,604 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import os +import tempfile +from typing import List, Optional + +import torch +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer import WER, CTCDecoding, CTCDecodingConfig +from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.classes.mixins import AccessMixin +from nemo.utils import logging, model_utils + + +class EncDecHybridRNNTCTCModel(EncDecRNNTModel, ASRBPEMixin): + """Base class for hybrid RNNT/CTC models.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + super().__init__(cfg=cfg, trainer=trainer) + + if 'aux_ctc' not in self.cfg: + raise ValueError( + "The config needs to have a section for the CTC decoder named as aux_ctc for Hybrid models."
+ ) + with open_dict(self.cfg.aux_ctc): + if "feat_in" not in self.cfg.aux_ctc.decoder or ( + not self.cfg.aux_ctc.decoder.feat_in and hasattr(self.encoder, '_feat_out') + ): + self.cfg.aux_ctc.decoder.feat_in = self.encoder._feat_out + if "feat_in" not in self.cfg.aux_ctc.decoder or not self.cfg.aux_ctc.decoder.feat_in: + raise ValueError("param feat_in of the decoder's config is not set!") + + if self.cfg.aux_ctc.decoder.num_classes < 1 and self.cfg.aux_ctc.decoder.vocabulary is not None: + logging.info( + "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format( + self.cfg.aux_ctc.decoder.num_classes, len(self.cfg.aux_ctc.decoder.vocabulary) + ) + ) + self.cfg.aux_ctc.decoder["num_classes"] = len(self.cfg.aux_ctc.decoder.vocabulary) + + self.ctc_decoder = EncDecRNNTModel.from_config_dict(self.cfg.aux_ctc.decoder) + self.ctc_loss_weight = self.cfg.aux_ctc.get("ctc_loss_weight", 0.5) + + self.ctc_loss = CTCLoss( + num_classes=self.ctc_decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self.cfg.aux_ctc.get("ctc_reduction", "mean_batch"), + ) + + ctc_decoding_cfg = self.cfg.aux_ctc.get('decoding', None) + if ctc_decoding_cfg is None: + ctc_decoding_cfg = OmegaConf.structured(CTCDecodingConfig) + with open_dict(self.cfg.aux_ctc): + self.cfg.aux_ctc.decoding = ctc_decoding_cfg + + self.ctc_decoding = CTCDecoding(self.cfg.aux_ctc.decoding, vocabulary=self.ctc_decoder.vocabulary) + self.ctc_wer = WER( + decoding=self.ctc_decoding, + use_cer=self.cfg.aux_ctc.get('use_cer', False), + dist_sync_on_step=True, + log_prediction=self.cfg.get("log_prediction", False), + ) + + # setting the RNNT decoder as the default one + self.use_rnnt_decoder = True + + @torch.no_grad() + def transcribe( + self, + paths2audio_files: List[str], + batch_size: int = 4, + return_hypotheses: bool = False, + partial_hypothesis: Optional[List['Hypothesis']] = None, + num_workers: int = 0, + channel_selector: Optional[ChannelSelectorType] = None, + ) -> (List[str], Optional[List['Hypothesis']]): + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + num_workers: (int) number of workers for DataLoader + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + + Returns: + A list of transcriptions in the same order as paths2audio_files. 
Will also return + """ + if self.use_rnnt_decoder: + return super().transcribe(**kwargs) + + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + # We will store transcriptions here + hypotheses = [] + all_hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + self.encoder.freeze() + self.decoder.freeze() + self.joint.freeze() + if hasattr(self, 'ctc_decoder'): + self.ctc_decoder.freeze() + + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w', encoding='utf-8') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': ''} + fp.write(json.dumps(entry) + '\n') + + config = { + 'paths2audio_files': paths2audio_files, + 'batch_size': batch_size, + 'temp_dir': tmpdir, + 'num_workers': num_workers, + 'channel_selector': channel_selector, + } + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + encoded, encoded_len = self.forward( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + ) + + logits = self.ctc_decoder(encoder_output=encoded) + best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor( + logits, encoded_len, return_hypotheses=return_hypotheses, + ) + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + best_hyp[idx].y_sequence = logits[idx][: logits_len[idx]] + if best_hyp[idx].alignments is None: + best_hyp[idx].alignments = best_hyp[idx].y_sequence + del logits + + hypotheses += best_hyp + if all_hyp is not None: + all_hypotheses += all_hyp + else: + all_hypotheses += best_hyp + + del encoded + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + + logging.set_verbosity(logging_level) + if mode is True: + self.encoder.unfreeze() + self.decoder.unfreeze() + self.joint.unfreeze() + if hasattr(self, 'ctc_decoder'): + self.ctc_decoder.unfreeze() + return hypotheses, all_hypotheses + + def change_vocabulary( + self, + new_vocabulary: List[str], + decoding_cfg: Optional[DictConfig] = None, + ctc_decoding_cfg: Optional[DictConfig] = None, + ): + """ + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + Args: + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + decoding_cfg: A config for the decoder, which is optional. 
If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + ctc_decoding_cfg: A config for CTC decoding, which is optional and can be used to change decoding type. + + Returns: None + + """ + super().change_vocabulary(new_vocabulary=new_vocabulary, decoding_cfg=decoding_cfg) + + # set up the new tokenizer for the CTC decoder + if hasattr(self, 'ctc_decoder'): + if self.ctc_decoder.vocabulary == new_vocabulary: + logging.warning( + f"Old {self.ctc_decoder.vocabulary} and new {new_vocabulary} match. Not changing anything." + ) + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + decoder_config = self.ctc_decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.ctc_decoder + self.ctc_decoder = EncDecHybridRNNTCTCModel.from_config_dict(new_decoder_config) + del self.ctc_loss + self.ctc_loss = CTCLoss( + num_classes=self.ctc_decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self.cfg.aux_ctc.get("ctc_reduction", "mean_batch"), + ) + + if ctc_decoding_cfg is None: + # Assume same decoding config as before + logging.info("No `ctc_decoding_cfg` passed when changing decoding strategy, using internal config") + ctc_decoding_cfg = self.cfg.aux_ctc.decoding + + # Assert the decoding config with all hyper parameters + ctc_decoding_cls = OmegaConf.structured(CTCDecodingConfig) + ctc_decoding_cls = OmegaConf.create(OmegaConf.to_container(ctc_decoding_cls)) + ctc_decoding_cfg = OmegaConf.merge(ctc_decoding_cls, ctc_decoding_cfg) + + self.ctc_decoding = CTCDecoding(decoding_cfg=ctc_decoding_cfg, vocabulary=self.ctc_decoder.vocabulary) + + self.ctc_wer = WER( + decoding=self.ctc_decoding, + use_cer=self.ctc_wer.use_cer, + log_prediction=self.ctc_wer.log_prediction, + dist_sync_on_step=True, + ) + + # Update config + with open_dict(self.cfg.aux_ctc): + self.cfg.aux_ctc.decoding = ctc_decoding_cfg + + with open_dict(self.cfg.aux_ctc): + self.cfg.aux_ctc.decoder = new_decoder_config + + ds_keys = ['train_ds', 'validation_ds', 'test_ds'] + for key in ds_keys: + if key in self.cfg: + with open_dict(self.cfg[key]): + self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary) + + logging.info(f"Changed the tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.") + + def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = None): + """ + Changes decoding strategy used during RNNT decoding process. + + Args: + decoding_cfg: A config for the decoder, which is optional. If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + decoder_type: (str) Can be set to 'rnnt' or 'ctc' to switch between appropriate decoder in a + model having RNN-T and CTC decoders. Defaults to None, in which case RNN-T decoder is + used. If set to 'ctc', it raises error if 'ctc_decoder' is not an attribute of the model. 
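For reference, a minimal sketch of the decoder switch described above, run against a trained hybrid checkpoint; the .nemo path is a placeholder, not a released model, and the exact decoding options are illustrative only:

    from omegaconf import DictConfig

    from nemo.collections.asr.metrics.wer import CTCDecodingConfig
    from nemo.collections.asr.models import EncDecHybridRNNTCTCModel

    # Placeholder checkpoint path; any trained hybrid RNNT-CTC model would do.
    asr_model = EncDecHybridRNNTCTCModel.restore_from("hybrid_model.nemo")

    # Route decoding through the auxiliary CTC head.
    ctc_cfg = CTCDecodingConfig(preserve_alignments=False, compute_timestamps=False)
    asr_model.change_decoding_strategy(decoding_cfg=ctc_cfg, decoder_type="ctc")
    assert asr_model.use_rnnt_decoder is False

    # Switch back to the RNNT decoder with a greedy strategy.
    rnnt_cfg = DictConfig({"strategy": "greedy", "greedy": {"max_symbols": 10}})
    asr_model.change_decoding_strategy(decoding_cfg=rnnt_cfg, decoder_type="rnnt")
    assert asr_model.use_rnnt_decoder is True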
+ """ + if decoder_type is None or decoder_type == 'rnnt': + self.use_rnnt_decoder = True + return super().change_decoding_strategy(decoding_cfg=decoding_cfg) + + assert decoder_type == 'ctc' and hasattr(self, 'ctc_decoder') + if decoding_cfg is None: + # Assume same decoding config as before + logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config") + decoding_cfg = self.cfg.aux_ctc.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(CTCDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.ctc_decoding = CTCDecoding(decoding_cfg=decoding_cfg, vocabulary=self.ctc_decoder.vocabulary) + + self.ctc_wer = WER( + decoding=self.ctc_decoding, + use_cer=self.ctc_wer.use_cer, + log_prediction=self.ctc_wer.log_prediction, + dist_sync_on_step=True, + ) + + # Update config + with open_dict(self.cfg.aux_ctc): + self.cfg.aux_ctc.decoding = decoding_cfg + + self.use_rnnt_decoder = False + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}") + + # PTL-specific methods + def training_step(self, batch, batch_nb): + # Reset access registry + if AccessMixin.is_access_enabled(): + AccessMixin.reset_registry(self) + + signal, signal_len, transcript, transcript_len = batch + + # forward() only performs encoder forward + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + # During training, loss must be computed, so decoder forward is necessary + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + sample_id = self._trainer.global_step + else: + log_every_n_steps = 1 + sample_id = batch_nb + + # If fused Joint-Loss-WER is not used + if not self.joint.fuse_loss_wer: + # Compute full joint and loss + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + + # Add auxiliary losses, if registered + loss_value = self.add_auxiliary_losses(loss_value) + + # Reset access registry + if AccessMixin.is_access_enabled(): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if (sample_id + 1) % log_every_n_steps == 0: + self.wer.update(encoded, encoded_len, transcript, transcript_len) + _, scores, words = self.wer.compute() + self.wer.reset() + tensorboard_logs.update({'training_batch_wer': scores.float() / words}) + + else: + # If fused Joint-Loss-WER is used + if (sample_id + 1) % log_every_n_steps == 0: + compute_wer = True + else: + compute_wer = False + + # Fused joint step + loss_value, wer, _, _ = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoder, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=transcript_len, + compute_wer=compute_wer, + ) + + # Add auxiliary losses, if registered + loss_value = self.add_auxiliary_losses(loss_value) + + # 
Reset access registry + if AccessMixin.is_access_enabled(): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if compute_wer: + tensorboard_logs.update({'training_batch_wer': wer}) + + if self.ctc_loss_weight > 0: + log_probs = self.ctc_decoder(encoder_output=encoded) + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + tensorboard_logs['train_rnnt_loss'] = loss_value + tensorboard_logs['train_ctc_loss'] = ctc_loss + loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss + tensorboard_logs['train_loss'] = loss_value + if (sample_id + 1) % log_every_n_steps == 0: + self.ctc_wer.update( + predictions=log_probs, + targets=transcript, + target_lengths=transcript_len, + predictions_lengths=encoded_len, + ) + ctc_wer, _, _ = self.ctc_wer.compute() + self.ctc_wer.reset() + tensorboard_logs.update({'training_batch_wer_ctc': ctc_wer}) + + # Log items + self.log_dict(tensorboard_logs) + + # Preserve batch acoustic model T and language model U parameters if normalizing + if self._optim_normalize_joint_txu: + self._optim_normalize_txu = [encoded_len.max(), transcript_len.max()] + + return {'loss': loss_value} + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + # TODO: add support for CTC decoding + signal, signal_len, transcript, transcript_len, sample_id = batch + + # forward() only performs encoder forward + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False + ) + + sample_id = sample_id.cpu().detach().numpy() + return list(zip(sample_id, best_hyp_text)) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + signal, signal_len, transcript, transcript_len = batch + + # forward() only performs encoder forward + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + tensorboard_logs = {} + + # If experimental fused Joint-Loss-WER is not used + if not self.joint.fuse_loss_wer: + if self.compute_eval_loss: + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + + tensorboard_logs['val_loss'] = loss_value + + self.wer.update(encoded, encoded_len, transcript, transcript_len) + wer, wer_num, wer_denom = self.wer.compute() + self.wer.reset() + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + + else: + # If experimental fused Joint-Loss-WER is used + compute_wer = True + + if self.compute_eval_loss: + decoded, target_len, states = self.decoder(targets=transcript, 
target_length=transcript_len) + else: + decoded = None + target_len = transcript_len + + # Fused joint step + loss_value, wer, wer_num, wer_denom = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoded, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=target_len, + compute_wer=compute_wer, + ) + + if loss_value is not None: + tensorboard_logs['val_loss'] = loss_value + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + + log_probs = self.ctc_decoder(encoder_output=encoded) + if self.compute_eval_loss: + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + tensorboard_logs['val_ctc_loss'] = ctc_loss + tensorboard_logs['val_rnnt_loss'] = loss_value + loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss + tensorboard_logs['val_loss'] = loss_value + self.ctc_wer.update( + predictions=log_probs, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + ) + ctc_wer, ctc_wer_num, ctc_wer_denom = self.ctc_wer.compute() + self.ctc_wer.reset() + tensorboard_logs['val_wer_num_ctc'] = ctc_wer_num + tensorboard_logs['val_wer_denom_ctc'] = ctc_wer_denom + tensorboard_logs['val_wer_ctc'] = ctc_wer + + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return tensorboard_logs + + def test_step(self, batch, batch_idx, dataloader_idx=0): + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx) + test_logs = { + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + # 'test_wer': logs['val_wer'], + } + if 'val_loss' in logs: + test_logs['test_loss'] = logs['val_loss'] + + if self.ctc_loss_weight > 0: + test_logs['test_wer_num_ctc'] = logs['val_wer_num_ctc'] + test_logs['test_wer_denom_ctc'] = logs['val_wer_denom_ctc'] + if 'val_ctc_loss' in logs: + test_logs['test_ctc_loss'] = logs['val_ctc_loss'] + if 'val_rnnt_loss' in logs: + test_logs['test_rnnt_loss'] = logs['val_rnnt_loss'] + + return test_logs + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + if self.compute_eval_loss: + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + val_loss_log = {'val_loss': val_loss_mean} + else: + val_loss_log = {} + wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {**val_loss_log, 'val_wer': wer_num.float() / wer_denom} + if self.ctc_loss_weight > 0: + ctc_wer_num = torch.stack([x['val_wer_num_ctc'] for x in outputs]).sum() + ctc_wer_denom = torch.stack([x['val_wer_denom_ctc'] for x in outputs]).sum() + tensorboard_logs['val_wer_ctc'] = ctc_wer_num.float() / ctc_wer_denom + return {**val_loss_log, 'log': tensorboard_logs} + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + if self.compute_eval_loss: + test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + test_loss_log = {'test_loss': test_loss_mean} + else: + test_loss_log = {} + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {**test_loss_log, 'test_wer': wer_num.float() / wer_denom} + + if self.ctc_loss_weight > 0: + ctc_wer_num = torch.stack([x['test_wer_num_ctc'] for x in outputs]).sum() + ctc_wer_denom = 
torch.stack([x['test_wer_denom_ctc'] for x in outputs]).sum() + tensorboard_logs['test_wer_ctc'] = ctc_wer_num.float() / ctc_wer_denom + + return {**test_loss_log, 'log': tensorboard_logs} + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + results = [] + return results diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index a01b8a64bd01..64850d53587f 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -512,7 +512,7 @@ def setup_streaming_params( """ streaming_cfg = CacheAwareStreamingConfig() if chunk_size is not None: - if chunk_size <= 1: + if chunk_size < 1: raise ValueError("chunk_size needs to be a number larger or equal to one.") lookahead_steps = chunk_size - 1 streaming_cfg.cache_drop_size = chunk_size - shift_size diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py new file mode 100644 index 000000000000..e59353102c39 --- /dev/null +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py @@ -0,0 +1,309 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
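As training_step and validation_step above show, the hybrid model optimizes a weighted sum of the transducer and CTC losses controlled by aux_ctc.ctc_loss_weight (0.5 by default). A standalone sketch of that interpolation, with purely illustrative loss values:

    import torch

    # Illustrative stand-ins for the per-batch losses computed in training_step.
    rnnt_loss = torch.tensor(12.4)
    ctc_loss = torch.tensor(18.1)
    ctc_loss_weight = 0.3  # cfg.aux_ctc.ctc_loss_weight; defaults to 0.5

    # Same blending used in EncDecHybridRNNTCTCModel.training_step / validation_step.
    loss = (1 - ctc_loss_weight) * rnnt_loss + ctc_loss_weight * ctc_loss
    print(float(loss))  # 0.7 * 12.4 + 0.3 * 18.1 = 14.11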
+ +import os +import shutil +import tempfile + +import pytest +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.metrics.wer_bpe import CTCBPEDecoding, CTCBPEDecodingConfig +from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode +from nemo.collections.common import tokenizers +from nemo.core.utils import numba_utils +from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ + +NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cpu_is_supported( + __NUMBA_MINIMUM_VERSION__ +) or numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__) + + +@pytest.fixture() +def hybrid_asr_model(test_data_dir): + preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})} + + model_defaults = {'enc_hidden': 1024, 'pred_hidden': 64} + + encoder = { + 'cls': 'nemo.collections.asr.modules.ConvASREncoder', + 'params': { + 'feat_in': 64, + 'activation': 'relu', + 'conv_mask': True, + 'jasper': [ + { + 'filters': model_defaults['enc_hidden'], + 'repeat': 1, + 'kernel': [1], + 'stride': [1], + 'dilation': [1], + 'dropout': 0.0, + 'residual': False, + 'separable': True, + 'se': True, + 'se_context_size': -1, + } + ], + }, + } + + decoder = { + '_target_': 'nemo.collections.asr.modules.RNNTDecoder', + 'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1,}, + } + + joint = { + '_target_': 'nemo.collections.asr.modules.RNNTJoint', + 'jointnet': {'joint_hidden': 32, 'activation': 'relu',}, + } + + decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}} + + tokenizer = {'dir': os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128"), 'type': 'wpe'} + + loss = {'loss_name': 'default', 'warprnnt_numba_kwargs': {'fastemit_lambda': 0.001}} + + aux_ctc = { + 'ctc_loss_weight': 0.3, + 'use_cer': False, + 'ctc_reduction': 'mean_batch', + 'decoder': { + '_target_': 'nemo.collections.asr.modules.ConvASRDecoder', + 'feat_in': 1024, + 'num_classes': -2, + 'vocabulary': None, + }, + 'decoding': DictConfig(CTCBPEDecodingConfig), + } + + modelConfig = DictConfig( + { + 'preprocessor': DictConfig(preprocessor), + 'model_defaults': DictConfig(model_defaults), + 'encoder': DictConfig(encoder), + 'decoder': DictConfig(decoder), + 'joint': DictConfig(joint), + 'tokenizer': DictConfig(tokenizer), + 'decoding': DictConfig(decoding), + 'loss': DictConfig(loss), + 'aux_ctc': DictConfig(aux_ctc), + } + ) + + model_instance = EncDecHybridRNNTCTCBPEModel(cfg=modelConfig) + return model_instance + + +class TestEncDecHybridRNNTCTCBPEModel: + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.with_downloads() + @pytest.mark.unit + def test_constructor(self, hybrid_asr_model): + hybrid_asr_model.train() + # TODO: make proper config and assert correct number of weights + # Check to/from config_dict: + confdict = hybrid_asr_model.to_config_dict() + instance2 = EncDecHybridRNNTCTCBPEModel.from_config_dict(confdict) + assert isinstance(instance2, EncDecHybridRNNTCTCBPEModel) + + @pytest.mark.with_downloads() + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_forward(self, hybrid_asr_model): + hybrid_asr_model 
= hybrid_asr_model.eval() + + hybrid_asr_model.preprocessor.featurizer.dither = 0.0 + hybrid_asr_model.preprocessor.featurizer.pad_to = 0 + + hybrid_asr_model.compute_eval_loss = False + + input_signal = torch.randn(size=(4, 512)) + length = torch.randint(low=161, high=500, size=[4]) + + with torch.no_grad(): + # batch size 1 + logprobs_instance = [] + for i in range(input_signal.size(0)): + logprobs_ins, _ = hybrid_asr_model.forward( + input_signal=input_signal[i : i + 1], input_signal_length=length[i : i + 1] + ) + logprobs_instance.append(logprobs_ins) + logits_instance = torch.cat(logprobs_instance, 0) + + # batch size 4 + logprobs_batch, _ = hybrid_asr_model.forward(input_signal=input_signal, input_signal_length=length) + + assert logits_instance.shape == logprobs_batch.shape + diff = torch.mean(torch.abs(logits_instance - logprobs_batch)) + assert diff <= 1e-6 + diff = torch.max(torch.abs(logits_instance - logprobs_batch)) + assert diff <= 1e-6 + + @pytest.mark.with_downloads() + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_save_restore_artifact(self, hybrid_asr_model): + hybrid_asr_model.train() + + with tempfile.TemporaryDirectory() as tmp_dir: + path = os.path.join(tmp_dir, 'rnnt_bpe.nemo') + hybrid_asr_model.save_to(path) + + new_model = EncDecHybridRNNTCTCBPEModel.restore_from(path) + assert isinstance(new_model, type(hybrid_asr_model)) + assert new_model.vocab_path.endswith('_vocab.txt') + + assert len(new_model.tokenizer.tokenizer.get_vocab()) == 128 + + @pytest.mark.with_downloads() + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_save_restore_artifact_spe(self, hybrid_asr_model, test_data_dir): + hybrid_asr_model.train() + + with tempfile.TemporaryDirectory() as tmpdir: + tokenizer_dir = os.path.join(test_data_dir, "asr", "tokenizers", "an4_spe_128") + hybrid_asr_model.change_vocabulary(new_tokenizer_dir=tokenizer_dir, new_tokenizer_type='bpe') + + save_path = os.path.join(tmpdir, 'ctc_bpe.nemo') + hybrid_asr_model.train() + hybrid_asr_model.save_to(save_path) + + new_model = EncDecHybridRNNTCTCBPEModel.restore_from(save_path) + assert isinstance(new_model, type(hybrid_asr_model)) + assert isinstance(new_model.tokenizer, tokenizers.SentencePieceTokenizer) + assert new_model.model_path.endswith('_tokenizer.model') + assert new_model.vocab_path.endswith('_vocab.txt') + assert new_model.spe_vocab_path.endswith('_tokenizer.vocab') + + @pytest.mark.with_downloads() + @pytest.mark.unit + def test_save_restore_artifact_agg(self, hybrid_asr_model, test_data_dir): + tokenizer_dir = os.path.join(test_data_dir, "asr", "tokenizers", "an4_spe_128") + tok_en = {"dir": tokenizer_dir, "type": "wpe"} + # the below is really an english tokenizer but we pretend it is spanish + tok_es = {"dir": tokenizer_dir, "type": "wpe"} + tcfg = DictConfig({"type": "agg", "langs": {"en": tok_en, "es": tok_es}}) + with tempfile.TemporaryDirectory() as tmpdir: + hybrid_asr_model.change_vocabulary(new_tokenizer_dir=tcfg, new_tokenizer_type="agg") + + save_path = os.path.join(tmpdir, "ctc_agg.nemo") + hybrid_asr_model.train() + hybrid_asr_model.save_to(save_path) + + new_model = EncDecHybridRNNTCTCBPEModel.restore_from(save_path) + assert isinstance(new_model, type(hybrid_asr_model)) + assert isinstance(new_model.tokenizer, tokenizers.AggregateTokenizer) + + # should be double + 
assert new_model.tokenizer.tokenizer.vocab_size == 254 + assert len(new_model.tokenizer.tokenizer.get_vocab()) == 254 + + @pytest.mark.with_downloads() + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_vocab_change(self, test_data_dir, hybrid_asr_model): + with tempfile.TemporaryDirectory() as tmpdir: + old_tokenizer_dir = os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128", 'vocab.txt') + new_tokenizer_dir = os.path.join(tmpdir, 'tokenizer') + + os.makedirs(new_tokenizer_dir, exist_ok=True) + shutil.copy2(old_tokenizer_dir, new_tokenizer_dir) + + nw1 = hybrid_asr_model.num_weights + hybrid_asr_model.change_vocabulary(new_tokenizer_dir=new_tokenizer_dir, new_tokenizer_type='wpe') + # No change + assert nw1 == hybrid_asr_model.num_weights + + with open(os.path.join(new_tokenizer_dir, 'vocab.txt'), 'a+') as f: + f.write("!\n") + f.write('$\n') + f.write('@\n') + + hybrid_asr_model.change_vocabulary(new_tokenizer_dir=new_tokenizer_dir, new_tokenizer_type='wpe') + + # rnn embedding + joint + bias + pred_embedding = 3 * (hybrid_asr_model.decoder.pred_hidden) + joint_joint = 3 * (hybrid_asr_model.joint.joint_hidden + 1) + ctc_decoder = 3 * (hybrid_asr_model.ctc_decoder._feat_in + 1) + assert hybrid_asr_model.num_weights == (nw1 + (pred_embedding + joint_joint) + ctc_decoder) + + @pytest.mark.with_downloads() + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_decoding_change(self, hybrid_asr_model): + assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyBatchedRNNTInfer) + + new_strategy = DictConfig({}) + new_strategy.strategy = 'greedy' + new_strategy.greedy = DictConfig({'max_symbols': 10}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyRNNTInfer) + + new_strategy = DictConfig({}) + new_strategy.strategy = 'beam' + new_strategy.beam = DictConfig({'beam_size': 1}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, beam_decode.BeamRNNTInfer) + assert hybrid_asr_model.decoding.decoding.search_type == "default" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'beam' + new_strategy.beam = DictConfig({'beam_size': 2}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, beam_decode.BeamRNNTInfer) + assert hybrid_asr_model.decoding.decoding.search_type == "default" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'tsd' + new_strategy.beam = DictConfig({'beam_size': 2}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, beam_decode.BeamRNNTInfer) + assert hybrid_asr_model.decoding.decoding.search_type == "tsd" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'alsd' + new_strategy.beam = DictConfig({'beam_size': 2}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, beam_decode.BeamRNNTInfer) + assert hybrid_asr_model.decoding.decoding.search_type == "alsd" + + assert hybrid_asr_model.ctc_decoding is not None + assert isinstance(hybrid_asr_model.ctc_decoding, CTCBPEDecoding) + assert 
hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy" + assert hybrid_asr_model.ctc_decoding.preserve_alignments is False + assert hybrid_asr_model.ctc_decoding.compute_timestamps is False + + cfg = CTCBPEDecodingConfig(preserve_alignments=True, compute_timestamps=True) + hybrid_asr_model.change_decoding_strategy(cfg, decoder_type="ctc") + + assert hybrid_asr_model.ctc_decoding.preserve_alignments is True + assert hybrid_asr_model.ctc_decoding.compute_timestamps is True + assert hybrid_asr_model.use_rnnt_decoder is False diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py new file mode 100644 index 000000000000..22926b6516ee --- /dev/null +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -0,0 +1,644 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy + +import pytest +import torch +from omegaconf import DictConfig, ListConfig + +from nemo.collections.asr.metrics.wer import CTCDecoding, CTCDecodingConfig +from nemo.collections.asr.models import EncDecHybridRNNTCTCModel +from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint, SampledRNNTJoint, StatelessTransducerDecoder +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts.utils import rnnt_utils +from nemo.core.utils import numba_utils +from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ +from nemo.utils.config_utils import assert_dataclass_signature_match + +NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cpu_is_supported( + __NUMBA_MINIMUM_VERSION__ +) or numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__) + + +@pytest.fixture() +def hybrid_asr_model(): + preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})} + + # fmt: off + labels = [' ', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', + 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', + 'x', 'y', 'z', "'", + ] + # fmt: on + + model_defaults = {'enc_hidden': 1024, 'pred_hidden': 64} + + encoder = { + 'cls': 'nemo.collections.asr.modules.ConvASREncoder', + 'params': { + 'feat_in': 64, + 'activation': 'relu', + 'conv_mask': True, + 'jasper': [ + { + 'filters': model_defaults['enc_hidden'], + 'repeat': 1, + 'kernel': [1], + 'stride': [1], + 'dilation': [1], + 'dropout': 0.0, + 'residual': False, + 'separable': True, + 'se': True, + 'se_context_size': -1, + } + ], + }, + } + + decoder = { + '_target_': 'nemo.collections.asr.modules.RNNTDecoder', + 'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1}, + } + + joint = { + '_target_': 'nemo.collections.asr.modules.RNNTJoint', + 'jointnet': {'joint_hidden': 32, 'activation': 'relu'}, + } + + decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}} + + loss = {'loss_name': 
'default', 'warprnnt_numba_kwargs': {'fastemit_lambda': 0.001}} + + aux_ctc = { + 'ctc_loss_weight': 0.3, + 'use_cer': False, + 'ctc_reduction': 'mean_batch', + 'decoder': { + '_target_': 'nemo.collections.asr.modules.ConvASRDecoder', + 'feat_in': 1024, + 'num_classes': len(labels), + 'vocabulary': labels, + }, + 'decoding': DictConfig(CTCDecodingConfig), + } + + modelConfig = DictConfig( + { + 'labels': ListConfig(labels), + 'preprocessor': DictConfig(preprocessor), + 'model_defaults': DictConfig(model_defaults), + 'encoder': DictConfig(encoder), + 'decoder': DictConfig(decoder), + 'joint': DictConfig(joint), + 'decoding': DictConfig(decoding), + 'loss': DictConfig(loss), + 'aux_ctc': DictConfig(aux_ctc), + } + ) + + model_instance = EncDecHybridRNNTCTCModel(cfg=modelConfig) + return model_instance + + +class TestEncDecHybridRNNTCTCModel: + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_constructor(self, hybrid_asr_model): + hybrid_asr_model.train() + # TODO: make proper config and assert correct number of weights + # Check to/from config_dict: + confdict = hybrid_asr_model.to_config_dict() + instance2 = EncDecHybridRNNTCTCModel.from_config_dict(confdict) + assert isinstance(instance2, EncDecHybridRNNTCTCModel) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_forward(self, hybrid_asr_model): + hybrid_asr_model = hybrid_asr_model.eval() + + hybrid_asr_model.preprocessor.featurizer.dither = 0.0 + hybrid_asr_model.preprocessor.featurizer.pad_to = 0 + + hybrid_asr_model.compute_eval_loss = False + + input_signal = torch.randn(size=(4, 512)) + length = torch.randint(low=161, high=500, size=[4]) + + with torch.no_grad(): + # batch size 1 + logprobs_instance = [] + for i in range(input_signal.size(0)): + logprobs_ins, _ = hybrid_asr_model.forward( + input_signal=input_signal[i : i + 1], input_signal_length=length[i : i + 1] + ) + logprobs_instance.append(logprobs_ins) + logprobs_instance = torch.cat(logprobs_instance, 0) + + # batch size 4 + logprobs_batch, _ = hybrid_asr_model.forward(input_signal=input_signal, input_signal_length=length) + + assert logprobs_instance.shape == logprobs_batch.shape + diff = torch.mean(torch.abs(logprobs_instance - logprobs_batch)) + assert diff <= 1e-6 + diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) + assert diff <= 1e-6 + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_vocab_change(self, hybrid_asr_model): + old_vocab = copy.deepcopy(hybrid_asr_model.joint.vocabulary) + nw1 = hybrid_asr_model.num_weights + hybrid_asr_model.change_vocabulary(new_vocabulary=old_vocab) + # No change + assert nw1 == hybrid_asr_model.num_weights + new_vocab = copy.deepcopy(old_vocab) + new_vocab.append('!') + new_vocab.append('$') + new_vocab.append('@') + hybrid_asr_model.change_vocabulary(new_vocabulary=new_vocab) + # fully connected + bias + # rnn embedding + joint + bias + pred_embedding = 3 * (hybrid_asr_model.decoder.pred_hidden) + joint_joint = 3 * (hybrid_asr_model.joint.joint_hidden + 1) + ctc_decoder = 3 * (hybrid_asr_model.ctc_decoder._feat_in + 1) + assert hybrid_asr_model.num_weights == (nw1 + (pred_embedding + joint_joint) + ctc_decoder) + assert hybrid_asr_model.ctc_decoder.vocabulary == 
hybrid_asr_model.joint.vocabulary + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + def test_decoding_change(self, hybrid_asr_model): + assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyBatchedRNNTInfer) + + new_strategy = DictConfig({}) + new_strategy.strategy = 'greedy' + new_strategy.greedy = DictConfig({'max_symbols': 10}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyRNNTInfer) + + new_strategy = DictConfig({}) + new_strategy.strategy = 'beam' + new_strategy.beam = DictConfig({'beam_size': 1}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, beam_decode.BeamRNNTInfer) + assert hybrid_asr_model.decoding.decoding.search_type == "default" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'beam' + new_strategy.beam = DictConfig({'beam_size': 2}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, beam_decode.BeamRNNTInfer) + assert hybrid_asr_model.decoding.decoding.search_type == "default" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'tsd' + new_strategy.beam = DictConfig({'beam_size': 2}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, beam_decode.BeamRNNTInfer) + assert hybrid_asr_model.decoding.decoding.search_type == "tsd" + + new_strategy = DictConfig({}) + new_strategy.strategy = 'alsd' + new_strategy.beam = DictConfig({'beam_size': 2}) + hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy) + assert isinstance(hybrid_asr_model.decoding.decoding, beam_decode.BeamRNNTInfer) + assert hybrid_asr_model.decoding.decoding.search_type == "alsd" + + assert hybrid_asr_model.ctc_decoding is not None + assert isinstance(hybrid_asr_model.ctc_decoding, CTCDecoding) + assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy" + assert hybrid_asr_model.ctc_decoding.preserve_alignments is False + assert hybrid_asr_model.ctc_decoding.compute_timestamps is False + + cfg = CTCDecodingConfig(preserve_alignments=True, compute_timestamps=True) + hybrid_asr_model.change_decoding_strategy(cfg, decoder_type="ctc") + + assert hybrid_asr_model.ctc_decoding.preserve_alignments is True + assert hybrid_asr_model.ctc_decoding.compute_timestamps is True + + @pytest.mark.unit + def test_GreedyRNNTInferConfig(self): + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + + result = assert_dataclass_signature_match( + greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyRNNTInferConfig, ignore_args=IGNORE_ARGS + ) + + signatures_match, cls_subset, dataclass_subset = result + + assert signatures_match + assert cls_subset is None + assert dataclass_subset is None + + @pytest.mark.unit + def test_GreedyBatchedRNNTInferConfig(self): + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + + result = assert_dataclass_signature_match( + greedy_decode.GreedyBatchedRNNTInfer, greedy_decode.GreedyBatchedRNNTInferConfig, ignore_args=IGNORE_ARGS + ) + + signatures_match, cls_subset, dataclass_subset = result + + assert signatures_match + assert cls_subset is None + assert 
dataclass_subset is None + + @pytest.mark.unit + def test_BeamRNNTInferConfig(self): + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + + result = assert_dataclass_signature_match( + beam_decode.BeamRNNTInfer, beam_decode.BeamRNNTInferConfig, ignore_args=IGNORE_ARGS + ) + + signatures_match, cls_subset, dataclass_subset = result + + assert signatures_match + assert cls_subset is None + assert dataclass_subset is None + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ) + def test_greedy_decoding(self, greedy_class): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = RNNTDecoder(prednet_cfg, vocab_size) + joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) + + greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "greedy_class", [greedy_decode.GreedyRNNTInfer], + ) + def test_greedy_multi_decoding(self, greedy_class): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = RNNTDecoder(prednet_cfg, vocab_size) + joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) + + greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + (partial_hyp) = greedy(encoder_output=enc_out, encoded_lengths=enc_len) + partial_hyp = partial_hyp[0] + _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len, partial_hypotheses=partial_hyp) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ) + def test_greedy_decoding_stateless_decoder(self, greedy_class): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', 
+ } + + decoder = StatelessTransducerDecoder(prednet_cfg, vocab_size) + joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) + + greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "greedy_class", [greedy_decode.GreedyRNNTInfer], + ) + def test_greedy_multi_decoding_stateless_decoder(self, greedy_class): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = StatelessTransducerDecoder(prednet_cfg, vocab_size) + joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) + + greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + (partial_hyp) = greedy(encoder_output=enc_out, encoded_lengths=enc_len) + partial_hyp = partial_hyp[0] + _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len, partial_hypotheses=partial_hyp) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ) + def test_greedy_decoding_preserve_alignment(self, greedy_class): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = RNNTDecoder(prednet_cfg, vocab_size) + joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) + + greedy = greedy_class( + decoder, joint_net, blank_index=len(token_list) - 1, preserve_alignments=True, max_symbols_per_step=5 + ) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + hyp = greedy(encoder_output=enc_out, encoded_lengths=enc_len)[0][0] # type: rnnt_utils.Hypothesis + assert hyp.alignments is not None + + for t in range(len(hyp.alignments)): + for u in range(len(hyp.alignments[t])): + logp, label = hyp.alignments[t][u] + assert torch.is_tensor(logp) + assert torch.is_tensor(label) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + {"search_type": "greedy"}, + {"search_type": "default", "score_norm": False, "return_best_hypothesis": False}, + {"search_type": "alsd", "alsd_max_target_len": 
20, "return_best_hypothesis": False}, + {"search_type": "tsd", "tsd_max_sym_exp_per_step": 3, "return_best_hypothesis": False}, + {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 2, "return_best_hypothesis": False}, + {"search_type": "maes", "maes_num_steps": 3, "maes_expansion_beta": 1, "return_best_hypothesis": False}, + ], + ) + def test_beam_decoding(self, beam_config): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + beam_size = 1 if beam_config["search_type"] == "greedy" else 2 + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = RNNTDecoder(prednet_cfg, vocab_size) + joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) + + beam = beam_decode.BeamRNNTInfer(decoder, joint_net, beam_size=beam_size, **beam_config,) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + _ = beam(encoder_output=enc_out, encoded_lengths=enc_len) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [{"search_type": "greedy"}, {"search_type": "default", "score_norm": False, "return_best_hypothesis": False},], + ) + def test_beam_decoding_preserve_alignments(self, beam_config): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + beam_size = 1 if beam_config["search_type"] == "greedy" else 2 + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = RNNTDecoder(prednet_cfg, vocab_size) + joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) + + beam = beam_decode.BeamRNNTInfer( + decoder, joint_net, beam_size=beam_size, **beam_config, preserve_alignments=True + ) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + hyp = beam(encoder_output=enc_out, encoded_lengths=enc_len)[0][0] # type: rnnt_utils.Hypothesis + + if isinstance(hyp, rnnt_utils.NBestHypotheses): + hyp = hyp.n_best_hypotheses[0] # select top hypothesis only + + assert hyp.alignments is not None + + for t in range(len(hyp.alignments)): + for u in range(len(hyp.alignments[t])): + logp, label = hyp.alignments[t][u] + assert torch.is_tensor(logp) + assert torch.is_tensor(label) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ) + def test_greedy_decoding_SampledRNNTJoint(self, greedy_class): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, 
+ 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = RNNTDecoder(prednet_cfg, vocab_size) + joint_net = SampledRNNTJoint(jointnet_cfg, vocab_size, n_samples=2, vocabulary=token_list) + + greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + {"search_type": "greedy"}, + {"search_type": "default", "score_norm": False, "return_best_hypothesis": False}, + {"search_type": "alsd", "alsd_max_target_len": 20, "return_best_hypothesis": False}, + {"search_type": "tsd", "tsd_max_sym_exp_per_step": 3, "return_best_hypothesis": False}, + {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 2, "return_best_hypothesis": False}, + {"search_type": "maes", "maes_num_steps": 3, "maes_expansion_beta": 1, "return_best_hypothesis": False}, + ], + ) + def test_beam_decoding_SampledRNNTJoint(self, beam_config): + token_list = [" ", "a", "b", "c"] + vocab_size = len(token_list) + beam_size = 1 if beam_config["search_type"] == "greedy" else 2 + + encoder_output_size = 4 + decoder_output_size = 4 + joint_output_shape = 4 + + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + jointnet_cfg = { + 'encoder_hidden': encoder_output_size, + 'pred_hidden': decoder_output_size, + 'joint_hidden': joint_output_shape, + 'activation': 'relu', + } + + decoder = RNNTDecoder(prednet_cfg, vocab_size) + joint_net = SampledRNNTJoint(jointnet_cfg, vocab_size, n_samples=2, vocabulary=token_list) + + beam = beam_decode.BeamRNNTInfer(decoder, joint_net, beam_size=beam_size, **beam_config,) + + # (B, D, T) + enc_out = torch.randn(1, encoder_output_size, 30) + enc_len = torch.tensor([30], dtype=torch.int32) + + with torch.no_grad(): + _ = beam(encoder_output=enc_out, encoded_lengths=enc_len)
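
Finally, a minimal inference sketch tying the pieces above together, assuming the transcribe path behaves as documented; the checkpoint and audio paths are placeholders:

    from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel

    # Placeholder checkpoint and audio; substitute a trained hybrid model and real files.
    model = EncDecHybridRNNTCTCBPEModel.restore_from("hybrid_conformer.nemo")
    audio = ["sample.wav"]

    # Right after loading, use_rnnt_decoder is True, so transcribe() goes through the RNNT decoder.
    rnnt_out = model.transcribe(paths2audio_files=audio, batch_size=1)

    # Switch to the auxiliary CTC head (decoding_cfg=None reuses cfg.aux_ctc.decoding) and transcribe again.
    model.change_decoding_strategy(decoding_cfg=None, decoder_type="ctc")
    ctc_out = model.transcribe(paths2audio_files=audio, batch_size=1)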