
Add DAB-DETR Object detection/segmentation model #30803

Merged: 116 commits into huggingface:main on Feb 4, 2025

Conversation

conditionedstimulus
Contributor

What does this PR do?

Add DAB-DETR Object detection model. Paper: https://arxiv.org/abs/2201.12329
Original code repo: https://github.com/IDEA-Research/DAB-DETR

Fixes # (issue)
[WIP] This model is part of the evolution of DETR models, alongside DN-DETR (not part of this PR), paving the way for newer and stronger object detection models such as DINO and Stable-DINO.

Who can review?

@amyeroberts

@amyeroberts
Collaborator

Hi @conditionedstimulus, thanks for opening a PR!

Just skimming over the modeling files, it looks like all of the modules are copied from, or can be copied from, Conditional DETR. Are there any architectural changes this model brings? If not, then all we need to do is convert the checkpoints and upload them to the Hub so that they can be loaded into ConditionalDETR directly.

@conditionedstimulus
Contributor Author

> Hi @conditionedstimulus, thanks for opening a PR!
>
> Just skimming over the modeling files, it looks like all of the modules are copied from, or can be copied from, Conditional DETR. Are there any architectural changes this model brings? If not, then all we need to do is convert the checkpoints and upload them to the Hub so that they can be loaded into ConditionalDETR directly.

Hi Amy,

I attached a photo comparing the decoder cross-attention in DETR, Conditional DETR, and DAB-DETR, as this is the main architectural difference. I copied the code from Conditional DETR because this model is an extension of Conditional DETR. I believe it would be cool and useful to include this model in the HF object detection collection.
[Screenshot: comparison of decoder cross-attention in DETR, Conditional DETR, and DAB-DETR]

@amyeroberts
Collaborator

@conditionedstimulus Thanks for sharing! OK, seems useful to have this available as an option as part of the DETR family in the library. Feel free to ping me when the PR is ready for review.

cc @qubvel for reference

@qubvel
Member

qubvel commented Jan 31, 2025

Noticed we don't have approval from @ArthurZucker yet; waiting for his review.

@ArthurZucker (Collaborator) left a comment

A few super small comments! Thanks for your patience! 🤗

Comment on lines +955 to +956
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
Collaborator

No, what I mean is that we should compute `(n, k) for n, k in zip([input_dim] + h, h + [output_dim])` only once, in the config. Then you know exactly which in and out dimensions should be used for the linear layers.

@conditionedstimulus
Contributor Author

Hi @ArthurZucker and @qubvel,

I’ve made most of the required modifications. Where I didn’t, I left comments on your feedback.
I also updated the test file where needed and added some additional information to the model card markdown file.

Thanks!

@qubvel
Member

qubvel commented Feb 3, 2025

@conditionedstimulus Thanks for the updates! Please update the converted weights for the other checkpoints on the Hub as well, and I will ask for the transfer.

@ArthurZucker (Collaborator) left a comment

Let's go! 🚀

Comment on lines 1245 to 1247
hidden_states = self.layernorm(hidden_states)
intermediate.pop()
intermediate.append(hidden_states)
Collaborator

intermediate_state = self.layernorm(hidden_states)
intermediate.append(intermediate_state)

vs

intermediate.append(self.layernorm(hidden_states))

Collaborator

This will avoid the ugly pop/append.

Contributor Author

I removed the list manipulation entirely. I didn't revisit the original code, but as I recall, this was part of a conditional section. Since we removed many of the configurations, the list manipulation had become a no-op: it popped the last element and appended the same value back. So I kept only the layer normalization of the hidden states.

@conditionedstimulus
Contributor Author

Hi @ArthurZucker and @qubvel,

I’ve finalized the last modification; if I understand correctly, this should be the final version, and we’ll roll it out soon.
I also updated the converted weights, so I believe it’s ready to be moved under the new organization.
I merged main too, and of course:

SKIPPED [1] tests/generation/test_utils.py:1458: The decoder-only derived from encoder-decoder models are not expected to support left-padding.
FAILED tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py::Qwen2_5_VLModelTest::test_prompt_lookup_decoding_matches_greedy_search - IndexError: index 41 is out of bound

Thanks for your review, guidance, and support! :)

Looking forward to the merge! 🤗

@qubvel
Member

qubvel commented Feb 3, 2025

run-slow: dab_detr


github-actions bot commented Feb 3, 2025

This comment contains run-slow, running the specified jobs: ['models/dab_detr'] ...

@qubvel
Member

qubvel commented Feb 4, 2025

