Skip to content

Commit

Permalink
Rename log_model_version_metadata to log_model_metadata (#2215)
Browse files Browse the repository at this point in the history
* Renamed to log_model_metadata

* Added to public

* Added older function just in case

* fix docstring darglint error

---------

Co-authored-by: Alex Strick van Linschoten <stricksubscriptions@fastmail.fm>
  • Loading branch information
htahir1 and strickvl authored Jan 4, 2024
1 parent 1e957ab commit 14fc0b8
Showing 3 changed files with 34 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/zenml/__init__.py
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@
save_artifact,
load_artifact,
)
from zenml.model.utils import log_model_metadata
from zenml.artifacts.artifact_config import ArtifactConfig
from zenml.artifacts.external_artifact import ExternalArtifact
from zenml.model.model_version import ModelVersion
@@ -57,6 +58,7 @@
"get_step_context",
"load_artifact",
"log_artifact_metadata",
"log_model_metadata",
"ModelVersion",
"pipeline",
"save_artifact",
27 changes: 27 additions & 0 deletions src/zenml/model/utils.py
Original file line number Diff line number Diff line change
@@ -125,6 +125,33 @@ def log_model_version_metadata(
This function can be used to log metadata for existing model versions.
Args:
metadata: The metadata to log.
model_name: The name of the model to log metadata for. Can
be omitted when being called inside a step with configured
`model_version` in decorator.
model_version: The version of the model to log metadata for. Can
be omitted when being called inside a step with configured
`model_version` in decorator.
"""
logger.warning(
"`log_model_version_metadata` is deprecated. Please use "
"`log_model_metadata` instead."
)
log_model_metadata(
metadata=metadata, model_name=model_name, model_version=model_version
)


def log_model_metadata(
metadata: Dict[str, "MetadataType"],
model_name: Optional[str] = None,
model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
"""Log model version metadata.
This function can be used to log metadata for existing model versions.
Args:
metadata: The metadata to log.
model_name: The name of the model to log metadata for. Can
10 changes: 5 additions & 5 deletions tests/integration/functional/model/test_model_version.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
from zenml.client import Client
from zenml.enums import ModelStages
from zenml.model.model_version import ModelVersion
from zenml.model.utils import log_model_version_metadata
from zenml.model.utils import log_model_metadata
from zenml.models import TagRequest

MODEL_NAME = "super_model"
@@ -64,7 +64,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
@step
def step_metadata_logging_functional():
"""Functional logging using implicit ModelVersion from context."""
log_model_version_metadata({"foo": "bar"})
log_model_metadata({"foo": "bar"})
assert get_step_context().model_version.metadata["foo"] == "bar"


@@ -357,17 +357,17 @@ def test_metadata_logging_functional(self, clean_client: "Client"):
)
mv._get_or_create_model_version()

log_model_version_metadata(
log_model_metadata(
{"foo": "bar"}, model_name=mv.name, model_version=mv.number
)

assert len(mv.metadata) == 1
assert mv.metadata["foo"] == "bar"

with pytest.raises(ValueError):
log_model_version_metadata({"foo": "bar"})
log_model_metadata({"foo": "bar"})

log_model_version_metadata(
log_model_metadata(
{"bar": "foo"}, model_name=mv.name, model_version="latest"
)

0 comments on commit 14fc0b8

Please sign in to comment.