diff --git a/distributed/core.py b/distributed/core.py
index 001f3bac31..7c331b4246 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -649,7 +649,6 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw
     msg = kwargs
     msg["reply"] = reply
     please_close = kwargs.get("close")
-    force_close = False
     if deserializers is None:
         deserializers = serializers
     if deserializers is not None:
@@ -661,15 +660,15 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw
             response = await comm.read(deserializers=deserializers)
         else:
             response = None
-    except EnvironmentError:
-        # On communication errors, we should simply close the communication
-        force_close = True
-        raise
+    except Exception as exc:
+        # If an exception occurred, we need to close the comm, if possible.
+        # Otherwise the other end might wait for a reply while this end is
+        # reusing the comm for something else.
+        comm.abort()
+        raise exc
     finally:
         if please_close:
             await comm.close()
-        elif force_close:
-            comm.abort()
 
     if isinstance(response, dict) and response.get("status") == "uncaught-error":
         if comm.deserialize:
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index fc758da233..0044c7e2e6 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -1783,3 +1783,32 @@ def f(x):
         return threading.get_ident() == x._thread_ident
 
     assert await c.submit(f, x)
+
+
+@gen_cluster(client=True)
+async def test_get_data_faulty_dep(c, s, a, b):
+    """This test creates a broken dependency and forces serialization by
+    requiring it to be submitted to another worker. The computation should
+    eventually finish by flagging the dep as bad and raising an appropriate
+    exception.
+    """
+
+    class BrokenDeserialization:
+        def __setstate__(self, *state):
+            raise AttributeError()
+
+        def __getstate__(self, *args):
+            return ""
+
+    def create():
+        return BrokenDeserialization()
+
+    def collect(*args):
+        return args
+
+    fut1 = c.submit(create, workers=[a.name])
+
+    fut2 = c.submit(collect, fut1, workers=[b.name])
+
+    with pytest.raises(ValueError, match="Could not find dependent create-"):
+        await fut2.result()
diff --git a/distributed/worker.py b/distributed/worker.py
index 205369623d..7ed3ad44d0 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -414,7 +414,6 @@ def __init__(
         )
         self.total_comm_nbytes = 10e6
         self.comm_nbytes = 0
-        self._missing_dep_flight = set()
 
         self.threads = dict()
 
@@ -1347,6 +1346,16 @@ async def get_data(
             compressed = await comm.write(msg, serializers=serializers)
             response = await comm.read(deserializers=serializers)
             assert response == "OK", response
+
+        except CommClosedError:
+            logger.exception(
+                "Other end hung up during get_data with %s -> %s",
+                self.address,
+                who,
+                exc_info=True,
+            )
+            comm.abort()
+            raise
         except EnvironmentError:
             logger.exception(
                 "failed during get data with %s -> %s", self.address, who, exc_info=True
@@ -1573,10 +1582,6 @@ def transition_flight_waiting(self, ts, worker=None, remove=True):
         except KeyError:
             pass
 
-        if not ts.who_has:
-            if ts.key not in self._missing_dep_flight:
-                self._missing_dep_flight.add(ts.key)
-                self.loop.add_callback(self.handle_missing_dep, ts)
         for dependent in ts.dependents:
             if dependent.state == "waiting":
                 if remove:  # try a new worker immediately
@@ -1841,14 +1846,7 @@ def ensure_communicating(self):
             missing_deps = {dep for dep in deps if not dep.who_has}
             if missing_deps:
                 logger.info("Can't find dependencies for key %s", key)
-                missing_deps2 = {
-                    dep
-                    for dep in missing_deps
-                    if dep.key not in self._missing_dep_flight
-                }
-                for dep in missing_deps2:
-                    self._missing_dep_flight.add(dep.key)
-                self.loop.add_callback(self.handle_missing_dep, *missing_deps2)
+                self.loop.add_callback(self.handle_missing_dep, *missing_deps)
 
                 deps = [dep for dep in deps if dep not in missing_deps]
 
@@ -2136,7 +2134,6 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None):
 
             if not busy:
                 self.repetitively_busy = 0
-                self.ensure_communicating()
             else:
                 # Exponential backoff to avoid hammering scheduler/worker
                 self.repetitively_busy += 1
@@ -2144,7 +2141,7 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None):
 
                 # See if anyone new has the data
                 await self.query_who_has(dep.key)
-                self.ensure_communicating()
+            self.ensure_communicating()
 
     def bad_dep(self, dep):
         exc = ValueError(
@@ -2164,7 +2161,7 @@ async def handle_missing_dep(self, *deps, **kwargs):
             if not deps:
                 return
 
-            for dep in deps:
+            for dep in list(deps):
                 if dep.suspicious_count > 5:
                     deps.remove(dep)
                     self.bad_dep(dep)
@@ -2177,10 +2174,7 @@ async def handle_missing_dep(self, *deps, **kwargs):
                     dep.key,
                     dep.suspicious_count,
                 )
-
-            who_has = await retry_operation(
-                self.scheduler.who_has, keys=list(dep.key for dep in deps)
-            )
+            who_has = await self.query_who_has(list(dep.key for dep in deps))
             who_has = {k: v for k, v in who_has.items() if v}
             self.update_who_has(who_has)
             for dep in deps:
@@ -2188,7 +2182,6 @@ async def handle_missing_dep(self, *deps, **kwargs):
 
                 if not who_has.get(dep.key):
                     self.log.append((dep.key, "no workers found", dep.dependents))
-                    self.release_key(dep.key)
                 else:
                     self.log.append((dep.key, "new workers found"))
                     for dependent in dep.dependents:
@@ -2204,12 +2197,6 @@ async def handle_missing_dep(self, *deps, **kwargs):
             else:
                 raise
         finally:
-            try:
-                for dep in deps:
-                    self._missing_dep_flight.remove(dep.key)
-            except KeyError:
-                pass
-
             self.ensure_communicating()
 
     async def query_who_has(self, *deps):
@@ -2266,15 +2253,6 @@ def release_key(self, key, cause=None, reason=None, report=True):
         if key in self.actors and not ts.dependents:
            del self.actors[key]
 
-        # for any dependencies of key we are releasing remove task as dependent
-        for dependency in ts.dependencies:
-            dependency.dependents.discard(ts)
-            if not dependency.dependents and dependency.state in (
-                "waiting",
-                "flight",
-            ):
-                self.release_key(dependency.key)
-
        for worker in ts.who_has:
            self.has_what[worker].discard(ts.key)
 
@@ -2949,7 +2927,6 @@ def validate_state(self):
                 assert (
                     ts_wait.state == "flight"
                     or ts_wait.state == "waiting"
-                    or ts.wait.key in self._missing_dep_flight
                     or ts_wait.who_has.issubset(self.in_flight_workers)
                 )
             if ts.state == "memory":
@@ -3260,7 +3237,7 @@ async def _get_data():
             except KeyError:
                 raise ValueError("Unexpected response", response)
             else:
-                if status == "OK":
+                if not comm.closed() and status == "OK":
                     await comm.write("OK")
             return response
         finally:
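For context, the central behavioral change in `send_recv` above is the move from aborting the comm only on `EnvironmentError` to aborting it on any exception. The sketch below illustrates that pattern under assumed names: `request_reply` and its signature are hypothetical, not the `distributed` API; only the `comm.write`/`comm.read`/`comm.abort` calls mirror what the diff does.

```python
# Hypothetical sketch of the abort-on-any-error rule that send_recv now
# follows; request_reply is an assumed name, not distributed's API.
async def request_reply(comm, msg):
    try:
        await comm.write(msg)      # send the request
        return await comm.read()   # wait for the matching reply
    except Exception:
        # Abort on *any* failure, not only OS-level errors: if the exchange
        # dies halfway, the peer may still owe this comm a reply, and
        # reusing the comm would pair that stale reply with the next request.
        comm.abort()
        raise
```

The old code only set `force_close` on `EnvironmentError`, so an error raised inside `comm.read` itself, such as the deserialization failure exercised by `test_get_data_faulty_dep`, could hand a half-finished exchange back for reuse.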