Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamped implementations of remote print() and warn(), fixing #7095 #7129

Merged
merged 7 commits into from
Oct 18, 2022
50 changes: 44 additions & 6 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
import os
import pickle
import re
import sys
import threading
Expand Down Expand Up @@ -629,18 +630,55 @@ class AllExit(Exception):

def _handle_print(event):
    """Replay a ``"print"`` event on the client by calling ``print()``.

    Parameters
    ----------
    event
        A two-item sequence. The first item is ignored; the second item
        (``msg``) is the payload to print.

        If ``msg`` is not a dict it is printed as-is. Otherwise it is read
        as a mapping with these keys:

        ``"args"``
            Required; a tuple of positional arguments for ``print()``
            (may be empty).
        ``"sep"``, ``"end"``, ``"flush"``
            Optional; forwarded to ``print()``. A missing key is forwarded
            as ``None``, which ``print()`` treats as its default.
        ``"file"``
            Optional; ``1`` selects ``sys.stdout``, ``2`` selects
            ``sys.stderr``, and ``None`` or a missing key leaves the
            choice to ``print()``.

    Raises
    ------
    TypeError
        If ``msg`` is a dict whose ``"args"`` is not a tuple, or whose
        ``"file"`` is anything other than ``1``, ``2`` or ``None``.
    """
    _, msg = event
    if not isinstance(msg, dict):
        # someone must have manually logged a print event with a hand-crafted
        # payload, rather than by calling worker.print(). In that case simply
        # print the payload and hope it works.
        print(msg)
        return

    args = msg.get("args")
    if not isinstance(args, tuple):
        # worker.print() will always send us a tuple of args, even if it's an
        # empty tuple.
        raise TypeError(
            f"_handle_print: client received non-tuple print args: {args!r}"
        )

    # File objects are not sent in the payload; only the two standard
    # streams are supported, identified by the integers 1 and 2. They are
    # resolved here, at call time, so any redirection of sys.stdout /
    # sys.stderr on the client is honoured.
    file = msg.get("file")
    if file == 1:
        file = sys.stdout
    elif file == 2:
        file = sys.stderr
    elif file is not None:
        raise TypeError(
            f"_handle_print: client received unsupported file kwarg: {file!r}"
        )

    print(
        *args, sep=msg.get("sep"), end=msg.get("end"), file=file, flush=msg.get("flush")
    )


def _handle_warn(event):
    """Re-issue a ``"warn"`` event on the client via ``warnings.warn()``.

    Parameters
    ----------
    event
        A two-item sequence. The first item is ignored; the second item
        (``msg``) is the payload.

        If ``msg`` is not a dict it is passed straight to
        ``warnings.warn()``. Otherwise it is read as a mapping with these
        keys:

        ``"message"``
            Required; the pickled warning message (or ``Warning`` instance).
        ``"category"``
            Optional; the pickled warning category. When it is missing or
            ``None``, ``warnings.warn()`` picks its default category.

    Raises
    ------
    TypeError
        If ``msg`` is a dict without a ``"message"`` key.
    """
    _, msg = event
    if not isinstance(msg, dict):
        # someone must have manually logged a warn event with a hand-crafted
        # payload, rather than by calling worker.warn(). In that case simply
        # warn the payload and hope it works.
        warnings.warn(msg)
        return

    if "message" not in msg:
        # TypeError makes sense here because it's analogous to calling a
        # function without a required positional argument
        raise TypeError(
            "_handle_warn: client received a warn event missing the required "
            '"message" argument.'
        )

    # NOTE(review): pickle.loads() can execute arbitrary code. This assumes
    # warn events only originate from trusted members of the cluster --
    # confirm against the deployment's trust model.
    message = pickle.loads(msg["message"])

    # "category" is optional. Only unpickle it when present: calling
    # pickle.loads(None) would raise an unrelated TypeError instead of
    # falling back to the default warning category.
    pickled_category = msg.get("category")
    category = None if pickled_category is None else pickle.loads(pickled_category)

    warnings.warn(message, category=category)


def _maybe_call_security_loader(address):
Expand Down
134 changes: 125 additions & 9 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7262,32 +7262,147 @@ async def test_log_event_warn(c, s, a, b):
def foo():
get_worker().log_event(["foo", "warn"], "Hello!")

with pytest.warns(Warning, match="Hello!"):
with pytest.warns(UserWarning, match="Hello!"):
await c.submit(foo)

