From 76ed5537567e37a07c691cbb6c2eea9932fcb7c9 Mon Sep 17 00:00:00 2001
From: Pavel Zhukov <33721692+LeaveMyYard@users.noreply.github.com>
Date: Thu, 4 Apr 2024 13:41:07 +0300
Subject: [PATCH] Refactor k8s workloads streaming (#256)

* Refactor k8s workloads streaming

* Fix tests
---
 .../core/integrations/kubernetes/__init__.py | 97 +++++++++----------
 robusta_krr/core/runner.py                   |  8 +-
 robusta_krr/utils/async_gen_merge.py         | 39 --------
 tests/conftest.py                            | 13 +--
 4 files changed, 51 insertions(+), 106 deletions(-)
 delete mode 100644 robusta_krr/utils/async_gen_merge.py

diff --git a/robusta_krr/core/integrations/kubernetes/__init__.py b/robusta_krr/core/integrations/kubernetes/__init__.py
index 335b47af..a772a5c2 100644
--- a/robusta_krr/core/integrations/kubernetes/__init__.py
+++ b/robusta_krr/core/integrations/kubernetes/__init__.py
@@ -2,7 +2,7 @@
 import logging
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, Iterable, Optional, Union
+from typing import Any, Awaitable, Callable, Iterable, Optional, Union
 
 from kubernetes import client, config  # type: ignore
 from kubernetes.client import ApiException
@@ -20,7 +20,6 @@
 from robusta_krr.core.models.config import settings
 from robusta_krr.core.models.objects import HPAData, K8sObjectData, KindLiteral, PodData
 from robusta_krr.core.models.result import ResourceAllocations
-from robusta_krr.utils.async_gen_merge import async_gen_merge
 from robusta_krr.utils.object_like_dict import ObjectLikeDict
 
 from . import config_patch as _
@@ -49,7 +48,7 @@ def __init__(self, cluster: Optional[str]=None):
         self.__jobs_for_cronjobs: dict[str, list[V1Job]] = {}
         self.__jobs_loading_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
 
-    async def list_scannable_objects(self) -> AsyncGenerator[K8sObjectData, None]:
+    async def list_scannable_objects(self) -> list[K8sObjectData]:
         """List all scannable objects.
 
         Returns:
@@ -61,10 +60,7 @@ async def list_scannable_objects(self) -> AsyncGenerator[K8sObjectData, None]:
         logger.debug(f"Resources: {settings.resources}")
 
         self.__hpa_list = await self._try_list_hpa()
-
-        # https://stackoverflow.com/questions/55299564/join-multiple-async-generators-in-python
-        # This will merge all the streams from all the cluster loaders into a single stream
-        async for object in async_gen_merge(
+        workload_object_lists = await asyncio.gather(
             self._list_deployments(),
             self._list_rollouts(),
             self._list_deploymentconfig(),
@@ -72,11 +68,15 @@ async def list_scannable_objects(self) -> AsyncGenerator[K8sObjectData, None]:
             self._list_all_daemon_set(),
             self._list_all_jobs(),
             self._list_all_cronjobs(),
-        ):
+        )
+
+        return [
+            object
+            for workload_objects in workload_object_lists
+            for object in workload_objects
             # NOTE: By default we will filter out kube-system namespace
-            if settings.namespaces == "*" and object.namespace == "kube-system":
-                continue
-            yield object
+            if not (settings.namespaces == "*" and object.namespace == "kube-system")
+        ]
 
     async def _list_jobs_for_cronjobs(self, namespace: str) -> list[V1Job]:
         if namespace not in self.__jobs_for_cronjobs:
@@ -185,12 +185,12 @@ async def _list_namespaced_or_global_objects(
         kind: KindLiteral,
         all_namespaces_request: Callable,
         namespaced_request: Callable
-    ) -> AsyncIterable[Any]:
+    ) -> list[Any]:
         logger.debug(f"Listing {kind}s in {self.cluster}")
         loop = asyncio.get_running_loop()
 
         if settings.namespaces == "*":
-            tasks = [
+            requests = [
                 loop.run_in_executor(
                     self.executor,
                     lambda: all_namespaces_request(
@@ -200,7 +200,7 @@ async def _list_namespaced_or_global_objects(
                 )
             ]
         else:
-            tasks = [
+            requests = [
                 loop.run_in_executor(
                     self.executor,
                     lambda ns=namespace: namespaced_request(
@@ -212,14 +212,14 @@ async def _list_namespaced_or_global_objects(
                 for namespace in settings.namespaces
             ]
 
-        total_items = 0
-        for task in asyncio.as_completed(tasks):
-            ret_single = await task
-            total_items += len(ret_single.items)
-            for item in ret_single.items:
-                yield item
+        result = [
+            item
+            for request_result in await asyncio.gather(*requests)
+            for item in request_result.items
+        ]
 
-        logger.debug(f"Found {total_items} {kind} in {self.cluster}")
+        logger.debug(f"Found {len(result)} {kind} in {self.cluster}")
+        return result
 
     async def _list_scannable_objects(
         self,
@@ -228,16 +228,17 @@ async def _list_scannable_objects(
         namespaced_request: Callable,
         extract_containers: Callable[[Any], Union[Iterable[V1Container], Awaitable[Iterable[V1Container]]]],
         filter_workflows: Optional[Callable[[Any], bool]] = None,
-    ) -> AsyncIterable[K8sObjectData]:
+    ) -> list[K8sObjectData]:
         if not self._should_list_resource(kind):
             logger.debug(f"Skipping {kind}s in {self.cluster}")
             return
 
         if not self.__kind_available[kind]:
             return
-
+
+        result = []
         try:
-            async for item in self._list_namespaced_or_global_objects(kind, all_namespaces_request, namespaced_request):
+            for item in await self._list_namespaced_or_global_objects(kind, all_namespaces_request, namespaced_request):
                 if filter_workflows is not None and not filter_workflows(item):
                     continue
 
@@ -245,8 +246,7 @@ async def _list_scannable_objects(
                 if asyncio.iscoroutine(containers):
                     containers = await containers
 
-                for container in containers:
-                    yield self.__build_scannable_object(item, container, kind)
+                result.extend(self.__build_scannable_object(item, container, kind) for container in containers)
         except ApiException as e:
             if kind in ("Rollout", "DeploymentConfig") and e.status in [400, 401, 403, 404]:
                 if self.__kind_available[kind]:
@@ -256,7 +256,9 @@ async def _list_scannable_objects(
                 logger.exception(f"Error {e.status} listing {kind} in cluster {self.cluster}: {e.reason}")
                 logger.error("Will skip this object type and continue.")
 
-    def _list_deployments(self) -> AsyncIterable[K8sObjectData]:
+        return result
+
+    def _list_deployments(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="Deployment",
             all_namespaces_request=self.apps.list_deployment_for_all_namespaces,
@@ -264,7 +266,7 @@ def _list_deployments(self) -> AsyncIterable[K8sObjectData]:
             extract_containers=lambda item: item.spec.template.spec.containers,
         )
 
-    def _list_rollouts(self) -> AsyncIterable[K8sObjectData]:
+    def _list_rollouts(self) -> list[K8sObjectData]:
         async def _extract_containers(item: Any) -> list[V1Container]:
             if item.spec.template is not None:
                 return item.spec.template.spec.containers
@@ -311,7 +313,7 @@ async def _extract_containers(item: Any) -> list[V1Container]:
             extract_containers=_extract_containers,
         )
 
-    def _list_deploymentconfig(self) -> AsyncIterable[K8sObjectData]:
+    def _list_deploymentconfig(self) -> list[K8sObjectData]:
         # NOTE: Using custom objects API returns dicts, but all other APIs return objects
         # We need to handle this difference using a small wrapper
         return self._list_scannable_objects(
@@ -335,7 +337,7 @@ def _list_deploymentconfig(self) -> AsyncIterable[K8sObjectData]:
             extract_containers=lambda item: item.spec.template.spec.containers,
         )
 
-    def _list_all_statefulsets(self) -> AsyncIterable[K8sObjectData]:
+    def _list_all_statefulsets(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="StatefulSet",
             all_namespaces_request=self.apps.list_stateful_set_for_all_namespaces,
@@ -343,7 +345,7 @@ def _list_all_statefulsets(self) -> AsyncIterable[K8sObjectData]:
             extract_containers=lambda item: item.spec.template.spec.containers,
         )
 
-    def _list_all_daemon_set(self) -> AsyncIterable[K8sObjectData]:
+    def _list_all_daemon_set(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="DaemonSet",
             all_namespaces_request=self.apps.list_daemon_set_for_all_namespaces,
@@ -351,7 +353,7 @@ def _list_all_daemon_set(self) -> AsyncIterable[K8sObjectData]:
             extract_containers=lambda item: item.spec.template.spec.containers,
         )
 
-    def _list_all_jobs(self) -> AsyncIterable[K8sObjectData]:
+    def _list_all_jobs(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="Job",
             all_namespaces_request=self.batch.list_job_for_all_namespaces,
@@ -363,7 +365,7 @@ def _list_all_jobs(self) -> AsyncIterable[K8sObjectData]:
             ),
         )
 
-    def _list_all_cronjobs(self) -> AsyncIterable[K8sObjectData]:
+    def _list_all_cronjobs(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="CronJob",
             all_namespaces_request=self.batch.list_cron_job_for_all_namespaces,
@@ -398,14 +400,10 @@ async def __list_hpa_v1(self) -> dict[HPAKey, HPAData]:
         }
 
     async def __list_hpa_v2(self) -> dict[HPAKey, HPAData]:
-        loop = asyncio.get_running_loop()
-        res = await loop.run_in_executor(
-            self.executor,
-            lambda: self._list_namespaced_or_global_objects(
-                kind="HPA-v2",
-                all_namespaces_request=self.autoscaling_v2.list_horizontal_pod_autoscaler_for_all_namespaces,
-                namespaced_request=self.autoscaling_v2.list_namespaced_horizontal_pod_autoscaler,
-            ),
+        res = await self._list_namespaced_or_global_objects(
+            kind="HPA-v2",
+            all_namespaces_request=self.autoscaling_v2.list_horizontal_pod_autoscaler_for_all_namespaces,
+            namespaced_request=self.autoscaling_v2.list_namespaced_horizontal_pod_autoscaler,
         )
         def __get_metric(hpa: V2HorizontalPodAutoscaler, metric_name: str) -> Optional[float]:
             return next(
@@ -429,7 +427,7 @@ def __get_metric(hpa: V2HorizontalPodAutoscaler, metric_name: str) -> Optional[f
                 target_cpu_utilization_percentage=__get_metric(hpa, "cpu"),
                 target_memory_utilization_percentage=__get_metric(hpa, "memory"),
             )
-            async for hpa in res
+            for hpa in res
         }
 
     # TODO: What should we do in case of other metrics bound to the HPA?
@@ -514,7 +512,7 @@ def _try_create_cluster_loader(self, cluster: Optional[str]) -> Optional[Cluster
             logger.error(f"Could not load cluster {cluster} and will skip it: {e}")
             return None
 
-    async def list_scannable_objects(self, clusters: Optional[list[str]]) -> AsyncIterable[K8sObjectData]:
+    async def list_scannable_objects(self, clusters: Optional[list[str]]) -> list[K8sObjectData]:
         """List all scannable objects.
 
         Yields:
@@ -529,13 +527,12 @@ async def list_scannable_objects(self, clusters: Optional[list[str]]) -> AsyncIt
         if self.cluster_loaders == {}:
             logger.error("Could not load any cluster.")
             return
-
-        # https://stackoverflow.com/questions/55299564/join-multiple-async-generators-in-python
-        # This will merge all the streams from all the cluster loaders into a single stream
-        async for object in async_gen_merge(
-            *[cluster_loader.list_scannable_objects() for cluster_loader in self.cluster_loaders.values()]
-        ):
-            yield object
+
+        return [
+            object
+            for cluster_loader in self.cluster_loaders.values()
+            for object in await cluster_loader.list_scannable_objects()
+        ]
 
     async def load_pods(self, object: K8sObjectData) -> list[PodData]:
         try:
diff --git a/robusta_krr/core/runner.py b/robusta_krr/core/runner.py
index 546dd013..8e08521c 100644
--- a/robusta_krr/core/runner.py
+++ b/robusta_krr/core/runner.py
@@ -275,12 +275,8 @@ async def _collect_result(self) -> Result:
         await asyncio.gather(*[self._check_data_availability(cluster) for cluster in clusters])
 
         with ProgressBar(title="Calculating Recommendation") as self.__progressbar:
-            scans_tasks = [
-                asyncio.create_task(self._gather_object_allocations(k8s_object))
-                async for k8s_object in self._k8s_loader.list_scannable_objects(clusters)
-            ]
-
-            scans = await asyncio.gather(*scans_tasks)
+            workloads = await self._k8s_loader.list_scannable_objects(clusters)
+            scans = await asyncio.gather(*[self._gather_object_allocations(k8s_object) for k8s_object in workloads])
 
         successful_scans = [scan for scan in scans if scan is not None]
 
diff --git a/robusta_krr/utils/async_gen_merge.py b/robusta_krr/utils/async_gen_merge.py
deleted file mode 100644
index 35c2c866..00000000
--- a/robusta_krr/utils/async_gen_merge.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import asyncio
-import logging
-from typing import AsyncIterable, TypeVar
-
-
-logger = logging.getLogger("krr")
-
-
-# Define a type variable for the values yielded by the async generators
-T = TypeVar("T")
-
-
-def async_gen_merge(*aiters: AsyncIterable[T]) -> AsyncIterable[T]:
-    queue = asyncio.Queue()
-    iters_remaining = set(aiters)
-
-    async def drain(aiter):
-        try:
-            async for item in aiter:
-                await queue.put(item)
-        except Exception:
-            logger.exception(f"Error in async generator {aiter}")
-        finally:
-            iters_remaining.discard(aiter)
-            await queue.put(None)
-
-    async def merged():
-        while iters_remaining or not queue.empty():
-            item = await queue.get()
-
-            if item is None:
-                continue
-
-            yield item
-
-    for aiter in aiters:
-        asyncio.create_task(drain(aiter))
-
-    return merged()
diff --git a/tests/conftest.py b/tests/conftest.py
index 61c389dd..b1d8d228 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,6 @@
 import random
 from datetime import datetime, timedelta
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, patch
 
 import numpy as np
 import pytest
@@ -26,15 +26,6 @@
 )
 
 
-class AsyncIter:
-    def __init__(self, items):
-        self.items = items
-
-    async def __aiter__(self):
-        for item in self.items:
-            yield item
-
-
 @pytest.fixture(autouse=True, scope="session")
 def mock_list_clusters():
     with patch(
@@ -48,7 +39,7 @@ def mock_list_scannable_objects():
     with patch(
         "robusta_krr.core.integrations.kubernetes.KubernetesLoader.list_scannable_objects",
        new=MagicMock(return_value=AsyncIter([TEST_OBJECT])),
+        new=AsyncMock(return_value=[TEST_OBJECT]),
     ):
         yield
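
Note on the pattern (an illustrative sketch, not part of the patch): the refactor replaces the queue-based merging of async generators (the deleted async_gen_merge helper) with a single asyncio.gather over coroutines that each return a complete list, which is then flattened. The sketch below shows the new shape in isolation; fetch_deployments and fetch_jobs are hypothetical stand-ins for the per-kind listing coroutines (_list_deployments, _list_all_jobs, and so on):

    import asyncio

    # Hypothetical stand-ins for the per-kind listing coroutines; each
    # returns a complete list instead of streaming items one by one.
    async def fetch_deployments() -> list[str]:
        await asyncio.sleep(0.1)  # stands in for an API call run in an executor
        return ["deployment/a", "deployment/b"]

    async def fetch_jobs() -> list[str]:
        await asyncio.sleep(0.2)
        return ["job/c"]

    async def list_workloads() -> list[str]:
        # Run all per-kind requests concurrently and wait for all of them...
        workload_lists = await asyncio.gather(fetch_deployments(), fetch_jobs())
        # ...then flatten the per-kind lists into a single list of workloads.
        return [workload for workloads in workload_lists for workload in workloads]

    print(asyncio.run(list_workloads()))  # ['deployment/a', 'deployment/b', 'job/c']

Compared with the deleted helper, asyncio.gather keeps the per-kind requests concurrent while returning results in a deterministic order and propagating exceptions to the caller by default, whereas async_gen_merge logged and swallowed them inside drain().

The conftest.py change follows from the new return type: once list_scannable_objects returns a plain list when awaited, the hand-rolled AsyncIter shim is unnecessary and AsyncMock suffices. A minimal sketch of that mocking behaviour (the values are hypothetical):

    import asyncio
    from unittest.mock import AsyncMock

    # Awaiting a call to AsyncMock(return_value=...) yields the configured value.
    mock_list = AsyncMock(return_value=["workload-1"])

    async def main() -> None:
        assert await mock_list() == ["workload-1"]

    asyncio.run(main())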