Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model Versions are taggable #2102

Merged
merged 12 commits into from
Dec 13, 2023
2 changes: 1 addition & 1 deletion examples/e2e/.copier-answers.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier
_commit: 2023.11.23-3-g11d8368
_commit: 2023.11.24
_src_path: gh:zenml-io/template-e2e-batch
data_quality_checks: true
email: ''
Expand Down
69 changes: 69 additions & 0 deletions src/zenml/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,75 @@
zenml code-repository delete <REPOSITORY_NAME_OR_ID>
```

Administering your Models
----------------------------

ZenML provides several CLI commands to help you administer your models and
model versions as part of the Model Control Plane.

To register a new model, you can use the following CLI command:
```bash
zenml model register --name <NAME> [--MODEL_OPTIONS]
```

To list all registered models, use:
```bash
zenml model list
```

To update a model, use:
```bash
zenml model update <MODEL_NAME_OR_ID> [--MODEL_OPTIONS]
```

If you would like to add or remove tags from the model, use:
```bash
zenml model update <MODEL_NAME_OR_ID> --tag <TAG> --tag <TAG> ..
--remove-tag <TAG> --remove-tag <TAG> ..
```

To delete a model, use:
```bash
zenml model delete <MODEL_NAME_OR_ID>
```

The CLI interface for models also helps to navigate through artifacts linked to a specific model versions.
```bash
zenml model data_artifacts <MODEL_NAME_OR_ID> [-v <VERSION>]
zenml model endpoint_artifacts <MODEL_NAME_OR_ID> [-v <VERSION>]
zenml model model_artifacts <MODEL_NAME_OR_ID> [-v <VERSION>]
```

You can also navigate the pipeline runs linked to a specific model versions:
```bash
zenml model runs <MODEL_NAME_OR_ID> [-v <VERSION>]
```

To list the model versions of a specific model, use:
```bash
zenml model version list <MODEL_NAME_OR_ID>
```

To delete a model version, use:
```bash
zenml model version delete <MODEL_NAME_OR_ID> <VERSION>
```

To update a model version, use:
```bash
zenml model version update <MODEL_NAME_OR_ID> <VERSION> [--MODEL_VERSION_OPTIONS]
```
These are some of the more common uses of model version updates:
- stage (i.e. promotion)
```bash
zenml model version update <MODEL_NAME_OR_ID> <VERSION> --stage <STAGE>
```
- tags
```bash
zenml model version update <MODEL_NAME_OR_ID> <VERSION> --tag <TAG> --tag <TAG> ..
--remove-tag <TAG> --remove-tag <TAG> ..
```

Administering your Pipelines
----------------------------

Expand Down
33 changes: 28 additions & 5 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _model_version_to_print(
"number": model_version.number,
"description": model_version.description,
"stage": model_version.stage,
"tags": [t.name for t in model_version.tags],
"data_artifacts_count": len(model_version.data_artifact_ids),
"model_artifacts_count": len(model_version.model_artifact_ids),
"endpoint_artifacts_count": len(model_version.endpoint_artifact_ids),
Expand Down Expand Up @@ -390,6 +391,22 @@ def list_model_versions(model_name_or_id: str, **kwargs: Any) -> None:
type=click.Choice(choices=ModelStages.values()),
help="The stage of the model version.",
)
@click.option(
"--tag",
"-t",
help="Tags to be added to the model.",
type=str,
required=False,
multiple=True,
)
@click.option(
"--remove-tag",
"-r",
help="Tags to be removed from the model.",
type=str,
required=False,
multiple=True,
)
@click.option(
"--force",
"-f",
Expand All @@ -400,6 +417,8 @@ def update_model_version(
model_name_or_id: str,
model_version_name_or_number_or_id: str,
stage: str,
tag: Optional[List[str]],
remove_tag: Optional[List[str]],
force: bool = False,
) -> None:
"""Update an existing model version stage in the Model Control Plane.
Expand All @@ -408,17 +427,21 @@ def update_model_version(
model_name_or_id: The ID or name of the model containing version.
model_version_name_or_number_or_id: The ID, number or name of the model version.
stage: The stage of the model version to be set.
tag: Tags to be added to the model version.
remove_tag: Tags to be removed from the model version.
force: Whether existing model version in target stage should be silently archived.
"""
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
try:
Client().update_model_version(
model_version = Client().update_model_version(
model_name_or_id=model_name_or_id,
version_name_or_id=model_version.id,
stage=stage,
add_tags=tag,
remove_tags=remove_tag,
force=force,
)
except RuntimeError:
Expand All @@ -435,15 +458,15 @@ def update_model_version(
if not confirmation:
cli_utils.declare("Model version stage update canceled.")
return
Client().update_model_version(
model_version = Client().update_model_version(
model_name_or_id=model_version.model.id,
version_name_or_id=model_version.id,
stage=stage,
add_tags=tag,
remove_tags=remove_tag,
force=True,
)
cli_utils.declare(
f"Model version '{model_version.name}' stage updated to '{stage}'."
)
cli_utils.print_table([_model_version_to_print(model_version)])


@version.command("delete", help="Delete an existing model version.")
Expand Down
9 changes: 9 additions & 0 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4569,6 +4569,7 @@ def create_model_version(
model_name_or_id: Union[str, UUID],
name: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> ModelVersionResponse:
"""Creates a new model version in Model Control Plane.

Expand All @@ -4577,6 +4578,7 @@ def create_model_version(
version in.
name: the name of the Model Version to be created.
description: the description of the Model Version to be created.
tags: Tags associated with the model.

Returns:
The newly created model version.
Expand All @@ -4590,6 +4592,7 @@ def create_model_version(
user=self.active_user.id,
workspace=self.active_workspace.id,
model=model_name_or_id,
tags=tags,
)
)

Expand Down Expand Up @@ -4747,6 +4750,8 @@ def update_model_version(
stage: Optional[Union[str, ModelStages]] = None,
force: bool = False,
name: Optional[str] = None,
add_tags: Optional[List[str]] = None,
remove_tags: Optional[List[str]] = None,
) -> ModelVersionResponse:
"""Get all model versions by filter.

Expand All @@ -4757,6 +4762,8 @@ def update_model_version(
force: Whether existing model version in target stage should be
silently archived or an error should be raised.
name: Target model version name to be set.
add_tags: Tags to add to the model version.
remove_tags: Tags to remove from to the model version.

Returns:
An updated model version.
Expand All @@ -4775,6 +4782,8 @@ def update_model_version(
stage=stage,
force=force,
name=name,
add_tags=add_tags,
remove_tags=remove_tags,
),
)

Expand Down
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ class TaggableResourceTypes(StrEnum):
ARTIFACT = "artifact"
ARTIFACT_VERSION = "artifact_version"
MODEL = "model"
MODEL_VERSION = "model_version"


class ResponseUpdateStrategy(StrEnum):
Expand Down
1 change: 1 addition & 0 deletions src/zenml/model/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def _get_or_create_model_version(self) -> "ModelVersionResponse":
name=self.version,
description=self.description,
model=model.id,
tags=self.tags,
)
mv_request = ModelVersionRequest.parse_obj(model_version_request)
try:
Expand Down
31 changes: 28 additions & 3 deletions src/zenml/models/v2/core/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
"""Models representing model versions."""

from datetime import datetime
from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel, Field, PrivateAttr, validator

from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
from zenml.enums import ModelStages
from zenml.models.tag_models import TagResponseModel
from zenml.models.v2.base.filter import AnyQuery
from zenml.models.v2.base.scoped import (
WorkspaceScopedFilter,
Expand Down Expand Up @@ -68,6 +69,9 @@ class ModelVersionRequest(WorkspaceScopedRequest):
model: UUID = Field(
description="The ID of the model containing version",
)
tags: Optional[List[str]] = Field(
title="Tags associated with the model version",
)


# ------------------ Update Model ------------------
Expand All @@ -88,7 +92,16 @@ class ModelVersionUpdate(BaseModel):
default=False,
)
name: Optional[str] = Field(
description="Target model version name to be set", default=None
description="Target model version name to be set",
default=None,
)
add_tags: Optional[List[str]] = Field(
description="Tags to be added to the model version",
default=None,
)
remove_tags: Optional[List[str]] = Field(
description="Tags to be removed from the model version",
default=None,
)

@validator("stage")
Expand Down Expand Up @@ -134,6 +147,9 @@ class ModelVersionResponseBody(WorkspaceScopedResponseBody):
description="Pipeline runs linked to the model version",
default={},
)
tags: List[TagResponseModel] = Field(
title="Tags associated with the model version", default=[]
)
created: datetime = Field(
title="The timestamp when this component was created."
)
Expand Down Expand Up @@ -246,6 +262,15 @@ def updated(self) -> datetime:
"""
return self.get_body().updated

@property
def tags(self) -> List["TagResponseModel"]:
"""The `tags` property.

Returns:
the value of the property.
"""
return self.get_body().tags

@property
def description(self) -> Optional[str]:
"""The `description` property.
Expand Down Expand Up @@ -293,7 +318,7 @@ def to_model_version(
limitations=self.model.limitations,
trade_offs=self.model.trade_offs,
ethics=self.model.ethics,
tags=[t.name for t in self.model.tags],
tags=[t.name for t in self.tags],
version=self.name,
was_created_in_this_run=was_created_in_this_run,
suppress_class_validation_warnings=suppress_class_validation_warnings,
Expand Down
9 changes: 9 additions & 0 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,14 @@ class ModelVersionSchema(NamedSchema, table=True):
back_populates="model_version",
sa_relationship_kwargs={"cascade": "delete"},
)
tags: List["TagResourceSchema"] = Relationship(
back_populates="model_version",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
cascade="delete",
overlaps="tags",
),
)

number: int = Field(sa_column=Column(INTEGER, nullable=False))
description: str = Field(sa_column=Column(TEXT, nullable=True))
Expand Down Expand Up @@ -331,6 +339,7 @@ def to_model(
data_artifact_ids=data_artifact_ids,
endpoint_artifact_ids=endpoint_artifact_ids,
pipeline_run_ids=pipeline_run_ids,
tags=[t.tag.to_model() for t in self.tags],
)

return ModelVersionResponse(
Expand Down
18 changes: 14 additions & 4 deletions src/zenml/zen_stores/schemas/tag_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
ArtifactSchema,
ArtifactVersionSchema,
)
from zenml.zen_stores.schemas.model_schemas import ModelSchema
from zenml.zen_stores.schemas.model_schemas import (
ModelSchema,
ModelVersionSchema,
)


class TagSchema(NamedSchema, table=True):
Expand Down Expand Up @@ -126,21 +129,28 @@ class TagResourceSchema(BaseSchema, table=True):
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
overlaps="tags,model,artifact_version",
overlaps="tags,model,artifact_version,model_version",
),
)
artifact_version: List["ArtifactVersionSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
overlaps="tags,model,artifact",
overlaps="tags,model,artifact,model_version",
),
)
model: List["ModelSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
overlaps="tags,artifact,artifact_version",
overlaps="tags,artifact,artifact_version,model_version",
),
)
model_version: List["ModelVersionSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
overlaps="tags,model,artifact,artifact_version",
),
)

Expand Down
Loading