From 02eb07be7711f06017d3566324f94ac583807a9c Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Sat, 19 Nov 2022 22:34:12 -0500 Subject: [PATCH] bug seeems fixed Signed-off-by: Hainan Xu --- nemo/collections/asr/losses/rnnt.py | 22 +- nemo/collections/asr/models/rnnt_models.py | 4 +- .../asr/parts/numba/rnnt_loss/rnnt.py | 5 +- .../rnnt_loss/utils/cuda_utils/gpu_rnnt.py | 16 +- .../utils/cuda_utils/gpu_rnnt_kernel.py | 362 +++++++++++++++++- 5 files changed, 365 insertions(+), 44 deletions(-) diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index f9a8e8f2e8f4..9b5618c082f0 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -232,7 +232,6 @@ def __init__(self, blank, big_blank_durations, reduction, sigma): @typecheck() def forward(self, acts, labels, act_lens, label_lens): -# print("pytorch sigma", self.sigma) acts = torch.log_softmax(acts, -1) - self.sigma forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens) return -forward_logprob @@ -248,12 +247,10 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens): if t == 0: log_alpha[:, t, u] = 0.0 else: -# print("pytorch scores", log_alpha[:, t-1, u], acts[:, t-1, 0, self.blank] ) log_alpha[:, t, u] = log_alpha[:, t-1, u] + acts[:, t-1, 0, self.blank] for i, d in enumerate(self.big_blank_durations): if t >= d: -# print("pytorch scores", log_alpha[:, t - d, u], acts[:, t - d, 0, self.blank + 1 + i]) tt = log_alpha[:, t - d, u] + acts[:, t - d, 0, self.blank + 1 + i] log_alpha[:, t, u] = torch.logsumexp(torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0) @@ -271,7 +268,6 @@ def compute_forward_prob(self, acts, labels, act_lens, label_lens): if t >= d: tt = log_alpha[:, t - d, u] + acts[:, t - d, u, self.blank + 1 + i] log_alpha[:, t, u] = torch.logsumexp(torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0) -# print("pytorch alpha", 0, t, u, float(log_alpha[0, t, u])) log_probs = [] for b in range(B): @@ -423,11 +419,12 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): B, T, U, V = 1, 11, 7, 5 B, T, U, V = 1, 2, 2, 2 # V is number of non blank labels B, T, U, V = 1, 3, 3, 5 # V is number of non blank labels - B, T, U, V = 8, 64, 32, 128 # V is number of non blank labels + B, T, U, V = 8, 264, 32, 128 # V is number of non blank labels - big_blank_durations = [2, 4, 8] - big_blank_idx_list=list(range(V + 1, V + 1 + len(big_blank_durations))) - sigma = 0.05 + big_blank_durations = [2,3,4,5,6,7,8] + big_blank_durations = [2] +# big_blank_durations = [] + sigma = 0.1 args = {} args['big_blank_durations'] = big_blank_durations args['sigma'] = sigma @@ -437,7 +434,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): args2['sigma'] = sigma Loss = RNNTLoss(V, reduction='mean_batch', loss_name='multiblank_rnnt_pytorch', loss_kwargs=args) - Loss2 = RNNTLoss(V, reduction='mean_batch', loss_name='multiblank_rnnt', loss_kwargs=args2) + Loss2 = RNNTLoss(V, reduction='mean_batch', loss_name='multiblank_rnnt', loss_kwargs=args2) if len(big_blank_durations) > 0 else RNNTLoss(V, reduction='mean_batch', loss_name='warprnnt_numba', loss_kwargs=args2) for t in range(22): acts = torch.rand([B, T, U, V + 1 + len(big_blank_durations)]) - 0.5 @@ -467,9 +464,6 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): loss2.backward() - print("loss diff", float(loss - loss2)) - print("grad norm diff per element", float(torch.norm(acts.grad - grad1) / (B * T * U * V))) -# print("they are") -# 
print(acts.grad) -# print(grad1) + print("loss diff", float(loss - loss2), float(loss), float(loss2)) + print("grad norm diff per element", float(torch.norm(acts.grad - grad1))) diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 1e02287a35cd..37732a749886 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -72,7 +72,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None)) self.loss = RNNTLoss( - num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs + num_classes=self.joint.num_classes_with_blank - 1 - self.joint.num_extra_outputs, loss_name=loss_name, loss_kwargs=loss_kwargs ) if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None: @@ -341,7 +341,7 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di del self.loss loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get('loss', None)) self.loss = RNNTLoss( - num_classes=self.joint.num_classes_with_blank - 1 - self.num_big_blanks, loss_name=loss_name, loss_kwargs=loss_kwargs + num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs ) if decoding_cfg is None: diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py index bcf489b9e998..f984c30ec39b 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py @@ -298,11 +298,10 @@ def multiblank_rnnt_loss_gpu( big_blank_labels = [blank_label + 1 + i for i in range(len(big_blank_durations))] merged_list = list(big_blank_labels) + list(big_blank_durations) - big_blank_workspace = torch.zeros(2 * len(big_blank_labels), device=acts.device, dtype=torch.long, requires_grad=False) + big_blank_workspace = torch.zeros(len(big_blank_labels), device=acts.device, dtype=torch.long, requires_grad=False) for i in range(0, len(big_blank_labels)): - big_blank_workspace[i] = big_blank_labels[i] - big_blank_workspace[i + len(big_blank_labels)] = big_blank_durations[i] + big_blank_workspace[i] = big_blank_durations[i] ### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ### acts, acts_shape = rnnt_helper.flatten_tensor(acts) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py index 4c811dcf133c..1f167abe3ed5 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py @@ -387,7 +387,7 @@ def compute_cost_and_score( if training: grads *= 0.0 # zero grads - used_offset, (denom, alphas, betas, llForward, llBackward, BBLabels, BBDuations) = self._prepare_workspace() + used_offset, (denom, alphas, betas, llForward, llBackward, bigblank_durations) = self._prepare_workspace() ######## START EXECUTION ######## self.log_softmax(acts, denom) @@ -406,8 +406,7 @@ def compute_cost_and_score( self.maxU_, self.alphabet_size_, self.blank_, - BBLabels, - BBDuations, + bigblank_durations, self.num_big_blanks, ) @@ -426,8 +425,7 @@ def compute_cost_and_score( self.maxU_, self.alphabet_size_, self.blank_, - BBLabels, - BBDuations, + bigblank_durations, self.num_big_blanks, ) @@ -449,8 +447,7 @@ def compute_cost_and_score( self.maxU_, self.alphabet_size_, self.blank_, - 
BBLabels, - BBDuations, + bigblank_durations, self.num_big_blanks, self.fastemit_lambda_, self.clamp_, @@ -531,7 +528,6 @@ def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]): llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_] used_offset += self.minibatch_ - BBLabels = self.big_blank_workspace[0:self.num_big_blanks] - BBDurations = self.big_blank_workspace[self.num_big_blanks:2* self.num_big_blanks] + bigblank_durations = self.big_blank_workspace[:self.num_big_blanks] - return used_offset, (denom, alphas, betas, llForward, llBackward, BBLabels, BBDurations,) + return used_offset, (denom, alphas, betas, llForward, llBackward, bigblank_durations) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 78e03f55b9c9..1bdd26e5dcc6 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -77,7 +77,344 @@ def compute_alphas_kernel( maxU: int, alphabet_size: int, blank_: int, - big_blank_number: torch.Tensor, +): + """ + Compute alpha (forward variable) probabilities over the transduction step. + + Args: + acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. + denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor + across entire vocabulary. + alphas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel with the forward variable + probabilities. + llForward: Zero tensor of shape [B]. Represents the log-likelihood of the forward pass. + Returned as the forward pass loss that is reduced by the optimizer. + xlen: Vector of length B which contains the actual acoustic sequence lengths in the padded + activation tensor. + ylen: Vector of length B which contains the actual target sequence lengths in the padded + activation tensor. + mlabels: Matrix of shape [B, U+1] (+1 here is due to token - usually the RNNT blank). + The matrix contains the padded target transcription that must be predicted. + minibatch: Int representing the batch size. + maxT: The maximum possible acoustic sequence length. Represents T in the logprobs tensor. + maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. + alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). + blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + + Updates: + Kernel inplace updates the following inputs: + - alphas: forward variable scores. + - llForward: log-likelihood of forward variable. 
+ """ + # // launch B blocks, each block has U threads + b = cuda.blockIdx.x # // batch id + u = cuda.threadIdx.x # label id, u + T = xlen[b] # select AM length of current sample + U = ylen[b] + 1 # select target length of current sample, +1 for the blank token + + labels: torch.Tensor = mlabels[b] # mb label start point, equivalent to mlabels + b * (maxU - 1) + offset = b * maxT * maxU # pointer indexing offset + + # alphas += offset # pointer offset, ignored since we explicitly add offset + + # Initilize alpha[b, t=0, u=0] for all b in B + if u == 0: + alphas[offset] = 0 + + # sync until all alphas are initialized + cuda.syncthreads() + + # Ordinary alpha calculations, broadcast across B=b and U=u + # Look up forward variable calculation from rnnt_numpy.forward_pass() + for n in range(1, T + U - 1): + t = n - u + + if u == 0: + # for t in range(1, T) step to initialize alphas[b, t, 0] + if t > 0 and t < T: + alphas[offset + t * maxU + u] = alphas[offset + (t - 1) * maxU + u] + logp( + denom, acts, maxT, maxU, alphabet_size, b, t - 1, 0, blank_ + ) + elif u < U: + # for u in range(1, U) step to initialize alphas[b, 0, u] + if t == 0: + alphas[offset + u] = alphas[offset + u - 1] + logp( + denom, acts, maxT, maxU, alphabet_size, b, 0, u - 1, labels[u - 1] + ) + + # for t in range(1, T) for u in range(1, U) step to compute alphas[b, t, u] + elif t > 0 and t < T: + no_emit = alphas[offset + (t - 1) * maxU + u] + logp( + denom, acts, maxT, maxU, alphabet_size, b, t - 1, u, blank_ + ) + emit = alphas[offset + t * maxU + u - 1] + logp( + denom, acts, maxT, maxU, alphabet_size, b, t, u - 1, labels[u - 1] + ) + + alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit) + + # sync across all B=b and U=u + cuda.syncthreads() + + # After final sync, alphas[b, T-1, U - 1] + logprobs[b, T-1, U-1, blank] + denom[b, T-1, U-1] gives + # log-likelihood of forward pass. + if u == 0: + loglike = alphas[offset + (T - 1) * maxU + U - 1] + logp( + denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_ + ) + llForward[b] = loglike + + +@cuda.jit() +def compute_betas_kernel( + acts: torch.Tensor, + denom: torch.Tensor, + betas: torch.Tensor, + llBackward: torch.Tensor, + xlen: torch.Tensor, + ylen: torch.Tensor, + mlabels: torch.Tensor, # [B, U] + minibatch: int, + maxT: int, + maxU: int, + alphabet_size: int, + blank_: int, +): + """ + Compute beta (backward variable) probabilities over the transduction step. + + Args: + acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. + denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor + across entire vocabulary. + betas: Zero tensor of shape [B, T, U]. Will be updated inside the kernel with the backward variable + probabilities. + llBackward: Zero tensor of shape [B]. Represents the log-likelihood of the backward pass. + Returned as the backward pass loss that is reduced by the optimizer. + xlen: Vector of length B which contains the actual acoustic sequence lengths in the padded + activation tensor. + ylen: Vector of length B which contains the actual target sequence lengths in the padded + activation tensor. + mlabels: Matrix of shape [B, U+1] (+1 here is due to token - usually the RNNT blank). + The matrix contains the padded target transcription that must be predicted. + minibatch: Int representing the batch size. + maxT: The maximum possible acoustic sequence length. Represents T in the logprobs tensor. 
+ maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. + alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). + blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + + Updates: + Kernel inplace updates the following inputs: + - betas: backward variable scores. + - llBackward: log-likelihood of backward variable. + """ + # // launch B blocks, each block has U threads + b = cuda.blockIdx.x # // batch id + u = cuda.threadIdx.x # label id, u + T = xlen[b] # select AM length of current sample + U = ylen[b] + 1 # select target length of current sample, +1 for the blank token + + labels: torch.Tensor = mlabels[b] # mb label start point, equivalent to mlabels + b * (maxU - 1) + offset = b * maxT * maxU # pointer indexing offset + + # betas += offset # pointer offset, ignored since we explicitly add offset + + # Initilize beta[b, t=T-1, u=U-1] for all b in B with log_probs[b, t=T-1, u=U-1, blank] + if u == 0: + betas[offset + (T - 1) * maxU + U - 1] = logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) + + # sync until all betas are initialized + cuda.syncthreads() + + # Ordinary beta calculations, broadcast across B=b and U=u + # Look up backward variable calculation from rnnt_numpy.backward_pass() + for n in range(T + U - 2, -1, -1): + t = n - u + + if u == (U - 1): + # for t in reversed(range(T - 1)) step to initialize betas[b, t, U-1] + if t >= 0 and t < (T - 1): + betas[offset + t * maxU + U - 1] = betas[offset + (t + 1) * maxU + U - 1] + logp( + denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_ + ) + elif u < U: + if t == T - 1: + # for u in reversed(range(U - 1)) step to initialize betas[b, T-1, u] + betas[offset + (T - 1) * maxU + u] = betas[offset + (T - 1) * maxU + u + 1] + logp( + denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u] + ) + elif (t >= 0) and (t < T - 1): + # for t in reversed(range(T - 1)) for u in reversed(range(U - 1)) step to compute betas[b, t, u] + no_emit = betas[offset + (t + 1) * maxU + u] + logp( + denom, acts, maxT, maxU, alphabet_size, b, t, u, blank_ + ) + emit = betas[offset + t * maxU + u + 1] + logp( + denom, acts, maxT, maxU, alphabet_size, b, t, u, labels[u] + ) + betas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(emit, no_emit) + + # sync across all B=b and U=u + cuda.syncthreads() + + # After final sync, betas[b, 0, 0] gives + # log-likelihood of backward pass. + if u == 0: + llBackward[b] = betas[offset] + + +@cuda.jit() +def compute_grad_kernel( + grads: torch.Tensor, + acts: torch.Tensor, + denom: torch.Tensor, + alphas: torch.Tensor, + betas: torch.Tensor, + logll: torch.Tensor, + xlen: torch.Tensor, + ylen: torch.Tensor, + mlabels: torch.Tensor, # [B, U] + minibatch: int, + maxT: int, + maxU: int, + alphabet_size: int, + blank_: int, + fastemit_lambda: float, + clamp: float, +): + """ + Compute gradients over the transduction step. + + Args: + grads: Zero Tensor of shape [B, T, U, V+1]. Is updated by this kernel to contain the gradients + of this batch of samples. + acts: Tensor of shape [B, T, U, V+1] flattened. Represents the logprobs activation tensor. + denom: Tensor of shape [B, T, U] flattened. Represents the denominator of the logprobs activation tensor + across entire vocabulary. + alphas: Alpha variable, contains forward probabilities. A tensor of shape [B, T, U]. + betas: Beta varoable, contains backward probabilities. A tensor of shape [B, T, U]. 
+ logll: Log-likelihood of the forward variable, represented as a vector of shape [B]. + Represents the log-likelihood of the forward pass. + xlen: Vector of length B which contains the actual acoustic sequence lengths in the padded + activation tensor. + ylen: Vector of length B which contains the actual target sequence lengths in the padded + activation tensor. + mlabels: Matrix of shape [B, U+1] (+1 here is due to token - usually the RNNT blank). + The matrix contains the padded target transcription that must be predicted. + minibatch: Int representing the batch size. + maxT: The maximum possible acoustic sequence length. Represents T in the logprobs tensor. + maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. + alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). + blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. + + Updates: + Kernel inplace updates the following inputs: + - grads: Gradients with respect to the log likelihood (logll). + """ + # Kernel call: + # blocks_per_grid = minibatch (b) * maxT (t) * maxU (u) + # threads_per_block = constant buffer size of parallel threads (v :: Constant) + tid = cuda.threadIdx.x # represents v, taking steps of some constant size + idx = tid # index of v < V+1; in steps of constant buffer size + col = cuda.blockIdx.x # represents a fused index of b * t * u + + # Decompose original indices from fused `col` + u = col % maxU # (b * t * u) % u = u + bt = (col - u) // maxU # (b * t * u - u) // U = b * t + t = bt % maxT # (b * t) % t = t + mb = (bt - t) // maxT # (b * t - t) // T = b + + # constants + T = xlen[mb] # select AM length of current sample + U = ylen[mb] + 1 # select target length of current sample, +1 for the blank token + labels: torch.Tensor = mlabels[mb] # labels = mlabels + mb * (maxU - 1); + + # Buffered gradient calculations, broadcast across B=b, T=t and U=u, looped over V with some constant stride. + # Look up gradient calculation from rnnt_numpy.compute_gradient() + if t < T and u < U: + # For cuda kernels, maximum number of threads per block is limited to some value. + # However, it may be the case that vocabulary size is larger than this limit + # To work around this, an arbitrary thread buffer size is chosen such that, + # 1) each element within the thread pool operates independently of the other + # 2) An inner while loop moves the index of each buffer element by the size of the buffer itself, + # such that all elements of the vocabulary size are covered in (V + 1 // thread_buffer) number of steps. + # As such, each thread will perform the while loop at least (V + 1 // thread_buffer) number of times + while idx < alphabet_size: + # remember, `col` represents the tri-index [b, t, u] + # therefore; logpk = denom[b, t, u] + acts[b, t, u, v] + logpk = denom[col] + acts[col * alphabet_size + idx] + # initialize the grad of the sample acts[b, t, u, v] + grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb]) + + # If FastEmit regularization is enabled, calculate the gradeint of probability of predicting the next label + # at the current timestep. 
+            # The formula for this is Equation 9 in https://arxiv.org/abs/2010.11148, multiplied by the log probability
+            # of the current step (t, u), normalized by the total log likelihood.
+            # Once the gradient has been calculated, scale it by `fastemit_lambda`, as in Equation 10.
+            if fastemit_lambda > 0.0 and u < U - 1:
+                fastemit_grad = fastemit_lambda * math.exp(
+                    alphas[col]  # alphas(t, u)
+                    + (denom[col] + acts[col * alphabet_size + labels[u]])  # y_hat(t, u)
+                    + betas[col + 1]  # betas(t, u+1)
+                    + logpk  # log Pr(k|t, u)
+                    - logll[mb]  # total log likelihood for normalization
+                )
+            else:
+                fastemit_grad = 0.0
+
+            # Update the gradient of act[b, t, u, v] with the gradient from FastEmit regularization
+            grad = grad + fastemit_grad
+
+            # // grad to last blank transition
+            # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u] + logpk - logll[b])
+            if (idx == blank_) and (t == T - 1) and (u == U - 1):
+                grad -= math.exp(alphas[col] + logpk - logll[mb])
+
+            # grad of blank across t < T;
+            # grad[b, t<T-1, u, v=blank] -= exp(alphas[b, t, u] + logpk - logll[b] + betas[b, t+1, u])
+            if (idx == blank_) and (t < T - 1):
+                grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + maxU])
+
+            # grad of correct token across u < U;
+            # grad[b, t, u<U-1, v=label[u]] -= exp(alphas[b, t, u] + logpk - logll[b] + betas[b, t, u+1])
+            if (u < U - 1) and (idx == labels[u]):
+                grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + 1])
+
+            # update grads[b, t, u, v] = grad
+            grads[col * alphabet_size + idx] = grad
+
+            # clamp gradient (if needed)
+            if clamp > 0.0:
+                g = grads[col * alphabet_size + idx]
+                g = min(g, clamp)
+                g = max(g, -clamp)
+                grads[col * alphabet_size + idx] = g
+
+            # update internal index through the thread_buffer;
+            # until idx < V + 1, such that entire vocabulary has been updated.
+            idx += GPU_RNNT_THREAD_SIZE
+
+
+@cuda.jit()
+def compute_multiblank_alphas_kernel(
+    acts: torch.Tensor,
+    denom: torch.Tensor,
+    sigma: float,
+    alphas: torch.Tensor,
+    llForward: torch.Tensor,
+    xlen: torch.Tensor,
+    ylen: torch.Tensor,
+    mlabels: torch.Tensor,  # [B]
+    minibatch: int,
+    maxT: int,
+    maxU: int,
+    alphabet_size: int,
+    blank_: int,
     big_blank_duration: torch.Tensor,
     num_big_blanks: int,
 ):
@@ -139,16 +476,13 @@ def compute_alphas_kernel(
                 alphas[offset + t * maxU + u] = alphas[offset + (t - 1) * maxU + u] + logp(
                     denom, acts, maxT, maxU, alphabet_size, b, t - 1, 0, blank_
                 ) - sigma
-#                print("KERNAL", alphas[offset + (t - 1) * maxU + u], logp(
-#                    denom, acts, maxT, maxU, alphabet_size, b, t - 1, 0, blank_
-#                ) - sigma)
 
                 for i in range(num_big_blanks):
                     if t >= big_blank_duration[i]:
                         alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(
                             alphas[offset + t * maxU + u],
                             alphas[offset + (t - big_blank_duration[i]) * maxU + u] + logp(
-                                denom, acts, maxT, maxU, alphabet_size, b, t - big_blank_duration[i], 0, big_blank_number[i]
+                                denom, acts, maxT, maxU, alphabet_size, b, t - big_blank_duration[i], 0, blank_ + 1 + i
                             ) - sigma
                         )

@@ -173,7 +507,7 @@ def compute_alphas_kernel(
             for i in range(num_big_blanks):
                 if t >= big_blank_duration[i]:
                     big_blank_no_emit = alphas[offset + (t - big_blank_duration[i]) * maxU + u] + logp(
-                        denom, acts, maxT, maxU, alphabet_size, b, t - big_blank_duration[i], u, big_blank_number[i]
+                        denom, acts, maxT, maxU, alphabet_size, b, t - big_blank_duration[i], u, blank_ + 1 + i
                     ) - sigma
                     alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp(
                         alphas[offset + t * maxU + u],
@@ -194,7 +528,7 @@ def compute_alphas_kernel(
             for i in range(num_big_blanks):
                 if T >= big_blank_duration[i]:
                     big_blank_loglike = alphas[offset + (T - big_blank_duration[i]) * maxU + U - 1] + logp(
-                        denom, acts, maxT, maxU, alphabet_size, b, T - big_blank_duration[i], U - 1, big_blank_number[i]
+                        denom, acts, maxT, maxU, alphabet_size, b, T - big_blank_duration[i], U - 1, blank_ + 1 + i
                     ) - sigma
                     loglike = rnnt_helper.log_sum_exp(loglike, big_blank_loglike)

@@ -214,7 +548,6 @@ def compute_betas_kernel(
     maxU: int,
     alphabet_size: int,
     blank_: int,
-    big_blank_number: torch.Tensor,
     big_blank_duration: torch.Tensor,
     num_big_blanks: int,
 ):
@@ -264,7 +597,7 @@
def compute_betas_kernel( if T >= big_blank_duration[i] and big_blank_duration[i] == 1: betas[offset + (T - 1) * maxU + U - 1] = rnnt_helper.log_sum_exp( betas[offset + (T - 1) * maxU + U - 1], - logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, big_blank_number[i]) - sigma + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_ + 1 + i) - sigma ) # sync until all betas are initialized @@ -287,14 +620,14 @@ def compute_betas_kernel( betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp( betas[offset + t * maxU + U - 1], betas[offset + (t + big_blank_duration[i]) * maxU + U - 1] + logp( - denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, big_blank_number[i] + denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_ + 1 + i ) - sigma ) elif t + big_blank_duration[i] == T and big_blank_duration[i] != 1: betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp( betas[offset + t * maxU + U - 1], logp( - denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, big_blank_number[i] + denom, acts, maxT, maxU, alphabet_size, b, t, U - 1, blank_ + 1 + i ) - sigma ) @@ -319,7 +652,7 @@ def compute_betas_kernel( for i in range(num_big_blanks): if t < T - big_blank_duration[i]: big_blank_no_emit = betas[offset + (t + big_blank_duration[i]) * maxU + u] + logp( - denom, acts, maxT, maxU, alphabet_size, b, t, u, big_blank_number[i] + denom, acts, maxT, maxU, alphabet_size, b, t, u, blank_ + 1 + i ) - sigma betas[offset + t * maxU + u] = rnnt_helper.log_sum_exp( betas[offset + t * maxU + u], @@ -350,7 +683,6 @@ def compute_grad_kernel( maxU: int, alphabet_size: int, blank_: int, - big_blank_number: torch.Tensor, big_blank_duration: torch.Tensor, num_big_blanks: int, fastemit_lambda: float, @@ -448,7 +780,7 @@ def compute_grad_kernel( grad -= math.exp(alphas[col] + logpk - logll[mb]) else: for i in range(num_big_blanks): - if (idx == big_blank_number[i]) and (t == T - big_blank_duration[i]) and (u == U - 1): + if (idx == blank_ + 1 + i) and (t == T - big_blank_duration[i]) and (u == U - 1): grad -= math.exp(alphas[col] + logpk - logll[mb]) # grad of blank across t < T; @@ -457,7 +789,7 @@ def compute_grad_kernel( grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + maxU]) else: for i in range(num_big_blanks): - if (idx == big_blank_number[i]) and (t < T - big_blank_duration[i]): + if (idx == blank_ + 1 + i) and (t < T - big_blank_duration[i]): grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + big_blank_duration[i] * maxU]) # grad of correct token across u < U;
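For reference, the recursion that both the PyTorch reference loss (`MultiblankRNNTLossPytorch.compute_forward_prob`) and the CUDA kernels in this patch implement can be restated as a minimal, batch-looped sketch. It assumes the layout used throughout the patch: joint logits of shape [B, T, U, V + 1 + num_big_blanks], the standard blank at index `blank`, and the big blank of duration `big_blank_durations[i]` at index `blank + 1 + i`. The function name, signature, and the quick-check values at the bottom are illustrative only, not part of the patch.

import torch


def multiblank_rnnt_forward_logprob(acts, labels, act_lens, label_lens, blank, big_blank_durations, sigma):
    """Batch-looped forward (alpha) recursion for multi-blank RNNT (reference sketch).

    acts:   [B, T, U, V + 1 + num_big_blanks] joint-network logits, with U = max label length + 1.
    labels: [B, U - 1] padded targets; act_lens / label_lens hold the true T and U - 1 per sample.
    """
    # Logit under-normalization: every emitted symbol pays an extra -sigma in log space.
    log_probs = torch.log_softmax(acts, dim=-1) - sigma
    lls = []
    for b in range(acts.size(0)):
        T, U = int(act_lens[b]), int(label_lens[b]) + 1
        alpha = torch.full((T, U), float("-inf"), dtype=log_probs.dtype, device=log_probs.device)
        alpha[0, 0] = 0.0
        for t in range(T):
            for u in range(U):
                if t == 0 and u == 0:
                    continue
                terms = []
                if u > 0:  # emit label u-1 while staying on frame t
                    terms.append(alpha[t, u - 1] + log_probs[b, t, u - 1, labels[b, u - 1]])
                if t > 0:  # emit the standard blank, advancing one frame
                    terms.append(alpha[t - 1, u] + log_probs[b, t - 1, u, blank])
                for i, d in enumerate(big_blank_durations):  # emit a big blank, advancing d frames
                    if t >= d:
                        terms.append(alpha[t - d, u] + log_probs[b, t - d, u, blank + 1 + i])
                alpha[t, u] = torch.logsumexp(torch.stack(terms), dim=0)
        # Terminate with a blank (standard, or any big blank whose duration fits) from the end.
        final = [alpha[T - 1, U - 1] + log_probs[b, T - 1, U - 1, blank]]
        for i, d in enumerate(big_blank_durations):
            if T >= d:
                final.append(alpha[T - d, U - 1] + log_probs[b, T - d, U - 1, blank + 1 + i])
        lls.append(torch.logsumexp(torch.stack(final), dim=0))
    return torch.stack(lls)  # [B] per-sample log-probabilities; negate for the loss


if __name__ == "__main__":
    # Quick autograd check on random inputs (shapes chosen arbitrarily for illustration).
    B, T, U, V = 2, 8, 4, 6
    durations = [2, 4]
    acts = torch.randn(B, T, U, V + 1 + len(durations), requires_grad=True)
    labels = torch.randint(0, V, (B, U - 1))
    act_lens = torch.tensor([T, T - 1])
    label_lens = torch.tensor([U - 1, U - 2])
    loss = -multiblank_rnnt_forward_logprob(acts, labels, act_lens, label_lens, V, durations, sigma=0.1).mean()
    loss.backward()  # acts.grad now holds the gradients of this reference loss

The `- sigma` term is the logit under-normalization: the PyTorch reference subtracts it once from `log_softmax`, while the kernels subtract it after each `logp(...)` call, which is equivalent. Negating the returned per-sample log-probabilities gives the quantity that the test loop in this patch compares between the `multiblank_rnnt_pytorch` and `multiblank_rnnt` losses, up to the reduction applied by `RNNTLoss`.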