[SigLIP] Add fast tokenizer #29969

Status: Open. Wants to merge 29 commits into base: main.

Commits (29):
d5d67b7  First draft (NielsRogge, Mar 30, 2024)
cbde88a  Fix more tests (NielsRogge, Mar 31, 2024)
de444e9  Add test (NielsRogge, Mar 31, 2024)
009fdc6  Remove print statements (NielsRogge, Apr 1, 2024)
f714af0  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Apr 22, 2024)
6cd05c2  Address comments (NielsRogge, Apr 22, 2024)
d67e40f  Use regex (NielsRogge, Apr 22, 2024)
f576078  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Aug 22, 2024)
de04050  Rebase (NielsRogge, Aug 22, 2024)
844c95c  Fix more tests (NielsRogge, Mar 31, 2024)
50500a5  remove strip in tokenize, keep characters used in special tokens, fix… (itazap, Aug 23, 2024)
8ba6e0b  ruff and FRAMEWORK error fix (itazap, Aug 24, 2024)
bf4f6db  remove unnecessary assertNotEqual from t5 (and siglip), add copied from) (itazap, Aug 26, 2024)
d850451  rm copied from (itazap, Aug 26, 2024)
e73fa01  typo (itazap, Aug 26, 2024)
cbe0a31  removing fast class (itazap, Sep 24, 2024)
d2b2339  updated tests for fast (itazap, Sep 30, 2024)
6379f9d  remove dev test file (Sep 30, 2024)
6f55733  Merge branch 'main' into add_siglip_fast_tokenizer_bis (itazap, Sep 30, 2024)
05f8b5c  Update src/transformers/models/auto/tokenization_auto.py (ArthurZucker, Oct 1, 2024)
e296021  Update tests/models/llama/test_tokenization_llama.py (itazap, Oct 1, 2024)
0f9669b  Update src/transformers/models/siglip/__init__.py (itazap, Oct 1, 2024)
0bca141  Update src/transformers/models/siglip/__init__.py (itazap, Oct 2, 2024)
80d2f46  add auto test (Oct 2, 2024)
2133809  fix test not to try importing Sigliptokenizerfast (Oct 2, 2024)
3a0e825  import pretrained instead of siglip (Oct 3, 2024)
e975d96  rm llama change (Oct 18, 2024)
dc8ed15  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Oct 21, 2024)
dfada5a  Make fixup (NielsRogge, Oct 21, 2024)
41 changes: 41 additions & 0 deletions src/transformers/convert_slow_tokenizer.py
@@ -19,6 +19,8 @@
allow to make our dependency on SentencePiece optional.
"""

import re
import string
import warnings
from typing import Dict, List, Tuple

@@ -1085,6 +1087,44 @@ def post_processor(self):
)


class SiglipConverter(SpmConverter):
handle_byte_fallback = True

def normalizer(self, proto):
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

list_normalizers = []

if self.original_tokenizer.do_lower_case:
list_normalizers.append(normalizers.Lowercase())
punctuation_to_remove = string.punctuation.replace(">", "").replace("<", "").replace("/", "")
list_normalizers.append(normalizers.Replace(Regex(r"[" + re.escape(punctuation_to_remove) + "]"), ""))
list_normalizers.extend(
[
normalizers.Replace(Regex(r"\s+"), " "),
normalizers.Strip(),
]
)

if not precompiled_charsmap:
list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
else:
list_normalizers.extend(
[normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
)

return normalizers.Sequence(list_normalizers)

def post_processor(self):
return processors.TemplateProcessing(
single=["$A", "</s>"],
pair=["$A", "</s>", "$B", "</s>"],
special_tokens=[
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
],
)


class WhisperConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder
@@ -1597,6 +1637,7 @@ def converted(self) -> Tokenizer:
"WhisperTokenizer": WhisperConverter,
"XLMRobertaTokenizer": XLMRobertaConverter,
"XLNetTokenizer": XLNetConverter,
"SiglipTokenizer": SiglipConverter,
"SplinterTokenizer": SplinterConverter,
"XGLMTokenizer": XGLMConverter,
"LlamaTokenizer": LlamaConverter,
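For context, a minimal sketch (not part of this diff) of how the new `SiglipConverter` is exercised once it is registered in `SLOW_TO_FAST_CONVERTERS`. The checkpoint name and special-token values are assumptions taken from the tests further down; `convert_slow_tokenizer` returns a `tokenizers.Tokenizer` that can be wrapped in `PreTrainedTokenizerFast`:

```python
# Sketch only: convert the slow SigLIP tokenizer using the converter registered above.
from transformers import PreTrainedTokenizerFast, SiglipTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")  # assumed checkpoint
backend = convert_slow_tokenizer(slow)  # dispatches to SiglipConverter via SLOW_TO_FAST_CONVERTERS
fast = PreTrainedTokenizerFast(
    tokenizer_object=backend,
    eos_token="</s>",   # assumed to match the slow tokenizer's defaults
    unk_token="<unk>",
    pad_token="</s>",
)

# The normalizer lowercases, drops punctuation except "<", ">" and "/", and collapses
# whitespace, so slow and fast outputs are expected to line up:
print(slow("A photo of a cat!!").input_ids)
print(fast("A photo of a cat!!").input_ids)
```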
@@ -30,7 +30,12 @@

TOKENIZER_CLASSES = {
# Phi3 uses Llama tokenizer
name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast")
name: getattr(
transformers,
"LlamaTokenizerFast"
if name == "Phi3Tokenizer"
else ("PreTrainedTokenizerFast" if name == "SiglipTokenizer" else name + "Fast"),
)
for name in SLOW_TO_FAST_CONVERTERS
}

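Spelled out, the nested conditional above resolves roughly as follows (illustrative helper only, not part of the diff):

```python
def fast_class_name(name: str) -> str:
    """Rough equivalent of the getattr target chosen for each slow tokenizer name."""
    if name == "Phi3Tokenizer":
        return "LlamaTokenizerFast"  # Phi3 reuses the Llama tokenizer
    if name == "SiglipTokenizer":
        return "PreTrainedTokenizerFast"  # SigLIP gets no dedicated *Fast class
    return name + "Fast"
```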
8 changes: 7 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -461,7 +461,13 @@
"SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
),
),
("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
(
"siglip",
(
"SiglipTokenizer" if is_sentencepiece_available() else None,
"PreTrainedTokenizerFast" if is_tokenizers_available() else None,
),
),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
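With the second slot of the mapping filled in, `AutoTokenizer` can now hand back a fast tokenizer for SigLIP. A sketch of the expected behavior (the checkpoint name is the one used in the new auto test below):

```python
from transformers import AutoTokenizer, PreTrainedTokenizerFast, SiglipTokenizer

slow = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
fast = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224", use_fast=True)

assert isinstance(slow, SiglipTokenizer)          # requires sentencepiece
assert isinstance(fast, PreTrainedTokenizerFast)  # requires tokenizers
```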
41 changes: 5 additions & 36 deletions src/transformers/models/glm/modeling_glm.py
@@ -25,7 +25,6 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -921,6 +920,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -1071,18 +1071,7 @@ def forward(

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = self.loss_function(logits, labels, self.vocab_size)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -1186,27 +1175,8 @@ def forward(

loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -1289,8 +1259,7 @@ def forward(

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
loss = self.loss_function(logits, labels, self.config)

if not return_dict:
output = (logits,) + outputs[2:]
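For reference, the hand-rolled losses deleted above boil down to the standard shift-and-flatten causal LM cross-entropy; a sketch of what the centralized `self.loss_function` call replaces (not the exact library helper):

```python
import torch
import torch.nn as nn


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Upcast for numerical stability, then shift so that tokens < n predict token n.
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens and compute cross-entropy on the same device as the logits.
    loss_fct = nn.CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1).to(shift_logits.device)
    return loss_fct(shift_logits, shift_labels)
```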
1 change: 1 addition & 0 deletions src/transformers/models/siglip/__init__.py
@@ -17,6 +17,7 @@
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)
4 changes: 2 additions & 2 deletions src/transformers/models/siglip/tokenization_siglip.py
@@ -267,7 +267,8 @@ def __setstate__(self, d):
self.sp_model.Load(self.vocab_file)

def remove_punctuation(self, text: str) -> str:
return text.translate(str.maketrans("", "", string.punctuation))
punctuation_to_remove = string.punctuation.replace(">", "").replace("<", "").replace("/", "")
return text.translate(str.maketrans("", "", punctuation_to_remove))

# source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
Expand All @@ -287,7 +288,6 @@ def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
else:
text = self.remove_punctuation(text)
text = re.sub(r"\s+", " ", text)
text = text.strip()

return text

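To illustrate the change (a quick sketch, example string assumed): keeping `<`, `>` and `/` out of the removal table means special tokens like `</s>` now survive canonicalization instead of collapsing to `s`:

```python
import string

old_table = str.maketrans("", "", string.punctuation)
new_table = str.maketrans(
    "", "", string.punctuation.replace(">", "").replace("<", "").replace("/", "")
)

text = "a photo of a cat </s>"
print(text.translate(old_table))  # 'a photo of a cat s'    -> '</s>' is destroyed
print(text.translate(new_table))  # 'a photo of a cat </s>' -> '</s>' is preserved
```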
5 changes: 5 additions & 0 deletions tests/models/auto/test_tokenization_auto.py
@@ -201,6 +201,11 @@ def test_PreTrainedTokenizerFast_from_pretrained(self):
self.assertEqual(tokenizer.padding_side, "right")
self.assertEqual(tokenizer.truncation_side, "right")

def test_PreTrainedTokenizerFast_inferred(self):
@ArthurZucker added this test for #33751 as Siglip would be the first model that can test this (#33751 (comment))

# Model does not have a fast tokenizer or PreTrainedTokenizerFast specified in config but can still load fast
tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224", use_fast=True)
self.assertEqual(type(tokenizer), PreTrainedTokenizerFast)

def test_auto_tokenizer_from_local_folder(self):
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
46 changes: 42 additions & 4 deletions tests/models/siglip/test_tokenization_siglip.py
@@ -18,7 +18,13 @@
import tempfile
import unittest

from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SiglipTokenizer
from transformers import (
SPIECE_UNDERLINE,
AddedToken,
BatchEncoding,
PreTrainedTokenizerFast,
SiglipTokenizer,
)
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property, is_tf_available, is_torch_available

@@ -40,6 +46,7 @@
class SiglipTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
from_pretrained_id = "google/siglip-base-patch16-224"
tokenizer_class = SiglipTokenizer
rust_tokenizer_class = PreTrainedTokenizerFast
test_rust_tokenizer = False
test_sentencepiece = True
test_sentencepiece_ignore_case = True
@@ -62,7 +69,9 @@ def test_convert_token_and_id(self):
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)

def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
tokenizer = SiglipTokenizer(SAMPLE_VOCAB)

vocab_keys = list(tokenizer.get_vocab().keys())

self.assertEqual(vocab_keys[0], "<unk>")
self.assertEqual(vocab_keys[1], "<s>")
@@ -139,6 +148,9 @@ def siglip_tokenizer(self):
def get_tokenizer(self, **kwargs) -> SiglipTokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

def get_rust_tokenizer(self, **kwargs):
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_rust_and_python_full_tokenizers with T5->Siglip
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
@@ -221,6 +233,14 @@ def test_subword_regularization_tokenizer(self):
def test_pickle_subword_regularization_tokenizer(self):
pass

@unittest.skip(reason="SiglipTokenizer has custom lowercase logic")
def test_added_tokens_do_lower_case(self):
pass

@unittest.skip(reason="SiglipTokenizer strips the punctuation for chat tokens")
def test_chat_template_return_assistant_tokens_mask(self):
pass

# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_special_tokens_initialization with T5->Siglip
def test_special_tokens_initialization(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
@@ -378,8 +398,7 @@ def test_some_edge_cases(self):
sp_tokens = tokenizer.sp_model.encode("</s>>", out_type=str)
self.assertEqual(sp_tokens, ["</", "s", ">", ">"])
tokens = tokenizer.tokenize("</s>>")
self.assertNotEqual(sp_tokens, tokens)
self.assertEqual(tokens, ["</s>"])
self.assertEqual(tokens, ["</s>", ">"])

tokens = tokenizer.tokenize("")
self.assertEqual(tokens, [])
Expand All @@ -397,6 +416,25 @@ def test_some_edge_cases(self):
self.assertEqual(tokens, [])
self.assertEqual(tokens, tokenizer.sp_model.encode("▁", out_type=str))

def test_compare_prepare_for_model(self):
if not self.test_slow_tokenizer:
# as we don't have a slow version, we can't compare the outputs between slow and fast versions
self.skipTest(reason="test_slow_tokenizer is set to False")

for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
string_sequence = "Asserting that both tokenizers are equal"
python_output = tokenizer_p.prepare_for_model(
tokenizer_p.encode(string_sequence, add_special_tokens=False), add_special_tokens=False
)
rust_output = tokenizer_r.prepare_for_model(
tokenizer_r.encode(string_sequence, add_special_tokens=False), add_special_tokens=False
)
for key in python_output:
self.assertEqual(python_output[key], rust_output[key])


@require_sentencepiece
@require_tokenizers
1 change: 0 additions & 1 deletion tests/models/t5/test_tokenization_t5.py
@@ -400,7 +400,6 @@ def test_some_edge_cases(self):
sp_tokens = tokenizer.sp_model.encode("</s>>", out_type=str)
self.assertEqual(sp_tokens, ["<", "/", "s", ">", ">"])
tokens = tokenizer.tokenize("</s>>")
self.assertNotEqual(sp_tokens, tokens)
self.assertEqual(tokens, ["</s>", ">"])

tokens = tokenizer.tokenize("")