def bar():
# missing "message" key should log TypeError
get_worker().log_event("warn", {})

with captured_logger(logging.getLogger("distributed.client")) as log:
await c.submit(bar)
assert "TypeError" in log.getvalue()


@gen_cluster(client=True)
async def test_log_event_warn_dask_warns(c, s, a, b):
from dask.distributed import warn

def foo():
def warn_simple():
warn("Hello!")

with pytest.warns(Warning, match="Hello!"):
await c.submit(foo)
with pytest.warns(UserWarning, match="Hello!"):
await c.submit(warn_simple)

def warn_deprecation_1():
# one way to do it...
warn("You have been deprecated by AI", DeprecationWarning)

with pytest.warns(DeprecationWarning, match="You have been deprecated by AI"):
await c.submit(warn_deprecation_1)

def warn_deprecation_2():
# another way to do it...
warn(DeprecationWarning("Your profession has been deprecated"))

with pytest.warns(DeprecationWarning, match="Your profession has been deprecated"):
await c.submit(warn_deprecation_2)

# user-defined warning subclass
class MyPrescientWarning(UserWarning):
pass

def warn_cassandra():
warn(MyPrescientWarning("Cassandra says..."))

with pytest.warns(MyPrescientWarning, match="Cassandra says..."):
await c.submit(warn_cassandra)


@gen_cluster(client=True, Worker=Nanny)
async def test_print(c, s, a, b, capsys):
async def test_print_remote(c, s, a, b, capsys):
from dask.distributed import print

def foo():
print("Hello!", 123)

def bar():
print("Hello!", 123, sep=":")

def baz():
print("Hello!", 123, sep=":", end="")

def frotz():
# like builtin print(), None values for kwargs should be same as
# defaults " ", "\n", sys.stdout, False, respectively.
# (But note we don't really have a good way to test for flushes.)
print("Hello!", 123, sep=None, end=None, file=None, flush=None)

def plugh():
# no positional arguments
print(sep=":", end=".")

def print_stdout():
print("meow", file=sys.stdout)

def print_stderr():
print("meow", file=sys.stderr)

def print_badfile():
print("meow", file="my arbitrary file object")

capsys.readouterr() # drop any output captured so far

await c.submit(foo)
out, err = capsys.readouterr()
assert "Hello! 123\n" == out

await c.submit(bar)
out, err = capsys.readouterr()
assert "Hello!:123\n" == out

await c.submit(baz)
out, err = capsys.readouterr()
assert "Hello!:123" in out
assert "Hello!:123" == out

await c.submit(frotz)
out, err = capsys.readouterr()
assert "Hello! 123\n" == out

await c.submit(plugh)
out, err = capsys.readouterr()
assert "." == out

await c.submit(print_stdout)
out, err = capsys.readouterr()
assert "meow\n" == out and "" == err

await c.submit(print_stderr)
out, err = capsys.readouterr()
assert "meow\n" == err and "" == out

with pytest.raises(TypeError):
await c.submit(print_badfile)


@gen_cluster(client=True, Worker=Nanny)
async def test_print_manual(c, s, a, b, capsys):
def foo():
get_worker().log_event("print", "Hello!")

capsys.readouterr() # drop any output captured so far

await c.submit(foo)
out, err = capsys.readouterr()
assert "Hello!\n" == out

def print_otherfile():
# this should log a TypeError in the client
get_worker().log_event("print", {"args": ("hello",), "file": "bad value"})

with captured_logger(logging.getLogger("distributed.client")) as log:
await c.submit(print_otherfile)
assert "TypeError" in log.getvalue()


@gen_cluster(client=True, Worker=Nanny)
async def test_print_manual_bad_args(c, s, a, b, capsys):
def foo():
get_worker().log_event("print", {"args": "not a tuple"})

with captured_logger(logging.getLogger("distributed.client")) as log:
await c.submit(foo)
assert "TypeError" in log.getvalue()


@gen_cluster(client=True, Worker=Nanny)
Expand All @@ -7303,13 +7418,14 @@ def foo():
assert "<object object at" in out


def test_print_simple(capsys):
def test_print_local(capsys):
from dask.distributed import print

print("Hello!", 123, sep=":")
capsys.readouterr() # drop any output captured so far

print("Hello!", 123, sep=":")
out, err = capsys.readouterr()
assert "Hello!:123" in out
assert "Hello!:123\n" == out


def _verify_cluster_dump(
Expand Down
Loading