better merge target generator
1-800-BAD-CODE committed Jul 29, 2023
1 parent 3a9f873 commit 5219b39
Showing 1 changed file with 44 additions and 21 deletions.
@@ -274,56 +274,79 @@ class MergeTargetsGenerator:
"U.S." -> "U. S."
"¿U.S." -> "¿U. S."
"a.m." -> "a. m."
"¿U.S." -> "¿U. S." (actually misses this case, currently)

    The intent of this class is to allow learning to "merge" spelled-out acronyms which are transcribed by an ASR
    model as individual characters. E.g., a typical ASR model will transcribe "the f b i agent", and in
    post-processing we want to merge this into "the fbi agent".

    Since it is difficult to deduce which tokens will be transcribed as characters, e.g., "f b i", or as contiguous
    tokens, e.g., "nato", we split (with some probability) any token which is all upper-case or which contains a
    period after each character (e.g., "a.m.").

    Args:
        post_labels: Post-punctuation tokens, to be ignored in counting character positions.
        pre_labels: Pre-punctuation tokens, analogous to `post_labels`.
        p_split: Split acronyms with this probability.
    """

    def __init__(self, post_labels: List[str], pre_labels: List[str], p_split: float = 0.8) -> None:
        self._p_split = p_split
        self._all_punc: Set[str] = set(post_labels + pre_labels)
        self._all_punc.discard(NULL_PUNCT_TOKEN)
        self._all_punc.discard(ACRONYM_TOKEN)
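        # For illustration (hypothetical label values): post_labels=[".", ","] and pre_labels=["¿"] would
        # yield _all_punc == {".", ",", "¿"}, with the sentinel tokens discarded if present.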
        # The regex we want for splitting is "<optional_punct_token><character><optional_punct_token>".
        punc_str = ''.join(re.escape(x) for x in self._all_punc)
        # Regex used to split an acronym into constituent chars with optional periods. Match zero or one punc
        # tokens on each side of a non-punc character.
        self._split_ptn = re.compile(rf"([{punc_str}]?[^{punc_str}][{punc_str}]?)")
        self._punc_ptn = re.compile(rf"[{punc_str}]+")
        self._rng = np.random.default_rng()  # TODO: seed? Doesn't matter.
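        # A quick illustration of the split pattern (assuming '.' and '¿' are among the punctuation labels):
        #   self._split_ptn.findall("U.S.")  -> ['U.', 'S.']
        #   self._split_ptn.findall("¿U.S.") -> ['¿U.', 'S.']
        #   self._split_ptn.findall("FBI")   -> ['F', 'B', 'I']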

    def generate_targets(self, input_text: str) -> Tuple[str, List[int]]:
        processed_tokens: List[str] = []
        char_level_targets: List[int] = []
        for token in input_text.split():
            # Split multi-char upper-case tokens: FBI, U.S., etc. Also catch lower-cased acronyms: a.m., etc.
            #
            # If all non-punc chars in the token belong to cased alphabets and are upper-case, it's an acronym.
            # If every char in the token is followed by a period, it's an acronym.
            # The `if` structure is convoluted, but enables short-circuits around expensive checks.
            do_split = False
            if len(token) > 1 and self._rng.random() < self._p_split:
                # Check whether every non-punc char is upper-case, e.g., NATO, FBI, etc.
                token_no_punc = self._punc_ptn.sub("", token)
                # Verify upper-case char-by-char, because `"A認".isupper() == True` but `"認".isupper() == False`.
                if len(token_no_punc) > 1 and all(x.isupper() for x in token_no_punc):
                    do_split = True
                # Check for punctuated initialisms: a.m., p.m., etc. Candidates must have 4+ chars.
                if not do_split and len(token) >= 4:
                    # This check fails for, e.g., '¿a.m.', but that pattern appears 0 times in news training data.
                    every_other_char = token[1::2]
                    do_split = all(x == "." for x in every_other_char)
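                    # For example, "a.m."[1::2] == ".." -> split, while "asap"[1::2] == "sp" -> no split.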

            if not do_split:
                processed_tokens.append(token)
                for char in token:
                    if char in self._all_punc:
                        # Don't count punctuation since it'll be removed before the example is generated
                        continue
                    char_level_targets.append(0)
                continue
            else:
                # This is an acronym that we should split
                subtokens = self._split_ptn.findall(token)
                processed_tokens.extend(subtokens)
                # "merge[i]" implies "remove the space between chars i and i+1", e.g.,
                # "FBI agent" -> (['F', 'B', 'I', 'agent'], [merge, merge, no, no, no, no, no, no])
                # such that we merge after 'F' and 'B' to recover "FBI agent". The reason for using character-level
                # targets is to make it easy to align with subword tokens down-stream.
                for i, subtoken in enumerate(subtokens):
                    target = 1 if i < len(subtokens) - 1 else 0
                    # Each of these subtokens has exactly one non-punc char, as dictated by the regex
                    char_level_targets.append(target)
        out_text = " ".join(processed_tokens)
        return out_text, char_level_targets
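
A rough usage sketch (the label values are hypothetical, and p_split=1.0 forces splitting, since the class is otherwise stochastic), matching the docstring's "FBI agent" example:

    gen = MergeTargetsGenerator(post_labels=[".", ","], pre_labels=["¿"], p_split=1.0)
    text, targets = gen.generate_targets("FBI agent")
    assert text == "F B I agent"
    # One target per non-punc char: merge after 'F' and 'B', then no merges within "agent".
    assert targets == [1, 1, 0, 0, 0, 0, 0, 0]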