run-slow: dab_detr


github-actions bot commented Feb 4, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/dab_detr']
quantizations: [] ...

@ydshieh
Collaborator

ydshieh commented Feb 4, 2025

run-slow: dab_detr


github-actions bot commented Feb 4, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/dab_detr']
quantizations: [] ...

@qubvel
Member

qubvel commented Feb 4, 2025

run-slow: dab_detr


github-actions bot commented Feb 4, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/dab_detr']
quantizations: [] ...

@qubvel merged commit 8d73a38 into huggingface:main on Feb 4, 2025
26 checks passed
@qubvel
Member

qubvel commented Feb 4, 2025

@conditionedstimulus Congratulations on merging the model! 🎉 It was a long journey, and we really appreciate that you were able to finish it 💪. Thank you for your contribution, and sorry for the delays on our side. Great job! 🚀

And feel free to share your achievement on social networks; we’d be happy to amplify it!

@conditionedstimulus
Contributor Author

conditionedstimulus commented Feb 4, 2025

@qubvel, @ArthurZucker

Thank you guys!
It was a long and fun journey, and I truly appreciate your support and guidance. I'm glad I could contribute! :)

elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025
* initial commit

* encoder+decoder layer changes WIP

* architecture checks

* working version of detection + segmentation

* fix modeling outputs

* fix return dict + output att/hs

* found the position embedding masking bug

* pre-training version

* added image processors

* typo in init.py

* iterupdate set to false

* fixed num_labels in class_output linear layer bias init

* multihead attention shape fixes

* test improvements

* test update

* dab-detr model_doc update

* dab-detr model_doc update2

* test fix:test_retain_grad_hidden_states_attentions

* config file clean and renaming variables

* config file clean and renaming variables fix

* updated convert_to_hf file

* small fixes

* style and quality checks

* return_dict fix

* Merge branch main into add_dab_detr

* small comment fix

* skip test_inputs_embeds test

* image processor updates + image processor test updates

* check copies test fix update

* updates for check_copies.py test

* updates for check_copies.py test2

* tied weights fix

* fixed image processing tests and fixed shared weights issues

* added numpy nd array option to get_Expected_values method in test_image_processing_dab_detr.py

* delete prints from test file

* SafeTensor modification to solve HF Trainer issue

* removing the safetensor modifications

* make fix copies and hf upload has been added.

* fixed index.md

* fixed repo consistency

* style fix and dabdetrimageprocessor docstring update

* requested modifications after the first review

* Update src/transformers/models/dab_detr/image_processing_dab_detr.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* repo consistency has been fixed

* update copied NestedTensor function after main merge

* Update src/transformers/models/dab_detr/modeling_dab_detr.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* temp commit

* temp commit2

* temp commit 3

* unit tests are fixed

* fixed repo consistency

* updated expected_boxes variable values based on related notebook results in DABDETRIntegrationTests file.

* temporary config modifications and repo consistency fixes

* Put dilation parameter back to config

* pattern embeddings have been added to the rename_keys method

* add dilation comment to config + add as an exception in check_config_attributes SPECIAL CASES

* delete FeatureExtractor part from docs.md

* requested modifications in modeling_dab_detr.py

* [run_slow] dab_detr

* deleted last segmentation code part, updated conversion script and changed the hf path in test files

* temp commit of requested modifications

* temp commit of requested modifications 2

* updated config file, resolved codepaths and refactored conversion script

* updated decoder layer block types and refactored conversion script

* style and quality update

* small modifications based on the request

* attentions are refactored

* removed loss functions from modeling file, added loss function to lossutils, tried to move the MLP layer generation to config but it failed

* deleted imageprocessor

* fixed conversion script + quality and style

* fixed config_att

* [run_slow] dab_detr

* changing model path in conversion file and in test file

* fix Decoder variable naming

* testing the old loss function

* switched back to the new loss function and testing with the old attention functions

* switched back to the new last good result modeling file

* moved back to the version when I asked the review

* missing new line at the end of the file

* old version test

* turned back to newest model version but changed image processor

* style fix

* style fix after merge main

* [run_slow] dab_detr

* [run_slow] dab_detr

* added device and type for head bias data part

* [run_slow] dab_detr

* fixed model head bias data fill

* changed test_inference_object_detection_head assertTrues to torch test assert_close

* fixes part 1

* quality update

* self.bbox_embed in decoder has been restored

* changed Assert true torch closeall methods to torch testing assertclose

* modelcard markdown file has been updated

* deleted intermediate list from decoder module

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Feb 16, 2025