diff --git a/distributed/client.py b/distributed/client.py index a09c2ca704..4036a889f2 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -7,6 +7,7 @@ import json import logging import os +import pickle import re import sys import threading @@ -629,18 +630,55 @@ class AllExit(Exception): def _handle_print(event): _, msg = event - if isinstance(msg, dict) and "args" in msg and "kwargs" in msg: - print(*msg["args"], **msg["kwargs"]) - else: + if not isinstance(msg, dict): + # someone must have manually logged a print event with a hand-crafted + # payload, rather than by calling worker.print(). In that case simply + # print the payload and hope it works. print(msg) + return + + args = msg.get("args") + if not isinstance(args, tuple): + # worker.print() will always send us a tuple of args, even if it's an + # empty tuple. + raise TypeError( + f"_handle_print: client received non-tuple print args: {args!r}" + ) + + file = msg.get("file") + if file == 1: + file = sys.stdout + elif file == 2: + file = sys.stderr + elif file is not None: + raise TypeError( + f"_handle_print: client received unsupported file kwarg: {file!r}" + ) + + print( + *args, sep=msg.get("sep"), end=msg.get("end"), file=file, flush=msg.get("flush") + ) def _handle_warn(event): _, msg = event - if isinstance(msg, dict) and "args" in msg and "kwargs" in msg: - warnings.warn(*msg["args"], **msg["kwargs"]) - else: + if not isinstance(msg, dict): + # someone must have manually logged a warn event with a hand-crafted + # payload, rather than by calling worker.warn(). In that case simply + # warn the payload and hope it works. warnings.warn(msg) + else: + if "message" not in msg: + # TypeError makes sense here because it's analogous to calling a + # function without a required positional argument + raise TypeError( + "_handle_warn: client received a warn event missing the required " + '"message" argument.' 
+ ) + warnings.warn( + pickle.loads(msg["message"]), + category=pickle.loads(msg.get("category", None)), + ) def _maybe_call_security_loader(address): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 142c746e73..a730f24584 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -7262,32 +7262,147 @@ async def test_log_event_warn(c, s, a, b): def foo(): get_worker().log_event(["foo", "warn"], "Hello!") - with pytest.warns(Warning, match="Hello!"): + with pytest.warns(UserWarning, match="Hello!"): await c.submit(foo) + def bar(): + # missing "message" key should log TypeError + get_worker().log_event("warn", {}) + + with captured_logger(logging.getLogger("distributed.client")) as log: + await c.submit(bar) + assert "TypeError" in log.getvalue() + @gen_cluster(client=True) async def test_log_event_warn_dask_warns(c, s, a, b): from dask.distributed import warn - def foo(): + def warn_simple(): warn("Hello!") - with pytest.warns(Warning, match="Hello!"): - await c.submit(foo) + with pytest.warns(UserWarning, match="Hello!"): + await c.submit(warn_simple) + + def warn_deprecation_1(): + # one way to do it... + warn("You have been deprecated by AI", DeprecationWarning) + + with pytest.warns(DeprecationWarning, match="You have been deprecated by AI"): + await c.submit(warn_deprecation_1) + + def warn_deprecation_2(): + # another way to do it... 
@gen_cluster(client=True, Worker=Nanny)
async def test_print_remote(c, s, a, b, capsys):
    """Check that ``dask.distributed.print`` on a worker is replayed on the client.

    Each helper below runs on a worker; the client's captured stdout/stderr is
    then compared against what builtin ``print`` would have produced locally.
    """
    from dask.distributed import print

    def defaults_only():
        print("Hello!", 123)

    def custom_sep():
        print("Hello!", 123, sep=":")

    def custom_sep_and_end():
        print("Hello!", 123, sep=":", end="")

    def all_none_kwargs():
        # like builtin print(), None values for kwargs should be same as
        # defaults " ", "\n", sys.stdout, False, respectively.
        # (But note we don't really have a good way to test for flushes.)
        print("Hello!", 123, sep=None, end=None, file=None, flush=None)

    def no_positional_args():
        print(sep=":", end=".")

    def to_stdout():
        print("meow", file=sys.stdout)

    def to_stderr():
        print("meow", file=sys.stderr)

    def to_unsupported_file():
        print("meow", file="my arbitrary file object")

    # Cases that only inspect stdout, in the order they are submitted.
    stdout_cases = [
        (defaults_only, "Hello! 123\n"),
        (custom_sep, "Hello!:123\n"),
        (custom_sep_and_end, "Hello!:123"),
        (all_none_kwargs, "Hello! 123\n"),
        (no_positional_args, "."),
    ]

    capsys.readouterr()  # drop any output captured so far

    for task, expected_out in stdout_cases:
        await c.submit(task)
        out, err = capsys.readouterr()
        assert expected_out == out

    # Explicit stream selection: the output must land on exactly one stream.
    await c.submit(to_stdout)
    out, err = capsys.readouterr()
    assert "meow\n" == out and "" == err

    await c.submit(to_stderr)
    out, err = capsys.readouterr()
    assert "meow\n" == err and "" == out

    # Anything other than None / sys.stdout / sys.stderr is rejected.
    with pytest.raises(TypeError):
        await c.submit(to_unsupported_file)
