Skip to content

Commit

Permalink
Merge pull request #443 from openvinotoolkit/include-model-group-id-i…
Browse files Browse the repository at this point in the history
…n-model-serialization

Include `model_group_id` in `Model` serialization and deserialization
  • Loading branch information
ljcornel authored Jun 18, 2024
2 parents 067302e + ec9baf6 commit 34eb5f5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
15 changes: 5 additions & 10 deletions geti_sdk/data_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,17 @@ def base_url(self, base_url: str):
model._base_url = base_url + f"/optimized_models/{model.id}"
self._base_url = base_url

@model_group_id.setter
def model_group_id(self, id_: str):
"""
Set the model group id for this model.
:param id: ID to set
"""
self._model_group_id = id_

def to_dict(self) -> Dict[str, Any]:
"""
Return the dictionary representation of the model.
:return:
"""
return attr.asdict(self, recurse=True, value_serializer=attr_value_serializer)
base_dict = attr.asdict(
self, recurse=True, value_serializer=attr_value_serializer
)
base_dict["model_group_id"] = self.model_group_id
return base_dict

@property
def overview(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion geti_sdk/deployment/deployed_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,9 @@ def from_model_and_hypers(
"""
model_dict = model.to_dict()
model_dict.update({"hyper_parameters": hyper_parameters})
model_group_id = model_dict.pop("model_group_id", None)
deployed_model = cls(**model_dict)
try:
model_group_id = model.model_group_id
base_url = model.base_url
deployed_model.model_group_id = model_group_id
deployed_model.base_url = base_url
Expand Down
10 changes: 8 additions & 2 deletions geti_sdk/rest_converters/model_rest_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def model_from_dict(input_dict: Dict[str, Any]) -> Model:
Intel® Geti™ /model_groups/models REST endpoint
:return: Model object corresponding to the data in `input_dict`
"""
return deserialize_dictionary(input_dict, output_type=Model)
model_group_id = input_dict.pop("model_group_id", None)
model_object = deserialize_dictionary(input_dict, output_type=Model)
model_object.model_group_id = model_group_id
return model_object

@staticmethod
def optimized_model_from_dict(input_dict: Dict[str, Any]) -> OptimizedModel:
Expand All @@ -56,4 +59,7 @@ def optimized_model_from_dict(input_dict: Dict[str, Any]) -> OptimizedModel:
the Intel® Geti™ /model_groups/models REST endpoint
:return: OptimizedModel object corresponding to the data in `input_dict`
"""
return deserialize_dictionary(input_dict, output_type=OptimizedModel)
model_group_id = input_dict.pop("model_group_id", None)
model_object = deserialize_dictionary(input_dict, output_type=OptimizedModel)
model_object.model_group_id = model_group_id
return model_object

0 comments on commit 34eb5f5

Please sign in to comment.