Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ResourceLimiter to be unlimited #8276

Merged
merged 4 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion distributed/dashboard/components/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4459,7 +4459,9 @@
for prefix in ["comm", "disk"]:
data[f"{prefix}_total"].append(d[prefix]["total"])
data[f"{prefix}_memory"].append(d[prefix]["memory"])
data[f"{prefix}_memory_limit"].append(d[prefix]["memory_limit"])
data[f"{prefix}_memory_limit"].append(

Check warning on line 4462 in distributed/dashboard/components/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/dashboard/components/scheduler.py#L4462

Added line #L4462 was not covered by tests
d[prefix]["memory_limit"] or 0
)
data[f"{prefix}_buckets"].append(d[prefix]["buckets"])
data[f"{prefix}_avg_duration"].append(
d[prefix]["diagnostics"].get("avg_duration", 0)
Expand Down
15 changes: 6 additions & 9 deletions distributed/shuffle/_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ShardsBuffer(Generic[ShardType]):
shards: defaultdict[str, _List[ShardType]]
sizes: defaultdict[str, int]
concurrency_limit: int
memory_limiter: ResourceLimiter | None
memory_limiter: ResourceLimiter
diagnostics: dict[str, float]
max_message_size: int

Expand All @@ -64,7 +64,7 @@ class ShardsBuffer(Generic[ShardType]):

def __init__(
self,
memory_limiter: ResourceLimiter | None,
memory_limiter: ResourceLimiter,
concurrency_limit: int = 2,
max_message_size: int = -1,
) -> None:
Expand Down Expand Up @@ -97,7 +97,7 @@ def heartbeat(self) -> dict[str, Any]:
"written": self.bytes_written,
"read": self.bytes_read,
"diagnostics": self.diagnostics,
"memory_limit": self.memory_limiter._maxvalue if self.memory_limiter else 0,
"memory_limit": self.memory_limiter.limit,
}

async def process(self, id: str, shards: list[ShardType], size: int) -> None:
Expand All @@ -119,8 +119,7 @@ async def process(self, id: str, shards: list[ShardType], size: int) -> None:
"avg_duration"
] + 0.02 * (stop - start)
finally:
if self.memory_limiter:
await self.memory_limiter.decrease(size)
await self.memory_limiter.decrease(size)
self.bytes_memory -= size

async def _process(self, id: str, shards: list[ShardType]) -> None:
Expand Down Expand Up @@ -198,15 +197,13 @@ async def write(self, data: dict[str, ShardType]) -> None:
self.bytes_memory += total_batch_size
self.bytes_total += total_batch_size

if self.memory_limiter:
self.memory_limiter.increase(total_batch_size)
self.memory_limiter.increase(total_batch_size)
async with self._shards_available:
for worker, shard in data.items():
self.shards[worker].append(shard)
self.sizes[worker] += sizes[worker]
self._shards_available.notify()
if self.memory_limiter:
await self.memory_limiter.wait_for_available()
await self.memory_limiter.wait_for_available()
del data
assert total_batch_size

