From 88901304b83061c7c0da5c2a605d9d62e05a32bf Mon Sep 17 00:00:00 2001 From: pryce-turner Date: Mon, 6 Jan 2025 11:44:16 -0800 Subject: [PATCH 1/3] WIP Signed-off-by: pryce-turner --- flytekit/core/node.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index ea089c6fd3..0be34d1d44 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -140,6 +140,8 @@ def with_overrides( cache: Optional[bool] = None, cache_version: Optional[str] = None, cache_serialize: Optional[bool] = None, + labels: Optional[Dict[str, str]] = None, + annotations: Optional[Dict[str, str]] = None, *args, **kwargs, ): @@ -221,6 +223,18 @@ def with_overrides( assert_not_promise(cache_serialize, "cache_serialize") self._metadata._cache_serializable = cache_serialize + if labels is not None: + if not isinstance(labels, dict): + raise AssertionError("Labels should be specified as dict[str, str]") + for k, v in labels.items(): + self._metadata._labels.append(_workflow_model.Label(var=k, label=v)) + + if annotations is not None: + if not isinstance(annotations, dict): + raise AssertionError("Annotations should be specified as dict[str, str]") + for k, v in annotations.items(): + self._metadata.__annotations.append(_workflow_model.Annotation(var=k, annotation=v)) + return self From f4f674bda29cc90820ec7501ecdbb1e318e74f9b Mon Sep 17 00:00:00 2001 From: pryce-turner Date: Mon, 6 Jan 2025 13:45:39 -0800 Subject: [PATCH 2/3] Added correct models for labels and anno Signed-off-by: pryce-turner --- flytekit/core/node.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 0be34d1d44..aea378b321 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -11,6 +11,7 @@ from flytekit.extras.accelerators import BaseAccelerator from flytekit.loggers import logger from flytekit.models import literals as _literal_models +from flytekit.models.admin import common as _common_model from flytekit.models.core import workflow as _workflow_model from flytekit.models.task import Resources as _resources_model @@ -226,14 +227,12 @@ def with_overrides( if labels is not None: if not isinstance(labels, dict): raise AssertionError("Labels should be specified as dict[str, str]") - for k, v in labels.items(): - self._metadata._labels.append(_workflow_model.Label(var=k, label=v)) + self._metadata._labels.append(_common_model.Label(values=labels)) if annotations is not None: if not isinstance(annotations, dict): raise AssertionError("Annotations should be specified as dict[str, str]") - for k, v in annotations.items(): - self._metadata.__annotations.append(_workflow_model.Annotation(var=k, annotation=v)) + self._metadata._annotations.append(_common_model.Annotation(values=annotations)) return self From a2db6fd3b9ddbc4ccdd1e2f27a82cb8a3d87f61d Mon Sep 17 00:00:00 2001 From: pryce-turner Date: Mon, 6 Jan 2025 15:19:24 -0800 Subject: [PATCH 3/3] Added tests and fixed NodeMetadata Signed-off-by: pryce-turner --- flytekit/core/node.py | 6 ++--- flytekit/models/core/workflow.py | 18 +++++++++++++++ .../flytekit/unit/core/test_node_creation.py | 22 +++++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index aea378b321..b7d76131e0 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -10,8 +10,8 @@ from flytekit.core.utils import _dnsify from flytekit.extras.accelerators import BaseAccelerator from flytekit.loggers import logger +from flytekit.models import common as _common_model from flytekit.models import literals as _literal_models -from flytekit.models.admin import common as _common_model from flytekit.models.core import workflow as _workflow_model from flytekit.models.task import Resources as _resources_model @@ -227,12 +227,12 @@ def with_overrides( if labels is not None: if not isinstance(labels, dict): raise AssertionError("Labels should be specified as dict[str, str]") - self._metadata._labels.append(_common_model.Label(values=labels)) + self._metadata._labels = _common_model.Labels(values=labels) if annotations is not None: if not isinstance(annotations, dict): raise AssertionError("Annotations should be specified as dict[str, str]") - self._metadata._annotations.append(_common_model.Annotation(values=annotations)) + self._metadata._annotations = _common_model.Annotations(values=annotations) return self diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 8d8bf9c9ef..e7c24e558c 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -172,6 +172,8 @@ def __init__( cacheable: typing.Optional[bool] = None, cache_version: typing.Optional[str] = None, cache_serializable: typing.Optional[bool] = None, + labels: typing.Optional[dict[str, str]] = None, + annotations: typing.Optional[dict[str, str]] = None, ): """ Defines extra information about the Node. @@ -183,6 +185,8 @@ def __init__( :param cacheable: Indicates that this nodes outputs should be cached. :param cache_version: The version of the cached data. :param cacheable: Indicates that cache operations on this node should be serialized. + :param labels: Identifying attributes to add to the k8s resource. + :param annotations: Arbitrary metadata to add to the k8s resource. """ self._name = name self._timeout = timeout if timeout is not None else datetime.timedelta() @@ -191,6 +195,8 @@ def __init__( self._cacheable = cacheable self._cache_version = cache_version self._cache_serializable = cache_serializable + self._labels = labels + self._annotations = annotations @property def name(self): @@ -229,6 +235,14 @@ def cache_version(self) -> typing.Optional[str]: def cache_serializable(self) -> typing.Optional[bool]: return self._cache_serializable + @property + def labels(self) -> typing.Optional[dict[str, str]]: + return self._labels + + @property + def annotations(self) -> typing.Optional[dict[str, str]]: + return self._annotations + def to_flyte_idl(self): """ :rtype: flyteidl.core.workflow_pb2.NodeMetadata @@ -240,6 +254,8 @@ def to_flyte_idl(self): cacheable=self.cacheable, cache_version=self.cache_version, cache_serializable=self.cache_serializable, + labels=self.labels, + annotations=self.annotations, ) if self.timeout: node_metadata.timeout.FromTimedelta(self.timeout) @@ -255,6 +271,8 @@ def from_flyte_idl(cls, pb2_object): pb2_object.cacheable if pb2_object.HasField("cacheable") else None, pb2_object.cache_version if pb2_object.HasField("cache_version") else None, pb2_object.cache_serializable if pb2_object.HasField("cache_serializable") else None, + pb2_object.labels if pb2_object.HasField("labels") else None, + pb2_object.annotations if pb2_object.HasField("annotations") else None, ) diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 381f456bdb..041d70bcdb 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -518,3 +518,25 @@ def my_wf(a: str) -> str: assert wf_spec.template.nodes[0].metadata.cache_serializable assert wf_spec.template.nodes[0].metadata.cacheable assert wf_spec.template.nodes[0].metadata.cache_version == "foo" + + +def test_override_labels_annotations(): + @task + def t1(a: str) -> str: + return f"*~*~*~{a}*~*~*~" + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(labels={"override": "label"}, annotations={"override": "annotation"}) + + serialization_settings = flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + + assert wf_spec.template.nodes[0].metadata.labels.values == {"override": "label"} + assert wf_spec.template.nodes[0].metadata.annotations.values == {"override": "annotation"}