LlamaTokenizer: Slow implementation opts for whitespace-lead token (different from fast) #24569

Closed · 1 of 2 tasks
lbeurerkellner opened this issue Jun 29, 2023 · 6 comments · Fixed by #24622

lbeurerkellner commented Jun 29, 2023

System Info

  • transformers version: 4.30.2
  • Platform: Linux-5.15.0-75-generic-x86_64-with-glibc2.31
  • Python version: 3.10.11
  • Huggingface_hub version: 0.14.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker @youn

Information

  • The official example scripts
  • My own modified scripts

Reproduction

Comparing slow and fast LlamaTokenizer instances with huggyllama/llama-7b.

from transformers import AutoTokenizer

model = "huggyllama/llama-7b"

fast = AutoTokenizer.from_pretrained(model)
slow = AutoTokenizer.from_pretrained(model, use_fast=False)

# use tokenize()
print(fast.tokenize("<s>uns"), slow.tokenize("<s>uns"))
# -> ['▁<s>', 'uns'] ['<s>', '▁uns']

# use __call__
print(fast(f"{fast.bos_token}uns", add_special_tokens=False), slow(f"{slow.bos_token}uns", add_special_tokens=False))
# -> {'input_ids': [1, 6948], 'token_type_ids': [0, 0], 'attention_mask': [1, 1]}
#    {'input_ids': [1, 9644], 'attention_mask': [1, 1]}

# round-tripping
print(fast.convert_tokens_to_string(fast.tokenize("<s>uns")), fast.convert_tokens_to_string(slow.tokenize("<s>uns")))
# -> <s>uns <s> uns   (the slow round-trip inserts an extra space)
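
# Sanity check (a sketch; the expected outputs are inferred from the
# tokenize() and __call__ results above, assuming the same
# huggyllama/llama-7b vocabulary): map the differing ids back to tokens.
print(fast.convert_ids_to_tokens([1, 6948]))  # expected: ['<s>', 'uns']
print(slow.convert_ids_to_tokens([1, 9644]))  # expected: ['<s>', '▁uns']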

Expected behavior

It looks like the slow LlamaTokenizer wrongly tokenizes uns. I would not expect the additional whitespace when round-tripping, or when tokenizing in the first place.

Thanks a lot in advance.

lbeurerkellner changed the title from "LlamaTokenizer: Fast and Slow implementations tokenize differently with" to "LlamaTokenizer: Slow implementation opts for whitespace-lead token" on Jun 29, 2023
lbeurerkellner changed the title from "LlamaTokenizer: Slow implementation opts for whitespace-lead token" to "LlamaTokenizer: Slow implementation opts for whitespace-lead token (different from fast)" on Jun 29, 2023
@ArthurZucker (Collaborator)

Thanks for reporting, will have a look

@Bearnardd (Contributor)

Hi @ArthurZucker! Are you currently working on this? If not, I think I could fix it pretty quickly :)

@ArthurZucker (Collaborator)

Sure! Feel free to take it! 😉 I'll have a look soon otherwise

@Bearnardd (Contributor)

@ArthurZucker @lbeurerkellner I have done some debugging and have a few observations. First, I checked other tokenizers that use LlamaTokenizer or LlamaTokenizerFast, and the results are pretty weird:

  1. The issue is not with uns specifically but with any word after a special token like <s>. Why this happens is pretty straightforward (see the sketch below the list):

# <s> is added to the Trie, so the text is split right after it is encountered
tokens = self.tokens_trie.split(text) # tokenization_utils.py:517

So it seems like it was a deliberate decision to split special tokens like this?

  2. Because of the above split, all slow tokenizers based on LlamaTokenizer return ['<s>', '▁uns'].

  3. More interestingly, most of the tokenizers based on LlamaTokenizerFast split the text into ['▁<s>', 'uns'] (e.g. fxmarty/tiny-llama-fast-tokenizer). But, for example, openlm-research/open_llama_3b, which is one of the most downloaded LLaMA-based models, outputs ['<s>', '▁uns'] even though it has the same tokenizer config as the one from fxmarty:

LlamaTokenizerFast(name_or_path='openlm-research/open_llama_3b', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True)}, clean_up_tokenization_spaces=False)
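
For illustration, a minimal sketch of the split from point 1 (tokens_trie is an internal attribute of slow tokenizers, so this may change between versions):

from transformers import AutoTokenizer

slow = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False)
# Special tokens are registered in the Trie, so the input is cut at "<s>"
# before any SentencePiece processing happens; "uns" then starts a new word,
# which is why the slow path produces '▁uns'.
print(slow.tokens_trie.split("<s>uns"))
# expected: ['<s>', 'uns']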

@ArthurZucker (Collaborator)

The fast tokenizer is working properly! As suspected, this is linked to #24622 and #24565. I am working on a fix for all our SPM-based models.

As for other tokenizers, I wouldn't use them as a reference, since a lot of them are outdated or don't include some of the fixes.

@ArthurZucker (Collaborator)

Actually, this is fixed: the output is now ['▁<s>', 'uns'] ['<s>', 'uns']. The fast tokenizer just works that way at the token-string level, but the encoded output is the same. Use:

slow = AutoTokenizer.from_pretrained(model, use_fast=False, legacy=False)
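
For completeness, a quick check that the two implementations now agree on ids (a sketch assuming a transformers version that includes the fix from #24622, i.e. one where LlamaTokenizer accepts the legacy flag):

from transformers import AutoTokenizer

model = "huggyllama/llama-7b"
fast = AutoTokenizer.from_pretrained(model)
slow = AutoTokenizer.from_pretrained(model, use_fast=False, legacy=False)

text = f"{fast.bos_token}uns"
# Both tokenizers should now produce the same input_ids for the same text.
print(fast(text, add_special_tokens=False)["input_ids"])
print(slow(text, add_special_tokens=False)["input_ids"])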