Expand Down
13 changes: 6 additions & 7 deletions distributed/shuffle/_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@ class CommShardsBuffer(ShardsBuffer):
How to send a list of shards to a worker
Expects an address of the target worker (string)
and a payload of shards (list of bytes) to send to that worker
memory_limiter : ResourceLimiter, optional
Limiter for memory usage (in bytes), or None if no limiting
should be applied. If the incoming data that has yet to be
processed exceeds this limit, then the buffer will block until
below the threshold. See :meth:`.write` for the implementation
of this scheme.
memory_limiter : ResourceLimiter
Limiter for memory usage (in bytes). If the incoming data that
has yet to be processed exceeds this limit, then the buffer will
block until below the threshold. See :meth:`.write` for the
implementation of this scheme.
concurrency_limit : int
Number of background tasks to run.
"""
Expand All @@ -54,7 +53,7 @@ class CommShardsBuffer(ShardsBuffer):
def __init__(
self,
send: Callable[[str, list[tuple[Any, bytes]]], Awaitable[None]],
memory_limiter: ResourceLimiter | None = None,
memory_limiter: ResourceLimiter,
concurrency_limit: int = 10,
):
super().__init__(
Expand Down
4 changes: 2 additions & 2 deletions distributed/shuffle/_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class DiskShardsBuffer(ShardsBuffer):
----------
directory : str or pathlib.Path
Where to write and read data. Ideally points to fast disk.
memory_limiter : ResourceLimiter, optional
memory_limiter : ResourceLimiter
Limiter for in-memory buffering (at most this much data)
before writes to disk occur. If the incoming data that has yet
to be processed exceeds this limit, then the buffer will block
Expand All @@ -122,7 +122,7 @@ def __init__(
self,
directory: str | pathlib.Path,
read: Callable[[pathlib.Path], tuple[Any, int]],
memory_limiter: ResourceLimiter | None = None,
memory_limiter: ResourceLimiter,
):
super().__init__(
memory_limiter=memory_limiter,
Expand Down
44 changes: 32 additions & 12 deletions distributed/shuffle/_limiter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import asyncio
from typing import Generic, TypeVar

from distributed.metrics import time

_T = TypeVar("_T", int, None)

class ResourceLimiter:

class ResourceLimiter(Generic[_T]):
"""Limit an abstract resource

This allows us to track usage of an abstract resource. If the usage of this
resources goes beyond a defined maxvalue, we can block further execution
resources goes beyond a defined limit, we can block further execution

Example::

Expand All @@ -18,39 +21,56 @@ class ResourceLimiter:
limiter.increase(2)
limiter.decrease(1)

