Commit
comment/style/variable name changes
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
Hainan Xu committed Nov 29, 2022
1 parent 02eb07b commit 881240c
Showing 10 changed files with 309 additions and 199 deletions.
@@ -1,4 +1,4 @@
# It contains the default values for training a Conformer-Transducer ASR model, large size (~120M) with Transducer loss and sub-word encoding.
# It contains the default values for training a Multiblank Conformer-Transducer ASR model with stateless decoders, large size (~120M) with Transducer loss and sub-word encoding.

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
@@ -17,10 +17,10 @@
#

# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer
# Multiblank transducer is described in https://arxiv.org/pdf/2211.03541
# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html
# The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large

name: "Conformer-Transducer-BPE"
name: "Multiblank-Conformer-Transducer-BPE"

model:
sample_rate: 16000
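
For orientation (not part of this commit), a minimal sketch of the multi-blank-specific loss overrides such a config is expected to carry, written with OmegaConf in Python. The key names below ("loss_name", the "multiblank_rnnt_kwargs" container) are assumptions inferred from resolve_rnnt_loss() and extract_rnnt_loss_cfg() usage elsewhere in this diff, not a copy of the shipped YAML.

# Hypothetical sketch; key paths are assumptions, values are illustrative.
from omegaconf import OmegaConf

loss_overrides = OmegaConf.create(
    {
        "loss": {
            "loss_name": "multiblank_rnnt",          # Numba multi-blank loss
            "multiblank_rnnt_kwargs": {              # assumed kwargs container name
                "big_blank_durations": [2, 4, 8],    # durations of the extra blank symbols
                "sigma": 0.05,                       # logit under-normalization weight
            },
        }
    }
)
print(OmegaConf.to_yaml(loss_overrides))
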
1 change: 0 additions & 1 deletion examples/asr/transcribe_speech.py
@@ -97,7 +97,6 @@ class TranscriptionConfig:
ctc_decoding: CTCDecodingConfig = CTCDecodingConfig()

# Decoding strategy for RNNT models
# rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1, strategy='greedy')
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)


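The removed line was a commented-out variant of the default; the active default remains RNNTDecodingConfig(fused_batch_size=-1), whose strategy field defaults to "greedy_batch" (see the dataclass later in this diff). A quick sketch, purely illustrative and not part of the commit:

# Sketch: the decoding strategy comes from the dataclass default, not from
# an explicit override in transcribe_speech.py.
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig

cfg = RNNTDecodingConfig(fused_batch_size=-1)
print(cfg.strategy)  # "greedy_batch"
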
123 changes: 50 additions & 73 deletions nemo/collections/asr/losses/rnnt.py
@@ -187,14 +187,23 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, big_blank_idx_list: list,
clamp = loss_kwargs.pop('clamp', -1.0)
big_blank_durations = loss_kwargs.pop('big_blank_durations', None)
sigma = loss_kwargs.pop('sigma', 0.0)
loss_func = MultiblankRNNTLossNumba(blank=blank_idx, big_blank_durations=big_blank_durations, reduction='none', fastemit_lambda=fastemit_lambda, clamp=clamp, sigma=sigma)
loss_func = MultiblankRNNTLossNumba(
blank=blank_idx,
big_blank_durations=big_blank_durations,
reduction='none',
fastemit_lambda=fastemit_lambda,
clamp=clamp,
sigma=sigma,
)
_warn_unused_additional_kwargs(loss_name, loss_kwargs)

elif loss_name == 'multiblank_rnnt_pytorch':
big_blank_durations = loss_kwargs.pop('big_blank_durations', None)
sigma = loss_kwargs.pop('sigma', 0.0)
big_blank_labels = [blank_idx + 1 + i for i in range(len(big_blank_durations))]
loss_func = MultiblankRNNTLossPytorch(blank=blank_idx, big_blank_durations=big_blank_durations, reduction='none', sigma=sigma)
loss_func = MultiblankRNNTLossPytorch(
blank=blank_idx, big_blank_durations=big_blank_durations, reduction='none', sigma=sigma
)

else:
raise ValueError(
@@ -203,7 +212,8 @@

return loss_func
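
For illustration (not part of this commit), the two multi-blank branches above are reached by passing the corresponding loss_name and loss_kwargs to the RNNTLoss wrapper; big_blank_durations and sigma are then popped from loss_kwargs exactly as shown:

# Illustrative construction of the multi-blank losses through RNNTLoss;
# num_classes is the number of non-blank labels, values are examples only.
from nemo.collections.asr.losses.rnnt import RNNTLoss

kwargs = {'big_blank_durations': [2, 4, 8], 'sigma': 0.05}
numba_loss = RNNTLoss(num_classes=128, loss_name='multiblank_rnnt', loss_kwargs=dict(kwargs))
pytorch_loss = RNNTLoss(num_classes=128, loss_name='multiblank_rnnt_pytorch', loss_kwargs=dict(kwargs))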

class RNNTLossPytorch(Loss):

class MultiblankRNNTLossPytorch(Loss):
@property
def input_types(self):
"""Input types definitions for CTCLoss.
@@ -247,42 +257,62 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens):
if t == 0:
log_alpha[:, t, u] = 0.0
else:
log_alpha[:, t, u] = log_alpha[:, t-1, u] + acts[:, t-1, 0, self.blank]
log_alpha[:, t, u] = log_alpha[:, t - 1, u] + acts[:, t - 1, 0, self.blank]

for i, d in enumerate(self.big_blank_durations):
if t >= d:
tt = log_alpha[:, t - d, u] + acts[:, t - d, 0, self.blank + 1 + i]
log_alpha[:, t, u] = torch.logsumexp(torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0)

log_alpha[:, t, u] = torch.logsumexp(
torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0
)

else:
if t == 0:
gathered = torch.gather(acts[:, t, u-1], dim=1, index=labels[:,u-1].view(-1,1).type(torch.int64) ).reshape(-1)
log_alpha[:, t, u] = log_alpha[:, t,u-1] + gathered.cuda()
gathered = torch.gather(
acts[:, t, u - 1], dim=1, index=labels[:, u - 1].view(-1, 1).type(torch.int64)
).reshape(-1)
log_alpha[:, t, u] = log_alpha[:, t, u - 1] + gathered.cuda()
else:
log_alpha[:, t, u] = torch.logsumexp(torch.stack([
log_alpha[:, t-1, u] + acts[:, t-1, u, self.blank],
log_alpha[:, t, u-1] + torch.gather(acts[:, t, u-1], dim=1, index=labels[:,u-1].view(-1,1).type(torch.int64) ).reshape(-1)
]), dim=0)
log_alpha[:, t, u] = torch.logsumexp(
torch.stack(
[
log_alpha[:, t - 1, u] + acts[:, t - 1, u, self.blank],
log_alpha[:, t, u - 1]
+ torch.gather(
acts[:, t, u - 1], dim=1, index=labels[:, u - 1].view(-1, 1).type(torch.int64)
).reshape(-1),
]
),
dim=0,
)

for i, d in enumerate(self.big_blank_durations):
if t >= d:
tt = log_alpha[:, t - d, u] + acts[:, t - d, u, self.blank + 1 + i]
log_alpha[:, t, u] = torch.logsumexp(torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0)
log_alpha[:, t, u] = torch.logsumexp(
torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0
)

log_probs = []
for b in range(B):
to_append = log_alpha[b, act_lens[b]-1, label_lens[b]] + acts[b, act_lens[b]-1, label_lens[b], self.blank]
to_append = (
log_alpha[b, act_lens[b] - 1, label_lens[b]] + acts[b, act_lens[b] - 1, label_lens[b], self.blank]
)

for i, d in enumerate(self.big_blank_durations):
if act_lens[b] >= d:
tt = log_alpha[b, act_lens[b] - d, label_lens[b]] + acts[b, act_lens[b] - d, label_lens[b], self.blank + 1 + i]
if act_lens[b] >= d:
tt = (
log_alpha[b, act_lens[b] - d, label_lens[b]]
+ acts[b, act_lens[b] - d, label_lens[b], self.blank + 1 + i]
)
to_append = torch.logsumexp(torch.stack([1.0 * to_append, tt]), dim=0)

log_probs.append(to_append)

log_prob = torch.stack(log_probs)
return log_prob
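
In equation form, the dynamic program above is the standard RNN-T forward recursion with one extra transition per big-blank duration (a sketch read off the loops above; D is the set of big-blank durations, indices are 0-based, and terms whose indices would be negative are dropped):

\alpha(t,u) = \operatorname{logsumexp}\Big( \alpha(t-1,u) + \log p(\varnothing \mid t-1,u),\;\; \alpha(t,u-1) + \log p(y_u \mid t,u-1),\;\; \big\{ \alpha(t-d,u) + \log p(\varnothing_d \mid t-d,u) \big\}_{d \in D,\, d \le t} \Big), \qquad \alpha(0,0) = 0,

and for an utterance with T frames and U target labels the returned log-probability is

\log P = \operatorname{logsumexp}\Big( \alpha(T-1,U) + \log p(\varnothing \mid T-1,U),\;\; \big\{ \alpha(T-d,U) + \log p(\varnothing_d \mid T-d,U) \big\}_{d \in D,\, d \le T} \Big),

where \varnothing is the standard blank, \varnothing_d the big blank of duration d, and y_u the u-th target label.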


class RNNTLoss(Loss):
@property
def input_types(self):
@@ -414,56 +444,3 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
)

return loss

if __name__ == "__main__":
B, T, U, V = 1, 11, 7, 5
B, T, U, V = 1, 2, 2, 2 # V is number of non blank labels
B, T, U, V = 1, 3, 3, 5 # V is number of non blank labels
B, T, U, V = 8, 264, 32, 128 # V is number of non blank labels

big_blank_durations = [2,3,4,5,6,7,8]
big_blank_durations = [2]
# big_blank_durations = []
sigma = 0.1
args = {}
args['big_blank_durations'] = big_blank_durations
args['sigma'] = sigma

args2 = {}
args2['big_blank_durations'] = big_blank_durations
args2['sigma'] = sigma

Loss = RNNTLoss(V, reduction='mean_batch', loss_name='multiblank_rnnt_pytorch', loss_kwargs=args)
Loss2 = RNNTLoss(V, reduction='mean_batch', loss_name='multiblank_rnnt', loss_kwargs=args2) if len(big_blank_durations) > 0 else RNNTLoss(V, reduction='mean_batch', loss_name='warprnnt_numba', loss_kwargs=args2)

for t in range(22):
acts = torch.rand([B, T, U, V + 1 + len(big_blank_durations)]) - 0.5
acts = torch.nn.Parameter(acts * 5, requires_grad=True)

labels = torch.randint(low=0, high=V - 1, size=[B, U])
act_lens = torch.randint(low=1, high=T + 1, size=[B])
label_lens = torch.randint(low=1, high=U + 1, size=[B]) - 1
act_lens[0] = T
label_lens[0] = U - 1
logits = acts

logits = logits.cuda()
labels = labels.cuda()
act_lens = act_lens.cuda()
label_lens = label_lens.cuda()

labels = labels.contiguous()

loss = Loss(log_probs=logits, targets=labels, input_lengths=act_lens, target_lengths=label_lens)
loss = torch.mean(loss)
loss.backward()
grad1 = torch.clone(acts.grad)
acts.grad *= 0.0

loss2 = Loss2(log_probs=logits, targets=labels, input_lengths=act_lens, target_lengths=label_lens)

loss2.backward()

print("loss diff", float(loss - loss2), float(loss), float(loss2))
print("grad norm diff per element", float(torch.norm(acts.grad - grad1)))

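The __main__ block removed above was an ad-hoc parity check between the PyTorch and Numba multi-blank losses. A condensed sketch of the same check, suitable for a standalone script (not part of the commit; it assumes a CUDA device, as the original did):

# Parity check between the pure-PyTorch and Numba multi-blank RNNT losses:
# with the same inputs, losses and gradients should agree up to numerics.
import torch
from nemo.collections.asr.losses.rnnt import RNNTLoss

def compare_multiblank_losses(B=8, T=264, U=32, V=128, durations=(2, 4), sigma=0.1):
    kwargs = {'big_blank_durations': list(durations), 'sigma': sigma}
    loss_pt = RNNTLoss(V, reduction='mean_batch',
                       loss_name='multiblank_rnnt_pytorch', loss_kwargs=dict(kwargs))
    loss_nb = RNNTLoss(V, reduction='mean_batch',
                       loss_name='multiblank_rnnt', loss_kwargs=dict(kwargs))

    acts = torch.nn.Parameter(5 * (torch.rand(B, T, U, V + 1 + len(durations)) - 0.5))
    logits = acts.cuda()                              # gradients still flow back to acts
    labels = torch.randint(0, V - 1, (B, U)).cuda()
    act_lens = torch.randint(1, T + 1, (B,)).cuda()
    act_lens[0] = T
    label_lens = (torch.randint(1, U + 1, (B,)) - 1).cuda()
    label_lens[0] = U - 1

    # reduction='mean_batch' already returns a scalar, so backward() needs no extra mean
    l1 = loss_pt(log_probs=logits, targets=labels,
                 input_lengths=act_lens, target_lengths=label_lens)
    l1.backward()
    grad1 = acts.grad.clone()
    acts.grad.zero_()

    l2 = loss_nb(log_probs=logits, targets=labels,
                 input_lengths=act_lens, target_lengths=label_lens)
    l2.backward()

    print("loss diff:", float(l1 - l2))
    print("grad diff norm:", float(torch.norm(acts.grad - grad1)))
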
68 changes: 43 additions & 25 deletions nemo/collections/asr/metrics/rnnt_wer.py
@@ -129,8 +129,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, big_blank_id_lis
super(AbstractRNNTDecoding, self).__init__()
self.cfg = decoding_cfg
self.blank_id = blank_id
self.big_blank_id_list = big_blank_id_list
self.big_blank_duration_list = big_blank_duration_list
self.big_blank_durations = self.cfg.get("big_blank_durations", None)
self.compute_hypothesis_token_set = self.cfg.get("compute_hypothesis_token_set", False)
self.preserve_alignments = self.cfg.get('preserve_alignments', None)
self.joint_fused_batch_size = self.cfg.get('fused_batch_size', None)
@@ -162,30 +161,50 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, big_blank_id_lis
raise ValueError("If `compute_timesteps` flag is set, then `preserve_alignments` flag must also be set.")

if self.cfg.strategy == 'greedy':

self.decoding = greedy_decode.GreedyRNNTInfer(
decoder_model=decoder,
joint_model=joint,
blank_index=self.blank_id,
big_blank_duration_list=self.big_blank_duration_list,
max_symbols_per_step=(
self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None)
),
preserve_alignments=self.preserve_alignments,
)
if self.big_blank_durations is not None:
self.decoding = greedy_decode.GreedyMultiblankRNNTInfer(
decoder_model=decoder,
joint_model=joint,
blank_index=self.blank_id,
big_blank_durations=self.big_blank_durations,
max_symbols_per_step=(
self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None)
),
preserve_alignments=self.preserve_alignments,
)
else:
self.decoding = greedy_decode.GreedyRNNTInfer(
decoder_model=decoder,
joint_model=joint,
blank_index=self.blank_id,
max_symbols_per_step=(
self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None)
),
preserve_alignments=self.preserve_alignments,
)

elif self.cfg.strategy == 'greedy_batch':

self.decoding = greedy_decode.GreedyBatchedRNNTInfer(
decoder_model=decoder,
joint_model=joint,
blank_index=self.blank_id,
big_blank_duration_list=self.big_blank_duration_list,
max_symbols_per_step=(
self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None)
),
preserve_alignments=self.preserve_alignments,
)
if self.big_blank_durations is not None:
self.decoding = greedy_decode.GreedyBatchedMultiblankRNNTInfer(
decoder_model=decoder,
joint_model=joint,
blank_index=self.blank_id,
big_blank_durations=self.big_blank_durations,
max_symbols_per_step=(
self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None)
),
preserve_alignments=self.preserve_alignments,
)
else:
self.decoding = greedy_decode.GreedyBatchedRNNTInfer(
decoder_model=decoder,
joint_model=joint,
blank_index=self.blank_id,
max_symbols_per_step=(
self.cfg.greedy.get('max_symbols', None) or self.cfg.greedy.get('max_symbols_per_step', None)
),
preserve_alignments=self.preserve_alignments,
)

elif self.cfg.strategy == 'beam':

@@ -929,7 +948,6 @@ def compute(self):
@dataclass
class RNNTDecodingConfig:
strategy: str = "greedy_batch"
duration: str = ""
compute_hypothesis_token_set: bool = False

# preserve decoding alignments
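
A minimal sketch (not part of the commit) of a decoding config that drives the new branches above; only the presence of "big_blank_durations" matters, since __init__ reads it with self.cfg.get("big_blank_durations", None):

# Sketch: with this config, strategy 'greedy_batch' selects
# GreedyBatchedMultiblankRNNTInfer; drop "big_blank_durations" (or leave it None)
# and the standard GreedyBatchedRNNTInfer is built instead.
from omegaconf import OmegaConf

decoding_cfg = OmegaConf.create(
    {
        "strategy": "greedy_batch",
        "greedy": {"max_symbols": 10},
        # Assumed extra key for multi-blank models, matching the cfg.get() above.
        "big_blank_durations": [2, 4, 8],
    }
)
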
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/rnnt_models.py
@@ -72,7 +72,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))

self.loss = RNNTLoss(
num_classes=self.joint.num_classes_with_blank - 1 - self.joint.num_extra_outputs, loss_name=loss_name, loss_kwargs=loss_kwargs
num_classes=self.joint.num_classes_with_blank - 1 - self.joint.num_extra_outputs,
loss_name=loss_name,
loss_kwargs=loss_kwargs,
)

if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None:
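
To make the new num_classes arithmetic concrete, a worked example (under the assumption, not stated in this hunk, that the joint's num_extra_outputs equals the number of big-blank durations):

# Worked example of the num_classes expression above.
vocab_size = 128                                            # non-blank labels
num_big_blanks = 3                                          # e.g. big_blank_durations = [2, 4, 8]

num_classes_with_blank = vocab_size + 1 + num_big_blanks    # 132 joint outputs
num_extra_outputs = num_big_blanks                          # assumption: one output per big blank
num_classes = num_classes_with_blank - 1 - num_extra_outputs
assert num_classes == vocab_size                            # RNNTLoss still gets the non-blank count
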
9 changes: 5 additions & 4 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
@@ -181,15 +181,13 @@ def rnnt_loss_gpu(
num_threads = max(1, num_threads) # have to use at least 1 thread

gpu_size, status = rnnt_helper.get_workspace_size(maxT, maxU, minibatch_size, gpu=True)

if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError("Invalid parameter passed when calculating working space memory")

# Select GPU index
cuda.select_device(acts.device.index)
gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False)


### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ###
acts, acts_shape = rnnt_helper.flatten_tensor(acts)

@@ -238,8 +236,6 @@ def rnnt_loss_gpu(
return True




def multiblank_rnnt_loss_gpu(
acts: torch.Tensor,
labels: torch.Tensor,
@@ -266,10 +262,15 @@ def multiblank_rnnt_loss_gpu(
costs: Zero vector of length [B] in which costs will be set.
grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set.
blank_label: Index of the blank token in the vocabulary.
big_blank_durations: A list of supported durations for big blank symbols
in the model, e.g. [2, 4, 8]. Note that we only include durations for big
blanks here, so the list does not include 1 for the standard blank.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
num_threads: Number of threads for OpenMP.
sigma: logit-undernormalization weight used in the multi-blank model. Refer to
the multi-blank paper https://arxiv.org/pdf/2211.03541 for detailed explanations.
"""
minibatch_size = acts.shape[0]
maxT = acts.shape[1]
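
For context on the sigma argument documented above: logit under-normalization, as described in the cited paper, subtracts a constant from the normalized log-probabilities that enter the loss during training (a paraphrase of the paper, not code from this commit):

\log \tilde{p}(k \mid t, u) \;=\; \log \operatorname{softmax}(z_{t,u})_k \;-\; \sigma

Per the paper, this mild under-normalization pushes the model toward more confident blank predictions, which is what lets the big blanks be used effectively and yields the decoding speed-up.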