Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update data models to account for REST API changes in Geti v1.16 #357

Merged
merged 6 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions geti_sdk/data_models/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ class Algorithm:
Representation of a supported algorithm on the Intel® Geti™ platform.
"""

algorithm_name: str
model_size: str
model_template_id: str
gigaflops: float
algorithm_name: Optional[str] = None # Deprecated in Geti v1.16, use 'name' instead
name: Optional[str] = None
summary: Optional[str] = None
domain: Optional[str] = attr.field(
default=None, converter=str_to_optional_enum_converter(Domain)
Expand All @@ -42,7 +43,10 @@ class Algorithm:
default=None, converter=str_to_optional_enum_converter(TaskType)
)
supports_auto_hpo: Optional[bool] = None
recommended_choice: Optional[bool] = None # Added in Geti v1.9
recommended_choice: Optional[bool] = (
None # Deprecated in Geti v1.16, use 'default_algorithm' instead
)
default_algorithm: Optional[bool] = None # Added in Geti v1.16
performance_category: Optional[str] = None # Added in Geti v1.9
lifecycle_stage: Optional[str] = None # Added in Geti v1.9

Expand All @@ -51,6 +55,11 @@ def __attrs_post_init__(self):
Convert domain to task type for backward compatibility with earlier versions of
the Intel® Geti™ platform
"""
if self.default_algorithm is None:
# For older Geti versions, that were still using 'recommended choice'
self.default_algorithm = self.recommended_choice
if self.name is None:
self.name = self.algorithm_name
ljcornel marked this conversation as resolved.
Show resolved Hide resolved
if self.domain is not None and self.task_type is None:
self.task_type = TaskType.from_domain(self.domain)
self.domain = None
Expand Down
2 changes: 1 addition & 1 deletion geti_sdk/data_models/configuration_identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def resolve_algorithm(self, algorithm: Algorithm):
:param algorithm: Algorithm instance to which the hyper parameters belong
:return:
"""
self.algorithm_name = algorithm.algorithm_name
self.algorithm_name = algorithm.name
self.model_template_id = algorithm.model_template_id


Expand Down
6 changes: 3 additions & 3 deletions geti_sdk/data_models/containers/algorithm_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def summary(self) -> str:
summary_str = "Algorithms:\n"
for algorithm in self.data:
summary_str += (
f" Name: {algorithm.algorithm_name}\n"
f" Name: {algorithm.name}\n"
f" Task type: {algorithm.task_type}\n"
f" Model size: {algorithm.model_size}\n"
f" Gigaflops: {algorithm.gigaflops}\n\n"
Expand All @@ -114,7 +114,7 @@ def get_by_name(self, name: str) -> Algorithm:
:return: Algorithm holding the algorithm details
"""
for algo in self.data:
if algo.algorithm_name == name:
if algo.name == name:
return algo
raise ValueError(
f"Algorithm named {name} was not found in the "
Expand All @@ -132,7 +132,7 @@ def get_default_for_task_type(self, task_type: TaskType) -> Algorithm:
:return: Default algorithm for the task
"""
task_algos = self.get_by_task_type(task_type=task_type)
default = [algo for algo in task_algos if algo.recommended_choice]
default = [algo for algo in task_algos if algo.default_algorithm]
if len(default) == 1:
return default[0]
else:
Expand Down
10 changes: 10 additions & 0 deletions geti_sdk/data_models/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ class Image(MediaItem):
"""

media_information: ImageInformation = attr.field(kw_only=True)
annotation_scene_id: Optional[str] = attr.field(
kw_only=True, default=None, repr=False
) # Added in Geti v 1.16
roi_id: Optional[str] = attr.field(
kw_only=True, default=None, repr=False
) # Added in Geti v 1.16
editor_name: Optional[str] = attr.field(
kw_only=True, default=None, repr=False
) # Added in Geti v 1.16

def __attrs_post_init__(self):
"""
Expand Down Expand Up @@ -274,6 +283,7 @@ class Video(MediaItem):
"""

media_information: VideoInformation = attr.field(kw_only=True)
matched_frames: Optional[int] = attr.field(kw_only=True, default=None, repr=False)

def __attrs_post_init__(self):
"""
Expand Down
15 changes: 13 additions & 2 deletions geti_sdk/data_models/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,30 @@ class TaskPerformance:
"""
Task Performance metrics in Intel® Geti™.

:var task_node_id: Overall score of the project or model
:var task_node_id: Unique ID of the task to which this Performance metric
applies. Deprecated in Geti v1.16, use `task_id` instead.
:var task_id: Unique ID of the task to which this Performance metric
applies.
:var score: Score of the project or model for each task
:var local_score: Accuracy of the model or project with respect to object
localization for each task
:var global_score: Accuracy of the model or project with respect to global
classification of the full image for each task
"""

task_node_id: Optional[str] = None
task_node_id: Optional[str] = None # Deprecated in Geti v1.16
task_id: Optional[str] = None
score: Optional[Score] = None
local_score: Optional[Score] = None
global_score: Optional[Score] = None

def __attrs_post_init__(self):
"""
Post initialization hook
"""
if self.task_id is None:
self.task_id = self.task_node_id


@attr.define()
class Performance:
Expand Down
1 change: 1 addition & 0 deletions geti_sdk/platform_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,4 @@ def is_geti(self) -> bool:
GETI_15_VERSION = GetiVersion("1.5.0-release-20230504111017")
GETI_18_VERSION = GetiVersion("1.8.0-release-20231018022911")
GETI_114_VERSION = GetiVersion("1.14.0-release-20240131095302")
GETI_116_VERSION = GetiVersion("1.16.0-release-20240320101320")
37 changes: 30 additions & 7 deletions geti_sdk/rest_clients/annotation_clients/base_annotation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from geti_sdk.data_models.media import MediaInformation, MediaItem
from geti_sdk.data_models.project import Dataset
from geti_sdk.http_session import GetiRequestException, GetiSession
from geti_sdk.platform_versions import GETI_116_VERSION
from geti_sdk.rest_clients.dataset_client import DatasetClient
from geti_sdk.rest_converters import AnnotationRESTConverter
from geti_sdk.rest_converters.annotation_rest_converter import (
Expand Down Expand Up @@ -101,17 +102,39 @@ def _get_all_media_in_dataset_by_type(
"""
if media_type == Image:
media_name = "images"
single_media_name = "image"
elif media_type == Video:
media_name = "videos"
single_media_name = "video"
else:
raise ValueError(f"Invalid media type specified: {media_type}.")
get_media_url = (
f"workspaces/{self.workspace_id}/projects/{self._project.id}"
f"/datasets/{dataset.id}/media/"
f"{media_name}?top=500"
)
response = self.session.get_rest_response(url=get_media_url, method="GET")
total_number_of_media: int = response["media_count"][media_name]

if self.session.version < GETI_116_VERSION:
get_media_url = (
f"workspaces/{self.workspace_id}/projects/{self._project.id}"
f"/datasets/{dataset.id}/media/"
f"{media_name}?top=500"
)
response = self.session.get_rest_response(url=get_media_url, method="GET")
total_number_of_media: int = response["media_count"][media_name]
else:
url = (
f"workspaces/{self.workspace_id}/projects/{self._project.id}"
f"/datasets/{dataset.id}/media:query?top=500"
)
data = {
"condition": "and",
"rules": [
{
"field": "MEDIA_TYPE",
"operator": "EQUAL",
"value": single_media_name,
}
],
}
response = self.session.get_rest_response(url=url, method="POST", data=data)
total_number_of_media: int = response[f"total_matched_{media_name}"]

raw_media_list: List[Dict[str, Any]] = []
while len(raw_media_list) < total_number_of_media:
for media_item_dict in response["media"]:
Expand Down
2 changes: 1 addition & 1 deletion geti_sdk/rest_clients/configuration_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def get_for_task_and_algorithm(self, task: Task, algorithm: Algorithm):
"""
if algorithm not in self.supported_algos.get_by_task_type(task.type):
raise ValueError(
f"The requested algorithm '{algorithm.algorithm_name}' is not "
f"The requested algorithm '{algorithm.name}' is not "
f"supported for a task of type '{task.type}'. Unable to retrieve "
f"configuration."
)
Expand Down
5 changes: 4 additions & 1 deletion geti_sdk/rest_clients/dataset_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from geti_sdk.data_models import Dataset, Project
from geti_sdk.http_session import GetiSession
from geti_sdk.platform_versions import GETI_116_VERSION
from geti_sdk.utils import deserialize_dictionary


Expand All @@ -37,7 +38,9 @@ def create_dataset(self, name: str) -> Dataset:
:param name: Name of the dataset to create
:return: The created dataset
"""
request_data = {"name": name, "use_for_training": False}
request_data = {"name": name}
if self.session.version < GETI_116_VERSION:
request_data.update({"use_for_training": False})
response = self.session.get_rest_response(
url=self.base_url,
method="POST",
Expand Down
27 changes: 23 additions & 4 deletions geti_sdk/rest_clients/media_client/media_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from geti_sdk.data_models.project import Dataset
from geti_sdk.data_models.utils import numpy_from_buffer
from geti_sdk.http_session import GetiRequestException, GetiSession
from geti_sdk.platform_versions import GETI_116_VERSION
from geti_sdk.rest_clients.dataset_client import DatasetClient
from geti_sdk.rest_converters.media_rest_converter import MediaRESTConverter

Expand Down Expand Up @@ -110,10 +111,28 @@ def _get_all(self, dataset: Optional[Dataset] = None) -> MediaList[MediaTypeVar]
if dataset is None:
dataset = self._project.training_dataset

response = self.session.get_rest_response(
url=f"{self.base_url(dataset=dataset)}?top=500", method="GET"
)
total_number_of_media: int = response["media_count"][self.plural_media_name]
if self.session.version < GETI_116_VERSION:
response = self.session.get_rest_response(
url=f"{self.base_url(dataset=dataset)}?top=500", method="GET"
)
total_number_of_media: int = response["media_count"][self.plural_media_name]
else:
url = f"{self._base_url}/{dataset.id}/media:query?top=500"
data = {
"condition": "and",
"rules": [
{
"field": "MEDIA_TYPE",
"operator": "EQUAL",
"value": f"{self._MEDIA_TYPE}",
}
],
}
response = self.session.get_rest_response(url=url, method="POST", data=data)
total_number_of_media: int = response[
f"total_matched_{self.plural_media_name}"
]

raw_media_list: List[Dict[str, Any]] = []
while len(raw_media_list) < total_number_of_media:
for media_item_dict in response["media"]:
Expand Down
10 changes: 3 additions & 7 deletions geti_sdk/rest_clients/model_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,7 @@ def get_model_group_by_algo_name(self, algorithm_name: str) -> Optional[ModelGro
"""
model_groups = self.get_all_model_groups()
return next(
(
group
for group in model_groups
if group.algorithm.algorithm_name == algorithm_name
),
(group for group in model_groups if group.algorithm.name == algorithm_name),
None,
)

Expand Down Expand Up @@ -321,7 +317,7 @@ def set_active_model(
if isinstance(algorithm, str):
algorithm_name = algorithm
elif isinstance(algorithm, Algorithm):
algorithm_name = algorithm.algorithm_name
algorithm_name = algorithm.name
else:
raise ValueError(
f"Invalid type {type(algorithm)}. Argument `algorithm` must be "
Expand All @@ -333,7 +329,7 @@ def set_active_model(
)
# Now we make sure that the algorithm is supported in the project
algorithms_supported_in_the_project = {
algorithm.algorithm_name
algorithm.name
for task in self.project.get_trainable_tasks()
for algorithm in self.supported_algos.get_by_task_type(task.type)
}
Expand Down
23 changes: 17 additions & 6 deletions geti_sdk/rest_clients/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from geti_sdk.data_models.enums import JobState, JobType
from geti_sdk.data_models.project import Dataset
from geti_sdk.http_session import GetiSession
from geti_sdk.platform_versions import GETI_11_VERSION, GETI_18_VERSION
from geti_sdk.platform_versions import (
GETI_11_VERSION,
GETI_18_VERSION,
GETI_116_VERSION,
)
from geti_sdk.rest_converters import (
ConfigurationRESTConverter,
JobRESTConverter,
Expand Down Expand Up @@ -177,7 +181,6 @@ def train_task(
dataset: Optional[Dataset] = None,
algorithm: Optional[Algorithm] = None,
train_from_scratch: bool = False,
enable_pot_optimization: bool = False,
hyper_parameters: Optional[TaskConfiguration] = None,
hpo_parameters: Optional[Dict[str, Any]] = None,
await_running_jobs: bool = True,
Expand All @@ -196,8 +199,6 @@ def train_task(
default), the default algorithm for the task will be used.
:param train_from_scratch: True to train the model from scratch, False to
continue training from an existing checkpoint (if any)
:param enable_pot_optimization: True to optimize the trained model with POT
after training is complete
:param hyper_parameters: Optional hyper parameters to use for training
:param hpo_parameters: Optional set of parameters to use for automatic hyper
parameter optimization. Only supported for version 1.1 and up
Expand All @@ -215,15 +216,20 @@ def train_task(
"""
if isinstance(task, int):
task = self.project.get_trainable_tasks()[task]
if dataset is not None and self.session.version >= GETI_116_VERSION:
logging.warning(
"Training on a Dataset other than the default training dataset is not "
"supported in the version of the Geti platform running on your server. "
"The `dataset` parameter will be disregarded."
)
if dataset is None:
dataset = self.project.datasets[0]
dataset = self.project.training_dataset
if algorithm is None:
algorithm = self.supported_algos.get_default_for_task_type(task.type)
request_data: Dict[str, Any] = {
"dataset_id": dataset.id,
"task_id": task.id,
"train_from_scratch": train_from_scratch,
"enable_pot_optimization": enable_pot_optimization,
"model_template_id": algorithm.model_template_id,
}
if hyper_parameters is not None:
Expand All @@ -242,6 +248,9 @@ def train_task(

if self.session.version.is_sc_1_1 or self.session.version.is_sc_mvp:
data = [request_data]
elif self.session.version >= GETI_116_VERSION:
request_data.pop("dataset_id")
data = request_data
else:
data = {"training_parameters": [request_data]}

Expand Down Expand Up @@ -278,6 +287,8 @@ def train_task(
if self.session.version < GETI_11_VERSION:
job = JobRESTConverter.from_dict(response)
job_id = job.id
elif self.session.version >= GETI_116_VERSION:
job_id = response["job_id"]
else:
job_id = response["job_ids"][0]
try:
Expand Down
2 changes: 1 addition & 1 deletion notebooks/007_train_project.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@
"source": [
"algorithm = available_algorithms.get_default_for_task_type(task.type)\n",
"\n",
"print(f\"Default algorithm for `{task.type}` task: `{algorithm.algorithm_name}`.\\n\")\n",
"print(f\"Default algorithm for `{task.type}` task: `{algorithm.name}`.\\n\")\n",
"print(algorithm.overview)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions notebooks/use_cases/103_parking_lot_train2deployment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,9 @@
"source": [
"latest_trained_models = []\n",
"for algo in available_algorithms:\n",
" model = model_client.get_latest_model_by_algo_name(algo.algorithm_name)\n",
" model = model_client.get_latest_model_by_algo_name(algo.name)\n",
" if model is not None:\n",
" print(f\"Retrieved latest trained model for algorithm {algo.algorithm_name}\")\n",
" print(f\"Retrieved latest trained model for algorithm {algo.name}\")\n",
" latest_trained_models.append(model)\n",
"\n",
"# For each algorithm, grab the optimized models that are available. These will be used for local deployment\n",
Expand Down
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Loading
Loading