
Fix NVIDIA#8891 by supporting GPU-side batched CTC Greedy Decoding (NVIDIA#9100)

* Support batched inference of greedy CTC decoding.

Fixes NVIDIA#8891

Basically, doing max() on the CPU one sample at a time is very slow. It is
better to do it all on the GPU before we copy the results over to the CPU.
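
As a minimal sketch of the difference (the shapes, device selection, and values here are made up for illustration and are not part of the commit):

import torch

# Toy dimensions: batch of 16 utterances, 500 frames, 1024-token vocabulary.
batch_size, max_time, vocab = 16, 500, 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
log_probs = torch.randn(batch_size, max_time, vocab, device=device)

# Slow pattern (what the per-sample loop amounts to): one small max()
# kernel and one device-to-host copy per utterance.
slow_labels = [log_probs[i].max(dim=-1).indices.cpu() for i in range(batch_size)]

# Fast pattern: a single batched max() on the device, then one copy.
best_logprobs, best_labels = log_probs.max(dim=-1)  # both [B, T], still on the GPU
best_labels = best_labels.cpu()                      # one transfer for the whole batch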

This new algorithm has the same interface as the old one and can be accessed by setting strategy to "greedy_batched" rather than "greedy".
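
A sketch of how a user might opt into the new strategy; the CTCDecodingConfig dataclass and the "greedy_batched" value come from this commit, while the checkpoint name and the change_decoding_strategy call are assumptions about the surrounding NeMo API, not part of this diff:

from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecCTCModel
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig

# Example checkpoint name only; any EncDecCTCModel should work the same way.
model = EncDecCTCModel.from_pretrained("stt_en_conformer_ctc_small")

# Build a decoding config that selects the batched greedy decoder.
decoding_cfg = OmegaConf.structured(CTCDecodingConfig(strategy="greedy_batched"))
model.change_decoding_strategy(decoding_cfg)

Because the interface is unchanged, nothing else in the transcription pipeline needs to change.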

Warn when using greedy rather than greedy_batched strategy.

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
galv authored May 9, 2024
1 parent 777e9d6 commit 2d2219c
Showing 7 changed files with 325 additions and 11 deletions.
25 changes: 19 additions & 6 deletions nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -213,31 +213,31 @@ def __init__(self, decoding_cfg, blank_id: int):
self.batch_dim_index = self.cfg.get('batch_dim_index', 0)
self.word_seperator = self.cfg.get('word_seperator', ' ')

possible_strategies = ['greedy', 'beam', 'pyctcdecode', 'flashlight']
possible_strategies = ['greedy', 'greedy_batched', 'beam', 'pyctcdecode', 'flashlight']
if self.cfg.strategy not in possible_strategies:
raise ValueError(f"Decoding strategy must be one of {possible_strategies}. Given {self.cfg.strategy}")

# Update preserve alignments
if self.preserve_alignments is None:
if self.cfg.strategy in ['greedy']:
if self.cfg.strategy in ['greedy', 'greedy_batched']:
self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False)
else:
self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False)

# Update compute timestamps
if self.compute_timestamps is None:
if self.cfg.strategy in ['greedy']:
if self.cfg.strategy in ['greedy', 'greedy_batched']:
self.compute_timestamps = self.cfg.greedy.get('compute_timestamps', False)
elif self.cfg.strategy in ['beam']:
self.compute_timestamps = self.cfg.beam.get('compute_timestamps', False)

# initialize confidence-related fields
self._init_confidence(self.cfg.get('confidence_cfg', None))

# Confidence estimation is not implemented for strategies other than `greedy`
# Confidence estimation is not implemented for strategies other than `greedy` and `greedy_batched`
if (
not self.preserve_frame_confidence
and self.cfg.strategy != 'greedy'
and self.cfg.strategy not in ('greedy', 'greedy_batched')
and self.cfg.beam.get('preserve_frame_confidence', False)
):
raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`")
@@ -247,6 +247,10 @@ def __init__(self, decoding_cfg, blank_id: int):
self.compute_timestamps |= self.preserve_frame_confidence

if self.cfg.strategy == 'greedy':
logging.warning(
"CTC decoding strategy 'greedy' is slower than 'greedy_batched', which implements the same exact interface. Consider changing your strategy to 'greedy_batched' for a free performance improvement.",
mode=logging_mode.ONCE,
)

self.decoding = ctc_greedy_decoding.GreedyCTCInfer(
blank_id=self.blank_id,
@@ -256,6 +260,15 @@ def __init__(self, decoding_cfg, blank_id: int):
confidence_method_cfg=self.confidence_method_cfg,
)

elif self.cfg.strategy == "greedy_batched":
self.decoding = ctc_greedy_decoding.GreedyBatchedCTCInfer(
blank_id=self.blank_id,
preserve_alignments=self.preserve_alignments,
compute_timestamps=self.compute_timestamps,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
)

elif self.cfg.strategy == 'beam':

self.decoding = ctc_beam_decoding.BeamCTCInfer(
@@ -1287,7 +1300,7 @@ def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:

@dataclass
class CTCDecodingConfig:
strategy: str = "greedy"
strategy: str = "greedy_batched"

# preserve decoding alignments
preserve_alignments: Optional[bool] = None
223 changes: 222 additions & 1 deletion nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -110,7 +110,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
def input_types(self):
"""Returns definitions of module input ports.
"""
# Input can be of dimention -
# Input can be of dimension -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

return {
@@ -266,6 +266,227 @@ def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)


class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):
"""A vectorized greedy CTC decoder.
This is basically always faster than GreedyCTCInfer, and supports
the same interface. See issue #8891 on GitHub for what is wrong
with GreedyCTCInfer: it loops over each element in the batch,
running kernels at batch size one, so CPU overheads end up
dominating. This implementation uses masking to perform the same
operation in a batched manner.
Args:
blank_index: int index of the blank token. Can be 0 or len(vocabulary).
preserve_alignments: Bool flag which preserves the history of logprobs generated during
decoding (sample / batched). When set to true, the Hypothesis will contain
the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensor.
compute_timestamps: A bool flag, which determines whether to compute the character/subword, or
word based timestamps, mapping the output log-probabilities to discrete intervals of timestamps.
The timestamps will be available in the returned Hypothesis.timestep as a dictionary.
preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores
generated during decoding. When set to true, the Hypothesis will contain
the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.
confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
name: The method name (str).
Supported values:
- 'max_prob' for using the maximum token probability as a confidence.
- 'entropy' for using a normalized entropy of a log-likelihood vector.
entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
Supported values:
- 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
Note that for this entropy, the alpha should comply with the following inequality:
(log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
where V is the model vocabulary size.
- 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
where α is a parameter. When α == 1, it works like the Gibbs entropy.
More: https://en.wikipedia.org/wiki/Tsallis_entropy
- 'renyi' for the Rényi entropy.
Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
where α is a parameter. When α == 1, it works like the Gibbs entropy.
More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy
alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
When the alpha equals one, scaling is not applied to 'max_prob',
and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))
entropy_norm: A mapping of the entropy value to the interval [0,1].
Supported values:
- 'lin' for using the linear mapping.
- 'exp' for using exponential mapping with linear shift.
"""

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
# Input can be of dimension -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

return {
"decoder_output": NeuralType(None, LogprobsType()),
"decoder_lengths": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
return {"predictions": [NeuralType(elements_type=HypothesisType())]}

def __init__(
self,
blank_id: int,
preserve_alignments: bool = False,
compute_timestamps: bool = False,
preserve_frame_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
):
super().__init__()

self.blank_id = blank_id
self.preserve_alignments = preserve_alignments
# we need timestamps to extract non-blank per-frame confidence
self.compute_timestamps = compute_timestamps | preserve_frame_confidence
self.preserve_frame_confidence = preserve_frame_confidence

# set confidence calculation method
self._init_confidence_method(confidence_method_cfg)

@typecheck()
def forward(
self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-regressively.
Args:
decoder_output: A tensor of size (batch, timesteps, features) or (batch, timesteps) (each timestep is a label).
decoder_lengths: list of int representing the length of each output sequence.
Returns:
packed list containing batch number of sentences (Hypotheses).
"""
if decoder_output.ndim == 2:
hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
else:
hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
packed_result = pack_hypotheses(hypotheses, decoder_lengths)
return (packed_result,)

@torch.no_grad()
def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor):
# x: [B, T, D]
# out_len: [B]

batch_size = x.shape[0]
max_time = x.shape[1]

predictions = x
# In CTC greedy decoding, each output maximum likelihood token
# is calculated independent of the other tokens.
predictions_logprobs, predictions_labels = predictions.max(dim=-1)

# Since predictions_logprobs is a padded matrix in the time
# dimension, we consider invalid timesteps to be "blank".
time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1))
# Sum the non-blank labels to compute the score of the
# transcription. This follows from Eq. (3) of "Connectionist
# Temporal Classification: Labelling Unsegmented Sequence Data
# with Recurrent Neural Networks".
scores = torch.where(non_blank_ids_mask, predictions_logprobs, 0.0).sum(axis=1)

scores = scores.cpu()
predictions_labels = predictions_labels.cpu()
out_len = out_len.cpu()

if self.preserve_alignments or self.preserve_frame_confidence:
predictions = predictions.cpu()

hypotheses = []

# This mimics the for loop in GreedyCTCInfer::forward.
for i in range(batch_size):
hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
hypothesis.score = scores[i]

prediction_labels_no_padding = predictions_labels[i, : out_len[i]].tolist()

assert predictions_labels.dtype == torch.int64
hypothesis.y_sequence = prediction_labels_no_padding

if self.preserve_alignments:
hypothesis.alignments = (
predictions[i, : out_len[i], :].clone(),
predictions_labels[i, : out_len[i]].clone(),
)
if self.compute_timestamps:
# TODO: Could do this in a vectorized manner... Would
# prefer to have nonzero_static, though, for sanity.
# Or do a prefix sum on out_len
hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
if self.preserve_frame_confidence:
hypothesis.frame_confidence = self._get_confidence(predictions[i, : out_len[i], :])

hypotheses.append(hypothesis)

return hypotheses

@torch.no_grad()
def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor):
"""
This does greedy decoding in the case where you have already found the
most likely token at each timestep.
"""
# x: [B, T]
# out_len: [B]

batch_size = x.shape[0]
max_time = x.shape[1]

predictions_labels = x
time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1))
predictions_labels = predictions_labels.cpu()
out_len = out_len.cpu()

hypotheses = []

for i in range(batch_size):
hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
hypothesis.y_sequence = predictions_labels[i, : out_len[i]].tolist()
hypothesis.score = -1.0

if self.preserve_alignments:
raise ValueError(
"Requested for alignments, but predictions provided were labels, not log probabilities."
)
if self.compute_timestamps:
# TODO: Could do this in a vectorized manner... Would
# prefer to have nonzero_static, though, for sanity.
# Or do a prefix sum on out_len
hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
if self.preserve_frame_confidence:
raise ValueError(
"Requested for per-frame confidence, but predictions provided were labels, not log probabilities."
)

hypotheses.append(hypothesis)

return hypotheses

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)


@dataclass
class GreedyCTCInferConfig:
preserve_alignments: bool = False
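For reference, here is a small standalone sketch that replays the masking and scoring steps of _greedy_decode_logprobs_batched above on toy data; the blank index, shapes, and lengths are invented for illustration:

import torch

blank_id = 4
log_probs = torch.randn(2, 6, 5).log_softmax(dim=-1)  # [B, T, D]; D includes the blank token
out_len = torch.tensor([6, 4])                         # number of valid frames per utterance

# One batched max over the vocabulary dimension, as in the new decoder.
pred_logprobs, pred_labels = log_probs.max(dim=-1)     # both [B, T]

# Mask out blanks and padded frames, then sum the surviving log-probs
# to get a per-utterance score.
time_steps = torch.arange(log_probs.shape[1]).unsqueeze(0).expand_as(pred_labels)
non_blank_mask = (pred_labels != blank_id) & (time_steps < out_len.unsqueeze(1))
scores = torch.where(non_blank_mask, pred_logprobs, 0.0).sum(dim=1)

for i in range(log_probs.shape[0]):
    labels = pred_labels[i, : out_len[i]].tolist()
    print(f"utt {i}: score={scores[i].item():.3f}, greedy labels={labels}")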
