Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix service connector type registry race conditions #3202

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 68 additions & 57 deletions src/zenml/service_connectors/service_connector_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Implementation of a service connector registry."""

import threading
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from zenml.logger import get_logger
Expand All @@ -34,6 +35,7 @@ def __init__(self) -> None:
"""Initialize the service connector registry."""
self.service_connector_types: Dict[str, ServiceConnectorTypeModel] = {}
self.initialized = False
self.lock = threading.RLock()

def register_service_connector_type(
self,
Expand All @@ -46,23 +48,25 @@ def register_service_connector_type(
service_connector_type: Service connector type.
overwrite: Whether to overwrite an existing service connector type.
"""
if (
service_connector_type.connector_type
not in self.service_connector_types
or overwrite
):
self.service_connector_types[
with self.lock:
if (
service_connector_type.connector_type
] = service_connector_type
logger.debug(
"Registered service connector type "
f"{service_connector_type.connector_type}."
)
else:
logger.debug(
f"Found existing service connector for type "
f"{service_connector_type.connector_type}: Skipping registration."
)
not in self.service_connector_types
or overwrite
):
self.service_connector_types[
service_connector_type.connector_type
] = service_connector_type
logger.debug(
"Registered service connector type "
f"{service_connector_type.connector_type}."
)
else:
logger.debug(
f"Found existing service connector for type "
f"{service_connector_type.connector_type}: Skipping "
"registration."
)

def get_service_connector_type(
self,
Expand Down Expand Up @@ -201,54 +205,61 @@ def instantiate_connector(
def register_builtin_service_connectors(self) -> None:
"""Registers the default built-in service connectors."""
# Only register built-in service connectors once
if self.initialized:
return
with self.lock:
if self.initialized:
return

self.initialized = True
self.initialized = True

try:
from zenml.integrations.aws.service_connectors.aws_service_connector import ( # noqa
AWSServiceConnector,
)
except ImportError as e:
logger.warning(f"Could not import AWS service connector: {e}.")
try:
from zenml.integrations.aws.service_connectors.aws_service_connector import ( # noqa
AWSServiceConnector,
)
except ImportError as e:
logger.warning(f"Could not import AWS service connector: {e}.")

try:
from zenml.integrations.gcp.service_connectors.gcp_service_connector import ( # noqa
GCPServiceConnector,
)
except ImportError as e:
logger.warning(f"Could not import GCP service connector: {e}.")
try:
from zenml.integrations.gcp.service_connectors.gcp_service_connector import ( # noqa
GCPServiceConnector,
)
except ImportError as e:
logger.warning(f"Could not import GCP service connector: {e}.")

try:
from zenml.integrations.azure.service_connectors.azure_service_connector import ( # noqa
AzureServiceConnector,
)
except ImportError as e:
logger.warning(f"Could not import Azure service connector: {e}.")
try:
from zenml.integrations.azure.service_connectors.azure_service_connector import ( # noqa
AzureServiceConnector,
)
except ImportError as e:
logger.warning(
f"Could not import Azure service connector: {e}."
)

try:
from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import ( # noqa
KubernetesServiceConnector,
)
except ImportError as e:
logger.warning(
f"Could not import Kubernetes service connector: {e}."
)
try:
from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import ( # noqa
KubernetesServiceConnector,
)
except ImportError as e:
logger.warning(
f"Could not import Kubernetes service connector: {e}."
)

try:
from zenml.service_connectors.docker_service_connector import ( # noqa
DockerServiceConnector,
)
except ImportError as e:
logger.warning(f"Could not import Docker service connector: {e}.")
try:
from zenml.service_connectors.docker_service_connector import ( # noqa
DockerServiceConnector,
)
except ImportError as e:
logger.warning(
f"Could not import Docker service connector: {e}."
)

try:
from zenml.integrations.hyperai.service_connectors.hyperai_service_connector import ( # noqa
HyperAIServiceConnector,
)
except ImportError as e:
logger.warning(f"Could not import HyperAI service connector: {e}.")
try:
from zenml.integrations.hyperai.service_connectors.hyperai_service_connector import ( # noqa
HyperAIServiceConnector,
)
except ImportError as e:
logger.warning(
f"Could not import HyperAI service connector: {e}."
)


service_connector_registry = ServiceConnectorRegistry()
Loading