# This will block since we're still not below maxvalue
# This will block since we're still not below limit
await limiter.wait_for_available()
"""

def __init__(self, maxvalue: int) -> None:
self._maxvalue = maxvalue
limit: _T
time_blocked_total: float
time_blocked_avg: float

_acquired: int
_condition: asyncio.Condition
_waiters: int

def __init__(self, limit: _T):
self.limit = limit
self._acquired = 0
self._condition = asyncio.Condition()
self._waiters = 0
self.time_blocked_total = 0.0
self.time_blocked_avg = 0.0

def __repr__(self) -> str:
return f"<ResourceLimiter maxvalue: {self._maxvalue} available: {self.available()}>"
return f"<ResourceLimiter limit: {self.limit} available: {self.available}>"

def available(self) -> int:
@property
def available(self) -> _T:
"""How far can the value be increased before blocking"""
return max(0, self._maxvalue - self._acquired)
if self.limit is None:
return self.limit
return max(0, self.limit - self._acquired)

@property
def full(self) -> bool:
"""Return True if the limit has been reached"""
return self.available is not None and not self.available

def free(self) -> bool:
@property
def empty(self) -> bool:
"""Return True if nothing has been acquired / the limiter is in a neutral state"""
return self._acquired == 0

async def wait_for_available(self) -> None:
"""Block until the counter drops below maxvalue"""
"""Block until the counter drops below limit"""
start = time()
duration = 0.0
try:
if self.available():
if not self.full:
return
async with self._condition:
self._waiters += 1
await self._condition.wait_for(self.available)
await self._condition.wait_for(lambda: not self.full)
self._waiters -= 1
duration = time() - start
finally:
Expand Down
4 changes: 2 additions & 2 deletions distributed/shuffle/tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def test_memory_limit(big_payloads):

many_small = [asyncio.create_task(buf.write(small_payload)) for _ in range(11)]
assert buf.memory_limiter
while buf.memory_limiter.available():
while buf.memory_limiter.available:
await asyncio.sleep(0.1)

new_put = asyncio.create_task(buf.write(small_payload))
Expand All @@ -80,7 +80,7 @@ async def test_memory_limit(big_payloads):
many_small = asyncio.gather(*many_small)
await new_put

while not buf.memory_limiter.free():
while not buf.memory_limiter.empty:
await asyncio.sleep(0.1)
buf.allow_process.clear()
big_tasks = [
Expand Down
14 changes: 8 additions & 6 deletions distributed/shuffle/tests/test_comm_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def test_basic(tmp_path):
async def send(address, shards):
d[address].extend(shards)

mc = CommShardsBuffer(send=send)
mc = CommShardsBuffer(send=send, memory_limiter=ResourceLimiter(None))
await mc.write({"x": b"0" * 1000, "y": b"1" * 500})
await mc.write({"x": b"0" * 1000, "y": b"1" * 500})

Expand All @@ -37,7 +37,7 @@ async def test_exceptions(tmp_path):
async def send(address, shards):
raise Exception(123)

mc = CommShardsBuffer(send=send)
mc = CommShardsBuffer(send=send, memory_limiter=ResourceLimiter(None))
await mc.write({"x": b"0" * 1000, "y": b"1" * 500})

while not mc._exception:
Expand All @@ -63,7 +63,9 @@ async def send(address, shards):
d[address].extend(shards)
sending_first.set()

mc = CommShardsBuffer(send=send, concurrency_limit=1)
mc = CommShardsBuffer(
send=send, concurrency_limit=1, memory_limiter=ResourceLimiter(None)
)
await mc.write({"x": b"0", "y": b"1"})
await mc.write({"x": b"0", "y": b"1"})
flush_task = asyncio.create_task(mc.flush())
Expand Down Expand Up @@ -96,7 +98,7 @@ async def send(address, shards):
send=send, memory_limiter=ResourceLimiter(parse_bytes("100 MiB"))
)
payload = {
x: gen_bytes(frac, comm_buffer.memory_limiter._maxvalue) for x in range(nshards)
x: gen_bytes(frac, comm_buffer.memory_limiter.limit) for x in range(nshards)
}

async with comm_buffer as mc:
Expand All @@ -113,7 +115,7 @@ async def send(address, shards):
assert len(d) == 10
assert (
sum(map(len, d[0]))
== len(gen_bytes(frac, comm_buffer.memory_limiter._maxvalue)) * nputs
== len(gen_bytes(frac, comm_buffer.memory_limiter.limit)) * nputs
)


Expand All @@ -137,7 +139,7 @@ async def send(address, shards):
send=send, memory_limiter=ResourceLimiter(parse_bytes("100 MiB"))
)
payload = {
x: gen_bytes(frac, comm_buffer.memory_limiter._maxvalue) for x in range(nshards)
x: gen_bytes(frac, comm_buffer.memory_limiter.limit) for x in range(nshards)
}

async with comm_buffer as mc:
Expand Down
20 changes: 14 additions & 6 deletions distributed/shuffle/tests/test_disk_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils_test import gen_test


Expand All @@ -20,7 +21,9 @@ def read_bytes(path: Path) -> tuple[bytes, int]:

@gen_test()
async def test_basic(tmp_path):
async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf:
async with DiskShardsBuffer(
directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
) as mf:
await mf.write({"x": b"0" * 1000, "y": b"1" * 500})
await mf.write({"x": b"0" * 1000, "y": b"1" * 500})

Expand All @@ -41,7 +44,9 @@ async def test_basic(tmp_path):
@gen_test()
async def test_read_before_flush(tmp_path):
payload = {"1": b"foo"}
async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf:
async with DiskShardsBuffer(
directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
) as mf:
with pytest.raises(RuntimeError):
mf.read(1)

Expand All @@ -59,7 +64,9 @@ async def test_read_before_flush(tmp_path):
@pytest.mark.parametrize("count", [2, 100, 1000])
@gen_test()
async def test_many(tmp_path, count):
async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf:
async with DiskShardsBuffer(
directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
) as mf:
d = {i: str(i).encode() * 100 for i in range(count)}

for _ in range(10):
Expand All @@ -84,7 +91,9 @@ async def _process(self, *args: Any, **kwargs: Any) -> None:

@gen_test()
async def test_exceptions(tmp_path):
async with BrokenDiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf:
async with BrokenDiskShardsBuffer(
directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
) as mf:
await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

while not mf._exception:
Expand Down Expand Up @@ -114,8 +123,7 @@ async def test_high_pressure_flush_with_exception(tmp_path):
payload = {f"shard-{ix}": [f"shard-{ix}".encode() * 100] for ix in range(100)}

async with EventuallyBrokenDiskShardsBuffer(
directory=tmp_path,
read=read_bytes,
directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
) as mf:
tasks = []
for _ in range(10):
Expand Down
Loading
Loading