diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index 997ea340a77..f8f1b62c353 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2023.11.23-3-g11d8368 +_commit: 2023.11.24 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: '' diff --git a/src/zenml/cli/__init__.py b/src/zenml/cli/__init__.py index 867284b1361..4fae2858900 100644 --- a/src/zenml/cli/__init__.py +++ b/src/zenml/cli/__init__.py @@ -882,6 +882,75 @@ zenml code-repository delete ``` +Administering your Models +---------------------------- + +ZenML provides several CLI commands to help you administer your models and +model versions as part of the Model Control Plane. + +To register a new model, you can use the following CLI command: +```bash +zenml model register --name [--MODEL_OPTIONS] +``` + +To list all registered models, use: +```bash +zenml model list +``` + +To update a model, use: +```bash +zenml model update [--MODEL_OPTIONS] +``` + +If you would like to add or remove tags from the model, use: +```bash +zenml model update --tag --tag .. + --remove-tag --remove-tag .. +``` + +To delete a model, use: +```bash +zenml model delete +``` + +The CLI interface for models also helps to navigate through artifacts linked to a specific model versions. +```bash +zenml model data_artifacts [-v ] +zenml model endpoint_artifacts [-v ] +zenml model model_artifacts [-v ] +``` + +You can also navigate the pipeline runs linked to a specific model versions: +```bash +zenml model runs [-v ] +``` + +To list the model versions of a specific model, use: +```bash +zenml model version list +``` + +To delete a model version, use: +```bash +zenml model version delete +``` + +To update a model version, use: +```bash +zenml model version update [--MODEL_VERSION_OPTIONS] +``` +These are some of the more common uses of model version updates: +- stage (i.e. promotion) +```bash +zenml model version update --stage +``` +- tags +```bash +zenml model version update --tag --tag .. + --remove-tag --remove-tag .. +``` + Administering your Pipelines ---------------------------- diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index bbf386f10dc..a0263d10b66 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -60,6 +60,7 @@ def _model_version_to_print( "number": model_version.number, "description": model_version.description, "stage": model_version.stage, + "tags": [t.name for t in model_version.tags], "data_artifacts_count": len(model_version.data_artifact_ids), "model_artifacts_count": len(model_version.model_artifact_ids), "endpoint_artifacts_count": len(model_version.endpoint_artifact_ids), @@ -390,6 +391,22 @@ def list_model_versions(model_name_or_id: str, **kwargs: Any) -> None: type=click.Choice(choices=ModelStages.values()), help="The stage of the model version.", ) +@click.option( + "--tag", + "-t", + help="Tags to be added to the model.", + type=str, + required=False, + multiple=True, +) +@click.option( + "--remove-tag", + "-r", + help="Tags to be removed from the model.", + type=str, + required=False, + multiple=True, +) @click.option( "--force", "-f", @@ -400,6 +417,8 @@ def update_model_version( model_name_or_id: str, model_version_name_or_number_or_id: str, stage: str, + tag: Optional[List[str]], + remove_tag: Optional[List[str]], force: bool = False, ) -> None: """Update an existing model version stage in the Model Control Plane. @@ -408,6 +427,8 @@ def update_model_version( model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The ID, number or name of the model version. stage: The stage of the model version to be set. + tag: Tags to be added to the model version. + remove_tag: Tags to be removed from the model version. force: Whether existing model version in target stage should be silently archived. """ model_version = Client().get_model_version( @@ -415,10 +436,12 @@ def update_model_version( model_version_name_or_number_or_id=model_version_name_or_number_or_id, ) try: - Client().update_model_version( + model_version = Client().update_model_version( model_name_or_id=model_name_or_id, version_name_or_id=model_version.id, stage=stage, + add_tags=tag, + remove_tags=remove_tag, force=force, ) except RuntimeError: @@ -435,15 +458,15 @@ def update_model_version( if not confirmation: cli_utils.declare("Model version stage update canceled.") return - Client().update_model_version( + model_version = Client().update_model_version( model_name_or_id=model_version.model.id, version_name_or_id=model_version.id, stage=stage, + add_tags=tag, + remove_tags=remove_tag, force=True, ) - cli_utils.declare( - f"Model version '{model_version.name}' stage updated to '{stage}'." - ) + cli_utils.print_table([_model_version_to_print(model_version)]) @version.command("delete", help="Delete an existing model version.") diff --git a/src/zenml/client.py b/src/zenml/client.py index d87344e6b31..d2a2a0f201f 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -4569,6 +4569,7 @@ def create_model_version( model_name_or_id: Union[str, UUID], name: Optional[str] = None, description: Optional[str] = None, + tags: Optional[List[str]] = None, ) -> ModelVersionResponse: """Creates a new model version in Model Control Plane. @@ -4577,6 +4578,7 @@ def create_model_version( version in. name: the name of the Model Version to be created. description: the description of the Model Version to be created. + tags: Tags associated with the model. Returns: The newly created model version. @@ -4590,6 +4592,7 @@ def create_model_version( user=self.active_user.id, workspace=self.active_workspace.id, model=model_name_or_id, + tags=tags, ) ) @@ -4747,6 +4750,8 @@ def update_model_version( stage: Optional[Union[str, ModelStages]] = None, force: bool = False, name: Optional[str] = None, + add_tags: Optional[List[str]] = None, + remove_tags: Optional[List[str]] = None, ) -> ModelVersionResponse: """Get all model versions by filter. @@ -4757,6 +4762,8 @@ def update_model_version( force: Whether existing model version in target stage should be silently archived or an error should be raised. name: Target model version name to be set. + add_tags: Tags to add to the model version. + remove_tags: Tags to remove from to the model version. Returns: An updated model version. @@ -4775,6 +4782,8 @@ def update_model_version( stage=stage, force=force, name=name, + add_tags=add_tags, + remove_tags=remove_tags, ), ) diff --git a/src/zenml/enums.py b/src/zenml/enums.py index 10c27c8f6d0..1ed406e4a45 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -320,6 +320,7 @@ class TaggableResourceTypes(StrEnum): ARTIFACT = "artifact" ARTIFACT_VERSION = "artifact_version" MODEL = "model" + MODEL_VERSION = "model_version" class ResponseUpdateStrategy(StrEnum): diff --git a/src/zenml/model/model_version.py b/src/zenml/model/model_version.py index 9a49a4de85a..aaf260b83ee 100644 --- a/src/zenml/model/model_version.py +++ b/src/zenml/model/model_version.py @@ -462,6 +462,7 @@ def _get_or_create_model_version(self) -> "ModelVersionResponse": name=self.version, description=self.description, model=model.id, + tags=self.tags, ) mv_request = ModelVersionRequest.parse_obj(model_version_request) try: diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 85e077b2401..3a72e00340b 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -14,13 +14,14 @@ """Models representing model versions.""" from datetime import datetime -from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union from uuid import UUID from pydantic import BaseModel, Field, PrivateAttr, validator from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.enums import ModelStages +from zenml.models.tag_models import TagResponseModel from zenml.models.v2.base.filter import AnyQuery from zenml.models.v2.base.scoped import ( WorkspaceScopedFilter, @@ -68,6 +69,9 @@ class ModelVersionRequest(WorkspaceScopedRequest): model: UUID = Field( description="The ID of the model containing version", ) + tags: Optional[List[str]] = Field( + title="Tags associated with the model version", + ) # ------------------ Update Model ------------------ @@ -88,7 +92,16 @@ class ModelVersionUpdate(BaseModel): default=False, ) name: Optional[str] = Field( - description="Target model version name to be set", default=None + description="Target model version name to be set", + default=None, + ) + add_tags: Optional[List[str]] = Field( + description="Tags to be added to the model version", + default=None, + ) + remove_tags: Optional[List[str]] = Field( + description="Tags to be removed from the model version", + default=None, ) @validator("stage") @@ -134,6 +147,9 @@ class ModelVersionResponseBody(WorkspaceScopedResponseBody): description="Pipeline runs linked to the model version", default={}, ) + tags: List[TagResponseModel] = Field( + title="Tags associated with the model version", default=[] + ) created: datetime = Field( title="The timestamp when this component was created." ) @@ -246,6 +262,15 @@ def updated(self) -> datetime: """ return self.get_body().updated + @property + def tags(self) -> List["TagResponseModel"]: + """The `tags` property. + + Returns: + the value of the property. + """ + return self.get_body().tags + @property def description(self) -> Optional[str]: """The `description` property. @@ -293,7 +318,7 @@ def to_model_version( limitations=self.model.limitations, trade_offs=self.model.trade_offs, ethics=self.model.ethics, - tags=[t.name for t in self.model.tags], + tags=[t.name for t in self.tags], version=self.name, was_created_in_this_run=was_created_in_this_run, suppress_class_validation_warnings=suppress_class_validation_warnings, diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 799e6fba69c..9891286bec7 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -241,6 +241,14 @@ class ModelVersionSchema(NamedSchema, table=True): back_populates="model_version", sa_relationship_kwargs={"cascade": "delete"}, ) + tags: List["TagResourceSchema"] = Relationship( + back_populates="model_version", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)", + cascade="delete", + overlaps="tags", + ), + ) number: int = Field(sa_column=Column(INTEGER, nullable=False)) description: str = Field(sa_column=Column(TEXT, nullable=True)) @@ -331,6 +339,7 @@ def to_model( data_artifact_ids=data_artifact_ids, endpoint_artifact_ids=endpoint_artifact_ids, pipeline_run_ids=pipeline_run_ids, + tags=[t.tag.to_model() for t in self.tags], ) return ModelVersionResponse( diff --git a/src/zenml/zen_stores/schemas/tag_schemas.py b/src/zenml/zen_stores/schemas/tag_schemas.py index 95466122cf2..75719790969 100644 --- a/src/zenml/zen_stores/schemas/tag_schemas.py +++ b/src/zenml/zen_stores/schemas/tag_schemas.py @@ -37,7 +37,10 @@ ArtifactSchema, ArtifactVersionSchema, ) - from zenml.zen_stores.schemas.model_schemas import ModelSchema + from zenml.zen_stores.schemas.model_schemas import ( + ModelSchema, + ModelVersionSchema, + ) class TagSchema(NamedSchema, table=True): @@ -126,21 +129,28 @@ class TagResourceSchema(BaseSchema, table=True): back_populates="tags", sa_relationship_kwargs=dict( primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)", - overlaps="tags,model,artifact_version", + overlaps="tags,model,artifact_version,model_version", ), ) artifact_version: List["ArtifactVersionSchema"] = Relationship( back_populates="tags", sa_relationship_kwargs=dict( primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)", - overlaps="tags,model,artifact", + overlaps="tags,model,artifact,model_version", ), ) model: List["ModelSchema"] = Relationship( back_populates="tags", sa_relationship_kwargs=dict( primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)", - overlaps="tags,artifact,artifact_version", + overlaps="tags,artifact,artifact_version,model_version", + ), + ) + model_version: List["ModelVersionSchema"] = Relationship( + back_populates="tags", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)", + overlaps="tags,model,artifact,artifact_version", ), ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index adddde1bab6..7b12e4c0edf 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -6403,6 +6403,13 @@ def create_model_version( ) session.add(model_version_schema) + if model_version.tags: + self._attach_tags_to_resource( + tag_names=model_version.tags, + resource_id=model_version_schema.id, + resource_type=TaggableResourceTypes.MODEL_VERSION, + ) + session.commit() return model_version_schema.to_model(hydrate=True) @@ -6562,6 +6569,19 @@ def update_model_version( f"Model version {existing_model_version_in_target_stage.name} has been set to {ModelStages.ARCHIVED.value}." ) + if model_version_update_model.add_tags: + self._attach_tags_to_resource( + tag_names=model_version_update_model.add_tags, + resource_id=existing_model_version.id, + resource_type=TaggableResourceTypes.MODEL_VERSION, + ) + if model_version_update_model.remove_tags: + self._detach_tags_from_resource( + tag_names=model_version_update_model.remove_tags, + resource_id=existing_model_version.id, + resource_type=TaggableResourceTypes.MODEL_VERSION, + ) + existing_model_version.update( target_stage=stage, target_name=model_version_update_model.name, diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index b1ae0b2c52a..9e14b430a03 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -87,12 +87,33 @@ def test_model_exists(self, clean_client: "Client"): def test_model_create_model_and_version(self, clean_client: "Client"): """Test if model and version are created, not existing before.""" with ModelContext(clean_client, create_model=False): - mv = ModelVersion(name=MODEL_NAME) + mv = ModelVersion(name=MODEL_NAME, tags=["tag1", "tag2"]) with mock.patch("zenml.model.model_version.logger.info") as logger: mv = mv._get_or_create_model_version() logger.assert_called() assert mv.name == str(mv.number) assert mv.model.name == MODEL_NAME + assert {t.name for t in mv.tags} == {"tag1", "tag2"} + assert {t.name for t in mv.model.tags} == {"tag1", "tag2"} + + def test_create_model_version_makes_proper_tagging( + self, clean_client: "Client" + ): + """Test if model versions get unique tags.""" + with ModelContext(clean_client, create_model=False): + mv = ModelVersion(name=MODEL_NAME, tags=["tag1", "tag2"]) + mv = mv._get_or_create_model_version() + assert mv.name == str(mv.number) + assert mv.model.name == MODEL_NAME + assert {t.name for t in mv.tags} == {"tag1", "tag2"} + assert {t.name for t in mv.model.tags} == {"tag1", "tag2"} + + mv = ModelVersion(name=MODEL_NAME, tags=["tag3", "tag4"]) + mv = mv._get_or_create_model_version() + assert mv.name == str(mv.number) + assert mv.model.name == MODEL_NAME + assert {t.name for t in mv.tags} == {"tag3", "tag4"} + assert {t.name for t in mv.model.tags} == {"tag1", "tag2"} def test_model_fetch_model_and_version_by_number( self, clean_client: "Client" @@ -203,7 +224,7 @@ def test_tags_properly_updated(self, clean_client: "Client"): tags=["foo", "bar"], delete_new_version_on_failure=False, ) - model_id = mv._get_or_create_model().id + model_id = mv._get_or_create_model_version().model.id clean_client.update_model(model_id, add_tags=["tag1", "tag2"]) model = mv._get_or_create_model() @@ -215,7 +236,26 @@ def test_tags_properly_updated(self, clean_client: "Client"): "tag2", } + clean_client.update_model_version( + model_id, "1", add_tags=["tag3", "tag4"] + ) + model_version = mv._get_or_create_model_version() + assert len(model_version.tags) == 4 + assert {t.name for t in model_version.tags} == { + "foo", + "bar", + "tag3", + "tag4", + } + clean_client.update_model(model_id, remove_tags=["tag1", "tag2"]) model = mv._get_or_create_model() assert len(model.tags) == 2 assert {t.name for t in model.tags} == {"foo", "bar"} + + clean_client.update_model_version( + model_id, "1", remove_tags=["tag3", "tag4"] + ) + model_version = mv._get_or_create_model_version() + assert len(model_version.tags) == 2 + assert {t.name for t in model_version.tags} == {"foo", "bar"} diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index bda40f41573..0ece4e2922a 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -1352,6 +1352,14 @@ def test_create_model_version_pass(self, client_with_model: "Client"): assert model_version.number == 4 assert model_version.description == "some desc" + model_version = client_with_model.create_model_version( + self.MODEL_NAME, tags=["a", "b"] + ) + + assert model_version.name == "5" + assert model_version.number == 5 + assert {t.name for t in model_version.tags} == {"a", "b"} + def test_create_model_version_duplicate_fails( self, client_with_model: "Client" ):