
Commit

Typing
hendrikmakait committed Oct 17, 2023
1 parent 433b82a commit 2f32a23
Showing 5 changed files with 38 additions and 25 deletions.
17 changes: 10 additions & 7 deletions distributed/shuffle/_limiter.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
 import asyncio
+from typing import Generic, TypeVar
 
 from distributed.metrics import time
 
+_T = TypeVar("_T", int, None)
+
 
-class ResourceLimiter:
+class ResourceLimiter(Generic[_T]):
     """Limit an abstract resource
 
     This allows us to track usage of an abstract resource. If the usage of this
@@ -22,15 +25,15 @@ class ResourceLimiter:
         await limiter.wait_for_available()
     """
 
-    limit: int | None
+    limit: _T
     time_blocked_total: float
    time_blocked_avg: float
 
     _acquired: int
     _condition: asyncio.Condition
     _waiters: int
 
-    def __init__(self, limit: int | None = None) -> None:
+    def __init__(self, limit: _T):
         self.limit = limit
         self._acquired = 0
         self._condition = asyncio.Condition()
@@ -42,19 +45,19 @@ def __repr__(self) -> str:
         return f"<ResourceLimiter limit: {self.limit} available: {self.available}>"
 
     @property
-    def available(self) -> int | None:
+    def available(self) -> _T:
         """How far can the value be increased before blocking"""
         if self.limit is None:
-            return None
+            return self.limit
         return max(0, self.limit - self._acquired)
 
     @property
     def full(self) -> bool:
         """Return True if the limit has been reached"""
-        return self.available is None or bool(self.available)
+        return self.available is not None and not self.available
 
     @property
-    def free(self) -> bool:
+    def empty(self) -> bool:
         """Return True if nothing has been acquired / the limiter is in a neutral state"""
         return self._acquired == 0
 
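
For readers skimming the diff: a minimal usage sketch of the limiter after this change, assuming only the surface shown above (the constructor, increase(), wait_for_available(), and the available/full/empty properties); the call sites are illustrative, not taken from this commit.

import asyncio

from distributed.shuffle._limiter import ResourceLimiter


async def main() -> None:
    # Bounded limiter: `limit` and `available` are ints, so `full` can become True.
    bounded = ResourceLimiter(1024)
    bounded.increase(1024)
    assert bounded.full and bounded.available == 0

    # Unbounded limiter: `limit` and `available` stay None and `full` is always False.
    unbounded = ResourceLimiter(None)
    unbounded.increase(2**40)
    assert not unbounded.full and unbounded.available is None
    assert not unbounded.empty  # something has been acquired

    # Expected to return immediately, since an unbounded limiter is never full.
    await unbounded.wait_for_available()


asyncio.run(main())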
4 changes: 2 additions & 2 deletions distributed/shuffle/tests/test_buffer.py
@@ -70,7 +70,7 @@ async def test_memory_limit(big_payloads):
 
     many_small = [asyncio.create_task(buf.write(small_payload)) for _ in range(11)]
     assert buf.memory_limiter
-    while buf.memory_limiter.available():
+    while buf.memory_limiter.available:
         await asyncio.sleep(0.1)
 
     new_put = asyncio.create_task(buf.write(small_payload))
@@ -80,7 +80,7 @@ async def test_memory_limit(big_payloads):
     many_small = asyncio.gather(*many_small)
     await new_put
 
-    while not buf.memory_limiter.free():
+    while not buf.memory_limiter.empty:
         await asyncio.sleep(0.1)
     buf.allow_process.clear()
     big_tasks = [
14 changes: 8 additions & 6 deletions distributed/shuffle/tests/test_comm_buffer.py
@@ -20,7 +20,7 @@ async def test_basic(tmp_path):
     async def send(address, shards):
         d[address].extend(shards)
 
-    mc = CommShardsBuffer(send=send)
+    mc = CommShardsBuffer(send=send, memory_limiter=ResourceLimiter(None))
     await mc.write({"x": b"0" * 1000, "y": b"1" * 500})
     await mc.write({"x": b"0" * 1000, "y": b"1" * 500})
 
@@ -37,7 +37,7 @@ async def test_exceptions(tmp_path):
     async def send(address, shards):
         raise Exception(123)
 
-    mc = CommShardsBuffer(send=send)
+    mc = CommShardsBuffer(send=send, memory_limiter=ResourceLimiter(None))
     await mc.write({"x": b"0" * 1000, "y": b"1" * 500})
 
     while not mc._exception:
@@ -63,7 +63,9 @@ async def send(address, shards):
         d[address].extend(shards)
         sending_first.set()
 
-    mc = CommShardsBuffer(send=send, concurrency_limit=1)
+    mc = CommShardsBuffer(
+        send=send, concurrency_limit=1, memory_limiter=ResourceLimiter(None)
+    )
     await mc.write({"x": b"0", "y": b"1"})
     await mc.write({"x": b"0", "y": b"1"})
     flush_task = asyncio.create_task(mc.flush())
@@ -96,7 +98,7 @@ async def send(address, shards):
         send=send, memory_limiter=ResourceLimiter(parse_bytes("100 MiB"))
     )
     payload = {
-        x: gen_bytes(frac, comm_buffer.memory_limiter._maxvalue) for x in range(nshards)
+        x: gen_bytes(frac, comm_buffer.memory_limiter.limit) for x in range(nshards)
     }
 
     async with comm_buffer as mc:
@@ -113,7 +115,7 @@ async def send(address, shards):
     assert len(d) == 10
     assert (
         sum(map(len, d[0]))
-        == len(gen_bytes(frac, comm_buffer.memory_limiter._maxvalue)) * nputs
+        == len(gen_bytes(frac, comm_buffer.memory_limiter.limit)) * nputs
     )
 
 
@@ -137,7 +139,7 @@ async def send(address, shards):
         send=send, memory_limiter=ResourceLimiter(parse_bytes("100 MiB"))
     )
     payload = {
-        x: gen_bytes(frac, comm_buffer.memory_limiter._maxvalue) for x in range(nshards)
+        x: gen_bytes(frac, comm_buffer.memory_limiter.limit) for x in range(nshards)
     }
 
     async with comm_buffer as mc:
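
A small sketch of the attribute switch these tests rely on (the private _maxvalue is replaced by the public, typed limit attribute); parse_bytes comes from dask.utils, and the concrete number is illustrative.

from dask.utils import parse_bytes

from distributed.shuffle._limiter import ResourceLimiter

limiter = ResourceLimiter(parse_bytes("100 MiB"))

# The tests above now read the public `limit` attribute rather than the
# previously accessed private `_maxvalue`.
assert limiter.limit == parse_bytes("100 MiB") == 100 * 2**20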
20 changes: 14 additions & 6 deletions distributed/shuffle/tests/test_disk_buffer.py
@@ -8,6 +8,7 @@
 import pytest
 
 from distributed.shuffle._disk import DiskShardsBuffer
+from distributed.shuffle._limiter import ResourceLimiter
 from distributed.utils_test import gen_test
 
 
@@ -20,7 +21,9 @@ def read_bytes(path: Path) -> tuple[bytes, int]:
 
 @gen_test()
 async def test_basic(tmp_path):
-    async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf:
+    async with DiskShardsBuffer(
+        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
+    ) as mf:
         await mf.write({"x": b"0" * 1000, "y": b"1" * 500})
         await mf.write({"x": b"0" * 1000, "y": b"1" * 500})
 
@@ -41,7 +44,9 @@ async def test_basic(tmp_path):
 @gen_test()
 async def test_read_before_flush(tmp_path):
     payload = {"1": b"foo"}
-    async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf:
+    async with DiskShardsBuffer(
+        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
+    ) as mf:
         with pytest.raises(RuntimeError):
             mf.read(1)
 
@@ -59,7 +64,9 @@ async def test_read_before_flush(tmp_path):
 @pytest.mark.parametrize("count", [2, 100, 1000])
 @gen_test()
 async def test_many(tmp_path, count):
-    async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf:
+    async with DiskShardsBuffer(
+        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
+    ) as mf:
         d = {i: str(i).encode() * 100 for i in range(count)}
 
         for _ in range(10):
@@ -84,7 +91,9 @@ async def _process(self, *args: Any, **kwargs: Any) -> None:
 
 @gen_test()
 async def test_exceptions(tmp_path):
-    async with BrokenDiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf:
+    async with BrokenDiskShardsBuffer(
+        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
+    ) as mf:
         await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]})
 
         while not mf._exception:
@@ -114,8 +123,7 @@ async def test_high_pressure_flush_with_exception(tmp_path):
     payload = {f"shard-{ix}": [f"shard-{ix}".encode() * 100] for ix in range(100)}
 
     async with EventuallyBrokenDiskShardsBuffer(
-        directory=tmp_path,
-        read=read_bytes,
+        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
     ) as mf:
         tasks = []
         for _ in range(10):
8 changes: 4 additions & 4 deletions distributed/shuffle/tests/test_limiter.py
@@ -60,19 +60,19 @@ async def test_limiter_basic():
 
 @gen_test()
 async def test_unlimited_limiter():
-    res = ResourceLimiter()
+    res = ResourceLimiter(None)
 
-    assert res.free
+    assert res.empty
     assert res.available is None
     assert not res.full
 
     res.increase(3)
-    assert not res.free
+    assert not res.empty
     assert res.available is None
     assert not res.full
 
     res.increase(2**40)
-    assert not res.free
+    assert not res.empty
     assert res.available is None
     assert not res.full
 
Expand Down
