Support Qwen2VL (#56)
* add: qwen2

* fix: add qwen2 paths

* fix: hidden size

* fix: image

* fix: use cond gen model

* fix: image_grid

* debug: breakpoint

* debug: breakpoint

* debug: breakpoint

* cleanup

* cleanup

* fix: cleanup

* fix: todo

* simplify negative trainer

* see long run

* revert

* test: resize

* fix: resize

* debug

* debug

* simplify script

* enable adapter resume

* test

* pixels

* fix

* fix

* max l

* process

* inference mode

* training

* debug

* force rope scaling

* debug

* debug

* single

* multi

* multiple of 32

* back to normal values

* add hardnegs

* bug fix

* bug fix

* modulo 32

* stop train

* no fix reso

* debug

* fix

* debug inputs

* input keys

* comment out

* genius

* fix

* fix

* fix

* print

* lesgo

* move logic to processor

* ff

* debug

* coms

* update qwen

* 512

* lint

* up resolution

* remove inputs for gen

* remove useless dep

* Changelog

* Update pyproject.toml

Co-authored-by: Tony Wu <28306721+tonywu71@users.noreply.github.com>

* Update colpali_engine/utils/dataset_transformation.py

Co-authored-by: Tony Wu <28306721+tonywu71@users.noreply.github.com>

* Update colpali_engine/trainer/colmodel_training.py

Co-authored-by: Tony Wu <28306721+tonywu71@users.noreply.github.com>

* Update colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py

Co-authored-by: Tony Wu <28306721+tonywu71@users.noreply.github.com>

* remove annoying warning

* add comment

* prepare deprecation of is_vision_model

* deprecation of useless trainer class

---------

Co-authored-by: Tony Wu <28306721+tonywu71@users.noreply.github.com>
ManuelFay and tonywu71 authored Sep 27, 2024
1 parent 59fbda2 commit 16c8bb3
Showing 14 changed files with 342 additions and 188 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -12,18 +12,23 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Add module-level imports for collators
- Add sanity check in the run inference example script
- Add E2E test for ColPali
- Add Qwen2-VL support

### Changed

- Improve code clarity in the run inference example script
- Subset the example dataset in the run inference example script
- Rename scorer test to `test_processing_utils`
- Greatly simplify routing logic in Trainer selection and when feeding arguments to the model forward pass (refactor)
- Removed the `ContrastiveNegativeTrainer` class, which is now integrated into `ContrastiveTrainer`. This should not affect the user-facing API.
- Bumped transformers version to 4.45.0 to get Qwen2-VL support

### Fixed

- Import HardNegCollator at module-level if and only if datasets is available
- Remove the need for `typer` in the run inference example script
- Fix edge case when empty suffix `""` given to processor
- Fix bug in HardNegCollator since 0.3.0

## [0.3.0] - 2024-09-10

2 changes: 1 addition & 1 deletion colpali_engine/collators/hard_neg_collator.py
@@ -40,4 +40,4 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:

         examples += [{"image": pos_image, "query": pos_query, "neg_image": neg_image}]

-        return self(examples)
+        return super().__call__(examples)
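The one-line change above fixes the changelog's "bug in HardNegCollator since 0.3.0": `return self(examples)` re-invoked the collator's own `__call__` on the freshly extended batch, recursing until Python hit its recursion limit, whereas `super().__call__` hands the examples to the parent collator exactly once. A minimal sketch of the pitfall, using illustrative stand-in classes rather than the real colpali_engine collators:

```python
# Sketch of the recursion bug fixed above; `Base` stands in for the parent
# VisualRetrieverCollator and is not the actual colpali_engine class.
class Base:
    def __call__(self, examples):
        return {"num_examples": len(examples)}


class Broken(Base):
    def __call__(self, examples):
        examples = examples + [{"extra": True}]
        return self(examples)  # re-enters Broken.__call__ -> RecursionError


class Fixed(Base):
    def __call__(self, examples):
        examples = examples + [{"extra": True}]
        return super().__call__(examples)  # delegates to Base.__call__ once


print(Fixed()([{"a": 1}]))  # {'num_examples': 2}
```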
11 changes: 7 additions & 4 deletions colpali_engine/collators/visual_retriever_collator.py
@@ -2,6 +2,7 @@

 from PIL.Image import Image

+from colpali_engine.models.idefics_2 import ColIdefics2Processor
 from colpali_engine.models.paligemma import ColPaliProcessor
 from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor

@@ -22,9 +23,10 @@ def __init__(
         self.max_length = max_length
         self.suffix = ""

-        self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[
-            self.processor.tokenizer.additional_special_tokens.index("<image>")
-        ]
+        if isinstance(self.processor, ColPaliProcessor) or isinstance(self.processor, ColIdefics2Processor):
+            self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[
+                self.processor.tokenizer.additional_special_tokens.index("<image>")
+            ]

         if isinstance(self.processor, ColPaliProcessor):
             if self.processor.tokenizer.padding_side != "right":
@@ -76,7 +78,8 @@ def __call__(
         batch_query = None

         if all([t is None for t in texts_query]):
-            print("All queries are `None`. Returning `None` for all queries.")
+            # print("All queries are `None`. Returning `None` for all queries.")
+            pass
         elif any([t is None for t in texts_query]):
             # If it's the first query that is not None but the rest are None, then it's hard negatives.
             raise ValueError("Some queries are None. This collator does not support None queries yet.")
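The new `isinstance` guard is what lets this collator accept `ColQwen2Processor`: the `"<image>"` lookup only succeeds for the PaliGemma and Idefics2 tokenizers, while Qwen2-VL marks images with different special tokens (e.g. `<|image_pad|>`), so the unconditional lookup would raise. A small sketch of the failure mode; the token list is an assumed excerpt, not read from the real tokenizer:

```python
# Why the unguarded "<image>" lookup breaks for Qwen2-VL (assumed token list).
additional_special_tokens = ["<|vision_start|>", "<|image_pad|>", "<|vision_end|>"]

try:
    additional_special_tokens.index("<image>")
except ValueError:
    print("no '<image>' token -> guard the lookup with isinstance, as above")
```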
2 changes: 2 additions & 0 deletions colpali_engine/models/__init__.py
@@ -1,2 +1,4 @@
 from .idefics_2 import BiIdefics2, ColIdefics2, ColIdefics2Processor
 from .paligemma import BiPali, BiPaliProj, ColPali, ColPaliProcessor
+from .qwen2 import ColQwen2, ColQwen2Processor

1 change: 1 addition & 0 deletions colpali_engine/models/qwen2/__init__.py
@@ -0,0 +1 @@
from .colqwen2 import ColQwen2, ColQwen2Processor
2 changes: 2 additions & 0 deletions colpali_engine/models/qwen2/colqwen2/__init__.py
@@ -0,0 +1,2 @@
from .modeling_colqwen2 import ColQwen2
from .processing_colqwen2 import ColQwen2Processor
55 changes: 55 additions & 0 deletions colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
@@ -0,0 +1,55 @@
from typing import ClassVar

import torch
from torch import nn
from transformers.models.qwen2_vl import Qwen2VLConfig, Qwen2VLForConditionalGeneration


class ColQwen2(Qwen2VLForConditionalGeneration):
    """
    ColQwen2 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config: Qwen2VLConfig):
        super().__init__(config=config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.hidden_size, self.dim)
        self.padding_side = "left"
        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # Delete output_hidden_states from kwargs
        kwargs.pop("output_hidden_states", None)

        # The following code is a hack to make sure the scatter in DDP is done correctly when training on multiple GPUs
        if "pixel_values" in kwargs:
            # compute pixel_values offsets
            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
            kwargs["pixel_values"] = torch.cat([pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0)

        position_ids, rope_deltas = self.get_rope_index(
            input_ids=kwargs["input_ids"],
            image_grid_thw=kwargs["image_grid_thw"],
            video_grid_thw=None,
            attention_mask=kwargs["attention_mask"],
        )
        outputs = super().forward(
            *args,
            **kwargs,
            position_ids=position_ids,
            rope_deltas=rope_deltas,
            use_cache=False,
            output_hidden_states=True,
        )  # (batch_size, sequence_length, hidden_size)

        # inputs = self.prepare_inputs_for_generation(*args, **kwargs, use_cache=False)
        # outputs = super().forward(*args, **kwargs, output_hidden_states=True)

        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)
        return proj
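For orientation, here is a hedged end-to-end sketch of how the new model and processor fit together. The checkpoint path is a placeholder (no official ColQwen2 checkpoint is named in this commit), the white image merely stands in for a document page, and since `forward` indexes `kwargs["image_grid_thw"]` directly, the text-only query batch passes `image_grid_thw=None` explicitly:

```python
# Usage sketch under the assumptions stated above; requires transformers >= 4.45.0.
import torch
from PIL import Image

from colpali_engine.models import ColQwen2, ColQwen2Processor

checkpoint = "path/to/a-colqwen2-checkpoint"  # hypothetical placeholder
model = ColQwen2.from_pretrained(checkpoint).eval()
processor = ColQwen2Processor.from_pretrained(checkpoint)

images = [Image.new("RGB", (448, 448), "white")]  # stand-in for document pages
queries = ["What does the figure show?"]

with torch.no_grad():
    doc_emb = model(**processor.process_images(images))  # (1, doc_seq_len, 128)
    query_emb = model(**processor.process_queries(queries), image_grid_thw=None)

# MaxSim late-interaction scores between every query and every page
scores = processor.score(list(torch.unbind(query_emb)), list(torch.unbind(doc_emb)))
print(scores.shape)  # expected: torch.Size([1, 1])
```

The notable design choice is the `pixel_values` round-trip: `process_images` pads each image's patch sequence to a common length so DDP can scatter one rectangular tensor across GPUs, and `forward` uses the `image_grid_thw` offsets to trim that padding off again before the underlying Qwen2-VL forward pass.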
150 changes: 150 additions & 0 deletions colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py
@@ -0,0 +1,150 @@
import math
from typing import List, Optional, Union

import torch
from PIL import Image
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLProcessor

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


class ColQwen2Processor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):
    """
    Processor for ColQwen2.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer.padding_side = "left"
        self.min_pixels = 4 * 28 * 28
        self.max_pixels = 768 * 28 * 28
        self.factor = 28
        self.max_ratio = 200

    @staticmethod
    def round_by_factor(number: float, factor: int) -> int:
        """Returns the closest integer to 'number' that is divisible by 'factor'."""
        return round(number / factor) * factor

    @staticmethod
    def ceil_by_factor(number: float, factor: int) -> int:
        """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
        return math.ceil(number / factor) * factor

    @staticmethod
    def floor_by_factor(number: float, factor: int) -> int:
        """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
        return math.floor(number / factor) * factor

    def smart_resize(self, height: int, width: int, factor: int, min_pixels: int, max_pixels: int) -> tuple[int, int]:
        """
        Rescales the image so that the following conditions are met:
        1. Both dimensions (height and width) are divisible by 'factor'.
        2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
        3. The aspect ratio of the image is maintained as closely as possible.
        """
        if max(height, width) / min(height, width) > self.max_ratio:
            raise ValueError(
                f"absolute aspect ratio must be smaller than {self.max_ratio}, "
                f"got {max(height, width) / min(height, width)}"
            )
        h_bar = max(factor, self.round_by_factor(height, factor))
        w_bar = max(factor, self.round_by_factor(width, factor))
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = self.floor_by_factor(height / beta, factor)
            w_bar = self.floor_by_factor(width / beta, factor)
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = self.ceil_by_factor(height * beta, factor)
            w_bar = self.ceil_by_factor(width * beta, factor)
        return h_bar, w_bar

    def process_images(
        self,
        images: List[Image.Image],
    ) -> BatchFeature:
        """
        Process images for ColQwen2.
        """
        texts_doc = (
            ["<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n"]
            * len(images)
        )

        def resize_and_convert(image: Image.Image) -> Image.Image:
            image_size = image.size
            resized_height, resized_width = self.smart_resize(
                image_size[1],
                image_size[0],
                factor=self.factor,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
            # print(f"Resizing image from {image_size} to {(resized_height, resized_width)}")
            return image.convert("RGB").resize((resized_width, resized_height))

        images = [resize_and_convert(image) for image in images]

        batch_doc = self(
            text=texts_doc,
            images=images,
            padding="longest",
            return_tensors="pt",
        )

        # The following code is a hack to make sure the scatter in DDP is done correctly when training on multiple GPUs
        offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
        # separate pixel_values for each image
        pixel_values = torch.split(batch_doc["pixel_values"], offsets.tolist())
        # pad pixel_values to the same length to be able to make it into a tensor
        max_length = max([len(pv) for pv in pixel_values])
        pixel_values = [
            torch.cat([pv, torch.zeros((max_length - len(pv), pv.shape[1]), dtype=pv.dtype, device=pv.device)])
            for pv in pixel_values
        ]
        batch_doc["pixel_values"] = torch.stack(pixel_values)

        return batch_doc

    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> BatchFeature:
        """
        Process queries for ColQwen2.
        """
        if suffix is None:
            suffix = "<pad>" * 10
        texts_query: List[str] = []

        for query in queries:
            query = f"Query: {query}"
            query += suffix  # add suffix (pad tokens)
            texts_query.append(query)

        batch_query = self(
            text=texts_query,
            return_tensors="pt",
            padding="longest",
            # max_length=max_length + self.image_seq_length,
        )

        return batch_query

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)
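To make the `smart_resize` arithmetic concrete, here is a worked example with the defaults set in `__init__` (`factor=28`, `min_pixels=4*28*28=3136`, `max_pixels=768*28*28=602112`). A 1000x800 image exceeds the pixel budget, so both sides are scaled down by the same factor and snapped to multiples of 28:

```python
# Worked example of the smart_resize arithmetic above (standalone re-derivation).
import math

height, width, factor, max_pixels = 1000, 800, 28, 768 * 28 * 28

# Step 1: round each side to the nearest multiple of 28.
h_bar = round(height / factor) * factor  # 36 * 28 = 1008
w_bar = round(width / factor) * factor   # 29 * 28 = 812

# Step 2: 1008 * 812 = 818,496 > 602,112, so shrink by beta, preserving the ratio.
beta = math.sqrt((height * width) / max_pixels)      # ~1.1527
h_bar = math.floor(height / beta / factor) * factor  # 30 * 28 = 840
w_bar = math.floor(width / beta / factor) * factor   # 24 * 28 = 672

assert h_bar % factor == 0 and w_bar % factor == 0
assert h_bar * w_bar <= max_pixels  # 840 * 672 = 564,480 pixels
print(h_bar, w_bar)  # 840 672
```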
66 changes: 15 additions & 51 deletions colpali_engine/trainer/colmodel_training.py
@@ -20,7 +20,7 @@
 from colpali_engine.loss.late_interaction_losses import (
     ColbertLoss,
 )
-from colpali_engine.trainer.contrastive_trainer import ContrastiveNegativeTrainer, ContrastiveTrainer
+from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
 from colpali_engine.trainer.eval_utils import CustomRetrievalEvaluator
 from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
 from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
@@ -67,6 +67,7 @@ def __post_init__(self):

         if self.pretrained_peft_model_name_or_path is not None:
             self.model.load_adapter(self.pretrained_peft_model_name_or_path)
+
             print(f"Loaded pretrained adapter from {self.pretrained_peft_model_name_or_path}")

         if self.peft_config is not None:
@@ -78,12 +79,6 @@ def __post_init__(self):
             self.model = get_peft_model(self.model, self.peft_config)
             self.model.print_trainable_parameters()
         else:
-            # Ugly debugging hack
-            # if self.model.model.config.text_config.vocab_size == 32000:
-            #     print("DEBUG: Resizing token embeddings - This should not happen in a real scenario!")
-            #     self.model.model.text_model.resize_token_embeddings(32003)
-            #     self.model.model.vision_model.encoder.layers = self.model.model.vision_model.encoder.layers[0:2]
-            # self.model.enable_input_require_grads()
             if self.pretrained_peft_model_name_or_path is None:
                 # self.model.add_adapter(self.peft_config)
                 # self.model.enable_adapters()
@@ -105,7 +100,6 @@ def __init__(self, config: ColModelTrainingConfig) -> None:
             self.dataset = self.dataset[0]
             self.collator = HardNegCollator(
                 processor=self.config.processor,
-                tokenizer=self.config.tokenizer,
                 max_length=self.config.max_length,
                 image_dataset=neg_dataset,
             )
@@ -120,26 +114,19 @@
     def train(self) -> None:
         if isinstance(self.collator, HardNegCollator):
             print("Training with hard negatives")
-            trainer = ContrastiveNegativeTrainer(
-                model=self.model,
-                train_dataset=self.dataset["train"],
-                eval_dataset=self.dataset["test"],
-                args=self.config.tr_args,
-                data_collator=self.collator,
-                loss_func=self.config.loss_func,
-                is_vision_model=self.config.processor is not None,
-            )
         else:
             print("Training with in-batch negatives")
-            trainer = ContrastiveTrainer(
-                model=self.model,
-                train_dataset=self.dataset["train"],
-                eval_dataset=self.dataset["test"],
-                args=self.config.tr_args,
-                data_collator=self.collator,
-                loss_func=self.config.loss_func,
-                is_vision_model=self.config.processor is not None,
-            )
+
+        trainer = ContrastiveTrainer(
+            model=self.model,
+            train_dataset=self.dataset["train"],
+            eval_dataset=self.dataset["test"],
+            args=self.config.tr_args,
+            data_collator=self.collator,
+            loss_func=self.config.loss_func,
+            is_vision_model=self.config.processor is not None,
+        )

         trainer.args.remove_unused_columns = False

         result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint)
@@ -148,10 +135,6 @@ def train(self) -> None:
     def eval_dataset(self, test_dataset):
         self.model.eval()

-        # # debug
-        # if len(test_dataset) > 200:
-        #     test_dataset = test_dataset.select(range(0, 100))
-
         idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None]
         idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None]
@@ -191,27 +174,8 @@ def eval_dataset(self, test_dataset):
         with torch.no_grad():
             for dataloader in [dataloader_with_query, dataloader_without_query]:
                 for batch in tqdm(dataloader):
-                    if "doc_pixel_values" not in batch:
-                        doc = self.model(
-                            input_ids=batch["doc_input_ids"].to(device),
-                            attention_mask=batch["doc_attention_mask"].to(device),
-                        )
-
-                    else:
-                        if "doc_pixel_attention_mask" in batch:
-                            doc = self.model(
-                                input_ids=batch["doc_input_ids"].to(device),
-                                attention_mask=batch["doc_attention_mask"].to(device),
-                                pixel_values=batch["doc_pixel_values"].to(device),
-                                pixel_attention_mask=batch["doc_pixel_attention_mask"].to(device),
-                            )
-                        else:
-                            doc = self.model(
-                                input_ids=batch["doc_input_ids"].to(device),
-                                attention_mask=batch["doc_attention_mask"].to(device),
-                                pixel_values=batch["doc_pixel_values"].to(device),
-                            )
-
+                    # feed only kwargs with 'doc_' prefix
+                    doc = self.model(**{k[4:]: v.to(device) for k, v in batch.items() if k.startswith("doc")})
                     ps.extend(list(torch.unbind(doc.to("cpu"))))

                     if "query_input_ids" in batch:
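The rewritten eval loop above replaces the per-architecture `if`/`else` with generic kwarg routing: every batch key prefixed with `doc_` is stripped of its four-character prefix and forwarded, so a new model only needs the collator to emit matching keys. A tiny sketch of the key transformation (placeholder strings stand in for tensors):

```python
# Sketch of the 'doc_' prefix routing: k[4:] strips the 4-char "doc_" prefix.
batch = {
    "doc_input_ids": "tensor...",
    "doc_attention_mask": "tensor...",
    "doc_pixel_values": "tensor...",
    "query_input_ids": "tensor...",  # ignored by the document forward pass
}

doc_kwargs = {k[4:]: v for k, v in batch.items() if k.startswith("doc")}
print(sorted(doc_kwargs))  # ['attention_mask', 'input_ids', 'pixel_values']
```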