diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml
index 65d916bf7de..0c0249dcc0b 100644
--- a/continuous_integration/environment-3.9.yaml
+++ b/continuous_integration/environment-3.9.yaml
@@ -53,7 +53,8 @@ dependencies:
   - pip:
       - git+https://github.com/dask/dask
       - git+https://github.com/dask/s3fs
-      - git+https://github.com/dask/zict
+      - git+https://github.com/ncclementi/zict@slow_raises  # remove this after the zict PR is merged
+      #- git+https://github.com/dask/zict
       # FIXME https://github.com/dask/distributed/issues/5345
       # - git+https://github.com/intake/filesystem_spec
       - git+https://github.com/joblib/joblib
diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml
index dbe621d373b..66c0b5906eb 100644
--- a/distributed/distributed-schema.yaml
+++ b/distributed/distributed-schema.yaml
@@ -226,7 +226,7 @@ properties:
           http:
             type: object
-            decription: Settings for Dask's embedded HTTP Server
+            description: Settings for Dask's embedded HTTP Server
             properties:
               routes:
                 type: array
@@ -504,9 +504,17 @@ properties:
                   When the process memory reaches this level the nanny process
                   will kill the worker (if a nanny is present)

+              max-spill:
+                oneOf:
+                  - type: string
+                  - {type: number, minimum: 0}
+                  - enum: [false]
+                description: >-
+                  Limit of the number of bytes to be spilled to disk.
+
           http:
             type: object
-            decription: Settings for Dask's embedded HTTP Server
+            description: Settings for Dask's embedded HTTP Server
             properties:
               routes:
                 type: array
diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index 6475e1cc223..07abe57d874 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -145,6 +145,10 @@ distributed:
       pause: 0.80  # fraction at which we pause worker threads
       terminate: 0.95  # fraction at which we terminate the worker

+      # Max size of the spilled data on disk (e.g. "10 GB")
+      # Set to false for no maximum.
+      max-spill: false
+
     http:
       routes:
         - distributed.http.worker.prometheus
diff --git a/distributed/spill.py b/distributed/spill.py
index 2c849c2447e..223b226ea82 100644
--- a/distributed/spill.py
+++ b/distributed/spill.py
@@ -1,76 +1,247 @@
 from __future__ import annotations

-from collections.abc import Hashable, Mapping
+import logging
+import time
+from collections.abc import Mapping
+from contextlib import contextmanager
 from functools import partial
-from typing import Any
+from typing import TYPE_CHECKING, Any

-from zict import Buffer, File, Func
+import zict
+from packaging.version import parse as parse_version
+
+if TYPE_CHECKING:
+    from typing_extensions import Literal

 from .protocol import deserialize_bytes, serialize_bytelist
 from .sizeof import safe_sizeof

+logger = logging.getLogger(__name__)
+

-class SpillBuffer(Buffer):
+class SpillBuffer(zict.Buffer):
     """MutableMapping that automatically spills out dask key/value pairs to disk when
-    the total size of the stored data exceeds the target
+    the total size of the stored data exceeds the target. If max_spill is provided,
+    the key/value pairs won't be spilled once this threshold has been reached.
+
+    Parameters
+    ----------
+    spill_directory: str
+        Location on disk to write the spill files to
+    target: int
+        Managed memory, in bytes, to start spilling at
+    max_spill: int | False, optional
+        Limit of the number of bytes to be spilled to disk. Set to False to disable.
+    min_log_interval: float, optional
+        Minimum interval, in seconds, between warnings in the log about a full disk
     """

-    spilled_by_key: dict[Hashable, int]
-    spilled_total: int
+    last_logged: float
+    min_log_interval: float
+    logged_pickle_errors: set[str]
+
+    def __init__(
+        self,
+        spill_directory: str,
+        target: int,
+        max_spill: int | Literal[False] = False,
+        min_log_interval: float = 2,
+    ):
+
+        if max_spill is not False and parse_version(zict.__version__) <= parse_version(
+            "2.0.0"
+        ):
+            raise ValueError("zict > 2.0.0 is required to set max_spill")

-    def __init__(self, spill_directory: str, target: int):
-        self.spilled_by_key = {}
-        self.spilled_total = 0
-        storage = Func(
-            partial(serialize_bytelist, on_error="raise"),
-            deserialize_bytes,
-            File(spill_directory),
-        )
         super().__init__(
-            {},
-            storage,
-            target,
-            weight=self._weight,
-            fast_to_slow_callbacks=[self._on_evict],
-            slow_to_fast_callbacks=[self._on_retrieve],
+            fast={},
+            slow=Slow(spill_directory, max_spill),
+            n=target,
+            weight=_in_memory_weight,
         )
+        self.last_logged = 0
+        self.min_log_interval = min_log_interval
+        self.logged_pickle_errors = set()  # keys logged with pickle error
+
+    @contextmanager
+    def handle_errors(self, key: str | None):
+        try:
+            yield
+        except MaxSpillExceeded as e:
+            # key is in self.fast; no keys have been lost on eviction
+            # Note: requires zict > 2.0
+            (key_e,) = e.args
+            assert key_e in self.fast
+            assert key_e not in self.slow
+            now = time.time()
+            if now - self.last_logged >= self.min_log_interval:
+                logger.warning(
+                    "Spill file on disk reached capacity; keeping data in memory"
+                )
+                self.last_logged = now
+            raise HandledError()
+        except OSError:
+            # Typically, this is a disk full error
+            now = time.time()
+            if now - self.last_logged >= self.min_log_interval:
+                logger.error(
+                    "Spill to disk failed; keeping data in memory", exc_info=True
+                )
+                self.last_logged = now
+            raise HandledError()
+        except PickleError as e:
+            key_e, orig_e = e.args
+            assert key_e in self.fast
+            assert key_e not in self.slow
+            if key_e == key:
+                assert key is not None
+                # The key we just inserted failed to serialize.
+                # This happens only when the key is individually larger than target.
+                # The exception will be caught by Worker and logged; the status of
+                # the task will be set to error.
+                del self[key]
+                raise orig_e
+            else:
+                # The key we just inserted is smaller than target, but it caused
+                # another, unrelated key to be spilled out of the LRU, and that key
+                # failed to serialize. There's nothing wrong with the new key. The
+                # older key is still in memory.
+                if key_e not in self.logged_pickle_errors:
+                    logger.error(f"Failed to pickle {key_e!r}", exc_info=True)
+                    self.logged_pickle_errors.add(key_e)
+                raise HandledError()
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        """If sizeof(value) < target, write key/value pair to self.fast; this may in
+        turn cause older keys to be spilled from fast to slow.
+        If sizeof(value) >= target, write key/value pair directly to self.slow instead.
+
+        Raises
+        ------
+        Exception
+            sizeof(value) >= target, and value failed to pickle.
+            The key/value pair has been forgotten.
+
+        In all other cases:
+
+        - an older value was evicted and failed to pickle,
+        - this value or an older one caused the disk to fill and raise OSError,
+        - this value or an older one caused the max_spill threshold to be exceeded,
+
+        this method does not raise and guarantees that the key/value pair that caused
+        the issue remains in fast.
+ """ + try: + with self.handle_errors(key): + super().__setitem__(key, value) + self.logged_pickle_errors.discard(key) + except HandledError: + assert key in self.fast + assert key not in self.slow + + def evict(self) -> int: + """Manually evict the oldest key/value pair, even if target has not been reached. + Returns sizeof(value). + If the eviction failed (value failed to pickle, disk full, or max_spill + exceeded), return -1; the key/value pair that caused the issue will remain in + fast. This method never raises. + """ + try: + with self.handle_errors(None): + _, _, weight = self.fast.evict() + return weight + except HandledError: + return -1 + + def __delitem__(self, key: str) -> None: + super().__delitem__(key) + self.logged_pickle_errors.discard(key) @property - def memory(self) -> Mapping[Hashable, Any]: + def memory(self) -> Mapping[str, Any]: """Key/value pairs stored in RAM. Alias of zict.Buffer.fast. For inspection only - do not modify directly! """ return self.fast @property - def disk(self) -> Mapping[Hashable, Any]: + def disk(self) -> Mapping[str, Any]: """Key/value pairs spilled out to disk. Alias of zict.Buffer.slow. For inspection only - do not modify directly! """ return self.slow - @staticmethod - def _weight(key: Hashable, value: Any) -> int: - return safe_sizeof(value) - - def _on_evict(self, key: Hashable, value: Any) -> None: - b = safe_sizeof(value) - self.spilled_by_key[key] = b - self.spilled_total += b - - def _on_retrieve(self, key: Hashable, value: Any) -> None: - self.spilled_total -= self.spilled_by_key.pop(key) - - def __setitem__(self, key: Hashable, value: Any) -> None: - self.spilled_total -= self.spilled_by_key.pop(key, 0) - super().__setitem__(key, value) - if key in self.slow: - # value is individually larger than target so it went directly to slow. - # _on_evict was not called. - b = safe_sizeof(value) - self.spilled_by_key[key] = b - self.spilled_total += b - - def __delitem__(self, key: Hashable) -> None: - self.spilled_total -= self.spilled_by_key.pop(key, 0) + @property + def spilled_total(self) -> int: + """Number of bytes spilled to disk. + Note that this is the pickled size, which may differ from the output of sizeof(). + """ + return self.slow.total_weight + + +def _in_memory_weight(key: str, value: Any) -> int: + return safe_sizeof(value) + + +# Internal exceptions. These are never raised by SpillBuffer. +class MaxSpillExceeded(Exception): + pass + + +class PickleError(Exception): + pass + + +class HandledError(Exception): + pass + + +class Slow(zict.Func): + max_weight: int | Literal[False] + weight_by_key: dict[str, int] + total_weight: int + + def __init__(self, spill_directory: str, max_weight: int | Literal[False] = False): + super().__init__( + partial(serialize_bytelist, on_error="raise"), + deserialize_bytes, + zict.File(spill_directory), + ) + self.max_weight = max_weight + self.weight_by_key = {} + self.total_weight = 0 + + def __setitem__(self, key: str, value: Any) -> None: + try: + pickled = self.dump(value) + except Exception as e: + # zict.LRU ensures that the key remains in fast if we raise. + # Wrap the exception so that it's recognizable by SpillBuffer, + # which will then unwrap it. + raise PickleError(key, e) + + pickled_size = sum(len(frame) for frame in pickled) + + # Thanks to Buffer.__setitem__, we never update existing keys in slow, + # but always delete them and reinsert them. 
+        assert key not in self.d
+        assert key not in self.weight_by_key
+
+        if (
+            self.max_weight is not False
+            and self.total_weight + pickled_size > self.max_weight
+        ):
+            # Stop the spill and ensure that the key ends up in SpillBuffer.fast.
+            # This exception is caught by SpillBuffer.__setitem__.
+            raise MaxSpillExceeded(key)
+
+        # Store to disk through File.
+        # This may raise OSError, which is caught by SpillBuffer above.
+        self.d[key] = pickled
+
+        self.weight_by_key[key] = pickled_size
+        self.total_weight += pickled_size
+
+    def __delitem__(self, key: str) -> None:
         super().__delitem__(key)
+        self.total_weight -= self.weight_by_key.pop(key)
diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py
index d013b141158..713393cbb26 100644
--- a/distributed/tests/test_spill.py
+++ b/distributed/tests/test_spill.py
@@ -1,8 +1,23 @@
+from __future__ import annotations
+
+import logging
+import os
+
 import pytest
+
+zict = pytest.importorskip("zict")
+from packaging.version import parse as parse_version
+
 from dask.sizeof import sizeof
+from distributed.compatibility import WINDOWS
+from distributed.protocol import serialize_bytelist
 from distributed.spill import SpillBuffer
+from distributed.utils_test import captured_logger
+
+
+def psize(*objs) -> int:
+    return sum(len(frame) for obj in objs for frame in serialize_bytelist(obj))


 def test_spillbuffer(tmpdir):
@@ -11,72 +26,327 @@ def test_spillbuffer(tmpdir):
     assert buf.memory is buf.fast
     assert buf.disk is buf.slow
-    assert not buf.spilled_by_key
+    assert not buf.slow.weight_by_key
+    assert buf.slow.total_weight == 0
     assert buf.spilled_total == 0
-    a, b, c, d = "a" * 100, "b" * 100, "c" * 100, "d" * 100
-    s = sizeof(a)
+    a, b, c, d = "a" * 100, "b" * 99, "c" * 98, "d" * 97
+
     # Test assumption made by this test, mostly for non CPython implementations
-    assert 100 < s < 200
+    assert 100 < sizeof(a) < 200
+    assert sizeof(a) != psize(a)

     buf["a"] = a
-    assert not buf.disk
-    assert not buf.spilled_by_key
-    assert buf.spilled_total == 0
+    assert not buf.slow
+    assert buf.fast.weights == {"a": sizeof(a)}
+    assert buf.fast.total_weight == sizeof(a)
+    assert buf.slow.weight_by_key == {}
+    assert buf.slow.total_weight == 0
     assert buf["a"] == a

     buf["b"] = b
-    assert not buf.disk
-    assert not buf.spilled_by_key
-    assert buf.spilled_total == 0
+    assert not buf.slow
+    assert not buf.slow.weight_by_key
+    assert buf.slow.total_weight == 0

     buf["c"] = c
-    assert set(buf.disk) == {"a"}
-    assert buf.spilled_by_key == {"a": s}
-    assert buf.spilled_total == s
+    assert set(buf.slow) == {"a"}
+    assert buf.slow.weight_by_key == {"a": psize(a)}
+    assert buf.slow.total_weight == psize(a)

     assert buf["a"] == a
-    assert set(buf.disk) == {"b"}
-    assert buf.spilled_by_key == {"b": s}
-    assert buf.spilled_total == s
+    assert set(buf.slow) == {"b"}
+    assert buf.slow.weight_by_key == {"b": psize(b)}
+    assert buf.slow.total_weight == psize(b)

     buf["d"] = d
-    assert set(buf.disk) == {"b", "c"}
-    assert buf.spilled_by_key == {"b": s, "c": s}
-    assert buf.spilled_total == s * 2
+    assert set(buf.slow) == {"b", "c"}
+    assert buf.slow.weight_by_key == {"b": psize(b), "c": psize(c)}
+    assert buf.slow.total_weight == psize(b, c)

     # Deleting an in-memory key does not automatically move spilled keys back to memory
     del buf["a"]
-    assert set(buf.disk) == {"b", "c"}
-    assert buf.spilled_by_key == {"b": s, "c": s}
-    assert buf.spilled_total == s * 2
+    assert set(buf.slow) == {"b", "c"}
+    assert buf.slow.weight_by_key == {"b": psize(b), "c": psize(c)}
+    assert buf.slow.total_weight == psize(b, c)
     with pytest.raises(KeyError):
         buf["a"]

     # Deleting a spilled key updates the metadata
     del buf["b"]
-    assert set(buf.disk) == {"c"}
-    assert buf.spilled_by_key == {"c": s}
-    assert buf.spilled_total == s
+    assert set(buf.slow) == {"c"}
+    assert buf.slow.weight_by_key == {"c": psize(c)}
+    assert buf.slow.total_weight == psize(c)

     with pytest.raises(KeyError):
         buf["b"]

     # Updating a spilled key moves it to the top of the LRU and to memory
     buf["c"] = c * 2
-    assert set(buf.disk) == {"d"}
-    assert buf.spilled_by_key == {"d": s}
-    assert buf.spilled_total == s
+    assert set(buf.slow) == {"d"}
+    assert buf.slow.weight_by_key == {"d": psize(d)}
+    assert buf.slow.total_weight == psize(d)

     # Single key is larger than target and goes directly into slow
     e = "e" * 500
-    slarge = sizeof(e)
+
     buf["e"] = e
-    assert set(buf.disk) == {"d", "e"}
-    assert buf.spilled_by_key == {"d": s, "e": slarge}
-    assert buf.spilled_total == s + slarge
+    assert set(buf.slow) == {"d", "e"}
+    assert buf.slow.weight_by_key == {"d": psize(d), "e": psize(e)}
+    assert buf.slow.total_weight == psize(d, e)

     # Updating a spilled key with another larger than target updates slow directly
-    buf["d"] = "d" * 500
-    assert set(buf.disk) == {"d", "e"}
-    assert buf.spilled_by_key == {"d": slarge, "e": slarge}
-    assert buf.spilled_total == slarge * 2
+    d = "d" * 500
+    buf["d"] = d
+    assert set(buf.slow) == {"d", "e"}
+    assert buf.slow.weight_by_key == {"d": psize(d), "e": psize(e)}
+    assert buf.slow.total_weight == psize(d, e)
+
+
+requires_zict_210 = pytest.mark.skipif(
+    parse_version(zict.__version__) <= parse_version("2.0.0"),
+    reason="requires zict version > 2.0.0",
+)
+
+
+@requires_zict_210
+def test_spillbuffer_maxlim(tmpdir):
+    buf = SpillBuffer(str(tmpdir), target=200, max_spill=600, min_log_interval=0)
+
+    a, b, c, d, e = "a" * 200, "b" * 100, "c" * 99, "d" * 199, "e" * 98
+
+    # size of a is bigger than target and smaller than max_spill;
+    # the key should be in slow
+    buf["a"] = a
+    assert not buf.fast
+    assert not buf.fast.weights
+    assert set(buf.slow) == {"a"}
+    assert buf.slow.weight_by_key == {"a": psize(a)}
+    assert buf.slow.total_weight == psize(a)
+    assert buf["a"] == a
+
+    # size of b is smaller than target; the key should be in fast
+    buf["b"] = b
+    assert set(buf.fast) == {"b"}
+    assert buf.fast.weights == {"b": sizeof(b)}
+    assert buf["b"] == b
+    assert buf.fast.total_weight == sizeof(b)
+
+    # size of c is smaller than target but b+c > target; c should stay in fast and b
+    # move to slow, since the max_spill limit has not been reached yet

+    buf["c"] = c
+    assert set(buf.fast) == {"c"}
+    assert buf.fast.weights == {"c": sizeof(c)}
+    assert buf["c"] == c
+    assert buf.fast.total_weight == sizeof(c)
+
+    assert set(buf.slow) == {"a", "b"}
+    assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)}
+    assert buf.slow.total_weight == psize(a, b)
+
+    # size of e < target but e+c > target; this would trigger movement of c to slow,
+    # but the max_spill limit prevents it, so both c and e remain in fast
+
+    with captured_logger(logging.getLogger("distributed.spill")) as logs_e:
+        buf["e"] = e
+
+    assert "disk reached capacity" in logs_e.getvalue()
+
+    assert set(buf.fast) == {"c", "e"}
+    assert buf.fast.weights == {"c": sizeof(c), "e": sizeof(e)}
+    assert buf["e"] == e
+    assert buf.fast.total_weight == sizeof(c) + sizeof(e)
+
+    assert set(buf.slow) == {"a", "b"}
+    assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)}
+    assert buf.slow.total_weight == psize(a, b)
+
+    # size of d > target, so d should go to slow, but slow has reached the max_spill
+    # limit; d ends up in fast with c (which can't be moved to slow either, because
+    # it won't fit)
+    with captured_logger(logging.getLogger("distributed.spill")) as logs_d:
+        buf["d"] = d
+
+    assert "disk reached capacity" in logs_d.getvalue()
+
+    assert set(buf.fast) == {"c", "d", "e"}
+    assert buf.fast.weights == {"c": sizeof(c), "d": sizeof(d), "e": sizeof(e)}
+    assert buf["d"] == d
+    assert buf.fast.total_weight == sizeof(c) + sizeof(d) + sizeof(e)
+
+    assert set(buf.slow) == {"a", "b"}
+    assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)}
+    assert buf.slow.total_weight == psize(a, b)
+
+    # Overwrite a key that was in slow, but the size of the new key is larger than
+    # max_spill
+
+    a_large = "a" * 500
+    assert psize(a_large) > 600  # size of max_spill
+
+    with captured_logger(logging.getLogger("distributed.spill")) as logs_alarge:
+        buf["a"] = a_large
+
+    assert "disk reached capacity" in logs_alarge.getvalue()
+
+    assert set(buf.fast) == {"a", "d", "e"}
+    assert set(buf.slow) == {"b", "c"}
+    assert buf.fast.total_weight == sizeof(d) + sizeof(a_large) + sizeof(e)
+    assert buf.slow.total_weight == psize(b, c)
+
+    # Overwrite a key that was in fast, but the size of the new key is larger than
+    # max_spill
+
+    d_large = "d" * 501
+    with captured_logger(logging.getLogger("distributed.spill")) as logs_dlarge:
+        buf["d"] = d_large
+
+    assert "disk reached capacity" in logs_dlarge.getvalue()
+
+    assert set(buf.fast) == {"a", "d", "e"}
+    assert set(buf.slow) == {"b", "c"}
+    assert buf.fast.total_weight == sizeof(a_large) + sizeof(d_large) + sizeof(e)
+    assert buf.slow.total_weight == psize(b, c)
+
+
+class MyError(Exception):
+    pass
+
+
+class Bad:
+    def __init__(self, size):
+        self.size = size
+
+    def __getstate__(self):
+        raise MyError()
+
+    def __sizeof__(self):
+        return self.size
+
+
+@requires_zict_210
+def test_spillbuffer_fail_to_serialize(tmpdir):
+    buf = SpillBuffer(str(tmpdir), target=200, max_spill=600, min_log_interval=0)
+
+    # bad data individually larger than the spill threshold target of 200
+    a = Bad(size=201)
+
+    # Exception caught in the worker
+    with pytest.raises(TypeError, match="Could not serialize"):
+        with captured_logger(logging.getLogger("distributed.spill")) as logs_bad_key:
+            buf["a"] = a
+
+    # spill.py must remain silent because we're already logging in worker.py
+    assert not logs_bad_key.getvalue()
+    assert not set(buf.fast)
+    assert not set(buf.slow)
+
+    b = Bad(size=100)  # this is small enough to fit in memory/fast
+
+    buf["b"] = b
+    assert set(buf.fast) == {"b"}
+
+    c = "c" * 100
+    with captured_logger(logging.getLogger("distributed.spill")) as logs_bad_key_mem:
+        # This goes to fast and tries to kick b out, but b remains in fast
+        # because it isn't picklable
+        buf["c"] = c
+
+    # worker.py won't intercept the exception here, so spill.py must dump the traceback
+    logs_value = logs_bad_key_mem.getvalue()
+    assert "Failed to pickle" in logs_value  # from distributed.spill
assert "Traceback" in logs_value # from distributed.spill + assert set(buf.fast) == {"b", "c"} + assert buf.fast.total_weight == sizeof(b) + sizeof(c) + assert not set(buf.slow) + + +@requires_zict_210 +@pytest.mark.skipif(WINDOWS, reason="Needs chmod") +def test_spillbuffer_oserror(tmpdir): + buf = SpillBuffer(str(tmpdir), target=200, max_spill=800, min_log_interval=0) + + a, b, c, d = ( + "a" * 200, + "b" * 100, + "c" * 201, + "d" * 101, + ) + + # let's have something in fast and something in slow + buf["a"] = a + buf["b"] = b + assert set(buf.fast) == {"b"} + assert set(buf.slow) == {"a"} + + # modify permissions of disk to be read only. + # This causes writes to raise OSError, just like in case of disk full. + os.chmod(tmpdir, 0o555) + + # Add key > than target + with captured_logger(logging.getLogger("distributed.spill")) as logs_oserror_slow: + buf["c"] = c + + assert "Spill to disk failed" in logs_oserror_slow.getvalue() + assert set(buf.fast) == {"b", "c"} + assert set(buf.slow) == {"a"} + + assert buf.slow.weight_by_key == {"a": psize(a)} + assert buf.fast.weights == {"b": sizeof(b), "c": sizeof(c)} + + del buf["c"] + assert set(buf.fast) == {"b"} + assert set(buf.slow) == {"a"} + + # add key to fast which is smaller than target but when added it triggers spill, + # which triggers OSError + with captured_logger(logging.getLogger("distributed.spill")) as logs_oserror_evict: + buf["d"] = d + + assert "Spill to disk failed" in logs_oserror_evict.getvalue() + assert set(buf.fast) == {"b", "d"} + assert set(buf.slow) == {"a"} + + assert buf.slow.weight_by_key == {"a": psize(a)} + assert buf.fast.weights == {"b": sizeof(b), "d": sizeof(d)} + + +@requires_zict_210 +def test_spillbuffer_evict(tmpdir): + buf = SpillBuffer(str(tmpdir), target=300, min_log_interval=0) + + a_bad = Bad(size=100) + a = "a" * 100 + + buf["a"] = a + + assert set(buf.fast) == {"a"} + assert not set(buf.slow) + assert buf.fast.weights == {"a": sizeof(a)} + + # successful eviction + weight = buf.evict() + assert weight == sizeof(a) + + assert not buf.fast + assert set(buf.slow) == {"a"} + assert buf.slow.weight_by_key == {"a": psize(a)} + + buf["a_bad"] = a_bad + + assert set(buf.fast) == {"a_bad"} + assert buf.fast.weights == {"a_bad": sizeof(a_bad)} + assert set(buf.slow) == {"a"} + assert buf.slow.weight_by_key == {"a": psize(a)} + + # unsuccessful eviction + with captured_logger(logging.getLogger("distributed.spill")) as logs_evict_key: + weight = buf.evict() + assert weight == -1 + + assert "Failed to pickle" in logs_evict_key.getvalue() + # bad keys stays in fast + assert set(buf.fast) == {"a_bad"} + assert buf.fast.weights == {"a_bad": sizeof(a_bad)} + assert set(buf.slow) == {"a"} + assert buf.slow.weight_by_key == {"a": psize(a)} diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 71be91215ff..b2a424a640a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -16,6 +16,7 @@ import psutil import pytest +from packaging.version import parse as parse_version from tlz import first, pluck, sliding_window import dask @@ -68,6 +69,17 @@ pytestmark = pytest.mark.ci1 +try: + import zict +except ImportError: + zict = None + +requires_zict = pytest.mark.skipif(not zict, reason="requires zict") +requires_zict_210 = pytest.mark.skipif( + not zict or parse_version(zict.__version__) <= parse_version("2.0.0"), + reason="requires zict version > 2.0.0", +) + @gen_cluster(nthreads=[]) async def test_worker_nthreads(s): @@ -899,58 +911,104 @@ def 
     assert result.data == 123


-@gen_cluster(client=True)
-async def test_fail_write_to_disk(c, s, a, b):
-    class Bad:
-        def __getstate__(self):
-            raise TypeError()
+class FailToPickle:
+    def __init__(self, *, reported_size=0, actual_size=0):
+        self.reported_size = int(reported_size)
+        self.data = "x" * int(actual_size)
+
+    def __getstate__(self):
+        raise TypeError()
+
+    def __sizeof__(self):
+        return self.reported_size

-        def __sizeof__(self):
-            return int(100e9)

-    future = c.submit(Bad)
+async def assert_basic_futures(c: Client) -> None:
+    futures = c.map(inc, range(10))
+    results = await c.gather(futures)
+    assert results == list(map(inc, range(10)))
+
+
+@requires_zict
+@gen_cluster(client=True)
+async def test_fail_write_to_disk_target_1(c, s, a, b):
+    """Test failure to spill, triggered by a key which is individually larger
+    than target. The data is lost and the task is marked as failed;
+    the worker remains in usable condition.
+    """
+    future = c.submit(FailToPickle, reported_size=100e9)
     await wait(future)

     assert future.status == "error"

-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match="Could not serialize"):
         await future

-    futures = c.map(inc, range(10))
-    results = await c._gather(futures)
-    assert results == list(map(inc, range(10)))
+    await assert_basic_futures(c)


-@pytest.mark.skip(reason="Our logic here is faulty")
+@requires_zict
 @gen_cluster(
-    nthreads=[("127.0.0.1", 2)], client=True, worker_kwargs={"memory_limit": 10e9}
+    client=True,
+    nthreads=[("", 1)],
+    worker_kwargs=dict(
+        memory_limit="1 kiB",
+        memory_target_fraction=0.5,
+        memory_spill_fraction=False,
+        memory_pause_fraction=False,
+    ),
 )
-async def test_fail_write_many_to_disk(c, s, a):
-    a.validate = False
-    await asyncio.sleep(0.1)
-    assert a.status == Status.running
+async def test_fail_write_to_disk_target_2(c, s, a):
+    """Test failure to spill, triggered by a key which is individually smaller
+    than target, so it is not spilled immediately. The data is retained and
+    the task is NOT marked as failed; the worker remains in usable condition.
+ """ + x = c.submit(FailToPickle, reported_size=256, key="x") + await wait(x) + assert x.status == "finished" + assert set(a.data.memory) == {"x"} - class Bad: - def __init__(self, x): - pass + y = c.submit(lambda: "y" * 256, key="y") + await wait(y) + assert set(a.data.memory) == {"x", "y"} + assert not a.data.disk - def __getstate__(self): - raise TypeError() + await assert_basic_futures(c) - def __sizeof__(self): - return int(2e9) - futures = c.map(Bad, range(11)) - future = c.submit(lambda *args: 123, *futures) +@requires_zict_210 +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs=dict( + memory_monitor_interval="10ms", + memory_limit="1 kiB", # Spill everything + memory_target_fraction=False, + memory_spill_fraction=0.7, + memory_pause_fraction=False, + ), +) +async def test_fail_write_to_disk_spill(c, s, a): + """Test failure to evict a key, triggered by the spill threshold""" + with captured_logger(logging.getLogger("distributed.spill")) as logs: + bad = c.submit(FailToPickle, actual_size=1_000_000, key="bad") + await wait(bad) - await wait(future) + # Must wait for memory monitor to kick in + while True: + logs_value = logs.getvalue() + if logs_value: + break + await asyncio.sleep(0.01) - with pytest.raises(Exception) as info: - await future + assert "Failed to pickle" in logs_value + assert "Traceback" in logs_value - # workers still operational - result = await c.submit(inc, 1, workers=a.address) - assert result == 2 + # key is in fast + assert bad.status == "finished" + assert bad.key in a.data.fast + + await assert_basic_futures(c) @gen_cluster() @@ -1166,6 +1224,61 @@ async def test_spill_target_threshold(c, s, a): assert set(a.data.disk) == {"y"} +@requires_zict_210 +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs=dict( + memory_limit=1600, + max_spill=600, + memory_target_fraction=0.6, + memory_spill_fraction=False, + memory_pause_fraction=False, + ), +) +async def test_spill_constrained(c, s, w): + """Test distributed.worker.memory.max-spill parameter""" + # spills starts at 1600*0.6=960 bytes of managed memory + + # size in memory ~200; size on disk ~400 + x = c.submit(lambda: "x" * 200, key="x") + await wait(x) + # size in memory ~500; size on disk ~700 + y = c.submit(lambda: "y" * 500, key="y") + await wait(y) + + assert set(w.data) == {x.key, y.key} + assert set(w.data.memory) == {x.key, y.key} + + z = c.submit(lambda: "z" * 500, key="z") + await wait(z) + + assert set(w.data) == {x.key, y.key, z.key} + + # max_spill has not been reached + assert set(w.data.memory) == {y.key, z.key} + assert set(w.data.disk) == {x.key} + + # zb is individually larger than max_spill + zb = c.submit(lambda: "z" * 1700, key="zb") + await wait(zb) + + assert set(w.data.memory) == {y.key, z.key, zb.key} + assert set(w.data.disk) == {x.key} + + del zb + while "zb" in w.data: + await asyncio.sleep(0.01) + + # zc is individually smaller than max_spill, but the evicted key together with + # x it exceeds max_spill + zc = c.submit(lambda: "z" * 500, key="zc") + await wait(zc) + assert set(w.data.memory) == {y.key, z.key, zc.key} + assert set(w.data.disk) == {x.key} + + +@requires_zict @gen_cluster( nthreads=[("", 1)], client=True, @@ -1208,6 +1321,7 @@ def __reduce__(self): await asyncio.sleep(0.01) +@requires_zict @gen_cluster( nthreads=[("", 1)], client=True, diff --git a/distributed/worker.py b/distributed/worker.py index d247388c5f9..bcf96664e40 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -459,6 +459,9 @@ class Worker(ServerNode): 
     memory_pause_fraction: float or False
         Fraction of memory at which we stop running new tasks
         (default: read from config key distributed.worker.memory.pause)
+    max_spill: int, string or False
+        Limit of the number of bytes to be spilled to disk.
+        (default: read from config key distributed.worker.memory.max-spill)
     executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
         The executor(s) to use. Depending on the type, it has the following meanings:
             - Executor instance: The default executor.
@@ -577,6 +580,7 @@ class Worker(ServerNode):
     memory_target_fraction: float | Literal[False]
     memory_spill_fraction: float | Literal[False]
     memory_pause_fraction: float | Literal[False]
+    max_spill: int | Literal[False]
     data: MutableMapping[str, Any]  # {task key: task payload}
     actors: dict[str, Actor | None]
     loop: IOLoop
@@ -629,6 +633,7 @@ def __init__(
         memory_target_fraction: float | Literal[False] | None = None,
         memory_spill_fraction: float | Literal[False] | None = None,
         memory_pause_fraction: float | Literal[False] | None = None,
+        max_spill: float | str | Literal[False] | None = None,
         extensions: list[type] | None = None,
         metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
         startup_information: Mapping[
@@ -874,6 +879,10 @@ def __init__(
             else dask.config.get("distributed.worker.memory.pause")
         )

+        if max_spill is None:
+            max_spill = dask.config.get("distributed.worker.memory.max-spill")
+        self.max_spill = False if max_spill is False else parse_bytes(max_spill)
+
         if isinstance(data, MutableMapping):
             self.data = data
         elif callable(data):
@@ -893,7 +902,9 @@ def __init__(
             else:
                 target = sys.maxsize
             self.data = SpillBuffer(
-                os.path.join(self.local_directory, "storage"), target=target
+                os.path.join(self.local_directory, "storage"),
+                target=target,
+                max_spill=self.max_spill,
             )
         else:
             self.data = {}
@@ -3730,8 +3741,11 @@ def check_pause(memory):
                         format_bytes(self.memory_limit),
                     )
                     break
-                k, v, weight = self.data.fast.evict()
-                del k, v
+                weight = self.data.evict()
+                if weight == -1:
+                    # Failed to evict: disk full, spill size limit exceeded, or pickle error
+                    break
+
                 total += weight
                 count += 1
                 # If the current buffer is filled with a lot of small values,
@@ -3748,6 +3762,7 @@ def check_pause(memory):
                 # before trying to evict even more data.
                 self._throttled_gc.collect()
                 memory = proc.memory_info().rss
+                check_pause(memory)
             if count:
                 logger.debug(
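
For reference, a minimal usage sketch of the SpillBuffer behaviour introduced by this patch (not part of the patch itself), assuming zict > 2.0.0 is installed; the temporary directory, keys, and byte sizes below are illustrative only:

import tempfile

from dask.sizeof import sizeof
from distributed.spill import SpillBuffer

buf = SpillBuffer(tempfile.mkdtemp(), target=200, max_spill=600)

buf["small"] = "x" * 100  # sizeof() < target: stays in buf.fast (memory)
buf["large"] = "y" * 300  # sizeof() >= target: pickled straight to buf.slow (disk)
assert "small" in buf.memory
assert "large" in buf.disk

# spilled_total reports the pickled size on disk (slow.total_weight),
# which generally differs from the in-memory sizeof()
assert buf.spilled_total > 0

# Manually spill the oldest in-memory key/value pair; returns sizeof(value),
# or -1 if the value fails to pickle, the disk is full, or max_spill would
# be exceeded (in which case the pair stays in fast)
assert buf.evict() == sizeof("x" * 100)
assert "small" in buf.disk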
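And a hypothetical end-to-end sketch of setting the new limit from the worker side; the cluster shape and the "100 MB" value are illustrative, not prescriptive:

import dask
from distributed import Client, LocalCluster

# The config key mirrors the new Worker(max_spill=...) keyword; it accepts
# False, an integer byte count, or a string such as "100 MB" (parsed with
# dask.utils.parse_bytes).
with dask.config.set({"distributed.worker.memory.max-spill": "100 MB"}):
    with LocalCluster(n_workers=1, processes=False) as cluster:
        with Client(cluster) as client:
            # Each worker's SpillBuffer is now capped at 100e6 pickled bytes
            print(client.run(lambda dask_worker: dask_worker.max_spill))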