Skip to content

Commit

Permalink
Merge branch 'main' into format_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tbartley94 authored Jul 27, 2024
2 parents 8fa1e51 + bc6d534 commit 1fecadc
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 54 deletions.
20 changes: 9 additions & 11 deletions nemo/collections/asr/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def move_dimension_to_the_front(tensor, dim_index):
# TODO: Add documentation
class BLEU(SacreBLEUScore):
"""
This metric computes numerator, denominator, hypotheses lengths, and target lengths for Overall Bilingual Evaluation Understudy (BLEU)
between prediction and reference texts. When doing distributed training/evaluation the result of
This metric computes numerator, denominator, hypotheses lengths, and target lengths for Overall Bilingual Evaluation Understudy (BLEU)
between prediction and reference texts. When doing distributed training/evaluation the result of
``res=BLEU.(predictions, predictions_lengths, targets, target_lengths)``
calls will be all-reduced between all workers using SUM operations.
If used with PytorchLightning LightningModule, include bleu_num bleur_den, bleu_pred_len, and bleu_target_len values inside
If used with PytorchLightning LightningModule, include bleu_num bleur_den, bleu_pred_len, and bleu_target_len values inside
validation_step results. Then aggregate (sum) then at the end of validation epoch to correctly compute validation BLEUR.
Example:
Expand Down Expand Up @@ -99,7 +99,6 @@ def __init__(
smooth=smooth,
dist_sync_on_step=dist_sync_on_step,
)
self.has_spl_tokens = False
self.decoding = decoding
self.decode = None
if isinstance(self.decoding, AbstractRNNTDecoding):
Expand All @@ -113,7 +112,6 @@ def __init__(
fold_consecutive=self.fold_consecutive,
)
elif isinstance(self.decoding, AbstractMultiTaskDecoding):
self.has_spl_tokens = True
self.decode = lambda predictions, prediction_lengths, predictions_mask, input_ids, targets: self.decoding.decode_predictions_tensor(
encoder_hidden_states=predictions,
encoder_input_mask=predictions_mask,
Expand Down Expand Up @@ -165,10 +163,6 @@ def update(
references.append(reference)
hypotheses, _ = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets)

if self.has_spl_tokens:
hypotheses = [self.decoding.strip_special_tokens(hyp) for hyp in hypotheses]
references = [self.decoding.strip_special_tokens(ref) for ref in references]

if self.log_prediction:
logging.info(f"\n")
logging.info(f"reference:{references[0]}")
Expand All @@ -185,7 +179,7 @@ def compute(self, return_all_metrics=True, prefix="", suffix=""):
only BLEU. Default: True.
prefix: str to prepend to metric value keys.
suffix: str to append to metric value keys.
Returns:
Dict: key-value pairs of BLEU metrics and values. Keys are prepended and appended with prefix
and suffix flags, respectively.
Expand All @@ -205,7 +199,11 @@ def compute(self, return_all_metrics=True, prefix="", suffix=""):

# Adding wrapper to avoid imports and extra variables over the namespace
def _compute_bleu(
self, predictions_lengths, targets_lengths, numerator, denominator,
self,
predictions_lengths,
targets_lengths,
numerator,
denominator,
):
return _bleu_score_compute(
predictions_lengths, targets_lengths, numerator, denominator, self.n_gram, self.weights, self.smooth
Expand Down
10 changes: 2 additions & 8 deletions nemo/collections/asr/metrics/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def word_error_rate_detail(
def word_error_rate_per_utt(hypotheses: List[str], references: List[str], use_cer=False) -> Tuple[List[float], float]:
"""
Computes Word Error Rate per utterance and the average WER
between two texts represented as corresponding lists of string.
between two texts represented as corresponding lists of string.
Hypotheses and references must have same length.
Args:
Expand Down Expand Up @@ -263,7 +263,6 @@ def __init__(
self.fold_consecutive = fold_consecutive
self.batch_dim_index = batch_dim_index

self.has_spl_tokens = False
self.decode = None
if isinstance(self.decoding, AbstractRNNTDecoding):
self.decode = lambda predictions, predictions_lengths, predictions_mask, input_ids, targets: self.decoding.rnnt_decoder_predictions_tensor(
Expand All @@ -276,7 +275,6 @@ def __init__(
fold_consecutive=self.fold_consecutive,
)
elif isinstance(self.decoding, AbstractMultiTaskDecoding):
self.has_spl_tokens = True
self.decode = lambda predictions, prediction_lengths, predictions_mask, input_ids, targets: self.decoding.decode_predictions_tensor(
encoder_hidden_states=predictions,
encoder_input_mask=predictions_mask,
Expand Down Expand Up @@ -326,10 +324,6 @@ def update(
references.append(reference)
hypotheses, _ = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets)

if self.has_spl_tokens:
hypotheses = [self.decoding.strip_special_tokens(hyp) for hyp in hypotheses]
references = [self.decoding.strip_special_tokens(ref) for ref in references]

if self.log_prediction:
logging.info(f"\n")
logging.info(f"reference:{references[0]}")
Expand Down
13 changes: 0 additions & 13 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,19 +918,6 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
return_hypotheses=trcfg.return_hypotheses,
)

if trcfg.return_hypotheses:
for hyp in best_hypotheses:
hyp.text = self.decoding.strip_special_tokens(hyp.text)
if all_hypotheses is not None:
for i in range(len(all_hypotheses)):
for j in range(len(all_hypotheses[i])):
all_hypotheses[i][j].text = self.decoding.strip_special_tokens(all_hypotheses[i][j].text)
else:
best_hypotheses = [self.decoding.strip_special_tokens(text) for text in best_hypotheses]
if all_hypotheses is not None:
for i in range(len(all_hypotheses)):
all_hypotheses[i] = [self.decoding.strip_special_tokens(text) for text in all_hypotheses[i]]

del enc_states, enc_mask, decoder_input_ids
if all_hypotheses is None:
return best_hypotheses
Expand Down
47 changes: 37 additions & 10 deletions nemo/collections/asr/parts/submodules/multitask_beam_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ class TransformerAEDBeamInfer(AEDBeamInfer, Typing):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
# Input can be of dimention -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

Expand All @@ -116,8 +115,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {"predictions": [NeuralType(elements_type=HypothesisType())]}

def __init__(
Expand All @@ -141,15 +139,18 @@ def __init__(
preserve_alignments=preserve_alignments,
)
self.beam_size = beam_size
self.bos = tokenizer.bos
self.pad = tokenizer.pad
self.eos = tokenizer.eos
self.beam_search = BeamSearchSequenceGenerator(
embedding=transformer_decoder.embedding,
decoder=transformer_decoder.decoder,
log_softmax=log_softmax_module,
max_sequence_length=transformer_decoder.max_sequence_length,
beam_size=beam_size,
bos=tokenizer.bos_id,
pad=tokenizer.pad_id,
eos=tokenizer.eos_id,
bos=self.bos,
pad=self.pad,
eos=self.eos,
len_pen=length_penalty,
max_delta_length=max_generation_delta,
)
Expand Down Expand Up @@ -196,9 +197,9 @@ def forward(
for i in range(len(topk_hypotheses)):
hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(self.beam_size)]
# Pack results into Hypotheses
packed_result.append(
NBestHypotheses(pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i]))
)
hypotheses = pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i])
self.format_hypotheses(hypotheses, decoder_input_ids)
packed_result.append(NBestHypotheses(hypotheses))
else:
beam_scores = [None for _ in range(len(best_hypo))]
best_hypo = best_hypo.detach().cpu()
Expand All @@ -207,9 +208,35 @@ def forward(
]
# Pack results into Hypotheses
packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores)
self.format_hypotheses(packed_result, decoder_input_ids)

