diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py index d331a6c86b53..67559eccf6e2 100644 --- a/nemo/collections/asr/parts/submodules/ctc_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py @@ -98,6 +98,10 @@ class AbstractCTCDecoding(ConfidenceMixin): Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). + method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -911,10 +915,15 @@ class CTCDecoding(AbstractCTCDecoding): exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). + method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1122,6 +1131,10 @@ class CTCBPEDecoding(AbstractCTCDecoding): Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). 
+ method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 7a260f3c6c89..71079f4b6382 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -96,6 +96,9 @@ class AbstractRNNTDecoding(ConfidenceMixin): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -209,7 +212,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): self.compute_timestamps = self.cfg.get('compute_timestamps', None) self.word_seperator = self.cfg.get('word_seperator', ' ') - if self.durations is not None and self.durations != []: # this means it's a TDT model. + self._is_tdt = self.durations is not None and self.durations != [] # this means it's a TDT model. 
+ if self._is_tdt: if blank_id == 0: raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models") if self.big_blank_durations is not None and self.big_blank_durations != []: @@ -254,6 +258,12 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): # initialize confidence-related fields self._init_confidence(self.cfg.get('confidence_cfg', None)) + if self._is_tdt: + if self.preserve_frame_confidence is True and self.preserve_alignments is False: + raise ValueError( + "If `preserve_frame_confidence` flag is set, then `preserve_alignments` flag must also be set." + ) + # Confidence estimation is not implemented for these strategies if ( not self.preserve_frame_confidence @@ -264,7 +274,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): if self.cfg.strategy == 'greedy': if self.big_blank_durations is None or self.big_blank_durations == []: - if self.durations is None or self.durations == []: + if not self._is_tdt: self.decoding = rnnt_greedy_decoding.GreedyRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -289,6 +299,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, + include_duration_confidence=self.tdt_include_duration_confidence, confidence_method_cfg=self.confidence_method_cfg, ) else: @@ -307,7 +318,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): elif self.cfg.strategy == 'greedy_batch': if self.big_blank_durations is None or self.big_blank_durations == []: - if self.durations is None or self.durations == []: + if not self._is_tdt: self.decoding = rnnt_greedy_decoding.GreedyBatchedRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -334,6 +345,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, + 
include_duration_confidence=self.tdt_include_duration_confidence, confidence_method_cfg=self.confidence_method_cfg, use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), ) @@ -530,7 +542,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp if self.big_blank_durations is not None and self.big_blank_durations != []: # multi-blank RNNT num_extra_outputs = len(self.big_blank_durations) prediction = [p for p in prediction if p < self.blank_id - num_extra_outputs] - elif self.durations is not None and self.durations != []: # TDT model. + elif self._is_tdt: # TDT model. prediction = [p for p in prediction if p < self.blank_id] else: # standard RNN-T prediction = [p for p in prediction if p != self.blank_id] @@ -569,28 +581,69 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothes Returns: A list of hypotheses with high-level confidence scores. """ - if self.exclude_blank_from_confidence: - for hyp in hypotheses_list: - hyp.token_confidence = hyp.non_blank_frame_confidence - else: + if self._is_tdt: + # if self.tdt_include_duration_confidence is True then frame_confidence elements consist of two numbers + maybe_pre_aggregate = ( + (lambda x: self._aggregate_confidence(x)) if self.tdt_include_duration_confidence else (lambda x: x) + ) for hyp in hypotheses_list: - offset = 0 token_confidence = [] - if len(hyp.timestep) > 0: - for ts, te in zip(hyp.timestep, hyp.timestep[1:] + [len(hyp.frame_confidence)]): - if ts != te: - # tokens are considered to belong to the last non-blank token, if any. - token_confidence.append( - self._aggregate_confidence( - [hyp.frame_confidence[ts][offset]] - + [fc[0] for fc in hyp.frame_confidence[ts + 1 : te]] + # trying to recover frame_confidence according to alignments + subsequent_blank_confidence = [] + # going backwards since tokens are considered belonging to the last non-blank token. 
+ for fc, fa in zip(hyp.frame_confidence[::-1], hyp.alignments[::-1]): + # there is only one score per frame most of the time + if len(fa) > 1: + for i, a in reversed(list(enumerate(fa))): + if a[-1] == self.blank_id: + if not self.exclude_blank_from_confidence: + subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i])) + elif not subsequent_blank_confidence: + token_confidence.append(maybe_pre_aggregate(fc[i])) + else: + token_confidence.append( + self._aggregate_confidence( + [maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence + ) ) - ) - offset = 0 + subsequent_blank_confidence = [] + else: + i, a = 0, fa[0] + if a[-1] == self.blank_id: + if not self.exclude_blank_from_confidence: + subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i])) + elif not subsequent_blank_confidence: + token_confidence.append(maybe_pre_aggregate(fc[i])) else: - token_confidence.append(hyp.frame_confidence[ts][offset]) - offset += 1 + token_confidence.append( + self._aggregate_confidence([maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence) + ) + subsequent_blank_confidence = [] + token_confidence = token_confidence[::-1] hyp.token_confidence = token_confidence + else: + if self.exclude_blank_from_confidence: + for hyp in hypotheses_list: + hyp.token_confidence = hyp.non_blank_frame_confidence + else: + for hyp in hypotheses_list: + offset = 0 + token_confidence = [] + if len(hyp.timestep) > 0: + for ts, te in zip(hyp.timestep, hyp.timestep[1:] + [len(hyp.frame_confidence)]): + if ts != te: + # tokens are considered to belong to the last non-blank token, if any. 
+ token_confidence.append( + self._aggregate_confidence( + [hyp.frame_confidence[ts][offset]] + + [fc[0] for fc in hyp.frame_confidence[ts + 1 : te]] + ) + ) + offset = 0 + else: + token_confidence.append(hyp.frame_confidence[ts][offset]) + offset += 1 + hyp.token_confidence = token_confidence if self.preserve_word_confidence: for hyp in hypotheses_list: hyp.word_confidence = self._aggregate_token_confidence(hyp) @@ -1010,6 +1063,9 @@ class RNNTDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1276,6 +1332,9 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. 
diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 464dc46e358c..e5de99cf0776 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2282,6 +2282,7 @@ class GreedyRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False + tdt_include_duration_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) def __post_init__(self): @@ -2298,6 +2299,7 @@ class GreedyBatchedRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False + tdt_include_duration_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) loop_labels: bool = True use_cuda_graph_decoder: bool = False @@ -2337,6 +2339,9 @@ class GreedyTDTInfer(_GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. + include_duration_confidence: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. 
@@ -2380,6 +2385,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, + include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, ): super().__init__( @@ -2392,6 +2398,7 @@ def __init__( confidence_method_cfg=confidence_method_cfg, ) self.durations = durations + self.include_duration_confidence = include_duration_confidence @typecheck() def forward( @@ -2517,7 +2524,11 @@ def _greedy_decode( if self.preserve_frame_confidence: # insert confidence into last timestep - hypothesis.frame_confidence[-1].append(self._get_confidence(logp)) + hypothesis.frame_confidence[-1].append( + (self._get_confidence_tensor(logp), self._get_confidence_tensor(duration_logp)) + if self.include_duration_confidence + else self._get_confidence_tensor(logp) + ) del logp @@ -2593,6 +2604,9 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. + include_duration_confidence: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. 
@@ -2636,6 +2650,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, + include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, use_cuda_graph_decoder: bool = False, ): @@ -2649,6 +2664,7 @@ def __init__( confidence_method_cfg=confidence_method_cfg, ) self.durations = durations + self.include_duration_confidence = include_duration_confidence # Depending on availability of `blank_as_pad` support # switch between more efficient batch decoding technique @@ -2663,6 +2679,7 @@ def __init__( max_symbols_per_step=self.max_symbols, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, + include_duration_confidence=include_duration_confidence, confidence_method_cfg=confidence_method_cfg, allow_cuda_graphs=use_cuda_graph_decoder, ) diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index c289ce06cdfa..b136446d97fb 100644 --- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -92,6 +92,7 @@ def __init__( logits_dim: int, preserve_alignments=False, preserve_frame_confidence=False, + include_duration_confidence: bool = False, ): """ @@ -105,6 +106,7 @@ def __init__( logits_dim: output dimension for Joint preserve_alignments: if alignments are needed preserve_frame_confidence: if frame confidence is needed + include_duration_confidence: if duration confidence is needed to be added to the frame confidence """ self.device = device self.float_dtype = float_dtype @@ -151,6 +153,7 @@ def __init__( float_dtype=self.float_dtype, store_alignments=preserve_alignments, store_frame_confidence=preserve_frame_confidence, + with_duration_confidence=include_duration_confidence, ) else: self.alignments = None @@ -186,6 +189,7 @@ def __init__( 
max_symbols_per_step: Optional[int] = None, preserve_alignments=False, preserve_frame_confidence=False, + include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, allow_cuda_graphs: bool = True, ): @@ -199,6 +203,7 @@ def __init__( max_symbols_per_step: max symbols to emit on each step (to avoid infinite looping) preserve_alignments: if alignments are needed preserve_frame_confidence: if frame confidence is needed + include_duration_confidence: if duration confidence is needed to be added to the frame confidence confidence_method_cfg: config for the confidence """ super().__init__() @@ -210,6 +215,7 @@ def __init__( self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments self.preserve_frame_confidence = preserve_frame_confidence + self.include_duration_confidence = include_duration_confidence self._SOS = self._blank_index self._init_confidence_method(confidence_method_cfg=confidence_method_cfg) assert self._SOS == self._blank_index # "blank as pad" algorithm only @@ -244,6 +250,7 @@ def loop_labels_torch( # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) + dtype = encoder_output_projected.dtype # init output structures: BatchedHyps (for results), BatchedAlignments + last decoder state # init empty batched hypotheses @@ -251,7 +258,7 @@ def loop_labels_torch( batch_size=batch_size, init_length=max_time * self.max_symbols if self.max_symbols is not None else max_time, device=device, - float_dtype=encoder_output_projected.dtype, + float_dtype=dtype, ) # sample state, will be replaced further when the decoding for hypothesis is done last_decoder_state = self.decoder.initialize_state(encoder_output_projected) @@ -263,9 +270,10 @@ def loop_labels_torch( logits_dim=self.joint.num_classes_with_blank, init_length=max_time * 2 if use_alignments else 1, # blank for each timestep + text tokens device=device, - 
float_dtype=encoder_output_projected.dtype, + float_dtype=dtype, store_alignments=self.preserve_alignments, store_frame_confidence=self.preserve_frame_confidence, + with_duration_confidence=self.include_duration_confidence, ) # durations @@ -327,7 +335,19 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)) + confidence=torch.stack( + ( + self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=dtype + ), + self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( + dtype=dtype + ), + ), + dim=-1, + ) + if self.include_duration_confidence + else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to(dtype=dtype) if self.preserve_frame_confidence else None, ) @@ -367,7 +387,21 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)) + confidence=torch.stack( + ( + self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=dtype + ), + self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( + dtype=dtype + ), + ), + dim=-1, + ) + if self.include_duration_confidence + else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=dtype + ) if self.preserve_frame_confidence else None, ) @@ -520,6 +554,7 @@ def _graph_reinitialize( logits_dim=self.joint.num_classes_with_blank, preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, + include_duration_confidence=self.include_duration_confidence, ) self.state.all_durations = 
self.durations.to(self.state.device) @@ -616,6 +651,7 @@ def _before_inner_loop_get_joint_output(self): # stage 2: get joint output, iteratively seeking for non-blank labels # blank label in `labels` tensor means "end of hypothesis" (for this index) self.state.active_mask_prev.copy_(self.state.active_mask, non_blocking=True) + dtype = self.state.encoder_output_projected.dtype logits = ( self.joint.joint_after_projection( self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze( @@ -644,9 +680,21 @@ def _before_inner_loop_get_joint_output(self): time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=self.state.labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor( - F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + confidence=torch.stack( + ( + self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=dtype), + self._get_confidence_tensor( + F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) + ).to(dtype=dtype), + ), + dim=-1, ) + if self.include_duration_confidence + else self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=dtype) if self.preserve_frame_confidence else None, ) @@ -672,6 +720,7 @@ def _inner_loop_code(self): self.state.time_indices_current_labels, out=self.state.time_indices_current_labels, ) + dtype = self.state.encoder_output_projected.dtype logits = ( self.joint.joint_after_projection( self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze( @@ -698,9 +747,21 @@ def _inner_loop_code(self): time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor( - 
F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + confidence=torch.stack( + ( + self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=dtype), + self._get_confidence_tensor( + F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) + ).to(dtype=dtype), + ), + dim=-1, ) + if self.include_duration_confidence + else self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=dtype) if self.preserve_frame_confidence else None, ) diff --git a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py index 8b15bc22eac6..96f90bee363c 100644 --- a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py +++ b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py @@ -172,7 +172,7 @@ def apply_confidence_parameters(decoding_cfg, hp): Updated decoding config. """ new_decoding_cfg = copy.deepcopy(decoding_cfg) - confidence_cfg_fields = ("aggregation", "exclude_blank") + confidence_cfg_fields = ("aggregation", "exclude_blank", "tdt_include_duration") confidence_method_cfg_fields = ("name", "alpha", "entropy_type", "entropy_norm") with open_dict(new_decoding_cfg): for p, v in hp.items(): diff --git a/nemo/collections/asr/parts/utils/asr_confidence_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_utils.py index 27ced569b1a9..20f75baf522e 100644 --- a/nemo/collections/asr/parts/utils/asr_confidence_utils.py +++ b/nemo/collections/asr/parts/utils/asr_confidence_utils.py @@ -136,6 +136,9 @@ class ConfidenceConfig: from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. 
+ tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -175,6 +178,7 @@ class ConfidenceConfig: preserve_word_confidence: bool = False exclude_blank: bool = True aggregation: str = "min" + tdt_include_duration: bool = False method_cfg: ConfidenceMethodConfig = field(default_factory=lambda: ConfidenceMethodConfig()) def __post_init__(self): @@ -361,6 +365,7 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): confidence_cfg.get('preserve_frame_confidence', False) | self.preserve_token_confidence ) self.exclude_blank_from_confidence = confidence_cfg.get('exclude_blank', True) + self.tdt_include_duration_confidence = confidence_cfg.get('tdt_include_duration', False) self.word_confidence_aggregation = confidence_cfg.get('aggregation', "min") # define aggregation functions @@ -368,8 +373,8 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): self._aggregate_confidence = self.confidence_aggregation_bank[self.word_confidence_aggregation] # Update preserve frame confidence - if self.preserve_frame_confidence is False: - if self.cfg.strategy in ['greedy', 'greedy_batch']: + if self.cfg.strategy in ['greedy', 'greedy_batch']: + if not self.preserve_frame_confidence: self.preserve_frame_confidence = self.cfg.greedy.get('preserve_frame_confidence', False) # OmegaConf.structured ensures that post_init check is always executed confidence_method_cfg = OmegaConf.structured(self.cfg.greedy).get('confidence_method_cfg', None) @@ -378,6 +383,8 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): if confidence_method_cfg is None else OmegaConf.structured(ConfidenceMethodConfig(**confidence_method_cfg)) ) + if not 
self.tdt_include_duration_confidence: + self.tdt_include_duration_confidence = self.cfg.greedy.get('tdt_include_duration_confidence', False) @abstractmethod def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothesis]: diff --git a/nemo/collections/asr/parts/utils/rnnt_utils.py b/nemo/collections/asr/parts/utils/rnnt_utils.py index 1cd2d2ddc255..158fe3609286 100644 --- a/nemo/collections/asr/parts/utils/rnnt_utils.py +++ b/nemo/collections/asr/parts/utils/rnnt_utils.py @@ -115,7 +115,7 @@ def non_blank_frame_confidence(self) -> List[float]: non_blank_frame_confidence = [] # self.timestep can be a dict for RNNT timestep = self.timestep['timestep'] if isinstance(self.timestep, dict) else self.timestep - if len(self.timestep) != 0 and self.frame_confidence is not None: + if len(timestep) != 0 and self.frame_confidence is not None: if any(isinstance(i, list) for i in self.frame_confidence): # rnnt t_prev = -1 offset = 0 @@ -405,6 +405,7 @@ def __init__( float_dtype: Optional[torch.dtype] = None, store_alignments: bool = True, store_frame_confidence: bool = False, + with_duration_confidence: bool = False, ): """ @@ -422,6 +423,7 @@ def __init__( if batch_size <= 0: raise ValueError(f"batch_size must be > 0, got {batch_size}") self.with_frame_confidence = store_frame_confidence + self.with_duration_confidence = with_duration_confidence self.with_alignments = store_alignments self._max_length = init_length @@ -442,7 +444,11 @@ def __init__( self.frame_confidence = torch.zeros(0, device=device, dtype=float_dtype) if self.with_frame_confidence: # tensor to store frame confidence - self.frame_confidence = torch.zeros((batch_size, self._max_length), device=device, dtype=float_dtype) + self.frame_confidence = torch.zeros( + [batch_size, self._max_length, 2] if self.with_duration_confidence else [batch_size, self._max_length], + device=device, + dtype=float_dtype, + ) self._batch_indices = torch.arange(batch_size, device=device) def clear_(self): @@ 
-462,7 +468,7 @@ def _allocate_more(self): self.logits = torch.cat((self.logits, torch.zeros_like(self.logits)), dim=1) self.labels = torch.cat((self.labels, torch.zeros_like(self.labels)), dim=-1) if self.with_frame_confidence: - self.frame_confidence = torch.cat((self.frame_confidence, torch.zeros_like(self.frame_confidence)), dim=-1) + self.frame_confidence = torch.cat((self.frame_confidence, torch.zeros_like(self.frame_confidence)), dim=1) self._max_length *= 2 def add_results_( diff --git a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py index 8a3c3f4e47c0..0c119b02ff7b 100644 --- a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py +++ b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py @@ -82,6 +82,7 @@ def get_experiment_params(cfg): String with the experiment name. """ blank = "no_blank" if cfg.exclude_blank else "blank" + duration = "duration" if cfg.tdt_include_duration else "no_duration" aggregation = cfg.aggregation method_name = cfg.method_cfg.name alpha = cfg.method_cfg.alpha @@ -91,15 +92,24 @@ def get_experiment_params(cfg): experiment_param_list = [ aggregation, str(cfg.exclude_blank), + str(cfg.tdt_include_duration), method_name, entropy_type, entropy_norm, str(alpha), ] - experiment_str = "-".join([aggregation, blank, method_name, entropy_type, entropy_norm, str(alpha)]) + experiment_str = "-".join([aggregation, blank, duration, method_name, entropy_type, entropy_norm, str(alpha)]) else: - experiment_param_list = [aggregation, str(cfg.exclude_blank), method_name, "-", "-", str(alpha)] - experiment_str = "-".join([aggregation, blank, method_name, str(alpha)]) + experiment_param_list = [ + aggregation, + str(cfg.exclude_blank), + str(cfg.tdt_include_duration), + method_name, + "-", + "-", + str(alpha), + ] + experiment_str = "-".join([aggregation, blank, duration, method_name, str(alpha)]) return experiment_param_list, 
experiment_str @@ -214,6 +224,7 @@ def main(cfg: ConfidenceBenchmarkingConfig): "model_type", "aggregation", "blank", + "duration", "method_name", "entropy_type", "entropy_norm", diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index 85156bf9e2c5..018c9bcc4aa2 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -265,7 +265,7 @@ def test_decoding_type_change(self, hybrid_asr_model): @pytest.mark.unit def test_GreedyRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'tdt_include_duration_confidence'] result = assert_dataclass_signature_match( greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyRNNTInferConfig, ignore_args=IGNORE_ARGS @@ -279,7 +279,7 @@ def test_GreedyRNNTInferConfig(self): @pytest.mark.unit def test_GreedyBatchedRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'tdt_include_duration_confidence'] result = assert_dataclass_signature_match( greedy_decode.GreedyBatchedRNNTInfer, greedy_decode.GreedyBatchedRNNTInferConfig, ignore_args=IGNORE_ARGS diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index d5ab0054ff87..a6e3714f20f5 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -387,7 +387,7 @@ def test_decoding_change(self, asr_model): @pytest.mark.unit def test_GreedyRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'tdt_include_duration_confidence'] result = assert_dataclass_signature_match( greedy_decode.GreedyRNNTInfer, 
greedy_decode.GreedyRNNTInferConfig, ignore_args=IGNORE_ARGS @@ -401,7 +401,7 @@ def test_GreedyRNNTInferConfig(self): @pytest.mark.unit def test_GreedyBatchedRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'tdt_include_duration_confidence'] result = assert_dataclass_signature_match( greedy_decode.GreedyBatchedRNNTInfer, greedy_decode.GreedyBatchedRNNTInferConfig, ignore_args=IGNORE_ARGS diff --git a/tutorials/asr/ASR_Confidence_Estimation.ipynb b/tutorials/asr/ASR_Confidence_Estimation.ipynb index ffcec8e16f39..eb8cd7b11688 100644 --- a/tutorials/asr/ASR_Confidence_Estimation.ipynb +++ b/tutorials/asr/ASR_Confidence_Estimation.ipynb @@ -466,6 +466,7 @@ " preserve_word_confidence=True,\n", " aggregation=\"prod\", # How to aggregate frame scores to token scores and token scores to word scores\n", " exclude_blank=False, # If true, only non-blank emissions contribute to confidence scores\n", + " tdt_include_duration=False, # If true, calculate duration confidence for the TDT models\n", " method_cfg=ConfidenceMethodConfig( # Config for per-frame scores calculation (before aggregation)\n", " name=\"max_prob\", # Or \"entropy\" (default), which usually works better\n", " entropy_type=\"gibbs\", # Used only for name == \"entropy\". 
Recommended: \"tsallis\" (default) or \"renyi\"\n", @@ -506,7 +507,7 @@ "outputs": [], "source": [ "current_test_set = test_sets[\"test_other\"]\n", - "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", + "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", "if is_rnnt:\n", " transcriptions = transcriptions[0]" ] @@ -530,12 +531,25 @@ }, "outputs": [], "source": [ + "def round_confidence(confidence_number, ndigits=3):\n", + " if isinstance(confidence_number, float):\n", + " return round(confidence_number, ndigits)\n", + " elif len(confidence_number.size()) == 0: # torch.tensor with one element\n", + " return round(confidence_number.item(), ndigits)\n", + " elif len(confidence_number.size()) == 1: # torch.tensor with a list of elements\n", + " return [round(c.item(), ndigits) for c in confidence_number]\n", + " else:\n", + " raise RuntimeError(f\"Unexpected confidence_number: `{confidence_number}`\")\n", + "\n", + "\n", "tran = transcriptions[0]\n", "print(\n", " f\"\"\" Recognized text: `{tran.text}`\\n\n", - " Word confidence: {[round(c, 3) for c in tran.word_confidence]}\\n\n", - " Token confidence: {[round(c, 3) for c in tran.token_confidence]}\\n\n", - " Frame confidence: {[([round(cc, 3) for cc in c] if is_rnnt else round(c, 3)) for c in tran.frame_confidence]}\"\"\"\n", + " Word confidence: {[round_confidence(c) for c in tran.word_confidence]}\\n\n", + " Token confidence: {[round_confidence(c) for c in tran.token_confidence]}\\n\n", + " Frame confidence: {\n", + " [([round_confidence(cc) for cc in c] if is_rnnt else round_confidence(c)) for c in tran.frame_confidence]\n", + " }\"\"\"\n", ")" ] }, @@ -726,7 +740,7 @@ " else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n", ")\n", "\n", - "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16, 
return_hypotheses=True, num_workers=4)\n", + "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", "if is_rnnt:\n", " transcriptions = transcriptions[0]" ] @@ -1067,7 +1081,7 @@ " else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n", ")\n", "\n", - "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", + "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", "if is_rnnt:\n", " transcriptions = transcriptions[0]\n", "\n", @@ -1238,7 +1252,7 @@ ")\n", "\n", "noise_transcriptions = model.transcribe(\n", - " paths2audio_files=noise_data.filepaths, batch_size=4, return_hypotheses=True, num_workers=4\n", + " audio=noise_data.filepaths, batch_size=4, return_hypotheses=True, num_workers=4\n", ")\n", "if is_rnnt:\n", " noise_transcriptions = noise_transcriptions[0]" @@ -1424,7 +1438,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4,