From a45f6a73a88981e0a530e692c6b4be29453f3096 Mon Sep 17 00:00:00 2001
From: Vibhu Jawa
Date: Wed, 27 Sep 2023 22:24:41 -0700
Subject: [PATCH] Test without moving to dask_cudf frames

---
 .../cugraph/cugraph/dask/common/part_utils.py | 62 ++++++++++++++++---
 .../simpleDistributedGraph.py                 | 19 ++++--
 2 files changed, 67 insertions(+), 14 deletions(-)

diff --git a/python/cugraph/cugraph/dask/common/part_utils.py b/python/cugraph/cugraph/dask/common/part_utils.py
index 7c0aad6c3ee..25311902b29 100644
--- a/python/cugraph/cugraph/dask/common/part_utils.py
+++ b/python/cugraph/cugraph/dask/common/part_utils.py
@@ -99,19 +99,65 @@ def _chunk_lst(ls, num_parts):
     return [ls[i::num_parts] for i in range(num_parts)]
 
 
-def persist_dask_df_equal_parts_per_worker(dask_df, client):
+def persist_dask_df_equal_parts_per_worker(
+    dask_df, client, return_type="dask_cudf.DataFrame"
+):
+    """
+    Persist dask_df with an equal number of partitions per worker.
+    Args:
+        dask_df: dask_cudf.DataFrame
+        client: dask.distributed.Client
+        return_type: str, "dask_cudf.DataFrame" or "dict"
+    Returns:
+        dask_cudf.DataFrame, or dict of {worker: [futures]} if return_type="dict"
+    """
+    if return_type not in ["dask_cudf.DataFrame", "dict"]:
+        raise ValueError("return_type must be either 'dask_cudf.DataFrame' or 'dict'")
+
     ddf_keys = dask_df.to_delayed()
     workers = client.scheduler_info()["workers"].keys()
     ddf_keys_ls = _chunk_lst(ddf_keys, len(workers))
-    persisted_keys = []
+    persisted_keys_d = {}
     for w, ddf_k in zip(workers, ddf_keys_ls):
-        persisted_keys.extend(
-            client.persist(ddf_k, workers=w, allow_other_workers=False)
+        persisted_keys_d[w] = client.compute(
+            ddf_k, workers=w, allow_other_workers=False, pure=False
         )
-    dask_df = dask_cudf.from_delayed(persisted_keys, meta=dask_df._meta).persist()
-    wait(dask_df)
-    client.rebalance(dask_df)
-    return dask_df
+
+    persisted_keys_ls = [
+        item for sublist in persisted_keys_d.values() for item in sublist
+    ]
+    wait(persisted_keys_ls)
+    if return_type == "dask_cudf.DataFrame":
+        dask_df = dask_cudf.from_delayed(
+            persisted_keys_ls, meta=dask_df._meta
+        ).persist()
+        wait(dask_df)
+        return dask_df
+
+    return persisted_keys_d
+
+
+def get_length_of_parts(persisted_keys_d, client):
+    """
+    Get the length of each persisted partition, per worker.
+    Args:
+        persisted_keys_d: dict of {worker: [persisted_keys]}
+        client: dask.distributed.Client
+    Returns:
+        length_of_parts: dict of {worker: [partition lengths]}
+    """
+    length_of_parts = {}
+    for w, p_keys in persisted_keys_d.items():
+        length_of_parts[w] = [
+            client.submit(
+                len, p_key, pure=False, workers=[w], allow_other_workers=False
+            )
+            for p_key in p_keys
+        ]
+
+    for w, len_futures in length_of_parts.items():
+        length_of_parts[w] = client.gather(len_futures)
+    return length_of_parts
 
 
 async def _extract_partitions(
diff --git a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py
index 01885c2d1c3..8e7294503b5 100644
--- a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py
+++ b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py
@@ -37,6 +37,7 @@
 from cugraph.structure.symmetrize import symmetrize
 from cugraph.dask.common.part_utils import (
     get_persisted_df_worker_map,
+    get_length_of_parts,
     persist_dask_df_equal_parts_per_worker,
 )
 from cugraph.dask import get_n_workers
@@ -322,9 +323,13 @@ def __from_edgelist(
             is_symmetric=not self.properties.directed,
         )
         ddf = ddf.repartition(npartitions=len(workers) * 2)
-        ddf = persist_dask_df_equal_parts_per_worker(ddf, _client)
-        num_edges = len(ddf)
-        ddf = get_persisted_df_worker_map(ddf, _client)
+        persisted_keys_d = persist_dask_df_equal_parts_per_worker(
+            ddf, _client, return_type="dict"
+        )
+        length_of_parts = get_length_of_parts(persisted_keys_d, _client)
+        num_edges = sum(
+            [item for sublist in length_of_parts.values() for item in sublist]
+        )
         delayed_tasks_d = {
             w: delayed(simpleDistributedGraphImpl._make_plc_graph)(
                 Comms.get_session_id(),
@@ -335,14 +340,16 @@ def __from_edgelist(
                 store_transposed,
                 num_edges,
             )
-            for w, edata in ddf.items()
+            for w, edata in persisted_keys_d.items()
         }
-        del ddf
         self._plc_graph = {
-            w: _client.compute(delayed_task, workers=w, allow_other_workers=False)
+            w: _client.compute(
+                delayed_task, workers=w, allow_other_workers=False, pure=False
+            )
             for w, delayed_task in delayed_tasks_d.items()
         }
         wait(list(self._plc_graph.values()))
+        del ddf
         del delayed_tasks_d
         _client.run(gc.collect)
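
Note: a minimal usage sketch of the two helpers added above, assuming a running
dask.distributed Client backed by GPU workers. The `cluster` object and the toy
edge list are hypothetical and not part of the patch; only the two helper
functions and their signatures come from the diff itself.

    from dask.distributed import Client
    import cudf
    import dask_cudf
    from cugraph.dask.common.part_utils import (
        get_length_of_parts,
        persist_dask_df_equal_parts_per_worker,
    )

    client = Client(cluster)  # hypothetical: an existing dask-cuda cluster

    # A toy edge list spread over a few partitions.
    edges = cudf.DataFrame({"src": [0, 1, 2, 3], "dst": [1, 2, 3, 0]})
    ddf = dask_cudf.from_cudf(edges, npartitions=4)

    # Persist per-worker futures instead of rebuilding a dask_cudf.DataFrame,
    # then derive the edge count from per-partition lengths, mirroring what
    # __from_edgelist now does.
    persisted_keys_d = persist_dask_df_equal_parts_per_worker(
        ddf, client, return_type="dict"
    )
    length_of_parts = get_length_of_parts(persisted_keys_d, client)
    num_edges = sum(n for lengths in length_of_parts.values() for n in lengths)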