Skip to content

Commit

Permalink
Forward exception for failed dependencies
Browse files Browse the repository at this point in the history
If there is a bad dependency, ensure that the graph is aborted
and that the exception is eventually forwarded to the client.
  • Loading branch information
fjetter committed Dec 14, 2020
1 parent 0797072 commit 9e29ace
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 45 deletions.
13 changes: 6 additions & 7 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,6 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw
msg = kwargs
msg["reply"] = reply
please_close = kwargs.get("close")
force_close = False
if deserializers is None:
deserializers = serializers
if deserializers is not None:
Expand All @@ -661,15 +660,15 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw
response = await comm.read(deserializers=deserializers)
else:
response = None
except EnvironmentError:
# On communication errors, we should simply close the communication
force_close = True
raise
except Exception as exc:
# If an exception occurred, we will need to close the comm, if possible.
# Otherwise the other end might wait for a reply while this end is
# reusing the comm for something else.
comm.abort()
raise exc
finally:
if please_close:
await comm.close()
elif force_close:
comm.abort()

if isinstance(response, dict) and response.get("status") == "uncaught-error":
if comm.deserialize:
Expand Down
29 changes: 29 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,3 +1783,32 @@ def f(x):
return threading.get_ident() == x._thread_ident

assert await c.submit(f, x)


@gen_cluster(client=True)
async def test_get_data_faulty_dep(c, s, a, b):
    """Check that a dependency which fails to deserialize is reported as bad.

    The dependency is pinned to worker ``a`` and the task consuming it to
    worker ``b``, which forces the value to be serialized and sent to another
    worker. The computation is expected to finish eventually, with the
    consuming future raising a ``ValueError`` that names the ``create-``
    dependency.
    """

    class BrokenDeserialization:
        # Restoring state always raises; the reported state is an empty string.
        def __getstate__(self, *args):
            return ""

        def __setstate__(self, *state):
            raise AttributeError()

    def create():
        return BrokenDeserialization()

    def collect(*args):
        return args

    # Build the broken object on worker ``a`` ...
    broken_dep = c.submit(create, workers=[a.name])
    # ... and require it as an input on worker ``b``.
    dependent = c.submit(collect, broken_dep, workers=[b.name])

    with pytest.raises(ValueError, match="Could not find dependent create-"):
        await dependent.result()
53 changes: 15 additions & 38 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,6 @@ def __init__(
)
self.total_comm_nbytes = 10e6
self.comm_nbytes = 0
self._missing_dep_flight = set()

self.threads = dict()

Expand Down Expand Up @@ -1347,6 +1346,16 @@ async def get_data(
compressed = await comm.write(msg, serializers=serializers)
response = await comm.read(deserializers=serializers)
assert response == "OK", response

except CommClosedError:
logger.exception(
"Other end hung up during get_data with %s -> %s",
self.address,
who,
exc_info=True,
)
comm.abort()
raise
except EnvironmentError:
logger.exception(
"failed during get data with %s -> %s", self.address, who, exc_info=True
Expand Down Expand Up @@ -1573,10 +1582,6 @@ def transition_flight_waiting(self, ts, worker=None, remove=True):
except KeyError:
pass

if not ts.who_has:
if ts.key not in self._missing_dep_flight:
self._missing_dep_flight.add(ts.key)
self.loop.add_callback(self.handle_missing_dep, ts)
for dependent in ts.dependents:
if dependent.state == "waiting":
if remove: # try a new worker immediately
Expand Down Expand Up @@ -1841,14 +1846,7 @@ def ensure_communicating(self):
missing_deps = {dep for dep in deps if not dep.who_has}
if missing_deps:
logger.info("Can't find dependencies for key %s", key)
missing_deps2 = {
dep
for dep in missing_deps
if dep.key not in self._missing_dep_flight
}
for dep in missing_deps2:
self._missing_dep_flight.add(dep.key)
self.loop.add_callback(self.handle_missing_dep, *missing_deps2)
self.loop.add_callback(self.handle_missing_dep, *missing_deps)

deps = [dep for dep in deps if dep not in missing_deps]

Expand Down Expand Up @@ -2136,15 +2134,14 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None):

if not busy:
self.repetitively_busy = 0
self.ensure_communicating()
else:
# Exponential backoff to avoid hammering scheduler/worker
self.repetitively_busy += 1
await asyncio.sleep(0.100 * 1.5 ** self.repetitively_busy)

# See if anyone new has the data
await self.query_who_has(dep.key)
self.ensure_communicating()
self.ensure_communicating()

def bad_dep(self, dep):
exc = ValueError(
Expand All @@ -2164,7 +2161,7 @@ async def handle_missing_dep(self, *deps, **kwargs):
if not deps:
return

for dep in deps:
for dep in list(deps):
if dep.suspicious_count > 5:
deps.remove(dep)
self.bad_dep(dep)
Expand All @@ -2177,18 +2174,14 @@ async def handle_missing_dep(self, *deps, **kwargs):
dep.key,
dep.suspicious_count,
)

who_has = await retry_operation(
self.scheduler.who_has, keys=list(dep.key for dep in deps)
)
who_has = await self.query_who_has(list(dep.key for dep in deps))
who_has = {k: v for k, v in who_has.items() if v}
self.update_who_has(who_has)
for dep in deps:
dep.suspicious_count += 1

if not who_has.get(dep.key):
self.log.append((dep.key, "no workers found", dep.dependents))
self.release_key(dep.key)
else:
self.log.append((dep.key, "new workers found"))
for dependent in dep.dependents:
Expand All @@ -2204,12 +2197,6 @@ async def handle_missing_dep(self, *deps, **kwargs):
else:
raise
finally:
try:
for dep in deps:
self._missing_dep_flight.remove(dep.key)
except KeyError:
pass

self.ensure_communicating()

async def query_who_has(self, *deps):
Expand Down Expand Up @@ -2266,15 +2253,6 @@ def release_key(self, key, cause=None, reason=None, report=True):
if key in self.actors and not ts.dependents:
del self.actors[key]

# for any dependencies of key we are releasing remove task as dependent
for dependency in ts.dependencies:
dependency.dependents.discard(ts)
if not dependency.dependents and dependency.state in (
"waiting",
"flight",
):
self.release_key(dependency.key)

for worker in ts.who_has:
self.has_what[worker].discard(ts.key)

Expand Down Expand Up @@ -2949,7 +2927,6 @@ def validate_state(self):
assert (
ts_wait.state == "flight"
or ts_wait.state == "waiting"
or ts.wait.key in self._missing_dep_flight
or ts_wait.who_has.issubset(self.in_flight_workers)
)
if ts.state == "memory":
Expand Down Expand Up @@ -3260,7 +3237,7 @@ async def _get_data():
except KeyError:
raise ValueError("Unexpected response", response)
else:
if status == "OK":
if not comm.closed() and status == "OK":
await comm.write("OK")
return response
finally:
Expand Down

0 comments on commit 9e29ace

Please sign in to comment.