return (packed_result,)

def format_hypotheses(self, packed_result: List[Hypothesis], decoder_input_ids: torch.Tensor | None) -> None:
"""
For each hypothesis in the mini-batch:
* Remove the decoder input ids (prompt) from the predictions
* Remove BOS, EOS, and PAD ids from the predictions.
Modifies results in-place.
"""
if decoder_input_ids is not None:
assert (
len(packed_result) == decoder_input_ids.shape[0]
), f"Mismatching number of examples {len(packed_result)=} {decoder_input_ids.shape[0]=}"
decoder_input_ids = decoder_input_ids.detach().cpu()
for hyp, prefix in zip(packed_result, decoder_input_ids):
assert (
hyp.y_sequence[: prefix.shape[0]] == prefix
).all(), f"The decoder input IDs were not found at the beginning of prediction: {hyp.y_sequence=} {prefix=})"
hyp.y_sequence = hyp.y_sequence[prefix.shape[0] :]
for hyp in packed_result:
ids = hyp.y_sequence
pos = -1
while ids[pos] == self.pad or ids[pos] == self.eos:
pos -= 1
if pos < -1:
hyp.y_sequence = ids[: pos + 1]


@dataclass
class AEDBeamInferConfig:
Expand Down
11 changes: 0 additions & 11 deletions nemo/collections/asr/parts/submodules/multitask_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,6 @@ def decode_ids_to_langs(self, tokens: List[int]) -> List[str]:
"""
raise NotImplementedError()

def strip_special_tokens(self, text: str):
"""
assuming all special tokens are of format <token>
Note that if any label/pred is of format <token>, it will be stripped
"""
assert isinstance(text, str), f"Expected str, got {type(text)}"
text = re.sub(r'<[^>]+>', '', text)
# strip spaces at the beginning and end;
# this is training data artifact, will be fixed in future (@kpuvvada)
return text.strip()


class MultiTaskDecoding(AbstractMultiTaskDecoding):
"""
Expand Down
2 changes: 1 addition & 1 deletion tutorials/asr/Multilang_ASR.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@
"outputs": [],
"source": [
"if not os.path.exists(\"convert_hf_dataset_to_nemo.py\"):\n",
" !wget https://raw.githubusercontent.com/NVIDIA/NeMo/${BRANCH}/scripts/speech_recognition/convert_hf_dataset_to_nemo.py"
" !wget https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/speech_recognition/convert_hf_dataset_to_nemo.py"
]
},
{
Expand Down

0 comments on commit 1fecadc

Please sign in to comment.