
Query preprocessing #12

Closed
hannah348 opened this issue Jul 25, 2024 · 4 comments
Comments

@hannah348

I was using the `process_queries` method and realized that the processor uses left padding. As a result, `batch_query["input_ids"] = batch_query["input_ids"][..., processor.image_seq_length :]` cuts off the padding tokens, and whenever there are padding tokens in a sequence, the end of the image sequence ends up at the beginning of the query. Is that intentional? If so, what is the reasoning behind it?
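
To make the misalignment concrete, here is a minimal sketch with made-up token ids and a toy `image_seq_length` (not the real vocabulary or sequence lengths):

```python
import torch

# Toy ids: 0 = <pad>, 9 = <image> placeholder, 11/12 = query text tokens.
image_seq_length = 4

# With padding_side="left", a shorter query is padded *before* the image block:
#   [pad, pad, img, img, img, img, q1, q2]
input_ids = torch.tensor([[0, 0, 9, 9, 9, 9, 11, 12]])

# Dropping the first `image_seq_length` tokens removes the pads plus only part
# of the image block, so image tokens leak into the "query":
print(input_ids[..., image_seq_length:])  # tensor([[ 9,  9, 11, 12]])
```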

@ManuelFay
Collaborator

Hello! Thanks for the catch! For PaliGemma, the tokenizer should be set with padding side = "right". I am pushing an update to force that behaviour and will push updated checkpoints!
In practice, as is, with a mock image it just introduces extra noise but should still work!
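
For reference, the intended setting is just right padding on the tokenizer, e.g. (base checkpoint name given only as an example):

```python
from transformers import AutoProcessor

# Example base checkpoint; the ColPali processor reuses the same tokenizer.
processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448")

# Pads now go after the query text, so the image block stays aligned at the front.
processor.tokenizer.padding_side = "right"
```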

@hannah348
Author

Does the model require that extra noise? If this was done during training as well, the padding tokens might degrade performance, since the model has fewer tokens available to represent the text and the attention mask does not exclude the image tokens in `colpali_engine/models/paligemma_colbert_architecture.py` line 83.

@ManuelFay
Collaborator

ManuelFay commented Jul 29, 2024

In practice, the tokens corresponding to the `input_ids` that should be replaced by the image soft tokens currently get included in the query. Since they are associated with nothing (and never trained, because they are otherwise replaced), they act as a learned padding token (but with attention). This should not hurt performance particularly and might even act as extra "buffer tokens".

I am retraining checkpoints with the fix (since it happens during training as well), which I will release once I push the update, along with the benchmark results!

@ManuelFay
Collaborator

So everything should be fixed!

It would be awesome if you could confirm using this new model:
https://huggingface.co/vidore/colpali-v1.1

and the code in branch: https://github.com/illuin-tech/colpali/tree/hard-negs

The base model version is fixed and the padding side is set to right, so the issue should be resolved @hannah348
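
A quick sanity check, assuming the new checkpoint ships the updated tokenizer config:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("vidore/colpali-v1.1")
assert processor.tokenizer.padding_side == "right"
```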

ManuelFay added a commit that referenced this issue Aug 29, 2024
## [0.2.0] - 2024-08-29
 
Large refactoring to address several issues and add features. This release is not backward compatible with previous versions.
Models trained under this version will exhibit degraded performance if used with the previous version of the code, and vice versa.

[Branch](#23)
 

### Added
- Added multiple options for training with hard negatives. This leads to better model performance!
- Added options for restarting training from a checkpoint.

### Changed

- Optionally load ColPali models from pre-initialized backbones of the same shape to remove any stochastic initialization when loading adapters. This fixes [11](#11) and [17](#17).
 
### Fixed
- Set padding side to right in the tokenizer to fix a misalignment issue between different query lengths in the same batch. Fixes [12](#12).
- Add 10 extra pad tokens by default to the query to act as reasoning buffers (see the sketch below). This enables the above fix without degrading performance and cleans up the old technique of using `<unused>` tokens.
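
Roughly, the query-side preprocessing now looks like the following sketch (illustrative only; the exact prompt prefix and suffix live in `process_queries`):

```python
def build_query_text(query: str, buffer_tokens: int = 10) -> str:
    # Append extra <pad> tokens after the query text; with right padding and a
    # full attention mask they act as learned "buffer" positions for the query.
    return f"Question: {query}" + "<pad>" * buffer_tokens
```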