From 396bc0e18097a4eb8b17f284864d93db9b0ffa70 Mon Sep 17 00:00:00 2001
From: Merve Noyan
Date: Mon, 21 Oct 2024 13:18:26 +0200
Subject: [PATCH] changes without validating logits

---
 src/transformers/models/siglip/convert_siglip_to_hf.py | 9 ++++++---
 src/transformers/models/siglip/modeling_siglip.py      | 7 ++++---
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/siglip/convert_siglip_to_hf.py b/src/transformers/models/siglip/convert_siglip_to_hf.py
index 5e6102e32a9ba0..1d08701d5ed8a0 100644
--- a/src/transformers/models/siglip/convert_siglip_to_hf.py
+++ b/src/transformers/models/siglip/convert_siglip_to_hf.py
@@ -100,6 +100,8 @@ def get_siglip_config(model_name):
         config.vision_config.intermediate_size = 4304
         config.vision_config.num_hidden_layers = 27
         config.vision_config.num_attention_heads = 16
+        if (config.vision_config.image_size==256 and config.text_config.vocab_size==250000 and config.vision_config.patch_size==16):
+            config.text_config.no_head = True
     else:
         raise ValueError("Model not supported")
 
@@ -112,8 +114,7 @@ def get_siglip_config(model_name):
 
 def create_rename_keys(config):
     rename_keys = []
     # fmt: off
-    if (config.vision_config.image_size==256 and config.text_config.vocab_size==250000 and config.vision_config.patch_size==16):
-        siglip_sovit_i18_256 = True
+
     # vision encoder
     rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
@@ -176,7 +177,8 @@ def create_rename_keys(config):
     rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
     rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
 
-    if not siglip_sovit_i18_256:
+
+    if not config.text_config.no_head:
         rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
         rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
 
@@ -308,6 +310,7 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit
     read_in_q_k_v_head(state_dict, config)
     # load HuggingFace model
     model = SiglipModel(config).eval()
+    print("config.text_config", config.text_config)
     model.load_state_dict(state_dict)
     # create processor
     # important: make tokenizer not return attention_mask since original one doesn't require it
diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py
index b18444adc26963..c9011bc1c85023 100644
--- a/src/transformers/models/siglip/modeling_siglip.py
+++ b/src/transformers/models/siglip/modeling_siglip.py
@@ -926,8 +926,8 @@ def __init__(self, config: SiglipTextConfig):
         self.embeddings = SiglipTextEmbeddings(config)
         self.encoder = SiglipEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        #if not (config.vision_config.image_size==256 and config.text_config.vocab_size==250000 and config.vision_config.patch_size==16):
-        self.head = nn.Linear(embed_dim, embed_dim)
+        if not hasattr(self.config, "no_head"):
+            self.head = nn.Linear(embed_dim, embed_dim)
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
     @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
@@ -978,7 +978,8 @@ def forward(
 
         # Assuming "sticky" EOS tokenization, last token is always EOS.
         pooled_output = last_hidden_state[:, -1, :]
-        pooled_output = self.head(pooled_output)
+        if not hasattr(self.config, "no_head"):
+            pooled_output = self.head(pooled_output)
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
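
Note for reviewers: below is a minimal sketch, outside the patch, of how the new `no_head` flag is meant to gate the text projection head. `TextPoolerSketch` is a hypothetical stand-in for the tail of `SiglipTextTransformer`, not code from this PR. It also swaps the patch's `hasattr(self.config, "no_head")` check for a `getattr(..., "no_head", False)` lookup. That defensive variant is an assumption on my part; the point is that checkpoints which never set the flag (everything except the 256px / 250k-vocab / patch-16 variant handled in `get_siglip_config`) keep the head and cannot hit an AttributeError, which is also a consideration for the `if not config.text_config.no_head` branch in `create_rename_keys`.

    import torch
    from torch import nn

    from transformers import SiglipTextConfig


    class TextPoolerSketch(nn.Module):
        """Illustrative stand-in for the tail of SiglipTextTransformer with an optional head."""

        def __init__(self, config):
            super().__init__()
            self.config = config
            embed_dim = config.hidden_size
            # Build the projection head only when the config does not opt out of it.
            # getattr(..., "no_head", False): configs that never set the flag keep the head.
            if not getattr(config, "no_head", False):
                self.head = nn.Linear(embed_dim, embed_dim)

        def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
            # Assuming "sticky" EOS tokenization, the last token is always EOS.
            pooled_output = last_hidden_state[:, -1, :]
            if not getattr(self.config, "no_head", False):
                pooled_output = self.head(pooled_output)
            return pooled_output


    # Hypothetical usage: mirror get_siglip_config, which sets no_head=True only for the
    # multilingual 256px / 250k-vocab / patch-16 checkpoint; other configs keep the head.
    cfg = SiglipTextConfig(vocab_size=250000)
    cfg.no_head = True
    pooler = TextPoolerSketch(cfg)
    hidden = torch.randn(2, 64, cfg.hidden_size)
    print(pooler(hidden).shape)  # torch.Size([2, <hidden_size>]); no projection applied

With `no_head` unset, the same module builds `self.head` and applies it to the pooled EOS token, matching the default SigLIP behaviour that the unmodified code paths rely on.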