Skip to content

Commit

Permalink
Revamped implementations of remote print() and warn(), fixing #7095
Browse files Browse the repository at this point in the history
… (#7129)
  • Loading branch information
maxbane authored Oct 18, 2022
1 parent ad879fb commit 334af11
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 30 deletions.
50 changes: 44 additions & 6 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
import os
import pickle
import re
import sys
import threading
Expand Down Expand Up @@ -629,18 +630,55 @@ class AllExit(Exception):

def _handle_print(event):
_, msg = event
if isinstance(msg, dict) and "args" in msg and "kwargs" in msg:
print(*msg["args"], **msg["kwargs"])
else:
if not isinstance(msg, dict):
# someone must have manually logged a print event with a hand-crafted
# payload, rather than by calling worker.print(). In that case simply
# print the payload and hope it works.
print(msg)
return

args = msg.get("args")
if not isinstance(args, tuple):
# worker.print() will always send us a tuple of args, even if it's an
# empty tuple.
raise TypeError(
f"_handle_print: client received non-tuple print args: {args!r}"
)

file = msg.get("file")
if file == 1:
file = sys.stdout
elif file == 2:
file = sys.stderr
elif file is not None:
raise TypeError(
f"_handle_print: client received unsupported file kwarg: {file!r}"
)

print(
*args, sep=msg.get("sep"), end=msg.get("end"), file=file, flush=msg.get("flush")
)


def _handle_warn(event):
_, msg = event
if isinstance(msg, dict) and "args" in msg and "kwargs" in msg:
warnings.warn(*msg["args"], **msg["kwargs"])
else:
if not isinstance(msg, dict):
# someone must have manually logged a warn event with a hand-crafted
# payload, rather than by calling worker.warn(). In that case simply
# warn the payload and hope it works.
warnings.warn(msg)
else:
if "message" not in msg:
# TypeError makes sense here because it's analogous to calling a
# function without a required positional argument
raise TypeError(
"_handle_warn: client received a warn event missing the required "
'"message" argument.'
)
warnings.warn(
pickle.loads(msg["message"]),
category=pickle.loads(msg.get("category", None)),
)


def _maybe_call_security_loader(address):
Expand Down
134 changes: 125 additions & 9 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7259,32 +7259,147 @@ async def test_log_event_warn(c, s, a, b):
def foo():
get_worker().log_event(["foo", "warn"], "Hello!")

with pytest.warns(Warning, match="Hello!"):
with pytest.warns(UserWarning, match="Hello!"):
await c.submit(foo)

def bar():
# missing "message" key should log TypeError
get_worker().log_event("warn", {})

with captured_logger(logging.getLogger("distributed.client")) as log:
await c.submit(bar)
assert "TypeError" in log.getvalue()


@gen_cluster(client=True)
async def test_log_event_warn_dask_warns(c, s, a, b):
from dask.distributed import warn

def foo():
def warn_simple():
warn("Hello!")

with pytest.warns(Warning, match="Hello!"):
await c.submit(foo)
with pytest.warns(UserWarning, match="Hello!"):
await c.submit(warn_simple)

def warn_deprecation_1():
# one way to do it...
warn("You have been deprecated by AI", DeprecationWarning)

with pytest.warns(DeprecationWarning, match="You have been deprecated by AI"):
await c.submit(warn_deprecation_1)

def warn_deprecation_2():
# another way to do it...
warn(DeprecationWarning("Your profession has been deprecated"))

with pytest.warns(DeprecationWarning, match="Your profession has been deprecated"):
await c.submit(warn_deprecation_2)

# user-defined warning subclass
class MyPrescientWarning(UserWarning):
pass

def warn_cassandra():
warn(MyPrescientWarning("Cassandra says..."))

with pytest.warns(MyPrescientWarning, match="Cassandra says..."):
await c.submit(warn_cassandra)


@gen_cluster(client=True, Worker=Nanny)
async def test_print(c, s, a, b, capsys):
async def test_print_remote(c, s, a, b, capsys):
from dask.distributed import print

def foo():
print("Hello!", 123)

def bar():
print("Hello!", 123, sep=":")

def baz():
print("Hello!", 123, sep=":", end="")

def frotz():
# like builtin print(), None values for kwargs should be same as
# defaults " ", "\n", sys.stdout, False, respectively.
# (But note we don't really have a good way to test for flushes.)
print("Hello!", 123, sep=None, end=None, file=None, flush=None)

def plugh():
# no positional arguments
print(sep=":", end=".")

def print_stdout():
print("meow", file=sys.stdout)

def print_stderr():
print("meow", file=sys.stderr)

def print_badfile():
print("meow", file="my arbitrary file object")

capsys.readouterr() # drop any output captured so far

await c.submit(foo)
out, err = capsys.readouterr()
assert "Hello! 123\n" == out

await c.submit(bar)
out, err = capsys.readouterr()
assert "Hello!:123\n" == out

await c.submit(baz)
out, err = capsys.readouterr()
assert "Hello!:123" in out
assert "Hello!:123" == out

await c.submit(frotz)
out, err = capsys.readouterr()
assert "Hello! 123\n" == out

await c.submit(plugh)
out, err = capsys.readouterr()
assert "." == out

await c.submit(print_stdout)
out, err = capsys.readouterr()
assert "meow\n" == out and "" == err

await c.submit(print_stderr)
out, err = capsys.readouterr()
assert "meow\n" == err and "" == out

with pytest.raises(TypeError):
await c.submit(print_badfile)


@gen_cluster(client=True, Worker=Nanny)
async def test_print_manual(c, s, a, b, capsys):
def foo():
get_worker().log_event("print", "Hello!")

capsys.readouterr() # drop any output captured so far

await c.submit(foo)
out, err = capsys.readouterr()
assert "Hello!\n" == out

def print_otherfile():
# this should log a TypeError in the client
get_worker().log_event("print", {"args": ("hello",), "file": "bad value"})

with captured_logger(logging.getLogger("distributed.client")) as log:
await c.submit(print_otherfile)
assert "TypeError" in log.getvalue()


@gen_cluster(client=True, Worker=Nanny)
async def test_print_manual_bad_args(c, s, a, b, capsys):
def foo():
get_worker().log_event("print", {"args": "not a tuple"})

with captured_logger(logging.getLogger("distributed.client")) as log:
await c.submit(foo)
assert "TypeError" in log.getvalue()


@gen_cluster(client=True, Worker=Nanny)
Expand All @@ -7300,13 +7415,14 @@ def foo():
assert "<object object at" in out


def test_print_simple(capsys):
def test_print_local(capsys):
from dask.distributed import print

print("Hello!", 123, sep=":")
capsys.readouterr() # drop any output captured so far

print("Hello!", 123, sep=":")
out, err = capsys.readouterr()
assert "Hello!:123" in out
assert "Hello!:123\n" == out


def _verify_cluster_dump(
Expand Down
Loading

0 comments on commit 334af11

Please sign in to comment.