diff --git a/src/zenml/__init__.py b/src/zenml/__init__.py index 5fa31652bd6..3cdd2d95765 100644 --- a/src/zenml/__init__.py +++ b/src/zenml/__init__.py @@ -42,6 +42,7 @@ save_artifact, load_artifact, ) +from zenml.model.utils import log_model_metadata from zenml.artifacts.artifact_config import ArtifactConfig from zenml.artifacts.external_artifact import ExternalArtifact from zenml.model.model_version import ModelVersion @@ -57,6 +58,7 @@ "get_step_context", "load_artifact", "log_artifact_metadata", + "log_model_metadata", "ModelVersion", "pipeline", "save_artifact", diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 164b8fdd035..f7e71c22c7b 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -125,6 +125,33 @@ def log_model_version_metadata( This function can be used to log metadata for existing model versions. + Args: + metadata: The metadata to log. + model_name: The name of the model to log metadata for. Can + be omitted when being called inside a step with configured + `model_version` in decorator. + model_version: The version of the model to log metadata for. Can + be omitted when being called inside a step with configured + `model_version` in decorator. + """ + logger.warning( + "`log_model_version_metadata` is deprecated. Please use " + "`log_model_metadata` instead." + ) + log_model_metadata( + metadata=metadata, model_name=model_name, model_version=model_version + ) + + +def log_model_metadata( + metadata: Dict[str, "MetadataType"], + model_name: Optional[str] = None, + model_version: Optional[Union[ModelStages, int, str]] = None, +) -> None: + """Log model version metadata. + + This function can be used to log metadata for existing model versions. + Args: metadata: The metadata to log. model_name: The name of the model to log metadata for. Can diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index 6026c6f900f..83150fe4991 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -20,7 +20,7 @@ from zenml.client import Client from zenml.enums import ModelStages from zenml.model.model_version import ModelVersion -from zenml.model.utils import log_model_version_metadata +from zenml.model.utils import log_model_metadata from zenml.models import TagRequest MODEL_NAME = "super_model" @@ -64,7 +64,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback): @step def step_metadata_logging_functional(): """Functional logging using implicit ModelVersion from context.""" - log_model_version_metadata({"foo": "bar"}) + log_model_metadata({"foo": "bar"}) assert get_step_context().model_version.metadata["foo"] == "bar" @@ -357,7 +357,7 @@ def test_metadata_logging_functional(self, clean_client: "Client"): ) mv._get_or_create_model_version() - log_model_version_metadata( + log_model_metadata( {"foo": "bar"}, model_name=mv.name, model_version=mv.number ) @@ -365,9 +365,9 @@ def test_metadata_logging_functional(self, clean_client: "Client"): assert mv.metadata["foo"] == "bar" with pytest.raises(ValueError): - log_model_version_metadata({"foo": "bar"}) + log_model_metadata({"foo": "bar"}) - log_model_version_metadata( + log_model_metadata( {"bar": "foo"}, model_name=mv.name, model_version="latest" )