TDT confidence fix (NVIDIA#8982)
* tdt confidence fix

---------

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
GNroy and pre-commit-ci[bot] authored Apr 21, 2024
1 parent 6533e48 commit 9bafd37
Showing 11 changed files with 239 additions and 51 deletions.
13 changes: 13 additions & 0 deletions nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -98,6 +98,10 @@ class AbstractCTCDecoding(ConfidenceMixin):
Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
attached to the regular frame confidence,
making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).
method_cfg:
A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
@@ -911,10 +915,15 @@ class CTCDecoding(AbstractCTCDecoding):
exclude_blank:
Bool flag indicating that blank token confidence scores are to be excluded
from the `token_confidence`.
aggregation:
Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
attached to the regular frame confidence,
making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).
method_cfg:
A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
@@ -1122,6 +1131,10 @@ class CTCBPEDecoding(AbstractCTCDecoding):
Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
attached to the regular frame confidence,
making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).
method_cfg:
A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
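The docstrings above describe two pieces of the confidence pipeline: per-token scores are collapsed into per-word scores with a configurable aggregation (`mean`, `min`, `max`, `prod`), and with `tdt_include_duration` each TDT frame-confidence element becomes a pair (`prediction_confidence`, `duration_confidence`). Below is a minimal standalone sketch of that aggregation step; the helper names are illustrative and not NeMo API.

# Illustrative sketch only (not part of the commit, not NeMo API).
from functools import reduce
import operator

AGGREGATIONS = {
    "mean": lambda xs: sum(xs) / len(xs),
    "min": min,
    "max": max,
    "prod": lambda xs: reduce(operator.mul, xs, 1.0),
}

def word_confidence(token_confidence, tokens_per_word, aggregation="min"):
    """Collapse per-token confidence into per-word confidence."""
    agg = AGGREGATIONS[aggregation]
    out, i = [], 0
    for n in tokens_per_word:                  # number of tokens forming each word
        out.append(agg(token_confidence[i:i + n]))
        i += n
    return out

# With `tdt_include_duration` enabled, a TDT frame-confidence element is a pair
# (prediction_confidence, duration_confidence) instead of a single score:
tdt_frame_confidence_element = (0.91, 0.78)

print(word_confidence([0.9, 0.8, 0.95], tokens_per_word=[2, 1], aggregation="prod"))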
101 changes: 80 additions & 21 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -96,6 +96,9 @@ class AbstractRNNTDecoding(ConfidenceMixin):
from the `token_confidence`.
aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
attached to the regular frame confidence,
making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).
method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
@@ -209,7 +212,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
self.compute_timestamps = self.cfg.get('compute_timestamps', None)
self.word_seperator = self.cfg.get('word_seperator', ' ')

if self.durations is not None and self.durations != []: # this means it's a TDT model.
self._is_tdt = self.durations is not None and self.durations != [] # this means it's a TDT model.
if self._is_tdt:
if blank_id == 0:
raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models")
if self.big_blank_durations is not None and self.big_blank_durations != []:
@@ -254,6 +258,12 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
# initialize confidence-related fields
self._init_confidence(self.cfg.get('confidence_cfg', None))

if self._is_tdt:
if self.preserve_frame_confidence is True and self.preserve_alignments is False:
raise ValueError(
"If `preserve_frame_confidence` flag is set, then `preserve_alignments` flag must also be set."
)

# Confidence estimation is not implemented for these strategies
if (
not self.preserve_frame_confidence
@@ -264,7 +274,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):

if self.cfg.strategy == 'greedy':
if self.big_blank_durations is None or self.big_blank_durations == []:
if self.durations is None or self.durations == []:
if not self._is_tdt:
self.decoding = rnnt_greedy_decoding.GreedyRNNTInfer(
decoder_model=decoder,
joint_model=joint,
Expand All @@ -289,6 +299,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
include_duration_confidence=self.tdt_include_duration_confidence,
confidence_method_cfg=self.confidence_method_cfg,
)
else:
@@ -307,7 +318,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):

elif self.cfg.strategy == 'greedy_batch':
if self.big_blank_durations is None or self.big_blank_durations == []:
if self.durations is None or self.durations == []:
if not self._is_tdt:
self.decoding = rnnt_greedy_decoding.GreedyBatchedRNNTInfer(
decoder_model=decoder,
joint_model=joint,
Expand All @@ -334,6 +345,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
include_duration_confidence=self.tdt_include_duration_confidence,
confidence_method_cfg=self.confidence_method_cfg,
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
)
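For TDT models the constructor now also enforces that frame confidence is only preserved together with alignments, and it threads the new `tdt_include_duration_confidence` setting into the greedy decoders shown above. Below is a hedged sketch of settings that would pass or fail that check; the field names follow the docstrings in this diff, and the exact config schema may differ.

# Illustrative settings only; field names follow the docstrings in this diff.
ok_decoding_cfg = {
    "strategy": "greedy_batch",
    "preserve_alignments": True,           # required whenever frame confidence is kept for TDT
    "confidence_cfg": {
        "preserve_frame_confidence": True,
        "tdt_include_duration": True,      # each frame-confidence element becomes a (prediction, duration) pair
    },
}

bad_decoding_cfg = {
    "strategy": "greedy_batch",
    "preserve_alignments": False,          # AbstractRNNTDecoding.__init__ now raises ValueError for TDT
    "confidence_cfg": {"preserve_frame_confidence": True},
}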
@@ -530,7 +542,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp
if self.big_blank_durations is not None and self.big_blank_durations != []: # multi-blank RNNT
num_extra_outputs = len(self.big_blank_durations)
prediction = [p for p in prediction if p < self.blank_id - num_extra_outputs]
elif self.durations is not None and self.durations != []: # TDT model.
elif self._is_tdt: # TDT model.
prediction = [p for p in prediction if p < self.blank_id]
else: # standard RNN-T
prediction = [p for p in prediction if p != self.blank_id]
@@ -569,28 +581,69 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothes
Returns:
A list of hypotheses with high-level confidence scores.
"""
if self.exclude_blank_from_confidence:
for hyp in hypotheses_list:
hyp.token_confidence = hyp.non_blank_frame_confidence
else:
if self._is_tdt:
# if self.tdt_include_duration_confidence is True then frame_confidence elements consist of two numbers
maybe_pre_aggregate = (
(lambda x: self._aggregate_confidence(x)) if self.tdt_include_duration_confidence else (lambda x: x)
)
for hyp in hypotheses_list:
offset = 0
token_confidence = []
if len(hyp.timestep) > 0:
for ts, te in zip(hyp.timestep, hyp.timestep[1:] + [len(hyp.frame_confidence)]):
if ts != te:
# <blank> tokens are considered to belong to the last non-blank token, if any.
token_confidence.append(
self._aggregate_confidence(
[hyp.frame_confidence[ts][offset]]
+ [fc[0] for fc in hyp.frame_confidence[ts + 1 : te]]
# trying to recover frame_confidence according to alignments
subsequent_blank_confidence = []
# going backwards since <blank> tokens are considered belonging to the last non-blank token.
for fc, fa in zip(hyp.frame_confidence[::-1], hyp.alignments[::-1]):
# there is only one score per frame most of the time
if len(fa) > 1:
for i, a in reversed(list(enumerate(fa))):
if a[-1] == self.blank_id:
if not self.exclude_blank_from_confidence:
subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i]))
elif not subsequent_blank_confidence:
token_confidence.append(maybe_pre_aggregate(fc[i]))
else:
token_confidence.append(
self._aggregate_confidence(
[maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence
)
)
)
offset = 0
subsequent_blank_confidence = []
else:
i, a = 0, fa[0]
if a[-1] == self.blank_id:
if not self.exclude_blank_from_confidence:
subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i]))
elif not subsequent_blank_confidence:
token_confidence.append(maybe_pre_aggregate(fc[i]))
else:
token_confidence.append(hyp.frame_confidence[ts][offset])
offset += 1
token_confidence.append(
self._aggregate_confidence([maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence)
)
subsequent_blank_confidence = []
token_confidence = token_confidence[::-1]
hyp.token_confidence = token_confidence
else:
if self.exclude_blank_from_confidence:
for hyp in hypotheses_list:
hyp.token_confidence = hyp.non_blank_frame_confidence
else:
for hyp in hypotheses_list:
offset = 0
token_confidence = []
if len(hyp.timestep) > 0:
for ts, te in zip(hyp.timestep, hyp.timestep[1:] + [len(hyp.frame_confidence)]):
if ts != te:
# <blank> tokens are considered to belong to the last non-blank token, if any.
token_confidence.append(
self._aggregate_confidence(
[hyp.frame_confidence[ts][offset]]
+ [fc[0] for fc in hyp.frame_confidence[ts + 1 : te]]
)
)
offset = 0
else:
token_confidence.append(hyp.frame_confidence[ts][offset])
offset += 1
hyp.token_confidence = token_confidence
if self.preserve_word_confidence:
for hyp in hypotheses_list:
hyp.word_confidence = self._aggregate_token_confidence(hyp)
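The new TDT branch of `compute_confidence` walks `frame_confidence` and `alignments` backwards, folds every <blank> frame into the last preceding non-blank token, and pre-aggregates each (`prediction_confidence`, `duration_confidence`) pair when duration confidence is enabled. A self-contained sketch of that backward attribution, using plain floats and `(token_id,)` alignment tuples in place of NeMo's tensors:

# Simplified restatement of the backward attribution above; plain Python, not NeMo code.
def mean(xs):
    return sum(xs) / len(xs)

def tdt_token_confidence(frame_confidence, alignments, blank_id,
                         exclude_blank=True, aggregate=mean):
    """frame_confidence[t][i] is the score of the i-th decision in frame t
    (already pre-aggregated if it was a pair); alignments[t][i][-1] is the token id."""
    token_confidence, trailing_blanks = [], []
    for fc, fa in zip(frame_confidence[::-1], alignments[::-1]):   # walk backwards over frames
        for i in reversed(range(len(fa))):
            if fa[i][-1] == blank_id:
                if not exclude_blank:
                    trailing_blanks.append(fc[i])                  # blanks belong to the last non-blank token
            elif not trailing_blanks:
                token_confidence.append(fc[i])
            else:
                token_confidence.append(aggregate([fc[i]] + trailing_blanks))
                trailing_blanks = []
    return token_confidence[::-1]

# Two frames: a token (id 3) followed by a <blank> (id 5) that is folded into it.
print(tdt_token_confidence([[0.9], [0.4]], [[(3,)], [(5,)]], blank_id=5, exclude_blank=False))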
@@ -1010,6 +1063,9 @@ class RNNTDecoding(AbstractRNNTDecoding):
from the `token_confidence`.
aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
attached to the regular frame confidence,
making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).
method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
@@ -1276,6 +1332,9 @@ class RNNTBPEDecoding(AbstractRNNTDecoding):
from the `token_confidence`.
aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
attached to the regular frame confidence,
making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).
method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
19 changes: 18 additions & 1 deletion nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -2282,6 +2282,7 @@ class GreedyRNNTInferConfig:
max_symbols_per_step: Optional[int] = 10
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
tdt_include_duration_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())

def __post_init__(self):
Expand All @@ -2298,6 +2299,7 @@ class GreedyBatchedRNNTInferConfig:
max_symbols_per_step: Optional[int] = 10
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
tdt_include_duration_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
loop_labels: bool = True
use_cuda_graph_decoder: bool = False
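Both greedy config dataclasses gain a `tdt_include_duration_confidence` field that defaults to `False`, so existing configs keep their behavior. A minimal instantiation sketch, setting only the fields visible in this diff (the dataclasses also contain fields not shown here):

# Sketch only: enable duration confidence on the batched greedy config;
# all fields not listed keep their defaults.
cfg = GreedyBatchedRNNTInferConfig(
    max_symbols_per_step=10,
    preserve_alignments=True,
    preserve_frame_confidence=True,
    tdt_include_duration_confidence=True,
)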
@@ -2337,6 +2339,9 @@ class GreedyTDTInfer(_GreedyRNNTInfer):
The length of the list corresponds to the Acoustic Length (T).
Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores.
U is the number of target tokens for the current timestep Ti.
include_duration_confidence: Bool flag indicating that the duration confidence scores are to be calculated and
attached to the regular frame confidence,
making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).
confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
@@ -2380,6 +2385,7 @@ def __init__(
max_symbols_per_step: Optional[int] = None,
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
include_duration_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
):
super().__init__(
Expand All @@ -2392,6 +2398,7 @@ def __init__(
confidence_method_cfg=confidence_method_cfg,
)
self.durations = durations
self.include_duration_confidence = include_duration_confidence

@typecheck()
def forward(
@@ -2517,7 +2524,11 @@ def _greedy_decode(

if self.preserve_frame_confidence:
# insert confidence into last timestep
hypothesis.frame_confidence[-1].append(self._get_confidence(logp))
hypothesis.frame_confidence[-1].append(
(self._get_confidence_tensor(logp), self._get_confidence_tensor(duration_logp))
if self.include_duration_confidence
else self._get_confidence_tensor(logp)
)

del logp

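With `include_duration_confidence` set, each frame-confidence entry appended above is a tuple of two confidence tensors (prediction and duration) instead of a single tensor, so consumers must accept both shapes. A hedged sketch of such handling (the helper below is illustrative, not NeMo API):

# Illustrative helper, not NeMo API: fold a frame-confidence entry down to one score.
def as_scalar(entry, combine=lambda p, d: 0.5 * (p + d)):
    """Accept either a plain prediction confidence or a
    (prediction_confidence, duration_confidence) pair."""
    if isinstance(entry, tuple):    # produced when include_duration_confidence=True
        return combine(*entry)
    return entry

print(as_scalar(0.93), as_scalar((0.93, 0.81)))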
@@ -2593,6 +2604,9 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer):
The length of the list corresponds to the Acoustic Length (T).
Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores.
U is the number of target tokens for the current timestep Ti.
include_duration_confidence: Bool flag indicating that the duration confidence scores are to be calculated and
attached to the regular frame confidence,
making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).
confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
@@ -2636,6 +2650,7 @@ def __init__(
max_symbols_per_step: Optional[int] = None,
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
include_duration_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
use_cuda_graph_decoder: bool = False,
):
Expand All @@ -2649,6 +2664,7 @@ def __init__(
confidence_method_cfg=confidence_method_cfg,
)
self.durations = durations
self.include_duration_confidence = include_duration_confidence

# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
@@ -2663,6 +2679,7 @@ def __init__(
max_symbols_per_step=self.max_symbols,
preserve_alignments=preserve_alignments,
preserve_frame_confidence=preserve_frame_confidence,
include_duration_confidence=include_duration_confidence,
confidence_method_cfg=confidence_method_cfg,
allow_cuda_graphs=use_cuda_graph_decoder,
)
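Finally, a rough end-to-end sketch of exercising the new behavior on a TDT checkpoint. `from_pretrained`, `change_decoding_strategy`, and `transcribe(..., return_hypotheses=True)` are existing NeMo entry points; the checkpoint name and the `tdt_include_duration` field name are assumptions based on the docstrings in this commit.

# Hedged usage sketch; the checkpoint name and confidence field names are assumptions.
from omegaconf import open_dict
import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-1.1b")  # any TDT checkpoint

decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "greedy_batch"
    decoding_cfg.preserve_alignments = True           # required together with frame confidence for TDT
    decoding_cfg.confidence_cfg = {
        "preserve_frame_confidence": True,
        "tdt_include_duration": True,                 # new flag introduced by this commit (assumed name)
    }
asr_model.change_decoding_strategy(decoding_cfg)

hyps = asr_model.transcribe(["sample.wav"], return_hypotheses=True)
# Depending on the NeMo version, `hyps` is a list of Hypothesis objects or a (best, all) tuple;
# each Hypothesis.frame_confidence element is now a (prediction_confidence, duration_confidence) pair.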
