Adding Hybrid RNNT-CTC model (#5364)
* added initial code. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added the confs. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added the confs. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* changed name from joint to hybrid. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fixed format. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed format. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* addressed comments. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* addressed comments. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added docs. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added docs. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added docs. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added docs. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added CI test. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added CI test. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed bugs in change_vocabs. Signed-off-by: vahidoox <vnoroozi@nvidia.com>
* fixed bugs in change_vocabs. Signed-off-by: vahidoox <vnoroozi@nvidia.com>
* fixed style. Signed-off-by: vahidoox <vnoroozi@nvidia.com>
* fixed style. Signed-off-by: vahidoox <vnoroozi@nvidia.com>
* fixed style. Signed-off-by: vahidoox <vnoroozi@nvidia.com>
* raise error for aux_ctc. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* raise error for aux_ctc. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* raise error for aux_ctc. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* raise error for aux_ctc. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* updated the streaming names. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added unittests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added unittests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added unittests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed tests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed tests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed tests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed tests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed tests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fixed tests. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added methods. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* added decoding. Signed-off-by: Vahid <vnoroozi@nvidia.com>
* fixed the tests. Signed-off-by: Vahid <vnoroozi@nvidia.com>

Signed-off-by: Vahid <vnoroozi@nvidia.com>
Signed-off-by: vahidoox <vnoroozi@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ced7133, commit 786a850
Showing 17 changed files with 2,787 additions and 3 deletions.
File renamed without changes.
@@ -0,0 +1,32 @@
# ASR with Hybrid Transducer/CTC Models

This directory contains example scripts to train ASR models with two decoders, trained with Transducer and CTC losses.

Currently supported models are:

* Character based Hybrid RNNT/CTC model
* Subword based Hybrid RNNT/CTC model
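
How the two losses are combined is not shown in this excerpt. The sketch below assumes the total loss is a linear interpolation controlled by `model.aux_ctc.ctc_loss_weight`, the override that appears in the training commands of this commit. `combine_losses` is a hypothetical helper written for illustration, not a NeMo function, and the loss values are made up.

```python
def combine_losses(rnnt_loss: float, ctc_loss: float, ctc_loss_weight: float = 0.3) -> float:
    """Interpolate between the Transducer loss and the auxiliary CTC loss.

    Assumed form: (1 - w) * rnnt_loss + w * ctc_loss, with w = ctc_loss_weight.
    """
    if not 0.0 <= ctc_loss_weight <= 1.0:
        raise ValueError("ctc_loss_weight must be in [0, 1]")
    return (1.0 - ctc_loss_weight) * rnnt_loss + ctc_loss_weight * ctc_loss


# Made-up per-batch losses, weighted as in the example commands (0.3).
print(combine_losses(rnnt_loss=2.0, ctc_loss=4.0, ctc_loss_weight=0.3))
```

Under this assumption, a weight of 0 reduces to pure Transducer training and a weight of 1 to pure CTC training.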

# Model execution overview

The training scripts in this directory execute in the following order. When preparing your own training-from-scratch / fine-tuning scripts, please follow this order for correct training/inference.

```mermaid
graph TD
    A[Hydra Overrides + Yaml Config] --> B{Config}
    B --> |Init| C[Trainer]
    C --> D[ExpManager]
    B --> D[ExpManager]
    C --> E[Model]
    B --> |Init| E[Model]
    E --> |Constructor| F1(Change Vocabulary)
    F1 --> F2(Setup Adapters if available)
    F2 --> G(Setup Train + Validation + Test Data loaders)
    G --> H1(Setup Optimization)
    H1 --> H2(Change Transducer Decoding Strategy)
    H2 --> I[Maybe init from pretrained]
    I --> J["trainer.fit(model)"]
```

During restoration of the model, you may pass the Trainer to the `restore_from` / `from_pretrained` call, or set it after the model has been initialized by using `model.set_trainer(Trainer)`.
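
The diagram's edges define a partial order over the setup steps. As an illustration only, the sketch below transcribes those edges by hand and uses Python's standard `graphlib` module (Python 3.9+) to print one valid linearization. The `edges` list and its labels are copied from the diagram, not taken from NeMo code.

```python
from graphlib import TopologicalSorter

# (source, target) pairs transcribed from the mermaid diagram.
edges = [
    ("Hydra Overrides + Yaml Config", "Config"),
    ("Config", "Trainer"),
    ("Trainer", "ExpManager"),
    ("Config", "ExpManager"),
    ("Trainer", "Model"),
    ("Config", "Model"),
    ("Model", "Change Vocabulary"),
    ("Change Vocabulary", "Setup Adapters if available"),
    ("Setup Adapters if available", "Setup Train + Validation + Test Data loaders"),
    ("Setup Train + Validation + Test Data loaders", "Setup Optimization"),
    ("Setup Optimization", "Change Transducer Decoding Strategy"),
    ("Change Transducer Decoding Strategy", "Maybe init from pretrained"),
    ("Maybe init from pretrained", "trainer.fit(model)"),
]

sorter = TopologicalSorter()
for source, target in edges:
    # add(node, *predecessors): `target` may only run after `source`.
    sorter.add(target, source)

order = list(sorter.static_order())
for step, name in enumerate(order, start=1):
    print(f"{step:2d}. {name}")
```

The diagram places no edge between ExpManager and Model, so a linearization may list them in either order. The example scripts in this commit call `exp_manager` before constructing the model.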
examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py (91 changes: 91 additions & 0 deletions)
@@ -0,0 +1,91 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Preparing the Tokenizer for the dataset
Use the `process_asr_text_tokenizer.py` script under <NEMO_ROOT>/scripts/tokenizers/ in order to prepare the tokenizer.
```sh
python <NEMO_ROOT>/scripts/tokenizers/process_asr_text_tokenizer.py \
    --manifest=<path to train manifest files, separated by commas>
    OR
    --data_file=<path to text data, separated by commas> \
    --data_root="<output directory>" \
    --vocab_size=<number of tokens in vocabulary> \
    --tokenizer=<"spe" or "wpe"> \
    --no_lower_case \
    --spe_type=<"unigram", "bpe", "char" or "word"> \
    --spe_character_coverage=1.0 \
    --log
```
# Training the model
```sh
python speech_to_text_hybrid_rnnt_ctc_bpe.py \
    # (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
    model.train_ds.manifest_filepath=<path to train manifest> \
    model.validation_ds.manifest_filepath=<path to val/test manifest> \
    model.tokenizer.dir=<path to directory of tokenizer (not full path to the vocab file!)> \
    model.tokenizer.type=<either bpe or wpe> \
    model.aux_ctc.ctc_loss_weight=0.3 \
    trainer.devices=-1 \
    trainer.max_epochs=100 \
    model.optim.name="adamw" \
    model.optim.lr=0.001 \
    model.optim.betas=[0.9,0.999] \
    model.optim.weight_decay=0.0001 \
    model.optim.sched.warmup_steps=2000 \
    exp_manager.create_wandb_logger=True \
    exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
    exp_manager.wandb_logger_kwargs.project="<Name of project>"
```
# Fine-tune a model
For documentation on fine-tuning this model, please visit -
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
"""

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(
    config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc_bpe"
)
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    asr_model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        if asr_model.prepare_test(trainer):
            trainer.test(asr_model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
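
The `manifest_filepath` overrides in the docstring above point at NeMo ASR manifest files. Below is a minimal sketch of writing one, assuming the usual JSON-lines layout with `audio_filepath`, `duration`, and `text` fields. The audio paths, durations, transcripts, and the `write_manifest` helper are placeholders for illustration.

```python
import json
import tempfile
from pathlib import Path

# Placeholder utterances; replace with real audio paths, durations (seconds), and transcripts.
entries = [
    {"audio_filepath": "audio/utt1.wav", "duration": 2.5, "text": "hello world"},
    {"audio_filepath": "audio/utt2.wav", "duration": 1.2, "text": "good morning"},
]


def write_manifest(path: Path, items: list) -> None:
    """Write one JSON object per line (hypothetical helper, not part of NeMo)."""
    with path.open("w", encoding="utf-8") as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")


# Written to a temporary directory so the sketch leaves no files behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    manifest_path = Path(tmp_dir) / "train_manifest.json"
    write_manifest(manifest_path, entries)
    written_lines = manifest_path.read_text(encoding="utf-8").splitlines()

print(written_lines[0])
```

The resulting path is what would be passed as `model.train_ds.manifest_filepath=<path to train manifest>`.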
examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py (100 changes: 100 additions & 0 deletions)
@@ -0,0 +1,100 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Training the model
Basic run (on CPU for 50 epochs):
    python examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py \
        # (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
        model.train_ds.manifest_filepath="<path to manifest file>" \
        model.validation_ds.manifest_filepath="<path to manifest file>" \
        trainer.devices=1 \
        trainer.accelerator='cpu' \
        trainer.max_epochs=50
Add PyTorch Lightning Trainer arguments from CLI:
    python speech_to_text_hybrid_rnnt_ctc_char.py \
        ... \
        +trainer.fast_dev_run=true
Hydra logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/.hydra)"
PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/lightning_logs)"
Override some args of optimizer:
    python speech_to_text_hybrid_rnnt_ctc_char.py \
        --config-path="../conf/conformer/hybrid_transducer_ctc/" \
        --config-name="conformer_hybrid_transducer_ctc" \
        model.train_ds.manifest_filepath="./an4/train_manifest.json" \
        model.validation_ds.manifest_filepath="./an4/test_manifest.json" \
        trainer.devices=2 \
        model.aux_ctc.ctc_loss_weight=0.3 \
        trainer.precision=16 \
        trainer.max_epochs=2 \
        model.optim.betas=[0.8,0.5] \
        model.optim.weight_decay=0.0001
Override optimizer entirely:
    python speech_to_text_hybrid_rnnt_ctc_char.py \
        --config-path="../conf/conformer/hybrid_transducer_ctc/" \
        --config-name="conformer_hybrid_transducer_ctc" \
        model.train_ds.manifest_filepath="./an4/train_manifest.json" \
        model.validation_ds.manifest_filepath="./an4/test_manifest.json" \
        model.aux_ctc.ctc_loss_weight=0.3 \
        trainer.devices=2 \
        trainer.precision=16 \
        trainer.max_epochs=2 \
        model.optim.name=adamw \
        model.optim.lr=0.001 \
        ~model.optim.args \
        +model.optim.args.betas=[0.8,0.5] \
        +model.optim.args.weight_decay=0.0005
# Fine-tune a model
For documentation on fine-tuning this model, please visit -
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
"""

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecHybridRNNTCTCModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecHybridRNNTCTCModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    asr_model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        if asr_model.prepare_test(trainer):
            trainer.test(asr_model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
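
The docstring above relies on three Hydra override forms: `key=value` to change an existing entry, `+key=value` to add a new one, and `~key` to delete one. The toy function below mimics those three forms on a plain nested dict to make the "override optimizer entirely" command concrete. It is a simplified stand-in, not Hydra's parser: `apply_overrides` and `_parse_value` are hypothetical helpers, values are parsed with `json.loads` (falling back to a plain string), and the starting config is a made-up placeholder.

```python
import json


def _parse_value(raw: str):
    """Parse numbers, lists, and booleans as JSON; keep bare words such as adamw as strings."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return raw


def apply_overrides(cfg: dict, overrides: list) -> dict:
    """Apply a simplified subset of Hydra-style overrides to a nested dict, in place."""
    for item in overrides:
        if item.startswith("~"):
            # "~a.b.c" deletes the entry at a.b.c.
            *parents, leaf = item[1:].split(".")
            node = cfg
            for key in parents:
                node = node[key]
            del node[leaf]
            continue

        adding = item.startswith("+")
        dotted, raw = item.lstrip("+").split("=", 1)
        *parents, leaf = dotted.split(".")
        node = cfg
        for key in parents:
            # "+" may create missing parent nodes; a plain override may not.
            node = node.setdefault(key, {}) if adding else node[key]
        if not adding and leaf not in node:
            raise KeyError(f"{dotted} does not exist; prefix it with '+' to add it")
        node[leaf] = _parse_value(raw)
    return cfg


# Placeholder starting config; the real values live in the YAML file.
cfg = {"model": {"optim": {"name": "novograd", "lr": 0.01, "args": {"betas": [0.9, 0.98]}}}}
apply_overrides(
    cfg,
    [
        "model.optim.name=adamw",
        "model.optim.lr=0.001",
        "~model.optim.args",
        "+model.optim.args.betas=[0.8,0.5]",
        "+model.optim.args.weight_decay=0.0005",
    ],
)
print(cfg["model"]["optim"])
```

Deleting `model.optim.args` before re-adding its fields ensures that no optimizer argument from the YAML file survives into the new optimizer.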
File renamed without changes.
File renamed without changes.