Commit 02eb07b

bug seems fixed
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
Hainan Xu committed Nov 29, 2022
1 parent 19b93b2 commit 02eb07b
Showing 5 changed files with 365 additions and 44 deletions.
22 changes: 8 additions & 14 deletions nemo/collections/asr/losses/rnnt.py
@@ -232,7 +232,6 @@ def __init__(self, blank, big_blank_durations, reduction, sigma):

     @typecheck()
     def forward(self, acts, labels, act_lens, label_lens):
-        # print("pytorch sigma", self.sigma)
         acts = torch.log_softmax(acts, -1) - self.sigma
         forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens)
         return -forward_logprob
@@ -248,12 +247,10 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens):
                 if t == 0:
                     log_alpha[:, t, u] = 0.0
                 else:
-                    # print("pytorch scores", log_alpha[:, t-1, u], acts[:, t-1, 0, self.blank] )
                     log_alpha[:, t, u] = log_alpha[:, t-1, u] + acts[:, t-1, 0, self.blank]
 
                     for i, d in enumerate(self.big_blank_durations):
                         if t >= d:
-                            # print("pytorch scores", log_alpha[:, t - d, u], acts[:, t - d, 0, self.blank + 1 + i])
                             tt = log_alpha[:, t - d, u] + acts[:, t - d, 0, self.blank + 1 + i]
                             log_alpha[:, t, u] = torch.logsumexp(torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0)

@@ -271,7 +268,6 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens):
                     if t >= d:
                         tt = log_alpha[:, t - d, u] + acts[:, t - d, u, self.blank + 1 + i]
                         log_alpha[:, t, u] = torch.logsumexp(torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0)
-                        # print("pytorch alpha", 0, t, u, float(log_alpha[0, t, u]))
 
         log_probs = []
         for b in range(B):
@@ -423,11 +419,12 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
     B, T, U, V = 1, 11, 7, 5
     B, T, U, V = 1, 2, 2, 2  # V is number of non blank labels
     B, T, U, V = 1, 3, 3, 5  # V is number of non blank labels
-    B, T, U, V = 8, 64, 32, 128  # V is number of non blank labels
+    B, T, U, V = 8, 264, 32, 128  # V is number of non blank labels
 
-    big_blank_durations = [2, 4, 8]
-    big_blank_idx_list=list(range(V + 1, V + 1 + len(big_blank_durations)))
-    sigma = 0.05
+    big_blank_durations = [2,3,4,5,6,7,8]
+    big_blank_durations = [2]
+    # big_blank_durations = []
+    sigma = 0.1
     args = {}
     args['big_blank_durations'] = big_blank_durations
     args['sigma'] = sigma
@@ -437,7 +434,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
     args2['sigma'] = sigma
 
     Loss = RNNTLoss(V, reduction='mean_batch', loss_name='multiblank_rnnt_pytorch', loss_kwargs=args)
-    Loss2 = RNNTLoss(V, reduction='mean_batch', loss_name='multiblank_rnnt', loss_kwargs=args2)
+    Loss2 = RNNTLoss(V, reduction='mean_batch', loss_name='multiblank_rnnt', loss_kwargs=args2) if len(big_blank_durations) > 0 else RNNTLoss(V, reduction='mean_batch', loss_name='warprnnt_numba', loss_kwargs=args2)
 
     for t in range(22):
         acts = torch.rand([B, T, U, V + 1 + len(big_blank_durations)]) - 0.5
@@ -467,9 +464,6 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):

         loss2.backward()
 
-        print("loss diff", float(loss - loss2))
-        print("grad norm diff per element", float(torch.norm(acts.grad - grad1) / (B * T * U * V)))
-        # print("they are")
-        # print(acts.grad)
-        # print(grad1)
+        print("loss diff", float(loss - loss2), float(loss), float(loss2))
+        print("grad norm diff per element", float(torch.norm(acts.grad - grad1)))

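The PyTorch reference loss above folds each big blank into the forward variable with a logsumexp: the standard blank advances one time frame, while a big blank of duration d consumes d frames in a single emission. Below is a minimal, self-contained sketch of that recursion for the u == 0 row only; the standalone function and its name are hypothetical, with shapes and indexing mirroring the diff.

import torch

def blank_row_forward(acts, blank, big_blank_durations):
    # acts: [B, T, U, V + 1 + num_big_blanks] log-probs (already shifted by sigma)
    # returns log_alpha for the u == 0 row: paths that emit only blanks
    B, T, _, _ = acts.shape
    log_alpha = torch.full((B, T), float('-inf'))
    log_alpha[:, 0] = 0.0
    for t in range(1, T):
        # standard blank: a one-frame step from t - 1
        log_alpha[:, t] = log_alpha[:, t - 1] + acts[:, t - 1, 0, blank]
        for i, d in enumerate(big_blank_durations):
            if t >= d:
                # big blank i: one emission spanning d frames, arriving from t - d
                tt = log_alpha[:, t - d] + acts[:, t - d, 0, blank + 1 + i]
                log_alpha[:, t] = torch.logsumexp(torch.stack([log_alpha[:, t], tt]), dim=0)
    return log_alpha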
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/rnnt_models.py
@@ -72,7 +72,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))
 
         self.loss = RNNTLoss(
-            num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs
+            num_classes=self.joint.num_classes_with_blank - 1 - self.joint.num_extra_outputs, loss_name=loss_name, loss_kwargs=loss_kwargs
         )
 
         if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None:
@@ -341,7 +341,7 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di
             del self.loss
             loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get('loss', None))
             self.loss = RNNTLoss(
-                num_classes=self.joint.num_classes_with_blank - 1 - self.num_big_blanks, loss_name=loss_name, loss_kwargs=loss_kwargs
+                num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs
             )
 
             if decoding_cfg is None:
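The bookkeeping behind this pair of changes: a multi-blank joint network emits one extra logit per big-blank duration, so the vocabulary size handed to the loss must strip both the standard blank and those extras. A worked example with illustrative numbers, assuming num_classes_with_blank counts every joint output including the big blanks:

# hypothetical values for illustration only
V = 128                                    # real vocabulary size
num_extra_outputs = 3                      # one big blank per duration, e.g. [2, 4, 8]
num_classes_with_blank = V + 1 + num_extra_outputs   # 132 joint outputs
num_classes = num_classes_with_blank - 1 - num_extra_outputs
assert num_classes == V                    # the loss sees only real labels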
5 changes: 2 additions & 3 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
@@ -298,11 +298,10 @@ def multiblank_rnnt_loss_gpu(
     big_blank_labels = [blank_label + 1 + i for i in range(len(big_blank_durations))]
     merged_list = list(big_blank_labels) + list(big_blank_durations)
 
-    big_blank_workspace = torch.zeros(2 * len(big_blank_labels), device=acts.device, dtype=torch.long, requires_grad=False)
+    big_blank_workspace = torch.zeros(len(big_blank_labels), device=acts.device, dtype=torch.long, requires_grad=False)
 
     for i in range(0, len(big_blank_labels)):
-        big_blank_workspace[i] = big_blank_labels[i]
-        big_blank_workspace[i + len(big_blank_labels)] = big_blank_durations[i]
+        big_blank_workspace[i] = big_blank_durations[i]
 
     ### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ###
     acts, acts_shape = rnnt_helper.flatten_tensor(acts)
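The halved workspace works because big-blank label ids are fully determined by position: label i is always blank_label + 1 + i, so the kernels can recompute labels on the fly and only the durations need to ride along in device memory. A sketch of the idea; the helper names here are hypothetical, not from the commit:

import torch

def make_big_blank_workspace(big_blank_durations, device):
    # one long per big blank; labels are no longer stored
    ws = torch.zeros(len(big_blank_durations), device=device, dtype=torch.long, requires_grad=False)
    for i, d in enumerate(big_blank_durations):
        ws[i] = d
    return ws

def big_blank_label(blank_label, i):
    # recomputed wherever needed instead of read from the workspace
    return blank_label + 1 + i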
@@ -387,7 +387,7 @@ def compute_cost_and_score(
         if training:
             grads *= 0.0  # zero grads
 
-        used_offset, (denom, alphas, betas, llForward, llBackward, BBLabels, BBDuations) = self._prepare_workspace()
+        used_offset, (denom, alphas, betas, llForward, llBackward, bigblank_durations) = self._prepare_workspace()
 
         ######## START EXECUTION ########
         self.log_softmax(acts, denom)
@@ -406,8 +406,7 @@
             self.maxU_,
             self.alphabet_size_,
             self.blank_,
-            BBLabels,
-            BBDuations,
+            bigblank_durations,
             self.num_big_blanks,
         )

@@ -426,8 +425,7 @@
             self.maxU_,
             self.alphabet_size_,
             self.blank_,
-            BBLabels,
-            BBDuations,
+            bigblank_durations,
             self.num_big_blanks,
         )

@@ -449,8 +447,7 @@
             self.maxU_,
             self.alphabet_size_,
             self.blank_,
-            BBLabels,
-            BBDuations,
+            bigblank_durations,
             self.num_big_blanks,
             self.fastemit_lambda_,
             self.clamp_,
@@ -531,7 +528,6 @@ def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]):
         llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
         used_offset += self.minibatch_
 
-        BBLabels = self.big_blank_workspace[0:self.num_big_blanks]
-        BBDurations = self.big_blank_workspace[self.num_big_blanks:2* self.num_big_blanks]
+        bigblank_durations = self.big_blank_workspace[:self.num_big_blanks]
 
-        return used_offset, (denom, alphas, betas, llForward, llBackward, BBLabels, BBDurations,)
+        return used_offset, (denom, alphas, betas, llForward, llBackward, bigblank_durations)
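_prepare_workspace carves every buffer out of one preallocated flat tensor by advancing an offset, and after this change the big-blank buffer contributes a single durations slice instead of a labels half and a durations half. A minimal sketch of that idiom with made-up sizes (the real code derives them from the minibatch, T, and U):

import torch

minibatch, alphas_size = 8, 8 * 64 * 32           # hypothetical sizes
gpu_workspace = torch.zeros(2 * alphas_size + 2 * minibatch)

used_offset = 0
alphas = gpu_workspace[used_offset : used_offset + alphas_size]
used_offset += alphas_size
betas = gpu_workspace[used_offset : used_offset + alphas_size]
used_offset += alphas_size
llForward = gpu_workspace[used_offset : used_offset + minibatch]
used_offset += minibatch
llBackward = gpu_workspace[used_offset : used_offset + minibatch]
used_offset += minibatch

# after the fix: one slice of durations, no stored labels
num_big_blanks = 3
big_blank_workspace = torch.tensor([2, 4, 8], dtype=torch.long)
bigblank_durations = big_blank_workspace[:num_big_blanks]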