+ + All arguments behave the same as those of ``builtins.print()``, with the + exception that the ``file`` keyword argument, if specified, must either be + ``sys.stdout`` or ``sys.stderr``; arbitrary file-like objects are not + allowed. + + All non-keyword arguments are converted to strings using ``str()`` and + written to the stream, separated by ``sep`` and followed by ``end``. Both + ``sep`` and ``end`` must be strings; they can also be ``None``, which means + to use the default values. If no objects are given, ``print()`` will just + write ``end``. + + Parameters + ---------- + sep : str, optional + String inserted between values, default a space. + end : str, optional + String appended after the last value, default a newline. + file : ``sys.stdout`` or ``sys.stderr``, optional + Defaults to the current sys.stdout. + flush : bool, default False + Whether to forcibly flush the stream. + + Examples + -------- + >>> from dask.distributed import Client, print + >>> client = Client(...) + >>> def worker_function(): + ... print("Hello from worker!") + >>> client.submit(worker_function) + + Hello from worker! """ try: worker = get_worker() except ValueError: pass else: + # We are in a worker: prepare all of the print args and kwargs to be + # serialized over the wire to the client. msg = { - "args": tuple(stringify(arg) for arg in args), - "kwargs": {k: stringify(v) for k, v in kwargs.items()}, + # According to the Python stdlib docs, builtin print() simply calls + # str() on each positional argument, so we do the same here. + "args": tuple(map(str, args)), + "sep": sep, + "end": end, + "flush": flush, } + if file == sys.stdout: + msg["file"] = 1 # type: ignore + elif file == sys.stderr: + msg["file"] = 2 # type: ignore + elif file is not None: + raise TypeError( + f"Remote printing to arbitrary file objects is not supported.
def warn(
    message: str | Warning,
    category: type[Warning] | None = UserWarning,
    stacklevel: int = 1,
    source: Any = None,
) -> None:
    """
    Emit a warning locally and, when running on a worker, on clients too.

    This mirrors the signature of ``warnings.warn()``. The warning is always
    issued in the calling process. If ``get_worker()`` finds a worker, a
    ``"warn"`` event carrying the pickled ``message`` and ``category`` is
    logged on that worker first, so that clients receiving the event can emit
    the same warning (subject to their own local filters, etc.).

    Parameters
    ----------
    message : str or Warning
        Warning text, or a ready-made warning instance.
    category : type[Warning], optional
        Warning class; defaults to ``UserWarning``.
    stacklevel : int, default 1
        As in ``warnings.warn()``. Applies to the local warning only; it is
        not sent with the event.
    source : Any, optional
        As in ``warnings.warn()``. Applies to the local warning only; it is
        not sent with the event.

    Examples
    --------
    >>> from dask.distributed import Client, warn
    >>> client = Client()
    >>> def do_warn():
    ...    warn("A warning from a worker.")
    >>> client.submit(do_warn).result()
    /path/to/distributed/client.py:678: UserWarning: A warning from a worker.
    """
    try:
        current_worker = get_worker()
    except ValueError:  # pragma: no cover
        # No worker available: only the local warning below is emitted.
        pass
    else:
        # msgpack cannot handle the message and category objects, so they are
        # pickled into bytes here. The expectation is that these are always
        # small objects.
        # stacklevel is left out because it would be meaningless in the
        # client's thread/process; source is left out to avoid serializing
        # arbitrary objects.
        payload = {
            "message": pickle.dumps(message),
            "category": pickle.dumps(category),
        }
        current_worker.log_event("warn", payload)

    # One extra stack level skips this wrapper, so the local warning points at
    # the source line that called us.
    warnings.warn(message, category=category, stacklevel=stacklevel + 1, source=source)