[Data] [1/n] Async iter_batches: Add Threadpool util (ray-project#33575)

Part 1 of async iter_batches support. Converts the existing _make_async_gen util so that it can use a threadpool instead of a single background thread.

---------

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Signed-off-by: Jack He <jackhe2345@gmail.com>
amogkam authored and ProjectsByJackHe committed May 4, 2023
1 parent d09f345 commit 2b4b9ad
Showing 6 changed files with 228 additions and 135 deletions.
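
At a glance, the call-pattern change to the helper (a sketch assembled from lines in the diff below, not standalone code from the PR):

# Before: a single background thread prefetched already-computed batches.
batch_iter = _make_async_gen(batch_iter, prefetch_buffer_size=prefetch_batches)

# After: the downstream computation is passed in as `fn` and executed on a
# threadpool of `num_workers` threads that share one input iterator.
batch_iter = _make_async_gen(blocks, fn=_iterator_fn, num_workers=prefetch_batches)
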
7 changes: 7 additions & 0 deletions python/ray/data/BUILD
@@ -11,6 +11,13 @@ py_library(
deps = ["//python/ray/tests:conftest"],
)

py_test_module_list(
files = glob(["tests/block_batching/test_*.py"]),
size = "medium",
tags = ["team:ml", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)

py_test_module_list(
files = glob(["tests/preprocessors/test_*.py"]),
size = "small",
6 changes: 6 additions & 0 deletions python/ray/data/_internal/block_batching/__init__.py
@@ -0,0 +1,6 @@
from ray.data._internal.block_batching.block_batching import (
batch_blocks,
batch_block_refs,
)

__all__ = ["batch_blocks", "batch_block_refs"]
python/ray/data/_internal/block_batching/block_batching.py
@@ -1,12 +1,11 @@
import collections
import itertools
import queue
import sys
import threading
from typing import Any, Callable, Iterator, Optional, TypeVar, Union

import ray
from ray.actor import ActorHandle
from ray.data._internal.block_batching.util import _make_async_gen
from ray.data._internal.batcher import Batcher, ShufflingBatcher
from ray.data._internal.stats import DatasetPipelineStats, DatasetStats
from ray.data._internal.memory_tracing import trace_deallocation
@@ -37,7 +36,7 @@ def batch_block_refs(
prefetch_blocks: int = 0,
clear_block_after_read: bool = False,
batch_size: Optional[int] = None,
batch_format: Optional[str] = "default",
batch_format: str = "default",
drop_last: bool = False,
collate_fn: Optional[Callable[[DataBatch], Any]] = None,
shuffle_buffer_min_size: Optional[int] = None,
@@ -129,7 +128,7 @@ def batch_blocks(
*,
stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None,
batch_size: Optional[int] = None,
batch_format: Optional[str] = "default",
batch_format: str = "default",
drop_last: bool = False,
collate_fn: Optional[Callable[[DataBatch], DataBatch]] = None,
shuffle_buffer_min_size: Optional[int] = None,
@@ -144,84 +143,43 @@ def batch_blocks(
This means that this function does not support block prefetching.
"""

(removed)
    batch_iter = _format_batches(
        _blocks_to_batches(
            block_iter=blocks,
            stats=stats,
            batch_size=batch_size,
            drop_last=drop_last,
            shuffle_buffer_min_size=shuffle_buffer_min_size,
            shuffle_seed=shuffle_seed,
            ensure_copy=ensure_copy,
        ),
        batch_format=batch_format,
        stats=stats,
    )

    if collate_fn is not None:

        def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]:
            for batch in iterator:
                yield collate_fn(batch)

        batch_iter = batch_fn_iter(batch_iter)

    if prefetch_batches > 0:
        batch_iter = _make_async_gen(batch_iter, prefetch_buffer_size=prefetch_batches)

(added)
    def _iterator_fn(base_iterator: Iterator[Block]) -> Iterator[DataBatch]:
        batch_iter = _format_batches(
            _blocks_to_batches(
                block_iter=base_iterator,
                stats=stats,
                batch_size=batch_size,
                drop_last=drop_last,
                shuffle_buffer_min_size=shuffle_buffer_min_size,
                shuffle_seed=shuffle_seed,
                ensure_copy=ensure_copy,
            ),
            batch_format=batch_format,
            stats=stats,
        )

        if collate_fn is not None:

            def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]:
                for batch in iterator:
                    yield collate_fn(batch)

            batch_iter = batch_fn_iter(batch_iter)
        yield from batch_iter

    if prefetch_batches > 0:
        batch_iter = _make_async_gen(
            blocks, fn=_iterator_fn, num_workers=prefetch_batches
        )
    else:
        batch_iter = _iterator_fn(blocks)

for formatted_batch in batch_iter:
user_timer = stats.iter_user_s.timer() if stats else nullcontext()
with user_timer:
yield formatted_batch


def _make_async_gen(
base_iterator: Iterator[T], prefetch_buffer_size: int = 1
) -> Iterator[T]:
"""Returns a new iterator with elements fetched from the base_iterator
in an async fashion using a background thread.
Args:
base_iterator: The iterator to asynchronously fetch from.
prefetch_buffer_size: The maximum number of items to prefetch. Increasing the
size allows for more computation overlap for very expensive downstream UDFs.
However it comes at the cost of additional memory overhead. Defaults to 1.
Returns:
An iterator with the same elements as the base_iterator.
"""

fetch_queue = queue.Queue(maxsize=prefetch_buffer_size)

sentinel = object()

def _async_fetch():
for item in base_iterator:
fetch_queue.put(item, block=True)

# Indicate done adding items.
fetch_queue.put(sentinel, block=True)

# Start a background thread which iterates through the base iterator,
# triggering execution and adding results to the queue until it is full.
# Iterating through the iterator returned by this function pulls
# ready items from the queue, allowing the background thread to continue execution.

fetch_thread = threading.Thread(target=_async_fetch)
fetch_thread.start()

while True:
next_item = fetch_queue.get(block=True)
if next_item is not sentinel:
yield next_item
fetch_queue.task_done()
if next_item is sentinel:
break

fetch_queue.join()
fetch_thread.join()


def _resolve_blocks(
block_ref_iter: Iterator[ObjectRef[Block]],
stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None,
@@ -385,7 +343,7 @@ def get_iter_next_batch_s_timer():

def _format_batches(
block_iter: Iterator[Block],
batch_format: Optional[str],
batch_format: str,
stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None,
) -> Iterator[DataBatch]:
"""Given an iterator of blocks, returns an iterator of formatted batches.
90 changes: 90 additions & 0 deletions python/ray/data/_internal/block_batching/util.py
@@ -0,0 +1,90 @@
import logging
import queue
import threading
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")

logger = logging.getLogger(__name__)


def _make_async_gen(
base_iterator: Iterator[T],
fn: Callable[[Iterator[T]], Iterator[U]],
num_workers: int = 1,
) -> Iterator[U]:
"""Returns a new iterator with elements fetched from the base_iterator
in an async fashion using a threadpool.
Each thread in the threadpool will fetch data from the base_iterator in a
thread-safe fashion, and apply the provided `fn` computation concurrently.
Args:
base_iterator: The iterator to asynchronously fetch from.
fn: The function to run on the input iterator.
num_workers: The number of threads to use in the threadpool.
Returns:
An iterator with the same elements as outputted from `fn`.
"""

# Use a lock to fetch from the base_iterator in a thread-safe fashion.
def convert_to_threadsafe_iterator(base_iterator: Iterator[T]) -> Iterator[T]:
class ThreadSafeIterator:
def __init__(self, it):
self.lock = threading.Lock()
self.it = it

def __next__(self):
with self.lock:
return next(self.it)

def __iter__(self):
return self

return ThreadSafeIterator(base_iterator)

thread_safe_generator = convert_to_threadsafe_iterator(base_iterator)

class Sentinel:
def __init__(self, thread_index: int):
self.thread_index = thread_index

output_queue = queue.Queue(1)

# Because pulling from the base iterator cannot happen concurrently,
# we must execute the expensive computation in a separate step which
# can be parallelized via a threadpool.
def execute_computation(thread_index: int):
try:
for item in fn(thread_safe_generator):
output_queue.put(item, block=True)
output_queue.put(Sentinel(thread_index), block=True)
except Exception as e:
output_queue.put(e, block=True)

threads = [
threading.Thread(target=execute_computation, args=(i,), daemon=True)
for i in range(num_workers)
]

for thread in threads:
thread.start()

num_threads_finished = 0
while True:
next_item = output_queue.get(block=True)
if isinstance(next_item, Exception):
output_queue.task_done()
raise next_item
if isinstance(next_item, Sentinel):
output_queue.task_done()
logger.debug(f"Thread {next_item.thread_index} finished.")
num_threads_finished += 1
threads[next_item.thread_index].join()
else:
yield next_item
output_queue.task_done()
if num_threads_finished >= num_workers:
break
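
For reference, a minimal usage sketch of the new helper (not part of the commit; the slow_transform name, sizes, and timings are illustrative only):

import time

from ray.data._internal.block_batching.util import _make_async_gen

def slow_transform(it):
    # Runs inside each worker thread; items are pulled one at a time from the
    # shared, lock-protected base iterator.
    for x in it:
        time.sleep(0.1)  # stand-in for an expensive per-item computation
        yield x * 2

# Four threads pull from the same base iterator and apply slow_transform
# concurrently. With num_workers > 1, output ordering is not guaranteed,
# hence the sorted() before comparing.
results = sorted(_make_async_gen(iter(range(8)), fn=slow_transform, num_workers=4))
assert results == [0, 2, 4, 6, 8, 10, 12, 14]
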
python/ray/data/tests/block_batching/test_block_batching.py
@@ -8,14 +8,13 @@
import pyarrow as pa

from ray.data.block import Block
from ray.data._internal.block_batching import (
from ray.data._internal.block_batching.block_batching import (
BlockPrefetcher,
batch_block_refs,
batch_blocks,
_prefetch_blocks,
_blocks_to_batches,
_format_batches,
_make_async_gen,
)


@@ -26,9 +25,9 @@ def block_generator(num_rows: int, num_blocks: int):

def test_batch_block_refs():
with mock.patch(
"ray.data._internal.block_batching._prefetch_blocks"
"ray.data._internal.block_batching.block_batching._prefetch_blocks"
) as mock_prefetch, mock.patch(
"ray.data._internal.block_batching.batch_blocks"
"ray.data._internal.block_batching.block_batching.batch_blocks"
) as mock_batch_blocks:
block_iter = block_generator(2, 2)
batch_iter = batch_block_refs(block_iter)
@@ -40,9 +39,9 @@ def test_batch_block_refs():

def test_batch_blocks():
with mock.patch(
"ray.data._internal.block_batching._blocks_to_batches"
"ray.data._internal.block_batching.block_batching._blocks_to_batches"
) as mock_batch, mock.patch(
"ray.data._internal.block_batching._format_batches"
"ray.data._internal.block_batching.block_batching._format_batches"
) as mock_format:
block_iter = block_generator(2, 2)
batch_iter = batch_blocks(block_iter)
@@ -123,64 +122,6 @@ def test_format_batches(batch_format):
assert isinstance(batch["foo"], np.ndarray)


def test_make_async_gen():
"""Tests that make_async_gen overlaps compute."""

num_items = 10

def gen():
for i in range(num_items):
time.sleep(2)
yield i

def sleep_udf(item):
time.sleep(3)
return item

iterator = _make_async_gen(gen())

start_time = time.time()
outputs = []
for item in iterator:
outputs.append(sleep_udf(item))
end_time = time.time()

assert outputs == list(range(num_items))

assert end_time - start_time < num_items * 3 + 3


def test_make_async_gen_buffer_size():
"""Tests that multiple items can be prefetched at a time
with larger buffer size."""

num_items = 5

def gen():
for i in range(num_items):
time.sleep(1)
yield i

def sleep_udf(item):
time.sleep(5)
return item

iterator = _make_async_gen(gen(), prefetch_buffer_size=4)

start_time = time.time()

# Only sleep for first item.
sleep_udf(next(iterator))

# All subsequent items should already be prefetched and should be ready.
for _ in iterator:
pass
end_time = time.time()

# 1 second for first item, 5 seconds for udf, 0.5 seconds buffer
assert end_time - start_time < 6.5


# Test for 3 cases
# 1. Batch size is less than block size
# 2. Batch size is more than block size
@@ -195,7 +136,8 @@ def sleep_batch_format(batch_iter, *args, **kwargs):
yield batch

with mock.patch(
"ray.data._internal.block_batching._format_batches", sleep_batch_format
"ray.data._internal.block_batching.block_batching._format_batches",
sleep_batch_format,
):
batch_iter = batch_blocks(
batch_size=batch_size, blocks=blocks, prefetch_batches=1