From c00098bf6c54392f86a1398bce02ea71c45d9904 Mon Sep 17 00:00:00 2001 From: Matteo Mortari Date: Fri, 16 Feb 2024 08:58:46 +0100 Subject: [PATCH] py: default metadata capture environment vars (#307) * py: default metadata capture environment vars * linting --- clients/python/README.md | 5 ++ clients/python/src/model_registry/_client.py | 19 +++++- clients/python/tests/test_client.py | 67 +++++++++++++++++++- 3 files changed, 87 insertions(+), 4 deletions(-) diff --git a/clients/python/README.md b/clients/python/README.md index f088f844..febca492 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -37,6 +37,11 @@ version = registry.get_model_version("my-model", "v2.0") experiment = registry.get_model_artifact("my-model", "v2.0") ``` +### Default values for metadata + +If not supplied, `metadata` values default to a predefined set of conventional values. +Refer to the technical documentation in the pydoc of the client. + ### Importing from Hugging Face Hub To import models from Hugging Face Hub, start by installing the `huggingface-hub` package, either directly or as an diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index a4d5bafc..826f061f 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -1,6 +1,7 @@ """Standard client for the model registry.""" from __future__ import annotations +import os from typing import get_args from warnings import warn @@ -98,7 +99,7 @@ def register_model( storage_key: Storage key. storage_path: Storage path. service_account_name: Service account name. - metadata: Additional version metadata. + metadata: Additional version metadata. Defaults to values returned by `default_metadata()`. Returns: Registered model. 
@@ -109,7 +110,7 @@ def register_model( version, author or self._author, description=description, - metadata=metadata or {}, + metadata=metadata or self.default_metadata(), ) self._register_model_artifact( mv, @@ -123,6 +124,19 @@ def register_model( return rm + def default_metadata(self) -> dict[str, ScalarType]: + """Default metadata values. + + When metadata is not explicitly supplied by the end user, these values are used + by default. + + Returns: + Mapping of each of AWS_S3_ENDPOINT, AWS_S3_BUCKET, AWS_DEFAULT_REGION that is set in the environment to its value. + """ + return { + key: os.environ[key] for key in ["AWS_S3_ENDPOINT", "AWS_S3_BUCKET", "AWS_DEFAULT_REGION"] if key in os.environ + } + def register_hf_model( self, repo: str, @@ -188,6 +202,7 @@ def register_hf_model( model_author = author source_uri = hf_hub_url(repo, path, revision=git_ref) metadata = { + **self.default_metadata(), "repo": repo, "source_uri": source_uri, "model_origin": "huggingface_hub", diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 7dee53ca..9e33b07b 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -1,3 +1,5 @@ +import os + import pytest from model_registry import ModelRegistry from model_registry.core import ModelRegistryAPIClient @@ -46,6 +48,7 @@ def test_register_existing_version(mr_client: ModelRegistry): def test_get(mr_client: ModelRegistry): name = "test_model" version = "1.0.0" + metadata = {"a": 1, "b": "2"} rm = mr_client.register_model( name, @@ -53,6 +56,7 @@ def test_get(mr_client: ModelRegistry): model_format_name="test_format", model_format_version="test_version", version=version, + metadata=metadata ) assert (_rm := mr_client.get_registered_model(name)) @@ -64,22 +68,81 @@ def test_get(mr_client: ModelRegistry): assert (_mv := mr_client.get_model_version(name, version)) assert mv.id == _mv.id + assert mv.metadata == metadata assert (_ma := mr_client.get_model_artifact(name, version)) assert ma.id == _ma.id +def test_default_md(mr_client: 
ModelRegistry): + name = "test_model" + version = "1.0.0" + env_values = {"AWS_S3_ENDPOINT": "value1", "AWS_S3_BUCKET": "value2", "AWS_DEFAULT_REGION": "value3"} + for k, v in env_values.items(): + os.environ[k] = v + + assert mr_client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version=version, + # metadata deliberately left out so that the defaults apply + ) + assert (mv := mr_client.get_model_version(name, version)) + assert mv.metadata == env_values + + for k in env_values: + os.environ.pop(k) # NOTE(review): not reached if an assert above fails; the variables then stay set + + def test_hf_import(mr_client: ModelRegistry): pytest.importorskip("huggingface_hub") name = "openai-community/gpt2" version = "1.2.3" + author = "test author" + + assert mr_client.register_hf_model( + name, + "onnx/decoder_model.onnx", + author=author, + version=version, + model_format_name="test format", + model_format_version="test version", + ) + assert (mv := mr_client.get_model_version(name, version)) + assert mv.author == author + assert mv.metadata["model_author"] == author + assert mv.metadata["model_origin"] == "huggingface_hub" + assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + assert mv.metadata["repo"] == name + assert mr_client.get_model_artifact(name, version) + + +def test_hf_import_default_env(mr_client: ModelRegistry): + """Test that setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata. + """ + pytest.importorskip("huggingface_hub") + name = "openai-community/gpt2" + version = "1.2.3" + author = "test author" + env_values = {"AWS_S3_ENDPOINT": "value1", "AWS_S3_BUCKET": "value2", "AWS_DEFAULT_REGION": "value3"} + for k, v in env_values.items(): + os.environ[k] = v assert mr_client.register_hf_model( name, "onnx/decoder_model.onnx", - author="test author", + author=author, version=version, model_format_name="test format", model_format_version="test version", ) - assert mr_client.get_model_version(name, 
version) + assert (mv := mr_client.get_model_version(name, version)) + assert mv.metadata["model_author"] == author + assert mv.metadata["model_origin"] == "huggingface_hub" + assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + assert mv.metadata["repo"] == name assert mr_client.get_model_artifact(name, version) + + for k in env_values: + os.environ.pop(k)