Skip to content

Commit

Permalink
add freeze/unfreeze test case in DETA
Browse files Browse the repository at this point in the history
  • Loading branch information
SangbumChoi committed Dec 8, 2023
1 parent 4fc204f commit a583fa7
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/models/deta/test_modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,26 @@ def create_and_check_deta_model(self, config, pixel_values, pixel_mask, labels):

self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))

def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels):
model = DetaModel(config=config)
model.to(torch_device)
model.eval()

model.freeze_backbone()

for _, param in model.backbone.model.named_parametesr():
self.parent.assertEqual(False, param.requires_grad)

def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels):
model = DetaModel(config=config)
model.to(torch_device)
model.eval()

model.unfreeze_backbone()

for _, param in model.backbone.model.named_parametesr():
self.parent.assertEqual(True, param.requires_grad)

def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
model = DetaForObjectDetection(config=config)
model.to(torch_device)
Expand Down Expand Up @@ -250,6 +270,14 @@ def test_deta_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_model(*config_and_inputs)

def test_deta_freeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)

def test_deta_unfreeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_unfreeze_backbone(*config_and_inputs)

def test_deta_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
Expand Down

0 comments on commit a583fa7

Please sign in to comment.