Skip to content

Commit

Permalink
Improve mAES algorithm with patches (#4662)
Browse files Browse the repository at this point in the history
* First draft implementation fix of maes

Signed-off-by: smajumdar <titu1994@gmail.com>

* add deduplication checks

Signed-off-by: smajumdar <smajumdar@nvidia.com>
  • Loading branch information
titu1994 authored Aug 3, 2022
1 parent 5c8fe3a commit 498ff20
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
40 changes: 27 additions & 13 deletions nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,15 @@ def __init__(
if self.maes_prefix_alpha < 0:
raise ValueError("`maes_prefix_alpha` must be a positive integer.")

if self.vocab_size < beam_size + maes_expansion_beta:
raise ValueError(
f"beam_size ({beam_size}) + expansion_beta ({maes_expansion_beta}) "
f"should be smaller or equal to vocabulary size ({self.vocab_size})."
)

if search_type == 'maes':
self.max_candidates += maes_expansion_beta

if self.maes_num_steps < 2:
raise ValueError("`maes_num_steps` must be greater than 1.")

Expand Down Expand Up @@ -989,21 +998,24 @@ def modified_adaptive_expansion_search(

# List that contains the blank token emisions
list_b = []
duplication_check = [hyp.y_sequence for hyp in hyps]

# Repeat for number of mAES steps
for n in range(self.maes_num_steps):
# Pack the decoder logits for all current hypothesis
beam_dec_out = torch.stack([h.dec_out[-1] for h in hyps]) # [H, 1, D]

# Extract the log probabilities
beam_logp = torch.log_softmax(
beam_logp, beam_idx = torch.log_softmax(
self.joint.joint(beam_enc_out, beam_dec_out) / self.softmax_temperature, dim=-1,
)
).topk(self.max_candidates, dim=-1)

beam_logp = beam_logp[:, 0, 0, :] # [B, V + 1]
beam_idx = beam_idx[:, 0, 0, :] # [B, max_candidates]

# Compute k expansions for all the current hypotheses
k_expansions = select_k_expansions(
hyps, beam_logp, beam, self.maes_expansion_gamma, self.maes_expansion_beta
hyps, beam_idx, beam_logp, self.maes_expansion_gamma, self.maes_expansion_beta
)

# List that contains the hypothesis after prefix expansion
Expand All @@ -1024,16 +1036,18 @@ def modified_adaptive_expansion_search(
list_b.append(new_hyp)
else:
# If the expansion was a token
new_hyp.y_sequence.append(int(k))

# TODO: Setup LM
if self.language_model is not None:
# new_hyp.score += self.lm_weight * float(
# hyp.lm_scores[k]
# )
pass

list_exp.append(new_hyp)
# new_hyp.y_sequence.append(int(k))
if (new_hyp.y_sequence + [int(k)]) not in duplication_check:
new_hyp.y_sequence.append(int(k))

# TODO: Setup LM
if self.language_model is not None:
# new_hyp.score += self.lm_weight * float(
# hyp.lm_scores[k]
# )
pass

list_exp.append(new_hyp)

# If there were no token expansions in any of the hypotheses,
# Early exit
Expand Down
12 changes: 5 additions & 7 deletions nemo/collections/asr/parts/utils/rnnt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def is_prefix(x: List[int], pref: List[int]) -> bool:


def select_k_expansions(
hyps: List[Hypothesis], logps: torch.Tensor, beam_size: int, gamma: float, beta: int,
hyps: List[Hypothesis], topk_idxs: torch.Tensor, topk_logps: torch.Tensor, gamma: float, beta: int,
) -> List[Tuple[int, Hypothesis]]:
"""
Obtained from https://github.com/espnet/espnet
Expand All @@ -128,8 +128,8 @@ def select_k_expansions(
Args:
hyps: Hypotheses.
beam_logp: Log-probabilities for hypotheses expansions.
beam_size: Beam size.
topk_idxs: Indices of candidates hypothesis. Shape = [B, num_candidates]
topk_logps: Log-probabilities for hypotheses expansions. Shape = [B, V + 1]
gamma: Allowed logp difference for prune-by-value method.
beta: Number of additional candidates to store.
Expand All @@ -139,15 +139,13 @@ def select_k_expansions(
k_expansions = []

for i, hyp in enumerate(hyps):
hyp_i = [(int(k), hyp.score + float(logp)) for k, logp in enumerate(logps[i])]
hyp_i = [(int(k), hyp.score + float(v)) for k, v in zip(topk_idxs[i], topk_logps[i])]
k_best_exp_val = max(hyp_i, key=lambda x: x[1])

k_best_exp_idx = k_best_exp_val[0]
k_best_exp = k_best_exp_val[1]

expansions = sorted(filter(lambda x: (k_best_exp - gamma) <= x[1], hyp_i), key=lambda x: x[1],)[
: beam_size + beta
]
expansions = sorted(filter(lambda x: (k_best_exp - gamma) <= x[1], hyp_i), key=lambda x: x[1],)

if len(expansions) > 0:
k_expansions.append(expansions)
Expand Down

0 comments on commit 498ff20

Please sign in to comment.