Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add labels and annotations to with_overrides #3037

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flytekit.core.utils import _dnsify
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.loggers import logger
from flytekit.models import common as _common_model
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.task import Resources as _resources_model
Expand Down Expand Up @@ -140,6 +141,8 @@ def with_overrides(
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
labels: Optional[Dict[str, str]] = None,
annotations: Optional[Dict[str, str]] = None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -221,6 +224,16 @@ def with_overrides(
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize

if labels is not None:
if not isinstance(labels, dict):
raise AssertionError("Labels should be specified as dict[str, str]")
self._metadata._labels = _common_model.Labels(values=labels)

if annotations is not None:
if not isinstance(annotations, dict):
raise AssertionError("Annotations should be specified as dict[str, str]")
self._metadata._annotations = _common_model.Annotations(values=annotations)

return self


Expand Down
18 changes: 18 additions & 0 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def __init__(
cacheable: typing.Optional[bool] = None,
cache_version: typing.Optional[str] = None,
cache_serializable: typing.Optional[bool] = None,
labels: typing.Optional[dict[str, str]] = None,
annotations: typing.Optional[dict[str, str]] = None,
):
"""
Defines extra information about the Node.
Expand All @@ -183,6 +185,8 @@ def __init__(
:param cacheable: Indicates that this nodes outputs should be cached.
:param cache_version: The version of the cached data.
:param cacheable: Indicates that cache operations on this node should be serialized.
:param labels: Identifying attributes to add to the k8s resource.
:param annotations: Arbitrary metadata to add to the k8s resource.
"""
self._name = name
self._timeout = timeout if timeout is not None else datetime.timedelta()
Expand All @@ -191,6 +195,8 @@ def __init__(
self._cacheable = cacheable
self._cache_version = cache_version
self._cache_serializable = cache_serializable
self._labels = labels
self._annotations = annotations

@property
def name(self):
Expand Down Expand Up @@ -229,6 +235,14 @@ def cache_version(self) -> typing.Optional[str]:
def cache_serializable(self) -> typing.Optional[bool]:
return self._cache_serializable

@property
def labels(self) -> typing.Optional[dict[str, str]]:
return self._labels

@property
def annotations(self) -> typing.Optional[dict[str, str]]:
return self._annotations

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.workflow_pb2.NodeMetadata
Expand All @@ -240,6 +254,8 @@ def to_flyte_idl(self):
cacheable=self.cacheable,
cache_version=self.cache_version,
cache_serializable=self.cache_serializable,
labels=self.labels,
annotations=self.annotations,
)
if self.timeout:
node_metadata.timeout.FromTimedelta(self.timeout)
Expand All @@ -255,6 +271,8 @@ def from_flyte_idl(cls, pb2_object):
pb2_object.cacheable if pb2_object.HasField("cacheable") else None,
pb2_object.cache_version if pb2_object.HasField("cache_version") else None,
pb2_object.cache_serializable if pb2_object.HasField("cache_serializable") else None,
pb2_object.labels if pb2_object.HasField("labels") else None,
pb2_object.annotations if pb2_object.HasField("annotations") else None,
)


Expand Down
22 changes: 22 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,25 @@ def my_wf(a: str) -> str:
assert wf_spec.template.nodes[0].metadata.cache_serializable
assert wf_spec.template.nodes[0].metadata.cacheable
assert wf_spec.template.nodes[0].metadata.cache_version == "foo"


def test_override_labels_annotations():
@task
def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(labels={"override": "label"}, annotations={"override": "annotation"})

serialization_settings = flytekit.configuration.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)

assert wf_spec.template.nodes[0].metadata.labels.values == {"override": "label"}
assert wf_spec.template.nodes[0].metadata.annotations.values == {"override": "annotation"}
Loading