From 235be08569000a5361354f766972e653212bf0d3 Mon Sep 17 00:00:00 2001 From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:57:30 +0900 Subject: [PATCH] [DETA] fix backbone freeze/unfreeze function (#27843) * [DETA] fix freeze/unfreeze function * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add freeze/unfreeze test case in DETA * fix type * fix typo 2 --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/deta/modeling_deta.py | 6 ++-- tests/models/deta/test_modeling_deta.py | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 0f245f2a305..19f0250e6f8 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1414,14 +1414,12 @@ def get_encoder(self): def get_decoder(self): return self.decoder - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.freeze_backbone def freeze_backbone(self): - for name, param in self.backbone.conv_encoder.model.named_parameters(): + for name, param in self.backbone.model.named_parameters(): param.requires_grad_(False) - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.unfreeze_backbone def unfreeze_backbone(self): - for name, param in self.backbone.conv_encoder.model.named_parameters(): + for name, param in self.backbone.model.named_parameters(): param.requires_grad_(True) # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py index d5bf32acaba..8581723ccb3 100644 --- a/tests/models/deta/test_modeling_deta.py +++ b/tests/models/deta/test_modeling_deta.py @@ -162,6 +162,26 @@ def create_and_check_deta_model(self, config, pixel_values, pixel_mask, labels): self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size)) + def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels): + model = DetaModel(config=config) + model.to(torch_device) + model.eval() + + model.freeze_backbone() + + for _, param in model.backbone.model.named_parameters(): + self.parent.assertEqual(False, param.requires_grad) + + def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels): + model = DetaModel(config=config) + model.to(torch_device) + model.eval() + + model.unfreeze_backbone() + + for _, param in model.backbone.model.named_parameters(): + self.parent.assertEqual(True, param.requires_grad) + def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): model = DetaForObjectDetection(config=config) model.to(torch_device) @@ -250,6 +270,14 @@ def test_deta_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_deta_model(*config_and_inputs) + def test_deta_freeze_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs) + + def test_deta_unfreeze_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs) + def test_deta_object_detection_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)