[SigLIP] Add fast tokenizer #29969

Open: wants to merge 29 commits into main (changes shown are from 24 commits)

Commits (29)
d5d67b7  First draft (NielsRogge, Mar 30, 2024)
cbde88a  Fix more tests (NielsRogge, Mar 31, 2024)
de444e9  Add test (NielsRogge, Mar 31, 2024)
009fdc6  Remove print statements (NielsRogge, Apr 1, 2024)
f714af0  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Apr 22, 2024)
6cd05c2  Address comments (NielsRogge, Apr 22, 2024)
d67e40f  Use regex (NielsRogge, Apr 22, 2024)
f576078  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Aug 22, 2024)
de04050  Rebase (NielsRogge, Aug 22, 2024)
844c95c  Fix more tests (NielsRogge, Mar 31, 2024)
50500a5  remove strip in tokenize, keep characters used in special tokens, fix… (itazap, Aug 23, 2024)
8ba6e0b  ruff and FRAMEWORK error fix (itazap, Aug 24, 2024)
bf4f6db  remove unnecessary assertNotEqual from t5 (and siglip), add copied from) (itazap, Aug 26, 2024)
d850451  rm copied from (itazap, Aug 26, 2024)
e73fa01  typo (itazap, Aug 26, 2024)
cbe0a31  removing fast class (itazap, Sep 24, 2024)
d2b2339  updated tests for fast (itazap, Sep 30, 2024)
6379f9d  remove dev test file (Sep 30, 2024)
6f55733  Merge branch 'main' into add_siglip_fast_tokenizer_bis (itazap, Sep 30, 2024)
05f8b5c  Update src/transformers/models/auto/tokenization_auto.py (ArthurZucker, Oct 1, 2024)
e296021  Update tests/models/llama/test_tokenization_llama.py (itazap, Oct 1, 2024)
0f9669b  Update src/transformers/models/siglip/__init__.py (itazap, Oct 1, 2024)
0bca141  Update src/transformers/models/siglip/__init__.py (itazap, Oct 2, 2024)
80d2f46  add auto test (Oct 2, 2024)
2133809  fix test not to try importing Sigliptokenizerfast (Oct 2, 2024)
3a0e825  import pretrained instead of siglip (Oct 3, 2024)
e975d96  rm llama change (Oct 18, 2024)
dc8ed15  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Oct 21, 2024)
dfada5a  Make fixup (NielsRogge, Oct 21, 2024)

41 changes: 41 additions & 0 deletions src/transformers/convert_slow_tokenizer.py
@@ -19,6 +19,8 @@
allow to make our dependency on SentencePiece optional.
"""

import re
import string
import warnings
from typing import Dict, List, Tuple

@@ -1086,6 +1088,44 @@ def post_processor(self):
)


class SiglipConverter(SpmConverter):
handle_byte_fallback = True

def normalizer(self, proto):
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

list_normalizers = []

if self.original_tokenizer.do_lower_case:
list_normalizers.append(normalizers.Lowercase())
punctuation_to_remove = string.punctuation.replace(">", "").replace("<", "").replace("/", "")
list_normalizers.append(normalizers.Replace(Regex(r"[" + re.escape(punctuation_to_remove) + "]"), ""))
list_normalizers.extend(
[
normalizers.Replace(Regex(r"\s+"), " "),
normalizers.Strip(),
]
)

if not precompiled_charsmap:
list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
else:
list_normalizers.extend(
[normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
)

return normalizers.Sequence(list_normalizers)

def post_processor(self):
return processors.TemplateProcessing(
single=["$A", "</s>"],
pair=["$A", "</s>", "$B", "</s>"],
special_tokens=[
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
],
)


class WhisperConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder
@@ -1557,6 +1597,7 @@ def converted(self) -> Tokenizer:
"WhisperTokenizer": WhisperConverter,
"XLMRobertaTokenizer": XLMRobertaConverter,
"XLNetTokenizer": XLNetConverter,
"SiglipTokenizer": SiglipConverter,
"SplinterTokenizer": SplinterConverter,
"XGLMTokenizer": XGLMConverter,
"LlamaTokenizer": LlamaConverter,
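For context, a minimal sketch (not part of the diff) of how the new converter would typically be exercised. The checkpoint name is the one used in the tests below; the behaviour described in the comments follows the converter added above and is illustrative, not authoritative:

from transformers import SiglipTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

# Load the existing slow (SentencePiece) tokenizer.
slow_tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")

# convert_slow_tokenizer looks up "SiglipTokenizer" in SLOW_TO_FAST_CONVERTERS
# (which now maps to SiglipConverter) and returns a `tokenizers.Tokenizer` whose
# normalizer lowercases when do_lower_case is set, removes punctuation except
# "<", ">" and "/", collapses whitespace and strips, and whose post-processor
# appends "</s>".
fast_backend = convert_slow_tokenizer(slow_tokenizer)
print(fast_backend.encode("A photo, of a dog!").tokens)
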
8 changes: 7 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -456,7 +456,13 @@
"SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
),
),
("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
(
"siglip",
(
"SiglipTokenizer" if is_sentencepiece_available() else None,
"PreTrainedTokenizerFast" if is_tokenizers_available() else None,
),
),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
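A short sketch of the behaviour this mapping enables, assuming the tokenizers library is installed: even though no SiglipTokenizerFast class exists, AutoTokenizer can fall back to a generic fast tokenizer built through the converter registered above (the checkpoint name is the one used in the tests):

from transformers import AutoTokenizer, PreTrainedTokenizerFast

tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224", use_fast=True)
# The returned object is a plain PreTrainedTokenizerFast, not a model-specific class.
assert isinstance(tokenizer, PreTrainedTokenizerFast)
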
1 change: 1 addition & 0 deletions src/transformers/models/siglip/__init__.py
@@ -17,6 +17,7 @@
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)
4 changes: 2 additions & 2 deletions src/transformers/models/siglip/tokenization_siglip.py
@@ -267,7 +267,8 @@ def __setstate__(self, d):
self.sp_model.Load(self.vocab_file)

def remove_punctuation(self, text: str) -> str:
return text.translate(str.maketrans("", "", string.punctuation))
punctuation_to_remove = string.punctuation.replace(">", "").replace("<", "").replace("/", "")
return text.translate(str.maketrans("", "", punctuation_to_remove))

# source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
@@ -287,7 +288,6 @@ def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
else:
text = self.remove_punctuation(text)
text = re.sub(r"\s+", " ", text)
text = text.strip()

return text

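Illustrative sketch (assuming the same checkpoint as the tests) of why "<", ">" and "/" are now kept: punctuation removal no longer destroys special tokens such as "</s>", while other punctuation is still dropped.

from transformers import SiglipTokenizer

tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")
# "," and "!" are removed, but the "</s>" special token survives.
print(tokenizer.remove_punctuation("a photo, of a dog!</s>"))
# expected output: "a photo of a dog</s>"
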
5 changes: 5 additions & 0 deletions tests/models/auto/test_tokenization_auto.py
@@ -201,6 +201,11 @@ def test_PreTrainedTokenizerFast_from_pretrained(self):
self.assertEqual(tokenizer.padding_side, "right")
self.assertEqual(tokenizer.truncation_side, "right")

def test_PreTrainedTokenizerFast_inferred(self):
Review comment: @ArthurZucker added this test for #33751 as Siglip would be the first model that can test this (#33751 (comment))

# Model does not have a fast tokenizer or PreTrainedTokenizerFast specified in config but can still load fast
tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224", use_fast=True)
self.assertEqual(type(tokenizer), PreTrainedTokenizerFast)

def test_auto_tokenizer_from_local_folder(self):
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
Expand Down
6 changes: 3 additions & 3 deletions tests/models/llama/test_tokenization_llama.py
@@ -28,7 +28,7 @@
AutoTokenizer,
LlamaTokenizer,
LlamaTokenizerFast,
PreTrainedTokenizerFast,
PreTrainedTokenizerFast
)
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers.testing_utils import (
@@ -54,9 +54,9 @@
class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
from_pretrained_id = ["hf-internal-testing/llama-tokenizer", "meta-llama/Llama-2-7b-hf"]
tokenizer_class = LlamaTokenizer
rust_tokenizer_class = LlamaTokenizerFast
rust_tokenizer_class = PreTrainedTokenizerFast

test_rust_tokenizer = False
test_rust_tokenizer = True
test_sentencepiece = True
from_pretrained_kwargs = {}

66 changes: 55 additions & 11 deletions tests/models/siglip/test_tokenization_siglip.py
@@ -18,7 +18,14 @@
import tempfile
import unittest

from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SiglipTokenizer
from transformers import (
SPIECE_UNDERLINE,
AddedToken,
AutoTokenizer,
BatchEncoding,
PreTrainedTokenizerFast,
SiglipTokenizer,
)
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property, is_tf_available, is_torch_available

@@ -40,6 +47,7 @@
class SiglipTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
from_pretrained_id = "google/siglip-base-patch16-224"
tokenizer_class = SiglipTokenizer
rust_tokenizer_class = PreTrainedTokenizerFast
test_rust_tokenizer = False
test_sentencepiece = True
test_sentencepiece_ignore_case = True
@@ -49,7 +57,7 @@ def setUp(self):
super().setUp()

# We have a SentencePiece fixture for testing
tokenizer = SiglipTokenizer(SAMPLE_VOCAB)
tokenizer = SiglipTokenizer.from_pretrained(self.from_pretrained_id[0])
tokenizer.save_pretrained(self.tmpdirname)

# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_convert_token_and_id with T5->Siglip
@@ -58,11 +66,15 @@ def test_convert_token_and_id(self):
token = "<s>"
token_id = 1

self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
tokenizer = SiglipTokenizer(SAMPLE_VOCAB)

self.assertEqual(tokenizer._convert_token_to_id(token), token_id)
self.assertEqual(tokenizer._convert_id_to_token(token_id), token)

def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
tokenizer = SiglipTokenizer(SAMPLE_VOCAB)

vocab_keys = list(tokenizer.get_vocab().keys())

self.assertEqual(vocab_keys[0], "<unk>")
self.assertEqual(vocab_keys[1], "<s>")
@@ -137,15 +149,19 @@ def siglip_tokenizer(self):

# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.get_tokenizer with T5->Siglip
def get_tokenizer(self, **kwargs) -> SiglipTokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
return SiglipTokenizer.from_pretrained(self.from_pretrained_id[0], **kwargs)

# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.get_rust_tokenizer with T5->Siglip
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
return AutoTokenizer.from_pretrained(self.from_pretrained_id[0], use_fast=True, **kwargs)

# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_rust_and_python_full_tokenizers with T5->Siglip
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
self.skipTest(reason="test_rust_tokenizer is set to False")

tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id[0])
rust_tokenizer = AutoTokenizer.from_pretrained(self.from_pretrained_id[0])

sequence = "I was born in 92000, and this is falsé."

@@ -221,6 +237,14 @@ def test_subword_regularization_tokenizer(self):
def test_pickle_subword_regularization_tokenizer(self):
pass

@unittest.skip(reason="SiglipTokenizer has custom lowercase logic")
def test_added_tokens_do_lower_case(self):
pass

@unittest.skip(reason="SiglipTokenizer strips the punctuation for chat tokens")
def test_chat_template_return_assistant_tokens_mask(self):
pass

# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_special_tokens_initialization with T5->Siglip
def test_special_tokens_initialization(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
@@ -231,7 +255,9 @@ def test_special_tokens_initialization(self):
pretrained_name, additional_special_tokens=added_tokens, **kwargs
)
tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
pretrained_name,
additional_special_tokens=added_tokens,
**kwargs,
)
tokenizer_p = self.tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs
@@ -378,8 +404,7 @@ def test_some_edge_cases(self):
sp_tokens = tokenizer.sp_model.encode("</s>>", out_type=str)
self.assertEqual(sp_tokens, ["</", "s", ">", ">"])
tokens = tokenizer.tokenize("</s>>")
self.assertNotEqual(sp_tokens, tokens)
self.assertEqual(tokens, ["</s>"])
self.assertEqual(tokens, ["</s>", ">"])

tokens = tokenizer.tokenize("")
self.assertEqual(tokens, [])
@@ -397,6 +422,25 @@
self.assertEqual(tokens, [])
self.assertEqual(tokens, tokenizer.sp_model.encode("▁", out_type=str))

def test_compare_prepare_for_model(self):
if not self.test_slow_tokenizer:
# as we don't have a slow version, we can't compare the outputs between slow and fast versions
self.skipTest(reason="test_slow_tokenizer is set to False")

for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
string_sequence = "Asserting that both tokenizers are equal"
python_output = tokenizer_p.prepare_for_model(
tokenizer_p.encode(string_sequence, add_special_tokens=False), add_special_tokens=False
)
rust_output = tokenizer_r.prepare_for_model(
tokenizer_r.encode(string_sequence, add_special_tokens=False), add_special_tokens=False
)
for key in python_output:
self.assertEqual(python_output[key], rust_output[key])


@require_sentencepiece
@require_tokenizers
1 change: 0 additions & 1 deletion tests/models/t5/test_tokenization_t5.py
@@ -400,7 +400,6 @@ def test_some_edge_cases(self):
sp_tokens = tokenizer.sp_model.encode("</s>>", out_type=str)
self.assertEqual(sp_tokens, ["<", "/", "s", ">", ">"])
tokens = tokenizer.tokenize("</s>>")
self.assertNotEqual(sp_tokens, tokens)
self.assertEqual(tokens, ["</s>", ">"])

tokens = tokenizer.tokenize("")
39 changes: 30 additions & 9 deletions tests/test_tokenization_common.py
@@ -377,8 +377,13 @@ def assert_batch_padded_input_match(
for i_r, i_p in zip(input_r[model_main_input_name], input_p[model_main_input_name]):
self.assert_padded_input_match(i_r, i_p, max_length, pad_token_id)

for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
self.assertSequenceEqual(i_r, i_p)
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)

if "attention_mask" in tokenizer_r.model_input_names:
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
self.assertSequenceEqual(i_r, i_p)

@staticmethod
def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
@@ -543,6 +548,15 @@ def test_model_input_names_signature(self):
# to make sure `tokenizer.pad(...)` works correctly
self.assertTrue(tokenizer.model_input_names[0] in accepted_model_main_input_names)

def test_model_input_names_python_rust_equals(self):
if not self.test_rust_tokenizer:
return

tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
rust_tokenizer = self.get_rust_tokenizer()
self.assertListEqual(tokenizer.model_input_names, rust_tokenizer.model_input_names)

def test_rust_tokenizer_signature(self):
if not self.test_rust_tokenizer:
self.skipTest(reason="test_rust_tokenizer is set to False")
@@ -2973,7 +2987,8 @@ def test_prepare_seq2seq_batch(self):
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
)
self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
if "attention_mask" in tokenizer.model_input_names:
self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
self.assertNotIn("decoder_input_ids", batch_encoder_only)

def test_is_fast(self):
@@ -3642,23 +3657,26 @@ def test_padding(self, max_length=50):
"This is a simple input", max_length=max_length, pad_to_max_length=True
)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
if "attention_mask" in tokenizer.model_input_names:
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
input_r = tokenizer_r.encode_plus(
"This is a simple input", max_length=max_length, padding="max_length"
)
input_p = tokenizer_p.encode_plus(
"This is a simple input", max_length=max_length, padding="max_length"
)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
if "attention_mask" in tokenizer.model_input_names:
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
self.assert_padded_input_match(
input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
)

self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
if "attention_mask" in tokenizer.model_input_names:
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

# Encode_plus - Pair input
input_r = tokenizer_r.encode_plus(
@@ -3668,21 +3686,24 @@
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
if "attention_mask" in tokenizer.model_input_names:
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
input_r = tokenizer_r.encode_plus(
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
)
input_p = tokenizer_p.encode_plus(
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
if "attention_mask" in tokenizer.model_input_names:
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
self.assert_padded_input_match(
input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
if "attention_mask" in tokenizer.model_input_names:
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

# Batch_encode_plus - Simple input
input_r = tokenizer_r.batch_encode_plus(