diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index b1f99869e..294a8efd7 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -11,7 +11,6 @@ import dask import dask.dataframe -import distributed from dask.base import compute_as_if_collection, tokenize from dask.dataframe.core import DataFrame, _concat as dd_concat, new_dd_object from dask.dataframe.shuffle import shuffle_group @@ -325,25 +324,37 @@ def shuffle( rank_to_out_part_ids, ignore_index, ) - distributed.wait(list(result_futures.values())) - del df_groups + wait(list(result_futures.values())) - # Step (c): extract individual dataframe-partitions + # Release dataframes from step (a) + for fut in df_groups: + fut.release() + + # Step (c): extract individual dataframe-partitions. We use `submit()` + # to control where the tasks are executed. + # TODO: can we do this without using `submit()` to avoid the overhead + # of creating a Future for each dataframe partition? name = f"explicit-comms-shuffle-getitem-{tokenize(name)}" dsk = {} - meta = None - for rank, parts in rank_to_out_part_ids.items(): - for i, part_id in enumerate(parts): - dsk[(name, part_id)] = (getitem, result_futures[rank], i) - if meta is None: - # Get the meta from the first output partition - meta = delayed(make_meta)( - delayed(getitem)(result_futures[rank], i) - ).compute() - assert meta is not None + for rank, worker in enumerate(c.worker_addresses): + if rank in workers: + for i, part_id in enumerate(rank_to_out_part_ids[rank]): + dsk[(name, part_id)] = c.client.submit( + getitem, result_futures[rank], i, workers=[worker] + ) + # Get the meta from the first output partition + meta = delayed(make_meta)(next(iter(dsk.values()))).compute() + + # Create a distributed Dataframe from all the pieces divs = [None] * (len(dsk) + 1) - return new_dd_object(dsk, name, meta, divs).persist() + ret = new_dd_object(dsk, name, meta, divs).persist() + wait(ret) + + # Release all temporary dataframes + for fut in [*result_futures.values(), *dsk.values()]: + fut.release() + return ret def get_rearrange_by_column_tasks_wrapper(func): diff --git a/docs/source/conf.py b/docs/source/conf.py index 2f7825a32..08d8bfdfd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -189,7 +189,8 @@ # -- Extension configuration ------------------------------------------------- - def setup(app): app.add_css_file("https://docs.rapids.ai/assets/css/custom.css") - app.add_js_file("https://docs.rapids.ai/assets/js/custom.js", loading_method="defer") + app.add_js_file( + "https://docs.rapids.ai/assets/js/custom.js", loading_method="defer" + )