From e1ed03607454bf010b5799eb767069bc765da007 Mon Sep 17 00:00:00 2001 From: Ajay Patel Date: Tue, 30 Jul 2024 21:33:29 -0400 Subject: [PATCH] Fix retriever test --- pyproject.toml | 2 +- src/utils/background_utils.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e9dcd8a..822e7cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ warn_unused_ignores = true mypy_path = "src/_stubs" [[tool.mypy.overrides]] -module = "click,wandb,wandb.*,click.testing,flaky,tensorflow,torch_xla,jax,datasets.features.features,datasets.iterable_dataset,datasets.fingerprint,datasets.builder,datasets.arrow_writer,datasets.splits,datasets.utils,datasets.utils.version,pyarrow.lib,huggingface_hub,huggingface_hub.utils._headers,huggingface_hub.utils._errors,dill,dill.source,transformers,bitsandbytes,sqlitedict,optimum.bettertransformer,optimum.bettertransformer.models,optimum.utils,transformers.utils.quantization_config,sortedcontainers,peft,psutil,ring,ctransformers,petals,petals.client.inference_session,hivemind.p2p.p2p_daemon_bindings.utils,huggingface_hub.utils,tqdm,ctransformers.transformers,vllm,litellm,litellm.llms.palm,litellm.exceptions,sentence_transformers,faiss,huggingface_hub.utils._validators,evaluate,transformers.trainer_callback,transformers.training_args,trl,guidance,sentence_transformers.models.Transformer,trl.trainer.utils,transformers.trainer_utils,setfit,joblib,setfit.modeling,transformers.utils.notebook,mistralai.models.chat_completion,accelerate.utils,accelerate.utils.constants,accelerate,transformers.trainer,sentence_transformers.util,Pyro5,Pyro5.server,Pyro5.api,Pyro5,datadreamer,huggingface_hub.repocard,transformers.trainer_pt_utils,traitlets.utils.warnings,orjson" +module = "click,wandb,wandb.*,click.testing,flaky,tensorflow,torch_xla,jax,datasets.features.features,datasets.iterable_dataset,datasets.fingerprint,datasets.builder,datasets.arrow_writer,datasets.splits,datasets.utils,datasets.utils.version,pyarrow.lib,huggingface_hub,huggingface_hub.utils._headers,huggingface_hub.utils._errors,dill,dill.source,transformers,bitsandbytes,sqlitedict,optimum.bettertransformer,optimum.bettertransformer.models,optimum.utils,transformers.utils.quantization_config,sortedcontainers,peft,psutil,ring,ctransformers,petals,petals.client.inference_session,hivemind.p2p.p2p_daemon_bindings.utils,huggingface_hub.utils,tqdm,ctransformers.transformers,vllm,litellm,litellm.llms.palm,litellm.exceptions,sentence_transformers,faiss,huggingface_hub.utils._validators,evaluate,transformers.trainer_callback,transformers.training_args,trl,guidance,sentence_transformers.models.Transformer,trl.trainer.utils,transformers.trainer_utils,setfit,joblib,setfit.modeling,transformers.utils.notebook,mistralai.models.chat_completion,accelerate.utils,accelerate.utils.constants,accelerate,transformers.trainer,sentence_transformers.util,Pyro5,Pyro5.server,Pyro5.api,Pyro5,datadreamer,huggingface_hub.repocard,transformers.trainer_pt_utils,traitlets.utils.warnings,orjson,Pyro5.errors" ignore_missing_imports = true [tool.pyright] diff --git a/src/utils/background_utils.py b/src/utils/background_utils.py index d7cd0e9..38c29c6 100644 --- a/src/utils/background_utils.py +++ b/src/utils/background_utils.py @@ -17,6 +17,7 @@ import Pyro5.api import Pyro5.server import torch +from Pyro5.errors import CommunicationError from .. import logging as datadreamer_logging @@ -300,11 +301,16 @@ def find_free_port(): # pragma: no cover return s.getsockname()[1] -def wait_for_port(port: int): # pragma: no cover +def wait_for_port(port: int, _proxy=None): # pragma: no cover while True: try: with socket.create_connection(("localhost", port), timeout=5): - break + try: + if _proxy is not None: + _proxy._Proxy__pyroCreateConnection() + break + except CommunicationError: + pass except OSError: pass @@ -383,7 +389,7 @@ def client_pickling_wrapper(orig_meth, *args, **kwargs): if run_in_background: # pragma: no cover self._proxy = _proxy self.process = process - wait_for_port(free_port) + wait_for_port(free_port, _proxy=_proxy) def __del__(self): if run_in_background: # pragma: no cover