Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model Versions are taggable #2102

Merged
merged 12 commits into from
Dec 13, 2023
33 changes: 28 additions & 5 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _model_version_to_print(
"number": model_version.number,
"description": model_version.description,
"stage": model_version.stage,
"tags": [t.name for t in model_version.tags],
"data_artifacts_count": len(model_version.data_artifact_ids),
"model_artifacts_count": len(model_version.model_artifact_ids),
"endpoint_artifacts_count": len(model_version.endpoint_artifact_ids),
Expand Down Expand Up @@ -393,6 +394,22 @@ def list_model_versions(model_name_or_id: str, **kwargs: Any) -> None:
type=click.Choice(choices=ModelStages.values()),
help="The stage of the model version.",
)
@click.option(
"--tag",
"-t",
help="Tags to be added to the model.",
type=str,
required=False,
multiple=True,
)
@click.option(
"--remove-tag",
"-r",
help="Tags to be removed from the model.",
type=str,
required=False,
multiple=True,
)
@click.option(
"--force",
"-f",
Expand All @@ -403,6 +420,8 @@ def update_model_version(
model_name_or_id: str,
model_version_name_or_number_or_id: str,
stage: str,
tag: Optional[List[str]],
remove_tag: Optional[List[str]],
force: bool = False,
) -> None:
"""Update an existing model version stage in the Model Control Plane.
Expand All @@ -411,17 +430,21 @@ def update_model_version(
model_name_or_id: The ID or name of the model containing version.
model_version_name_or_number_or_id: The ID, number or name of the model version.
stage: The stage of the model version to be set.
tag: Tags to be added to the model version.
remove_tag: Tags to be removed from the model version.
force: Whether existing model version in target stage should be silently archived.
"""
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
try:
Client().update_model_version(
model_version = Client().update_model_version(
model_name_or_id=model_name_or_id,
version_name_or_id=model_version.id,
stage=stage,
add_tags=tag,
remove_tags=remove_tag,
force=force,
)
except RuntimeError:
Expand All @@ -438,15 +461,15 @@ def update_model_version(
if not confirmation:
cli_utils.declare("Model version stage update canceled.")
return
Client().update_model_version(
model_version = Client().update_model_version(
model_name_or_id=model_version.model.id,
version_name_or_id=model_version.id,
stage=stage,
add_tags=tag,
remove_tags=remove_tag,
force=True,
)
cli_utils.declare(
f"Model version '{model_version.name}' stage updated to '{stage}'."
)
cli_utils.print_table([_model_version_to_print(model_version)])


@version.command("delete", help="Delete an existing model version.")
Expand Down
9 changes: 9 additions & 0 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4556,13 +4556,15 @@ def create_model_version(
model_name_or_id: Union[str, UUID],
name: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> ModelVersionResponseModel:
"""Creates a new model version in Model Control Plane.

Args:
model_name_or_id: the name or id of the model to create model version in.
name: the name of the Model Version to be created.
description: the description of the Model Version to be created.
tags: Tags associated with the model.

Returns:
The newly created model version.
Expand All @@ -4576,6 +4578,7 @@ def create_model_version(
user=self.active_user.id,
workspace=self.active_workspace.id,
model=model_name_or_id,
tags=tags,
)
)

Expand Down Expand Up @@ -4730,6 +4733,8 @@ def update_model_version(
stage: Optional[Union[str, ModelStages]] = None,
force: bool = False,
name: Optional[str] = None,
add_tags: Optional[List[str]] = None,
remove_tags: Optional[List[str]] = None,
) -> ModelVersionResponseModel:
"""Get all model versions by filter.

Expand All @@ -4740,6 +4745,8 @@ def update_model_version(
force: Whether existing model version in target stage should be
silently archived or an error should be raised.
name: Target model version name to be set.
add_tags: Tags to add to the model version.
remove_tags: Tags to remove from to the model version.

Returns:
An updated model version.
Expand All @@ -4758,6 +4765,8 @@ def update_model_version(
stage=stage,
force=force,
name=name,
add_tags=add_tags,
remove_tags=remove_tags,
),
)

Expand Down
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ class TaggableResourceTypes(StrEnum):
ARTIFACT = "artifact"
ARTIFACT_VERSION = "artifact_version"
MODEL = "model"
MODEL_VERSION = "model_version"


class ResponseUpdateStrategy(StrEnum):
Expand Down
1 change: 1 addition & 0 deletions src/zenml/model/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def _get_or_create_model_version(self) -> "ModelVersionResponseModel":
name=self.version,
description=self.description,
model=model.id,
tags=self.tags,
)
mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
try:
Expand Down
10 changes: 9 additions & 1 deletion src/zenml/models/model_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ class ModelVersionRequestModel(
model: UUID = Field(
description="The ID of the model containing version",
)
tags: Optional[List[str]] = Field(
title="Tags associated with the model version",
)


class ModelVersionResponseModel(
Expand Down Expand Up @@ -158,6 +161,9 @@ class ModelVersionResponseModel(
description="Pipeline runs linked to the model version",
default={},
)
tags: List[TagResponseModel] = Field(
title="Tags associated with the model version", default=[]
)

def to_model_version(
self,
Expand All @@ -184,7 +190,7 @@ def to_model_version(
limitations=self.model.limitations,
trade_offs=self.model.trade_offs,
ethics=self.model.ethics,
tags=[t.name for t in self.model.tags],
tags=[t.name for t in self.tags],
version=self.name,
was_created_in_this_run=was_created_in_this_run,
suppress_class_validation_warnings=suppress_class_validation_warnings,
Expand Down Expand Up @@ -450,6 +456,8 @@ class ModelVersionUpdateModel(BaseModel):
name: Optional[str] = Field(
description="Target model version name to be set", default=None
)
add_tags: Optional[List[str]] = None
remove_tags: Optional[List[str]] = None

@validator("stage")
def _validate_stage(cls, stage: str) -> str:
Expand Down
9 changes: 9 additions & 0 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ class ModelVersionSchema(NamedSchema, table=True):
back_populates="model_version",
sa_relationship_kwargs={"cascade": "delete"},
)
tags: List["TagResourceSchema"] = Relationship(
back_populates="model_version",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
cascade="delete",
overlaps="tags",
),
)

number: int = Field(sa_column=Column(INTEGER, nullable=False))
description: str = Field(sa_column=Column(TEXT, nullable=True))
Expand Down Expand Up @@ -310,6 +318,7 @@ def to_model(
endpoint_artifact_ids=endpoint_artifact_ids,
data_artifact_ids=data_artifact_ids,
pipeline_run_ids=pipeline_run_ids,
tags=[t.tag.to_model() for t in self.tags],
)

def update(
Expand Down
18 changes: 14 additions & 4 deletions src/zenml/zen_stores/schemas/tag_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
ArtifactSchema,
ArtifactVersionSchema,
)
from zenml.zen_stores.schemas.model_schemas import ModelSchema
from zenml.zen_stores.schemas.model_schemas import (
ModelSchema,
ModelVersionSchema,
)


class TagSchema(NamedSchema, table=True):
Expand Down Expand Up @@ -126,21 +129,28 @@ class TagResourceSchema(BaseSchema, table=True):
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
overlaps="tags,model,artifact_version",
overlaps="tags,model,artifact_version,model_version",
),
)
artifact_version: List["ArtifactVersionSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
overlaps="tags,model,artifact",
overlaps="tags,model,artifact,model_version",
),
)
model: List["ModelSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
overlaps="tags,artifact,artifact_version",
overlaps="tags,artifact,artifact_version,model_version",
),
)
model_version: List["ModelVersionSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
overlaps="tags,model,artifact,artifact_version",
),
)

Expand Down
23 changes: 23 additions & 0 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6332,6 +6332,13 @@ def create_model_version(
)
session.add(model_version_schema)

if model_version.tags:
self._attach_tags_to_resource(
tag_names=model_version.tags,
resource_id=model_version_schema.id,
resource_type=TaggableResourceTypes.MODEL_VERSION,
)

session.commit()
return model_version_schema.to_model()

Expand Down Expand Up @@ -6482,6 +6489,22 @@ def update_model_version(
f"Model version {existing_model_version_in_target_stage.name} has been set to {ModelStages.ARCHIVED.value}."
)

if model_version_update_model.add_tags:
self._attach_tags_to_resource(
tag_names=model_version_update_model.add_tags,
resource_id=existing_model_version.id,
resource_type=TaggableResourceTypes.MODEL_VERSION,
)
model_version_update_model.add_tags = None
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
if model_version_update_model.remove_tags:
self._detach_tags_from_resource(
tag_names=model_version_update_model.remove_tags,
resource_id=existing_model_version.id,
resource_type=TaggableResourceTypes.MODEL_VERSION,
)

model_version_update_model.remove_tags = None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also a general question here: I started seeing this pattern with the other taggable entities as well. In the earlier iterations, we have used the update method of a Schema to apply such changes. Lately, I see that the update is happening one layer up on the SqlZenStore level instead of the Schema level. Is there a reason why we see this switch when it comes to models and artifacts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure that there is a real explanation behind that. Can you share some Schemas having such "extended" updates? For instance in run_metadata we use an extra method to create those, so it is not really closely coupled with the object itself (e.g. create_run_metadata is called as one of the steps in save_artifact utility).
It might turn out that having those very related entities coupled in one update method is a bit better choice, IMO. But it is definitely not a battle I want to fight in, so if you share examples and reject my explanations - I will rework 👍🏼

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in their first iterations, most schemas had an update function which was relatively more extended. However, as I went through the current implementation, I see that this is not the case anymore. I don't think we need to rework anything in the scope of this PR though, ideally, this would be a separate cleaning-up ticket in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we need examples in any case:

  • Code repository schemas: The entire update on the schema happens within the update method not in the SQLZenStore.
  • Stack component schemas: The update on the configuration and labels are happening inside the update function whereas the update on the connectors is happening on the SQLZenStore.,

This is the inconsistency that I find suboptimal. IMO, the update should either happen fully on the SQLZenStore or fully on the update method of a schema. Otherwise, the code becomes very hard to track and read.

existing_model_version.update(
target_stage=stage,
target_name=model_version_update_model.name,
Expand Down
42 changes: 40 additions & 2 deletions tests/integration/functional/model/test_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,31 @@ def test_model_exists(self):
def test_model_create_model_and_version(self):
"""Test if model and version are created, not existing before."""
with ModelContext(create_model=False):
mv = ModelVersion(name=MODEL_NAME)
mv = ModelVersion(name=MODEL_NAME, tags=["tag1", "tag2"])
with mock.patch("zenml.model.model_version.logger.info") as logger:
mv = mv._get_or_create_model_version()
logger.assert_called()
assert mv.name == str(mv.number)
assert mv.model.name == MODEL_NAME
assert {t.name for t in mv.tags} == {"tag1", "tag2"}
assert {t.name for t in mv.model.tags} == {"tag1", "tag2"}

def test_create_model_version_makes_proper_tagging(self):
"""Test if model versions get unique tags."""
with ModelContext(create_model=False):
mv = ModelVersion(name=MODEL_NAME, tags=["tag1", "tag2"])
mv = mv._get_or_create_model_version()
assert mv.name == str(mv.number)
assert mv.model.name == MODEL_NAME
assert {t.name for t in mv.tags} == {"tag1", "tag2"}
assert {t.name for t in mv.model.tags} == {"tag1", "tag2"}

mv = ModelVersion(name=MODEL_NAME, tags=["tag3", "tag4"])
mv = mv._get_or_create_model_version()
assert mv.name == str(mv.number)
assert mv.model.name == MODEL_NAME
assert {t.name for t in mv.tags} == {"tag3", "tag4"}
assert {t.name for t in mv.model.tags} == {"tag1", "tag2"}

def test_model_fetch_model_and_version_by_number(self):
"""Test model and model version retrieval by exact version number."""
Expand Down Expand Up @@ -201,7 +220,7 @@ def test_tags_properly_updated(self):
tags=["foo", "bar"],
delete_new_version_on_failure=False,
)
model_id = mv._get_or_create_model().id
model_id = mv._get_or_create_model_version().model.id

Client().update_model(model_id, add_tags=["tag1", "tag2"])
model = mv._get_or_create_model()
Expand All @@ -213,7 +232,26 @@ def test_tags_properly_updated(self):
"tag2",
}

Client().update_model_version(
model_id, "1", add_tags=["tag3", "tag4"]
)
model_version = mv._get_or_create_model_version()
assert len(model_version.tags) == 4
assert {t.name for t in model_version.tags} == {
"foo",
"bar",
"tag3",
"tag4",
}

Client().update_model(model_id, remove_tags=["tag1", "tag2"])
model = mv._get_or_create_model()
assert len(model.tags) == 2
assert {t.name for t in model.tags} == {"foo", "bar"}

Client().update_model_version(
model_id, "1", remove_tags=["tag3", "tag4"]
)
model_version = mv._get_or_create_model_version()
assert len(model_version.tags) == 2
assert {t.name for t in model_version.tags} == {"foo", "bar"}
Loading
Loading