st standalone model (#6969)
* st standalone model

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* style fix

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* sacrebleu import fix, unused imports removed

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* import guard for nlp inside asr transformer bpe model

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* codeql fixes

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* comments answered

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* import ordering fix

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* yttm for asr removed

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* logging added

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* added inference and translate method

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AlexGrinch and pre-commit-ci[bot] authored Jul 18, 2023
1 parent 84ae944 commit 47e782a
Showing 6 changed files with 1,114 additions and 1 deletion.
218 changes: 218 additions & 0 deletions examples/asr/conf/speech_translation/fast-conformer_transformer.yaml
@@ -0,0 +1,218 @@
# This config contains the default values for training an autoregressive FastConformer-Transformer ST (speech translation) model with sub-word encoding.

# Architecture and training config:
# Default learning parameters in this config are set for an effective batch size of 2K. To train with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use a higher accumulate_grad_batches.
# Recommended configs for other FastConformer-Transformer variants differ only in a few parameters; the rest are the same as in this config file.
# An extra (linear projection) layer is added between the FastConformer encoder and the Transformer decoder if they have different hidden sizes.
# It is recommended to initialize the FastConformer encoder from an ASR pre-trained encoder for better accuracy and faster convergence.
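# Illustrative effective-batch-size arithmetic (example numbers only, not defaults of this config):
#   effective batch size = train_ds.batch_size x number of GPUs x trainer.accumulate_grad_batches
#   e.g. 4 x 8 GPUs x 64 accumulate_grad_batches = 2048 (~2K)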

name: "FastConformer-Transformer-BPE-st"

# Initialize model encoder with pre-trained ASR FastConformer encoder for faster convergence and improved accuracy
init_from_nemo_model:
  model0:
    path: ???
    include: ["preprocessor", "encoder"]

model:
  sample_rate: 16000
  label_smoothing: 0.0
  log_prediction: true # enables logging sample predictions in the output during training

  train_ds:
    is_tarred: true
    tarred_audio_filepaths: ???
    manifest_filepath: ???
    sample_rate: 16000
    shuffle: false
    trim_silence: false
    batch_size: 4
    num_workers: 8

  validation_ds:
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 16 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 4
    pin_memory: true
    use_start_end_token: true

  test_ds:
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 16 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 4
    pin_memory: true
    use_start_end_token: true

  # recommend small vocab size of 128 or 256 when using 4x sub-sampling
  # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
  tokenizer:
    dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe)
    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
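  # Example tokenizer-training command (indicative only; check the script's --help in your NeMo version for exact flags):
  #   python scripts/tokenizers/process_asr_text_tokenizer.py --manifest=<comma-separated train manifests> \
  #     --data_root=<output dir> --vocab_size=<vocab size> --tokenizer=spe --spe_type=bpe --log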

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: ${model.sample_rate}
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    log: true
    frame_splicing: 1
    dither: 0.00001
    pad_to: 0
    pad_value: 0.0

  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2 # set to zero to disable it
    # you may use lower time_masks for smaller models to have a faster convergence
    time_masks: 10 # set to zero to disable it
    freq_width: 27
    time_width: 0.05

  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1 # you may set it if you need different output size other than the default d_model
    n_layers: 17
    d_model: 512

    # Sub-sampling params
    subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory
    subsampling_factor: 8 # must be power of 2
    subsampling_conv_channels: 256 # -1 sets it to d_model
    causal_downsampling: false
    reduction: null
    reduction_position: null
    reduction_factor: 1

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    xscaling: true # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 9
    conv_norm_type: batch_norm
    conv_context_size: null

    ### regularization
    dropout: 0.1 # The dropout used in most of the Conformer Modules
    dropout_pre_encoder: 0.1
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules

  transf_encoder:
    num_layers: 0
    hidden_size: 512
    inner_size: 2048
    num_attention_heads: 8
    ffn_dropout: 0.1
    attn_score_dropout: 0.1
    attn_layer_dropout: 0.1

  transf_decoder:
    library: nemo
    model_name: null
    pretrained: false
    max_sequence_length: 512
    num_token_types: 0
    embedding_dropout: 0.1
    learn_positional_encodings: false
    hidden_size: 512
    inner_size: 2048
    num_layers: 6
    num_attention_heads: 4
    ffn_dropout: 0.1
    attn_score_dropout: 0.1
    attn_layer_dropout: 0.1
    hidden_act: relu
    pre_ln: true
    pre_ln_final_layer_norm: true

  head:
    num_layers: 1
    activation: relu
    log_softmax: true
    dropout: 0.0
    use_transformer_init: true

  beam_search:
    beam_size: 4
    len_pen: 0.0
    max_generation_delta: 50

  optim:
    name: adam
    lr: 0.0001
    # optimizer arguments
    betas: [0.9, 0.98]
    # weight_decay is usually less necessary here since SpecAug already provides strong augmentation
    # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
    # weight decay of 0.0 with lr of 2.0 also works fine
    #weight_decay: 1e-3

    # scheduler setup
    sched:
      name: InverseSquareRootAnnealing
      #d_model: ${model.encoder.d_model}
      # scheduler config override
      warmup_steps: 1000
      warmup_ratio: null
      min_lr: 1e-6

trainer:
  gpus: -1 # number of GPUs, -1 uses all available GPUs
  num_nodes: 1
  max_epochs: 100
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: 16 # Should be set to 16 for O1 and O2 to enable AMP.
  log_every_n_steps: 100 # Interval of logging.
  enable_progress_bar: True
  resume_from_checkpoint: null # path to a checkpoint file to continue training from; restores the whole state including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; set to 0 to disable
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: "val_sacreBLEU"
    mode: "max"
    save_top_k: 3
    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

  # you need to set these two to True to continue the training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
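
For reference, here is a minimal sketch (not something this commit ships) of consuming the config above without Hydra: it loads the YAML with OmegaConf, fills the mandatory `???` fields, and then builds and trains the model the same way the training script below does. Every path in it is a hypothetical placeholder.

```python
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecTransfModelBPE

# Load the config shipped in this commit.
cfg = OmegaConf.load("examples/asr/conf/speech_translation/fast-conformer_transformer.yaml")

# Fill the mandatory (???) fields; every path below is a hypothetical placeholder.
cfg.init_from_nemo_model.model0.path = "/models/pretrained_asr_encoder.nemo"
cfg.model.train_ds.tarred_audio_filepaths = "/data/st/train_audio_{0..511}.tar"
cfg.model.train_ds.manifest_filepath = "/data/st/train_manifest.json"
cfg.model.validation_ds.manifest_filepath = "/data/st/dev_manifest.json"
cfg.model.test_ds.manifest_filepath = "/data/st/test_manifest.json"
cfg.model.tokenizer.dir = "/data/st/tokenizer_bpe"

# Same construction path as the training script below (exp_manager is omitted in this sketch,
# so checkpointing and loggers are not set up).
trainer = pl.Trainer(**cfg.trainer)
model = EncDecTransfModelBPE(cfg=cfg.model, trainer=trainer)
model.maybe_init_from_pretrained_checkpoint(cfg)  # pulls in the pre-trained ASR encoder weights
trainer.fit(model)
```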
70 changes: 70 additions & 0 deletions examples/asr/speech_translation/speech_to_text_transformer.py
@@ -0,0 +1,70 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Training the model
```sh
python speech_to_text_transformer.py \
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
model.train_ds.tarred_audio_filepaths=<path to tar files with audio> \
model.train_ds.manifest_filepath=<path to audio data manifest> \
model.validation_ds.manifest_filepath=<path to validation manifest> \
model.test_ds.manifest_filepath=<path to test manifest> \
model.tokenizer.dir=<path to directory of tokenizer (not full path to the vocab file!)> \
model.tokenizer.type=<either bpe or wpe> \
trainer.gpus=-1 \
trainer.strategy="ddp" \
trainer.max_epochs=100 \
model.optim.name="adamw" \
model.optim.lr=0.001 \
model.optim.betas=[0.9,0.999] \
model.optim.weight_decay=0.0001 \
model.optim.sched.warmup_steps=2000 \
exp_manager.create_wandb_logger=True \
exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
exp_manager.wandb_logger_kwargs.project="<Name of project>"
```
"""

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecTransfModelBPE
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="../conf/speech_translation/", config_name="fast-conformer_transformer")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecTransfModelBPE(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    asr_model.maybe_init_from_pretrained_checkpoint(cfg)
    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        if asr_model.prepare_test(trainer):
            trainer.test(asr_model)


if __name__ == '__main__':
    main()
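
The commit message notes that an inference/translate method was added to the standalone ST model. As a rough, hedged illustration (not something this commit ships), the sketch below restores a trained .nemo checkpoint and calls `translate`, assuming it follows NeMo's transcribe-style interface of a list of audio file paths plus a batch size; the checkpoint path and audio files are hypothetical, and the actual signature may differ.

```python
from nemo.collections.asr.models import EncDecTransfModelBPE

# Restore a speech-translation checkpoint saved by exp_manager (always_save_nemo: True); path is hypothetical.
st_model = EncDecTransfModelBPE.restore_from("nemo_experiments/FastConformer-Transformer-BPE-st.nemo")
st_model.eval()

# Assumed transcribe-style interface: audio file paths in, translated text out.
audio_files = ["sample_utt_0.wav", "sample_utt_1.wav"]
translations = st_model.translate(audio_files, batch_size=2)
for wav, text in zip(audio_files, translations):
    print(f"{wav} -> {text}")
```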