[canary] Test for CanaryTokenizer + refactoring (#8285)
* Test for CanaryTokenizer

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Attempt at refactor...

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
pzelasko authored Jan 31, 2024
1 parent a595213 commit 9b7aa0f
Showing 3 changed files with 98 additions and 5 deletions.
9 changes: 5 additions & 4 deletions nemo/collections/asr/parts/mixins/mixins.py
@@ -214,10 +214,11 @@ def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig):
                 self.AGGREGATE_TOKENIZERS_DICT_PREFIX
             ][lang]['type']

-        if tokenizer_cfg.get('is_canary', False):
-            # CanaryTokenizer has easy access to spl_tokens, which the aggregate
-            # tokenizer doesn't have for now; TODO: merge both later
-            self.tokenizer = tokenizers.CanaryTokenizer(tokenizers_dict)
+        if "custom_tokenizer" in tokenizer_cfg:
+            # The class implementing this is usually a ModelPT, so it has access to the Serialization mixin by extension.
+            self.tokenizer = self.from_config_dict(
+                {"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "tokenizers": tokenizers_dict}
+            )
         else:
             self.tokenizer = tokenizers.AggregateTokenizer(tokenizers_dict)

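For context, a minimal sketch (not part of the commit) of a tokenizer config that takes the new custom_tokenizer branch above; the field layout mirrors the test added in this commit, and the directory paths are placeholders:

from omegaconf import OmegaConf

# Hypothetical aggregate-tokenizer config: `custom_tokenizer._target_` replaces
# the removed `is_canary` flag; the directory paths are placeholders.
tokenizer_cfg = OmegaConf.create(
    {
        "type": "agg",
        "dir": None,
        "langs": {
            "spl_tokens": {"dir": "/path/to/spl_tokens", "type": "bpe"},
            "en": {"dir": "/path/to/en_tokenizer", "type": "bpe"},
        },
        "custom_tokenizer": {
            "_target_": "nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer",
        },
    }
)
# With this config, _setup_aggregate_tokenizer() builds the tokenizer through
# Serialization.from_config_dict(), passing the per-language tokenizers_dict
# as the `tokenizers` argument of CanaryTokenizer.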
25 changes: 24 additions & 1 deletion nemo/collections/common/tokenizers/canary_tokenizer.py
@@ -13,9 +13,11 @@
 # limitations under the License.

 from functools import cached_property
+from pathlib import Path
 from typing import Dict

 from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
+from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model

 __all__ = ['CanaryTokenizer']

@@ -32,6 +34,7 @@
 }

 SPECIAL_TOKENS = [
+    "<pad>",
     "<|endoftext|>",
     "<|startoftranscript|>",
     *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())],
@@ -40,9 +43,10 @@
     "<|nopnc|>",
     "<|pnc|>",
     "<|nospeech|>",
-    "<pad>",
 ]

+UNUSED_SPECIAL_TOKENS = [f"<|spltoken{i}|>" for i in range(18)]


 class CanaryTokenizer(AggregateTokenizer):
     """
@@ -96,3 +100,22 @@ def to_language_id(self, language):
             return token_id

         raise KeyError(f"Language {language} not found in tokenizer.")
+
+    @staticmethod
+    def build_special_tokenizer(output_dir: str | Path) -> SentencePieceTokenizer:
+        output_dir = Path(output_dir)
+        output_dir.mkdir(exist_ok=True, parents=True)
+        text_path = output_dir / "train_text.txt"
+        all_tokens = SPECIAL_TOKENS + UNUSED_SPECIAL_TOKENS
+        train_text = "\n".join(all_tokens)
+        text_path.write_text(train_text)
+        model_path = output_dir / "tokenizer.model"
+        create_spt_model(
+            str(text_path),
+            vocab_size=32,
+            sample_size=-1,
+            do_lower_case=False,
+            output_dir=str(output_dir),
+            user_defined_symbols=all_tokens,
+        )
+        return SentencePieceTokenizer(str(model_path))
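As a quick usage note (a sketch, not part of the commit): the new static helper can be called on its own to materialize the 32-token special-token SentencePiece model. The output directory below is a placeholder, and the expected sizes and ids follow the assertions in the new test:

from nemo.collections.common.tokenizers.canary_tokenizer import CanaryTokenizer

# Build the special-token SentencePiece model into a scratch directory (placeholder path).
spl_tokenizer = CanaryTokenizer.build_special_tokenizer("/tmp/spl_tokens")

# vocab_size is 32, matching create_spt_model(vocab_size=32) above.
assert spl_tokenizer.vocab_size == 32

# Per the assertions in the new test, "<|startoftranscript|>" tokenizes to two ids:
# the "▁" piece first, then the special token itself.
print(spl_tokenizer.text_to_ids("<|startoftranscript|>"))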
69 changes: 69 additions & 0 deletions tests/collections/asr/test_custom_tokenizer.py
@@ -0,0 +1,69 @@
from unittest.mock import Mock

import pytest
import sentencepiece as spm
from omegaconf import OmegaConf

from nemo.collections.asr.parts.mixins import ASRBPEMixin
from nemo.collections.common.tokenizers.canary_tokenizer import SPECIAL_TOKENS, UNUSED_SPECIAL_TOKENS, CanaryTokenizer
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model
from nemo.core import Serialization


@pytest.fixture(scope="session")
def special_tokenizer_path(tmp_path_factory) -> str:
    tmpdir = tmp_path_factory.mktemp("spl_tokens")
    CanaryTokenizer.build_special_tokenizer(tmpdir)
    return str(tmpdir)


@pytest.fixture(scope="session")
def lang_tokenizer_path(tmp_path_factory) -> str:
    tmpdir = tmp_path_factory.mktemp("klingon_tokens")
    text_path = tmpdir / "text.txt"
    text_path.write_text("a\nb\nc\nd\n")
    create_spt_model(text_path, vocab_size=8, sample_size=-1, do_lower_case=False, output_dir=str(tmpdir))
    return str(tmpdir)


def test_canary_tokenizer_build_special_tokenizer(tmp_path):
    tokenizer = CanaryTokenizer.build_special_tokenizer(tmp_path)
    expected_tokens = ["<unk>"] + SPECIAL_TOKENS + UNUSED_SPECIAL_TOKENS + ["▁"]
    tokens = []
    for i in range(tokenizer.tokenizer.vocab_size()):
        tokens.append(tokenizer.tokenizer.IdToPiece(i))
    assert expected_tokens == tokens


def test_canary_tokenizer_init_from_cfg(special_tokenizer_path, lang_tokenizer_path):
    class DummyModel(ASRBPEMixin, Serialization):
        pass

    model = DummyModel()
    model.register_artifact = Mock(side_effect=lambda self, x: x)
    config = OmegaConf.create(
        {
            "type": "agg",
            "dir": None,
            "langs": {
                "spl_tokens": {"dir": special_tokenizer_path, "type": "bpe"},
                "en": {"dir": lang_tokenizer_path, "type": "bpe"},
            },
            "custom_tokenizer": {"_target_": "nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer",},
        }
    )
    model._setup_aggregate_tokenizer(config)
    tokenizer = model.tokenizer

    assert isinstance(tokenizer, CanaryTokenizer)
    assert len(tokenizer.tokenizers_dict) == 2
    assert set(tokenizer.tokenizers_dict.keys()) == {"spl_tokens", "en"}

    assert isinstance(tokenizer.tokenizers_dict["spl_tokens"], SentencePieceTokenizer)
    assert tokenizer.tokenizers_dict["spl_tokens"].vocab_size == 32

    assert isinstance(tokenizer.tokenizers_dict["en"], SentencePieceTokenizer)
    assert tokenizer.tokenizers_dict["en"].vocab_size == 6

    assert tokenizer.text_to_ids("<|startoftranscript|>", lang_id="spl_tokens") == [31, 3]  # "_" comes first
    assert tokenizer.text_to_ids("a", lang_id="en") == [32 + 1, 32 + 2]
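A note on the last two assertions (illustrative arithmetic, not library code): the aggregate tokenizer concatenates sub-tokenizer vocabularies in insertion order, so the "spl_tokens" model owns global ids 0..31 and the "en" model's local ids are shifted by its vocab size of 32:

# "spl_tokens" is registered first, so its 32 ids occupy the global range 0..31.
spl_vocab_size = 32
# In the tiny English model, local id 0 is presumably "<unk>" (as in the special
# tokenizer above), so "▁" and "a" get local ids 1 and 2.
en_local_ids = [1, 2]
global_ids = [spl_vocab_size + i for i in en_local_ids]
assert global_ids == [33, 34]  # i.e. [32 + 1, 32 + 2], matching the final assertion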
