
[Data] Refactor block batching to follow iterator pattern #31425

Merged: 17 commits, Jan 10, 2023

Conversation

@amogkam (Contributor) commented Jan 4, 2023

As discussed offline with @clarkzinzow (#30190 (comment)), this PR refactors block batching to follow a chained iterators pattern.

This allows for more flexibility, composability, and better testing of components upstream of Iterator[Block] (formatting, shuffling, batching, prefetching).

This PR is a pure refactor plus new tests: there are no API or functionality changes. It also consolidates the map_batches and iter_batches codepaths.
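Conceptually, each stage becomes a generator that consumes and produces an iterator of blocks, so stages compose by plain chaining. A minimal sketch of the pattern, with toy stand-in stages rather than the actual ray.data internals:

    import random
    from typing import Iterator, List

    Block = List[int]  # Toy stand-in; real blocks are Arrow/pandas tables.

    def prefetch(blocks: Iterator[Block]) -> Iterator[Block]:
        # Placeholder stage: a real implementation would kick off fetches
        # for upcoming blocks before yielding the current one.
        for block in blocks:
            yield block

    def shuffle(blocks: Iterator[Block], seed: int = 0) -> Iterator[Block]:
        rng = random.Random(seed)
        for block in blocks:
            shuffled = block[:]
            rng.shuffle(shuffled)
            yield shuffled

    def batch(blocks: Iterator[Block], batch_size: int) -> Iterator[Block]:
        buffer: Block = []
        for block in blocks:
            buffer.extend(block)
            while len(buffer) >= batch_size:
                yield buffer[:batch_size]
                buffer = buffer[batch_size:]
        if buffer:
            yield buffer

    # Stages chain cleanly and can each be unit-tested in isolation.
    batches = batch(shuffle(prefetch(iter([[1, 2], [3, 4, 5]]))), batch_size=2)
    print(list(batches))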


Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@amogkam amogkam added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Jan 5, 2023
@amogkam amogkam changed the title [Data] Refactor block batching to follow iterator format [Data] Refactor block batching to follow iterator pattern Jan 5, 2023
@clarkzinzow (Contributor) left a comment

Looking good overall; my main note is about using collections.deque rather than queue.Queue for the sliding prefetch window.

Comment on lines 218 to 239
    sliding_window = queue.Queue(maxsize=window_size)

    # Create the initial set of blocks to prefetch.
    while not sliding_window.full():
        try:
            sliding_window.put(next(block_ref_iter))
        except StopIteration:
            break
    with stats.iter_wait_s.timer() if stats else nullcontext():
        prefetcher.prefetch_blocks(list(sliding_window.queue))

    while not sliding_window.empty():
        block_ref = sliding_window.get()
        try:
            sliding_window.put(next(block_ref_iter))
            with stats.iter_wait_s.timer() if stats else nullcontext():
                prefetcher.prefetch_blocks(list(sliding_window.queue))
        except StopIteration:
            pass
        yield block_ref
        if clear_block_after_read:
            ray._private.internal_api.free(block_ref, local_only=False)
@clarkzinzow (Contributor) commented:

I think that we'd want to stick to collections.deque for a single-threaded sliding window implementation (more efficient, less complicated semantics):

Suggested change:

    sliding_window = collections.deque(
        itertools.islice(block_ref_iter, window_size), maxlen=window_size
    )
    while sliding_window:
        block_ref = sliding_window.popleft()
        try:
            sliding_window.append(next(block_ref_iter))
            with stats.iter_wait_s.timer() if stats else nullcontext():
                prefetcher.prefetch_blocks(list(sliding_window))
        except StopIteration:
            pass
        yield block_ref
        if clear_block_after_read:
            ray._private.internal_api.free(block_ref, local_only=False)

Even after we have a background thread worker, I don't think that we'd want to have a multithreading queue inside of _prefetch_blocks. We can keep each of these "batch preprocessing" generators threading-agnostic by pushing the producer generator into the background thread and wiring up a multithreading queue between the producer generator and the consumer generator, which should be a lot cleaner and easier to evolve (and easier to test).
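A minimal sketch of that wiring, under the design described above (make_async and its parameters are illustrative names, not ray.data API):

    import queue
    import threading

    def make_async(producer_gen, buffer_size=2):
        # Run the producer generator on a background thread and hand items
        # to the consumer through a bounded multithreading queue, keeping
        # the producer and consumer generators themselves threading-agnostic.
        q = queue.Queue(maxsize=buffer_size)
        sentinel = object()

        def worker():
            for item in producer_gen:
                q.put(item)  # Blocks when the buffer is full (backpressure).
            q.put(sentinel)  # Signal that the producer is exhausted.

        threading.Thread(target=worker, daemon=True).start()
        while True:
            item = q.get()
            if item is sentinel:
                return
            yield item

    # Example: any producer generator can be moved to a background thread
    # without changing the producer or consumer code.
    for x in make_async(iter(range(5))):
        print(x)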

@amogkam (PR author) replied:

Good point, updated

    sliding_window = queue.Queue(maxsize=window_size)

    # Create the initial set of blocks to prefetch.
    while not sliding_window.full():
@clarkzinzow (Contributor) commented:

It should be noted that this LBYL pattern is not thread-safe in multithreaded code, since the Queue class makes no guarantees that a subsequent put() will not block even if full() returns False. https://docs.python.org/3/library/queue.html#queue.Queue.full

The more idiomatic/correct pattern is EAFP, where you try to sliding_window.put_nowait() and catch a queue.Full exception.

I know that this isn't an issue for single-threaded use of Queue, but just pointing it out for the follow-up PR.

    # Create the initial set of blocks to prefetch.
    while not sliding_window.full():
        try:
            sliding_window.put(next(block_ref_iter))
@clarkzinzow (Contributor) commented:

This should probably be sliding_window.put_nowait() since we'd rather throw an error if the Queue somehow ends up being full (e.g. due to a bug) rather than hanging forever. Same with sliding_window.put() and sliding_window.get() below.
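A minimal sketch of the EAFP pattern with the non-blocking variants (the names below are stand-ins, not the PR's code):

    import queue

    window_size = 4
    block_ref_iter = iter(range(10))  # Stand-in for the real block ref iterator.
    sliding_window = queue.Queue(maxsize=window_size)

    # Fill: attempt the put and catch queue.Full instead of checking full()
    # first (which can race in multithreaded code); put_nowait also raises
    # rather than hanging forever if the queue is unexpectedly full.
    for block_ref in block_ref_iter:
        try:
            sliding_window.put_nowait(block_ref)
        except queue.Full:
            break

    # Drain: the same idea with get_nowait and queue.Empty.
    while True:
        try:
            block_ref = sliding_window.get_nowait()
        except queue.Empty:
            break
        print(block_ref)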

@amogkam (PR author) replied:

Good to know; will keep this in mind for the next PR, since we changed to collections.deque in this one.

    for block_ref in block_ref_iter:
        yield block_ref
        if clear_block_after_read:
            ray._private.internal_api.free(block_ref, local_only=False)
@clarkzinzow (Contributor) commented:

An interesting thing to note is that this block ref clearing assumes that block_ref is no longer in use after control is returned to this generator, so this assumes no buffering by downstream generators, which may or may not hold true for future tweaks. We should keep this in mind.
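A minimal illustration of the hazard, using a hypothetical downstream generator that buffers:

    def buffered_downstream(block_refs, buffer_size=4):
        # Buffers refs before yielding them. Each next() pulled from
        # block_refs returns control to the upstream generator, which may
        # then free() a block whose ref is still sitting unconsumed here.
        buffer = []
        for ref in block_refs:
            buffer.append(ref)
            if len(buffer) >= buffer_size:
                yield from buffer
                buffer.clear()
        yield from buffer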

@c21 (Contributor) left a comment

Nice refactoring, thanks @amogkam!

# Signal to the batcher that there are no more blocks to add.
batcher.done_adding()

# Get any leftover batches in ShufflingBatcher.
@c21 (Contributor) commented:

nit: ShufflingBatcher -> batcher?

@amogkam (PR author) replied:

This is specifically ShufflingBatcher.

The regular Batcher will no longer have any full batches at this point, but ShufflingBatcher may still have full batches if the shuffle buffer size is larger than the batch size.
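A toy sketch of why the drain loop matters for ShufflingBatcher (a stand-in class, not the actual ray.data implementation):

    import random

    class ToyShufflingBatcher:
        # Toy stand-in: buffers rows, shuffles them, and emits fixed-size
        # batches. While adding, it keeps at least buffer_size rows buffered;
        # after done_adding(), it drains every remaining full batch.
        def __init__(self, batch_size, buffer_size):
            self.batch_size = batch_size
            self.buffer_size = buffer_size
            self.buffer = []
            self.done = False

        def add(self, block):
            self.buffer.extend(block)

        def done_adding(self):
            self.done = True
            random.shuffle(self.buffer)

        def has_batch(self):
            if self.done:
                return len(self.buffer) >= self.batch_size
            return len(self.buffer) - self.buffer_size >= self.batch_size

        def next_batch(self):
            batch = self.buffer[: self.batch_size]
            self.buffer = self.buffer[self.batch_size :]
            return batch

    batcher = ToyShufflingBatcher(batch_size=2, buffer_size=5)
    batcher.add(list(range(8)))
    batcher.done_adding()
    # buffer_size > batch_size, so multiple full batches remain to drain.
    while batcher.has_batch():
        print(batcher.next_batch())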

@amogkam amogkam requested review from c21 and clarkzinzow January 7, 2023 04:07
@amogkam (PR author) commented Jan 7, 2023

Thanks for the review guys! I updated the PR, please take another look!

@c21 (Contributor) left a comment

LGTM

@clarkzinzow (Contributor) left a comment

LGTM overall, only thing is the sliding_window.queue line!

@c21 (Contributor) commented Jan 9, 2023

Just FYI, there seem to be some CI test failures (example):

    ray.exceptions.RayTaskError(TypeError): ray::RayTrainWorker._RayTrainWorker__execute() (pid=27294, ip=172.16.16.3, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7fb2e055d790>)
      File "/ray/python/ray/train/_internal/worker_group.py", line 31, in __execute
        raise skipped from exception_cause(skipped)
      File "/ray/python/ray/train/_internal/utils.py", line 129, in discard_return_wrapper
        train_func(*args, **kwargs)
      File "/root/.cache/bazel/_bazel_root/5fe90af4e7d1ed9fcf52f59e39e126f5/execroot/com_github_ray_project_ray/bazel-out/k8-opt/bin/doc/datasets_train.runfiles/com_github_ray_project_ray/doc/source/ray-core/_examples/datasets_train/datasets_train.py", line 439, in train_func
        train_torch_dataset, net, device, criterion, optimizer
      File "/root/.cache/bazel/_bazel_root/5fe90af4e7d1ed9fcf52f59e39e126f5/execroot/com_github_ray_project_ray/bazel-out/k8-opt/bin/doc/datasets_train.runfiles/com_github_ray_project_ray/doc/source/ray-core/_examples/datasets_train/datasets_train.py", line 340, in train_epoch
        for i, (inputs, labels) in enumerate(dataset):
      File "/ray/python/ray/data/_internal/torch_iterable_dataset.py", line 10, in __iter__
        yield from it
      File "/ray/python/ray/data/dataset.py", line 3031, in make_generator
        local_shuffle_seed=local_shuffle_seed,
      File "/ray/python/ray/data/dataset_pipeline.py", line 213, in iter_batches
        shuffle_seed=local_shuffle_seed,
    TypeError: batch_block_refs() takes 1 positional argument but 2 positional arguments (and 7 keyword-only arguments) were given
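The TypeError points at a call site that still passes positional arguments to a signature that became keyword-only; a minimal reproduction of that failure mode (the signature below is illustrative, not the actual batch_block_refs):

    def batch_block_refs(block_ref_iter, *, batch_size=None, shuffle_seed=None):
        # Everything after the bare * must be passed by keyword.
        return []

    batch_block_refs(iter([]), batch_size=32)  # OK: keyword-only call.
    # batch_block_refs(iter([]), 32)  # Raises:
    # TypeError: batch_block_refs() takes 1 positional argument but 2 were given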


@amogkam amogkam merged commit 86ec3e2 into ray-project:master Jan 10, 2023
@amogkam amogkam deleted the refactor-block-batching branch January 10, 2023 02:06
AmeerHajAli pushed a commit that referenced this pull request Jan 12, 2023