Skip to content

Commit

Permalink
Added more pod spec properties for k8s orchestrator (#2097)
Browse files Browse the repository at this point in the history
* Added more pod spec properties for k8s orchestrator

* Added more pod spec properties for k8s orchestrator

* Applied to kubeflow too

* Pod settings

* Linting

* Added volume mounts and volumes to tests

* Added volume mounts and volumes to tests

* Took care of security warnings

* Took care of security warnings

* Took care of security warnings

* Took care of security warnings

* Auto-update of Starter template

* Auto-update of Starter template

* Auto-update of E2E template

* Auto-update of NLP template

---------

Co-authored-by: GitHub Actions <actions@github.com>
  • Loading branch information
htahir1 and actions-user authored Jan 2, 2024
1 parent 8592869 commit 88f75c0
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 2 deletions.
32 changes: 31 additions & 1 deletion src/zenml/integrations/kubeflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

from zenml.integrations.kubernetes import serialization_utils
from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
from zenml.logger import get_logger

logger = get_logger(__name__)

if TYPE_CHECKING:
from kfp.dsl import ContainerOp
Expand All @@ -33,7 +36,18 @@ def apply_pod_settings(
container_op: The container to which to apply the settings.
settings: The settings to apply.
"""
from kubernetes.client.models import V1Affinity, V1Toleration
from kubernetes.client.models import (
V1Affinity,
V1Toleration,
V1Volume,
V1VolumeMount,
)

if settings.host_ipc:
logger.warning(
"Host IPC is set to `True` but not supported in this orchestrator. "
"Ignoring..."
)

for key, value in settings.node_selectors.items():
container_op.add_node_selector_constraint(label_name=key, value=value)
Expand All @@ -54,6 +68,22 @@ def apply_pod_settings(
)
container_op.add_toleration(toleration)

if settings.volumes:
for v in settings.volumes:
volume: (
V1Volume
) = serialization_utils.deserialize_kubernetes_model(v, "V1Volume")
container_op.add_volume(volume)

if settings.volume_mounts:
for v in settings.volume_mounts:
volume_mount: (
V1VolumeMount
) = serialization_utils.deserialize_kubernetes_model(
v, "V1VolumeMount"
)
container_op.container.add_volume_mount(volume_mount)

resource_requests = settings.resources.get("requests") or {}
for name, value in resource_requests.items():
container_op.add_resource_request(name, value)
Expand Down
14 changes: 14 additions & 0 deletions src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ def add_pod_settings(
for container in pod_spec.containers:
assert isinstance(container, k8s_client.V1Container)
container._resources = settings.resources
if settings.volume_mounts:
if container.volume_mounts:
container.volume_mounts.extend(settings.volume_mounts)
else:
container.volume_mounts = settings.volume_mounts

if settings.volumes:
if pod_spec.volumes:
pod_spec.volumes.extend(settings.volumes)
else:
pod_spec.volumes = settings.volumes

if settings.host_ipc:
pod_spec.host_ipc = settings.host_ipc


def build_cron_job_manifest(
Expand Down
58 changes: 58 additions & 0 deletions src/zenml/integrations/kubernetes/pod_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
V1Affinity,
V1ResourceRequirements,
V1Toleration,
V1Volume,
V1VolumeMount,
)


Expand All @@ -37,13 +39,69 @@ class KubernetesPodSettings(BaseSettings):
tolerations: Tolerations to apply to the pod.
resources: Resource requests and limits for the pod.
annotations: Annotations to apply to the pod metadata.
volumes: Volumes to mount in the pod.
volume_mounts: Volume mounts to apply to the pod containers.
host_ipc: Whether to enable host IPC for the pod.
"""

node_selectors: Dict[str, str] = {}
affinity: Dict[str, Any] = {}
tolerations: List[Dict[str, Any]] = []
resources: Dict[str, Dict[str, str]] = {}
annotations: Dict[str, str] = {}
volumes: List[Dict[str, Any]] = []
volume_mounts: List[Dict[str, Any]] = []
host_ipc: bool = False

@validator("volumes", pre=True)
def _convert_volumes(
cls, value: List[Union[Dict[str, Any], "V1Volume"]]
) -> List[Dict[str, Any]]:
"""Converts Kubernetes volumes to dicts.
Args:
value: The volumes list.
Returns:
The converted volumes.
"""
from kubernetes.client.models import V1Volume

result = []
for element in value:
if isinstance(element, V1Volume):
result.append(
serialization_utils.serialize_kubernetes_model(element)
)
else:
result.append(element)

return result

@validator("volume_mounts", pre=True)
def _convert_volume_mounts(
cls, value: List[Union[Dict[str, Any], "V1VolumeMount"]]
) -> List[Dict[str, Any]]:
"""Converts Kubernetes volume mounts to dicts.
Args:
value: The volume mounts list.
Returns:
The converted volume mounts.
"""
from kubernetes.client.models import V1VolumeMount

result = []
for element in value:
if isinstance(element, V1VolumeMount):
result.append(
serialization_utils.serialize_kubernetes_model(element)
)
else:
result.append(element)

return result

@validator("affinity", pre=True)
def _convert_affinity(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,19 @@ def _create_test_models() -> List[Any]:
key="key", operator="Equal", value="value", effect="NoExecute"
)

return [affinity, toleration]
volume = k8s_client.V1Volume(
name="cache-volume",
empty_dir=k8s_client.V1EmptyDirVolumeSource(
medium="Memory", size_limit="1Gi"
),
)

volume_mount = k8s_client.V1VolumeMount(
mount_path="/dev/shm", # nosec
name="cache-volume",
)

return [affinity, toleration, volume, volume_mount]


@pytest.mark.parametrize("model", _create_test_models())
Expand Down

0 comments on commit 88f75c0

Please sign in to comment.