From 4554d739b932b2a36732bb247626f67a88e097b3 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 18 Nov 2024 15:46:10 +0100 Subject: [PATCH] Fix service connector type registry race conditions --- .../service_connector_registry.py | 125 ++++++++++-------- 1 file changed, 68 insertions(+), 57 deletions(-) diff --git a/src/zenml/service_connectors/service_connector_registry.py b/src/zenml/service_connectors/service_connector_registry.py index e0801f2f1dc..47386371a24 100644 --- a/src/zenml/service_connectors/service_connector_registry.py +++ b/src/zenml/service_connectors/service_connector_registry.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Implementation of a service connector registry.""" +import threading from typing import TYPE_CHECKING, Dict, List, Optional, Union from zenml.logger import get_logger @@ -34,6 +35,7 @@ def __init__(self) -> None: """Initialize the service connector registry.""" self.service_connector_types: Dict[str, ServiceConnectorTypeModel] = {} self.initialized = False + self.lock = threading.RLock() def register_service_connector_type( self, @@ -46,23 +48,25 @@ def register_service_connector_type( service_connector_type: Service connector type. overwrite: Whether to overwrite an existing service connector type. """ - if ( - service_connector_type.connector_type - not in self.service_connector_types - or overwrite - ): - self.service_connector_types[ + with self.lock: + if ( service_connector_type.connector_type - ] = service_connector_type - logger.debug( - "Registered service connector type " - f"{service_connector_type.connector_type}." - ) - else: - logger.debug( - f"Found existing service connector for type " - f"{service_connector_type.connector_type}: Skipping registration." - ) + not in self.service_connector_types + or overwrite + ): + self.service_connector_types[ + service_connector_type.connector_type + ] = service_connector_type + logger.debug( + "Registered service connector type " + f"{service_connector_type.connector_type}." + ) + else: + logger.debug( + f"Found existing service connector for type " + f"{service_connector_type.connector_type}: Skipping " + "registration." + ) def get_service_connector_type( self, @@ -201,54 +205,61 @@ def instantiate_connector( def register_builtin_service_connectors(self) -> None: """Registers the default built-in service connectors.""" # Only register built-in service connectors once - if self.initialized: - return + with self.lock: + if self.initialized: + return - self.initialized = True + self.initialized = True - try: - from zenml.integrations.aws.service_connectors.aws_service_connector import ( # noqa - AWSServiceConnector, - ) - except ImportError as e: - logger.warning(f"Could not import AWS service connector: {e}.") + try: + from zenml.integrations.aws.service_connectors.aws_service_connector import ( # noqa + AWSServiceConnector, + ) + except ImportError as e: + logger.warning(f"Could not import AWS service connector: {e}.") - try: - from zenml.integrations.gcp.service_connectors.gcp_service_connector import ( # noqa - GCPServiceConnector, - ) - except ImportError as e: - logger.warning(f"Could not import GCP service connector: {e}.") + try: + from zenml.integrations.gcp.service_connectors.gcp_service_connector import ( # noqa + GCPServiceConnector, + ) + except ImportError as e: + logger.warning(f"Could not import GCP service connector: {e}.") - try: - from zenml.integrations.azure.service_connectors.azure_service_connector import ( # noqa - AzureServiceConnector, - ) - except ImportError as e: - logger.warning(f"Could not import Azure service connector: {e}.") + try: + from zenml.integrations.azure.service_connectors.azure_service_connector import ( # noqa + AzureServiceConnector, + ) + except ImportError as e: + logger.warning( + f"Could not import Azure service connector: {e}." + ) - try: - from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import ( # noqa - KubernetesServiceConnector, - ) - except ImportError as e: - logger.warning( - f"Could not import Kubernetes service connector: {e}." - ) + try: + from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import ( # noqa + KubernetesServiceConnector, + ) + except ImportError as e: + logger.warning( + f"Could not import Kubernetes service connector: {e}." + ) - try: - from zenml.service_connectors.docker_service_connector import ( # noqa - DockerServiceConnector, - ) - except ImportError as e: - logger.warning(f"Could not import Docker service connector: {e}.") + try: + from zenml.service_connectors.docker_service_connector import ( # noqa + DockerServiceConnector, + ) + except ImportError as e: + logger.warning( + f"Could not import Docker service connector: {e}." + ) - try: - from zenml.integrations.hyperai.service_connectors.hyperai_service_connector import ( # noqa - HyperAIServiceConnector, - ) - except ImportError as e: - logger.warning(f"Could not import HyperAI service connector: {e}.") + try: + from zenml.integrations.hyperai.service_connectors.hyperai_service_connector import ( # noqa + HyperAIServiceConnector, + ) + except ImportError as e: + logger.warning( + f"Could not import HyperAI service connector: {e}." + ) service_connector_registry = ServiceConnectorRegistry()