Add possibility of spill to constrained disk [WIP] #5521

Closed · wants to merge 4 commits
9 changes: 8 additions & 1 deletion distributed/distributed-schema.yaml
@@ -501,10 +501,17 @@ properties:
description: >-
When the process memory reaches this level the nanny process will kill
the worker (if a nanny is present)

spill-limit:
oneOf:
- {type: string}
- {enum: [false]}
description: |
A limit on spilling to disk: once this limit is reached, the worker stops writing new data to disk.

http:
type: object
decription: Settings for Dask's embedded HTTP Server
description: Settings for Dask's embedded HTTP Server
properties:
routes:
type: array
3 changes: 3 additions & 0 deletions distributed/distributed.yaml
@@ -145,6 +145,9 @@ distributed:
pause: 0.80 # fraction at which we pause worker threads
terminate: 0.95 # fraction at which we terminate the worker

# spill-limit: maximum number of bytes to spill to disk (False disables the limit)
spill-limit: False

http:
routes:
- distributed.http.worker.prometheus
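For context (not part of the diff): a minimal sketch of how the new key could be set from user code, assuming it ships under the name used above, distributed.worker.memory.spill-limit; the "500 MB" value is purely illustrative, and per the worker.py changes below a string value would be parsed with dask.utils.parse_bytes.

import dask

# Set the spill limit programmatically instead of editing distributed.yaml
dask.config.set({"distributed.worker.memory.spill-limit": "500 MB"})
print(dask.config.get("distributed.worker.memory.spill-limit"))  # -> "500 MB"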
51 changes: 46 additions & 5 deletions distributed/spill.py
@@ -1,9 +1,14 @@
from __future__ import annotations

from collections.abc import Hashable, Mapping
from functools import partial
from ctypes import sizeof  # noqa: F401  TODO: unused? this is the C sizeof, not Dask's
from functools import (  # noqa: F401  TODO: is total_ordering actually needed?
partial,
total_ordering,
)
from typing import Any

from typing_extensions import Literal
from zict import Buffer, File, Func

from .protocol import deserialize_bytes, serialize_bytelist
@@ -18,17 +23,25 @@ class SpillBuffer(Buffer):
spilled_by_key: dict[Hashable, int]
spilled_total: int

def __init__(self, spill_directory: str, target: int):
def __init__(
self,
spill_directory: str,
target: int,
disk_limit: int | Literal[False] | None = None,
):
self.spilled_by_key = {}
self.spilled_total = 0
storage = Func(

self.disk_limit = disk_limit
self.spilled_total_disk = 0  # TODO: maybe choose a different name
self.storage = Func(
partial(serialize_bytelist, on_error="raise"),
deserialize_bytes,
File(spill_directory),
)
super().__init__(
{},
storage,
self.storage,
target,
weight=self._weight,
fast_to_slow_callbacks=[self._on_evict],
Expand All @@ -51,6 +64,16 @@ def disk(self) -> Mapping[Hashable, Any]:

@staticmethod
def _weight(key: Hashable, value: Any) -> int:
# disk_limit is False by default, so check that a limit is actually set;
# otherwise the size comparison below would be applied unconditionally.
# Returning -1 does keep the key in fast, but it also skews the running
# total of bytes held in fast, so this approach does not work yet; the
# check lives in __setitem__ instead (and this is a staticmethod, so
# there is no self to consult here anyway).
# if self.disk_limit and (
#     safe_sizeof(value) + self.spilled_total_disk > self.disk_limit
# ):
#     print("spill-limit reached, keeping task in memory")
#     return -1  # this should keep the key in fast
# else:
return safe_sizeof(value)

def _on_evict(self, key: Hashable, value: Any) -> None:
@@ -63,14 +86,32 @@ def _on_retrieve(self, key: Hashable, value: Any) -> None:

def __setitem__(self, key: Hashable, value: Any) -> None:
self.spilled_total -= self.spilled_by_key.pop(key, 0)
super().__setitem__(key, value)
# super().__setitem__(key, value) is replaced by the logic below so the
# disk limit can keep oversized values in memory

if self.weight(key, value) <= self.n or (
self.disk_limit
and (safe_sizeof(value) + self.spilled_total_disk > self.disk_limit)
):
print("im here")
if key in self.slow:
del self.slow[key]
self.fast[key] = value
else:
if key in self.fast:
del self.fast[key]
self.slow[key] = value

if key in self.slow:
# value is individually larger than target so it went directly to slow.
# _on_evict was not called.
b = safe_sizeof(value)
self.spilled_by_key[key] = b
self.spilled_total += b

if self.disk_limit:
# track total spilled to disk (on disk) if limit is provided
self.spilled_total_disk += len(self.storage.d.get(key))

def __delitem__(self, key: Hashable) -> None:
self.spilled_total -= self.spilled_by_key.pop(key, 0)
super().__delitem__(key)
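Putting the __setitem__ changes together, the intended behavior (mirroring the new test below) is: a value that would normally be spilled is kept in memory once the bytes already on disk plus the incoming value would exceed disk_limit. A condensed sketch of that intent, using the same sizes as the test; the temporary directory and thresholds are illustrative, and since this PR is WIP the accounting caveats flagged in the code comments still apply.

import tempfile

from distributed.spill import SpillBuffer

buf = SpillBuffer(tempfile.mkdtemp(), target=200, disk_limit=500)

buf["a"] = "a" * 100  # fits under target -> stays in memory
buf["b"] = "b" * 200  # larger than target -> serialized and written to disk
buf["c"] = "c" * 200  # would push disk usage past disk_limit -> kept in memory

assert set(buf.disk) == {"b"}
assert set(buf.memory) == {"a", "c"}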
33 changes: 33 additions & 0 deletions distributed/tests/test_spill.py
@@ -80,3 +80,36 @@ def test_spillbuffer(tmpdir):
assert set(buf.disk) == {"d", "e"}
assert buf.spilled_by_key == {"d": slarge, "e": slarge}
assert buf.spilled_total == slarge * 2


def test_spillbuffer_disk_limit(tmpdir):
buf = SpillBuffer(str(tmpdir), target=200, disk_limit=500)

# Convenience aliases
assert buf.memory is buf.fast
assert buf.disk is buf.slow

assert not buf.spilled_by_key
assert buf.spilled_total == 0
assert buf.spilled_total_disk == 0

a, b, c = "a" * 100, "b" * 200, "c" * 200

s = sizeof(b)

buf["a"] = a
assert not buf.disk
assert not buf.spilled_by_key
assert buf.spilled_total == buf.spilled_total_disk == 0
assert set(buf.memory) == {"a"}

buf["b"] = b
assert set(buf.disk) == {"b"}
assert buf.spilled_by_key == {"b": s}
assert buf.spilled_total == s
assert buf.spilled_total_disk == len(buf.storage.d.get("b"))

# add a key that will go over the disk limit, should keep it in fast
buf["c"] = c
assert set(buf.memory) == {"a", "c"}
# this works, but the count of what is in fast is off: the commented-out _weight
# approach would sum a -1 into the fast total instead of the value's real size
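To make the trailing comment concrete, here is a toy illustration of why a -1 sentinel weight skews the "fast" accounting; the assumption (consistent with the comments in _weight above) is that the buffer's fast store sums the reported weights, and the byte counts below are made up.

# Real in-memory sizes vs. the weights the -1 trick would report
real_sizes = {"a": 149, "c": 249}
reported_weights = {"a": 149, "c": -1}  # -1 used to force "keep in fast"

print(sum(real_sizes.values()))        # 398 bytes actually held in memory
print(sum(reported_weights.values()))  # 148 bytes according to the buffer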
12 changes: 12 additions & 0 deletions distributed/worker.py
@@ -394,6 +394,9 @@ class Worker(ServerNode):
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
spill_limit: int, str or False (TODO: exact accepted types are still undecided)
Limit on the number of bytes that may be spilled to disk.
(default: read from config key distributed.worker.memory.spill-limit)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
@@ -512,6 +515,7 @@ class Worker(ServerNode):
memory_target_fraction: float | Literal[False]
memory_spill_fraction: float | Literal[False]
memory_pause_fraction: float | Literal[False]
spill_limit: int | Literal[False]
data: MutableMapping[str, Any] # {task key: task payload}
actors: dict[str, Actor | None]
loop: IOLoop
@@ -563,6 +567,7 @@ def __init__(
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
spill_limit: str | Literal[False] | None = None,
extensions: list[type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
@@ -811,6 +816,12 @@ def __init__(
else dask.config.get("distributed.worker.memory.pause")
)

self.spill_limit = (
parse_bytes(spill_limit)
if spill_limit is not None
else dask.config.get("distributed.worker.memory.spill-limit")
)

if isinstance(data, MutableMapping):
self.data = data
elif callable(data):
@@ -829,6 +840,7 @@ def __init__(
* (self.memory_target_fraction or self.memory_spill_fraction)
)
or sys.maxsize,
disk_limit=self.spill_limit,
)
else:
self.data = {}
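Finally, a hedged sketch of how the new keyword could be passed when constructing a worker directly, mirroring the constructor change above; the scheduler address and memory_limit value are placeholders, and the accepted types for spill_limit are still marked as TODO in the docstring.

from distributed import Worker

async def start_worker() -> Worker:
    # Address and limits are illustrative, not recommendations.
    return await Worker(
        "tcp://127.0.0.1:8786",
        memory_limit="2 GB",
        spill_limit="500 MB",  # falls back to the config key when omitted, per this diff
    )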