-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## [0.2.0] - 2024-08-29 Large refactoring to adress several issues and add features. This release is not backward compatible with previous versions. The models trained under this version will exhibit degraded performance if used with the previous version of the code and vice versa. [Branch](#23) ### Added - Added multiple training options for training with hard negatives. This leads to better model performance ! - Added options for restarting training from a checkpoint. ### Changed - Optionally load ColPali models from pre-initialized backbones of the same shape to remove any stochastic initialization when loading adapters. This fixes [11](#11) and [17](#17). ### Fixed - Set padding side to right in the tokenizer to fix misalignement issue between different query lengths in the same batch. Fixes [12](#12) - Add 10 extra pad token by default to the query to act as reasoning buffers. This enables the above fix to be made without degrading performance and cleans up the old technique of using <unused> tokens.
- Loading branch information
Showing
30 changed files
with
1,041 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
|
||
# Change Log | ||
All notable changes to this project will be documented in this file. | ||
|
||
The format is based on [Keep a Changelog](http://keepachangelog.com/) | ||
and this project adheres to [Semantic Versioning](http://semver.org/). | ||
|
||
|
||
## [0.2.0] - 2024-08-29 | ||
|
||
Large refactoring to adress several issues and add features. This release is not backward compatible with previous versions. | ||
The models trained under this version will exhibit degraded performance if used with the previous version of the code and vice versa. | ||
|
||
[Branch](https://github.com/illuin-tech/colpali/pull/23) | ||
|
||
|
||
### Added | ||
- Added multiple training options for training with hard negatives. This leads to better model performance ! | ||
- Added options for restarting training from a checkpoint. | ||
|
||
### Changed | ||
|
||
- Optionally load ColPali models from pre-initialized backbones of the same shape to remove any stochastic initialization when loading adapters. This fixes [11](https://github.com/illuin-tech/colpali/issues/11) and [17](https://github.com/illuin-tech/colpali/issues/17). | ||
|
||
### Fixed | ||
- Set padding side to right in the tokenizer to fix misalignement issue between different query lengths in the same batch. Fixes [12](https://github.com/illuin-tech/colpali/issues/12) | ||
- Add 10 extra pad token by default to the query to act as reasoning buffers. This enables the above fix to be made without degrading performance and cleans up the old technique of using <unused> tokens. | ||
|
||
## [0.1.1] - 2024-08-28 | ||
|
||
Minor patch release to fix packaging issues. | ||
|
||
### Fixed | ||
|
||
- [Branch](https://github.com/illuin-tech/colpali/commit/bd55e88c7af7069dde943f00665181fb94631cdd | ||
Fix .gitignore to include all necessary files in the package. | ||
|
||
## [0.1.0] - 2024-08-28 | ||
|
||
Initial code release corresponding to the paper. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from random import randint | ||
|
||
from datasets import Dataset, DatasetDict | ||
from transformers import PreTrainedTokenizer, ProcessorMixin | ||
|
||
from .custom_collator import CustomCollator | ||
|
||
|
||
class HardNegCollator(CustomCollator): | ||
def __init__( | ||
self, | ||
processor: ProcessorMixin = None, | ||
tokenizer: PreTrainedTokenizer = None, | ||
max_length: int = 2048, | ||
add_suffix: bool = True, | ||
image_dataset: Dataset = None, | ||
): | ||
super().__init__(processor, tokenizer, max_length, add_suffix) | ||
self.image_dataset = image_dataset | ||
assert self.image_dataset is not None, "image_dataset must be provided" | ||
|
||
def get_image_from_image_dataset(self, image_idx): | ||
return self.image_dataset[int(image_idx)]["image"] | ||
|
||
def __call__(self, examples): | ||
# assert len(examples) == 1, "HardNegCollator only supports a single example at at time" | ||
|
||
tmp_examples = examples | ||
examples = [] | ||
for example in tmp_examples: | ||
pos_image = self.get_image_from_image_dataset(example["gold_index"]) | ||
pos_query = example["query"] | ||
# randomly sample a negative image amongst the top 10 | ||
neg_image = self.get_image_from_image_dataset(example["negs"][randint(0, 9)]) | ||
examples += [{"image": pos_image, "query": pos_query, "neg_image": neg_image}] | ||
|
||
# reorder examples | ||
if self.processor is None: | ||
return self.forward_text(examples) | ||
if self.processor.__class__.__name__ == "Idefics2Processor": | ||
return self.forward_vision_idefics(examples) | ||
if self.processor.__class__.__name__ == "PaliGemmaProcessor": | ||
return self.forward_vision_pali(examples) | ||
if self.processor.__class__.__name__ == "SiglipProcessor": | ||
return self.forward_vision_siglip(examples) | ||
raise ValueError("Processor not supported") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from datasets import Dataset, DatasetDict | ||
from transformers import PreTrainedTokenizer, ProcessorMixin | ||
|
||
from .custom_collator import CustomCollator | ||
|
||
|
||
class HardNegCollator(CustomCollator): | ||
def __init__( | ||
self, | ||
processor: ProcessorMixin = None, | ||
tokenizer: PreTrainedTokenizer = None, | ||
max_length: int = 2048, | ||
add_suffix: bool = True, | ||
image_dataset: Dataset = None, | ||
): | ||
super().__init__(processor, tokenizer, max_length, add_suffix) | ||
self.image_dataset = image_dataset | ||
assert self.image_dataset is not None, "image_dataset must be provided" | ||
|
||
def get_image_from_docid(self, docid): | ||
example_idx, image_idx = docid.split("_") | ||
target_image = self.image_dataset[int(example_idx)]["images"][int(image_idx)] | ||
return target_image | ||
|
||
def __call__(self, examples): | ||
tmp_examples = examples | ||
examples = [] | ||
for example in tmp_examples: | ||
pos_image = self.get_image_from_docid(example["positive_passages"][0]["docid"]) | ||
pos_query = example["query"] | ||
neg_images_ids = [doc["docid"] for doc in example["negative_passages"][:1]] | ||
neg_images = [self.get_image_from_docid(docid) for docid in neg_images_ids] | ||
|
||
examples += [{"image": pos_image, "query": pos_query, "neg_image": neg_images[0]}] | ||
|
||
if self.processor is None: | ||
return self.forward_text(examples) | ||
if self.processor.__class__.__name__ == "Idefics2Processor": | ||
return self.forward_vision_idefics(examples) | ||
if self.processor.__class__.__name__ == "PaliGemmaProcessor": | ||
return self.forward_vision_pali(examples) | ||
if self.processor.__class__.__name__ == "SiglipProcessor": | ||
return self.forward_vision_siglip(examples) | ||
raise ValueError("Processor not supported") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.