Commit: Add lhotse fixes for rnnt model training and WER hanging issue with fuse batching (#10787)

* Add lhotse fixes for rnnt model training and WER hanging issue with fuse batching

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>
Co-authored-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: nithinraok <nithinraok@users.noreply.github.com>
2 people authored and Nithin Rao Koluguri committed Oct 9, 2024
1 parent 5e22a30 commit 7e30ba6
Showing 4 changed files with 46 additions and 37 deletions.
15 changes: 6 additions & 9 deletions nemo/collections/asr/data/audio_to_text_lhotse.py
@@ -51,15 +51,12 @@ def __init__(self, tokenizer):
     def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
         audio, audio_lens, cuts = self.load_audio(cuts)
         tokens = [
-            torch.as_tensor(
-                sum(
-                    (
-                        # Supervisions may come pre-tokenized from the dataloader.
-                        s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language)
-                        for s in c.supervisions
-                    ),
-                    start=[],
-                )
+            torch.cat(
+                [
+                    torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language))
+                    for s in c.supervisions
+                ],
+                dim=0,
             )
             for c in cuts
         ]
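Note on the change above: sum(..., start=[]) concatenates plain Python lists, so it cannot handle supervisions that arrive pre-tokenized as tensors; normalizing every sequence with torch.as_tensor and joining with torch.cat handles both. A minimal standalone sketch (token values invented for illustration):

    import torch

    # One supervision already tokenized to a tensor, one still a raw id list.
    pre_tokenized = torch.tensor([5, 11, 42])
    fresh_ids = [7, 3]

    # Old approach: sum((...), start=[]) only concatenates Python lists,
    # so the pre-tokenized tensor above would break it.
    # New approach: normalize each sequence to a tensor, then concatenate.
    tokens = torch.cat([torch.as_tensor(seq) for seq in (pre_tokenized, fresh_ids)], dim=0)
    print(tokens)  # tensor([ 5, 11, 42,  7,  3])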
3 changes: 2 additions & 1 deletion nemo/collections/asr/metrics/wer.py
@@ -254,8 +254,9 @@ def __init__(
         fold_consecutive=True,
         batch_dim_index=0,
         dist_sync_on_step=False,
+        sync_on_compute=True,
     ):
-        super().__init__(dist_sync_on_step=dist_sync_on_step)
+        super().__init__(dist_sync_on_step=dist_sync_on_step, sync_on_compute=sync_on_compute)

         self.decoding = decoding
         self.use_cer = use_cer
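Background for the change above: torchmetrics' Metric base class accepts a sync_on_compute flag, and forwarding it lets callers disable the cross-rank all-reduce that compute() performs by default. A toy metric sketch showing the plumbing (not the actual NeMo WER class):

    import torch
    from torchmetrics import Metric

    class ToyErrorRate(Metric):
        def __init__(self, dist_sync_on_step=False, sync_on_compute=True):
            # sync_on_compute=False keeps compute() rank-local: no collective
            # call is issued, so a rank that computes while others skip the
            # call cannot hang waiting for them.
            super().__init__(dist_sync_on_step=dist_sync_on_step, sync_on_compute=sync_on_compute)
            self.add_state("errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, errors: int, total: int):
            self.errors += errors
            self.total += total

        def compute(self):
            return self.errors / self.total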
64 changes: 38 additions & 26 deletions nemo/collections/asr/modules/rnnt.py
@@ -74,8 +74,7 @@ class StatelessTransducerDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable):

     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         return {
             "targets": NeuralType(('B', 'T'), LabelsType()),
             "target_length": NeuralType(tuple('B'), LengthsType()),
@@ -84,8 +83,7 @@ def input_types(self):

     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {
             "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
             "prednet_lengths": NeuralType(tuple('B'), LengthsType()),
@@ -382,15 +380,20 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:

     @classmethod
     def batch_replace_states_mask(
-        cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor,
+        cls,
+        src_states: list[torch.Tensor],
+        dst_states: list[torch.Tensor],
+        mask: torch.Tensor,
     ):
         """Replace states in dst_states with states from src_states using the mask"""
         # same as `dst_states[0][mask] = src_states[0][mask]`, but non-blocking
         torch.where(mask.unsqueeze(-1), src_states[0], dst_states[0], out=dst_states[0])

     @classmethod
     def batch_replace_states_all(
-        cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor],
+        cls,
+        src_states: list[torch.Tensor],
+        dst_states: list[torch.Tensor],
     ):
         """Replace states in dst_states with states from src_states"""
         dst_states[0].copy_(src_states[0])
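The torch.where(..., out=...) call in batch_replace_states_mask is equivalent to boolean-mask assignment on the state tensor but avoids the host-device synchronization that advanced indexing triggers on CUDA tensors. A small self-contained check (shapes invented):

    import torch

    B, H = 4, 8
    src = torch.randn(B, H)
    dst = torch.randn(B, H)
    mask = torch.tensor([True, False, True, False])

    # Reference: masked assignment (on GPU this syncs to evaluate the mask).
    expected = dst.clone()
    expected[mask] = src[mask]

    # Non-blocking equivalent: broadcast the mask over the hidden dim and
    # write the selection back into dst in place.
    torch.where(mask.unsqueeze(-1), src, dst, out=dst)
    assert torch.equal(dst, expected)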
@@ -591,8 +594,7 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMixin):

     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         return {
             "targets": NeuralType(('B', 'T'), LabelsType()),
             "target_length": NeuralType(tuple('B'), LengthsType()),
@@ -601,8 +603,7 @@ def input_types(self):

     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {
             "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
             "prednet_lengths": NeuralType(tuple('B'), LengthsType()),
@@ -1018,19 +1019,19 @@ def batch_score_hypothesis(

     def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
         """
-            Create batch of decoder states.
+        Create batch of decoder states.

-            Args:
-                batch_states (list): batch of decoder states
-                    ([L x (B, H)], [L x (B, H)])
+        Args:
+            batch_states (list): batch of decoder states
+                ([L x (B, H)], [L x (B, H)])

-            decoder_states (list of list): list of decoder states
-                [B x ([L x (1, H)], [L x (1, H)])]
+            decoder_states (list of list): list of decoder states
+                [B x ([L x (1, H)], [L x (1, H)])]

-            Returns:
-                batch_states (tuple): batch of decoder states
-                    ([L x (B, H)], [L x (B, H)])
-            """
+        Returns:
+            batch_states (tuple): batch of decoder states
+                ([L x (B, H)], [L x (B, H)])
+        """
         # LSTM has 2 states
         new_states = [[] for _ in range(len(decoder_states[0]))]
         for layer in range(self.pred_rnn_layers):
@@ -1109,7 +1110,9 @@ def batch_replace_states_mask(

     @classmethod
     def batch_replace_states_all(
-        cls, src_states: Tuple[torch.Tensor, torch.Tensor], dst_states: Tuple[torch.Tensor, torch.Tensor],
+        cls,
+        src_states: Tuple[torch.Tensor, torch.Tensor],
+        dst_states: Tuple[torch.Tensor, torch.Tensor],
     ):
         """Replace states in dst_states with states from src_states"""
         dst_states[0].copy_(src_states[0])
@@ -1253,8 +1256,7 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin):

     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         return {
             "encoder_outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
             "decoder_outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
@@ -1266,8 +1268,7 @@ def input_types(self):

     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         if not self._fuse_loss_wer:
             return {
                 "outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
@@ -1490,6 +1491,10 @@ def forward(
                 sub_transcripts = sub_transcripts.detach()

                 # Update WER on each process without syncing
+                if self.training:
+                    original_sync = self.wer._to_sync
+                    self.wer._to_sync = False
+
                 self.wer.update(
                     predictions=sub_enc,
                     predictions_lengths=sub_enc_lens,
@@ -1500,6 +1505,9 @@
                 wer, wer_num, wer_denom = self.wer.compute()
                 self.wer.reset()

+                if self.training:
+                    self.wer._to_sync = original_sync
+
                 wers.append(wer)
                 wer_nums.append(wer_num)
                 wer_denoms.append(wer_denom)
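Taken together, the two hunks above flip the metric's internal _to_sync flag off around the per-sub-batch update/compute/reset cycle during training, so compute() stays rank-local and cannot deadlock when ranks process different numbers of sub-batches. A condensed sketch of that pattern (function and keyword names are illustrative, not the NeMo API):

    def per_subbatch_wer(metric, sub_batches, training=True):
        results = []
        if training:
            original_sync = metric._to_sync
            metric._to_sync = False  # keep compute() local: no collective, no hang
        for preds, pred_lens, refs, ref_lens in sub_batches:
            metric.update(predictions=preds, predictions_lengths=pred_lens,
                          targets=refs, targets_lengths=ref_lens)
            results.append(metric.compute())  # (wer, wer_num, wer_denom)
            metric.reset()
        if training:
            metric._to_sync = original_sync
        return results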
@@ -2047,7 +2055,11 @@ def forward(
         return losses, wer, wer_num, wer_denom

     def sampled_joint(
-        self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor,
+        self,
+        f: torch.Tensor,
+        g: torch.Tensor,
+        transcript: torch.Tensor,
+        transcript_lengths: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute the sampled joint step of the network.
1 change: 0 additions & 1 deletion nemo/collections/common/data/lhotse/dataloader.py
@@ -316,7 +316,6 @@ def get_lhotse_dataloader_from_config(
             duration_bins=determine_bucket_duration_bins(config),
             num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate,
             buffer_size=config.bucket_buffer_size,
-            concurrent=config.concurrent_bucketing,
             rank=0 if is_tarred else global_rank,
             world_size=1 if is_tarred else world_size,
         )
