Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Jul 25, 2023
1 parent aab0b56 commit 7979a0a
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,11 @@ def _call_plc_uniform_neighbor_sample_legacy(
random_state=random_state,
return_hops=return_hops,
)

output = convert_to_cudf(
cp_arrays, weight_t, with_edge_properties, return_offsets=return_offsets
)

if isinstance(output, (list, tuple)) and len(output) == 1:
return output[0]
return output
Expand Down
10 changes: 6 additions & 4 deletions python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
graph,
seeds_per_call: int = 200_000,
batches_per_partition: int = 100,
renumber: bool=False,
renumber: bool = False,
log_level: int = None,
**kwargs,
):
Expand Down Expand Up @@ -107,7 +107,7 @@ def batch_size(self) -> int:
@property
def batches_per_partition(self) -> int:
return self.__batches_per_partition

@property
def renumber(self) -> bool:
return self.__renumber
Expand Down Expand Up @@ -254,7 +254,7 @@ def flush(self) -> None:
with_batch_ids=True,
with_edge_properties=True,
return_offsets=True,
renumber=self.__renumber
renumber=self.__renumber,
)

if self.__renumber:
Expand All @@ -281,7 +281,9 @@ def flush(self) -> None:
# Write batches to parquet
self.__write(samples, offsets, renumber_map)
if isinstance(self.__batches, dask_cudf.DataFrame):
futures = [f.release() for f in futures_of(samples)] + [f.release() for f in futures_of(offsets)]
futures = [f.release() for f in futures_of(samples)] + [
f.release() for f in futures_of(offsets)
]
if renumber_map is not None:
futures += [f.release() for f in futures_of(renumber_map)]
wait(futures)
Expand Down
4 changes: 3 additions & 1 deletion python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def _write_samples_to_parquet(
renumber_map_o = cudf.concat(
[
offsets_p.renumber_map_offsets + map_offset,
cudf.Series([len(renumber_map_p) + len(offsets_p) + 1], dtype="int32"),
cudf.Series(
[len(renumber_map_p) + len(offsets_p) + 1], dtype="int32"
),
]
)

Expand Down
6 changes: 3 additions & 3 deletions python/cugraph/cugraph/tests/sampling/test_bulk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_bulk_sampler_simple(scratch_dir):
bs.flush()

recovered_samples = cudf.read_parquet(samples_path)
assert 'map' not in recovered_samples.columns
assert "map" not in recovered_samples.columns

for b in batches["batch"].unique().values_host.tolist():
assert b in recovered_samples["batch_id"].values_host.tolist()
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_bulk_sampler_remainder(scratch_dir):
bs.flush()

recovered_samples = cudf.read_parquet(samples_path)
assert 'map' not in recovered_samples.columns
assert "map" not in recovered_samples.columns

for b in batches["batch"].unique().values_host.tolist():
assert b in recovered_samples["batch_id"].values_host.tolist()
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_bulk_sampler_large_batch_size(scratch_dir):
bs.flush()

recovered_samples = cudf.read_parquet(samples_path)
assert 'map' not in recovered_samples.columns
assert "map" not in recovered_samples.columns

for b in batches["batch"].unique().values_host.tolist():
assert b in recovered_samples["batch_id"].values_host.tolist()
Expand Down
4 changes: 2 additions & 2 deletions python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_bulk_sampler_simple(dask_client, scratch_dir):
bs.flush()

recovered_samples = cudf.read_parquet(samples_path)
assert 'map' not in recovered_samples.columns
assert "map" not in recovered_samples.columns

for b in batches["batch"].unique().compute().values_host.tolist():
assert b in recovered_samples["batch_id"].values_host.tolist()
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_bulk_sampler_mg_graph_sg_input(dask_client, scratch_dir):
bs.flush()

recovered_samples = cudf.read_parquet(samples_path)
assert 'map' not in recovered_samples.columns
assert "map" not in recovered_samples.columns

for b in batches["batch"].unique().values_host.tolist():
assert b in recovered_samples["batch_id"].values_host.tolist()
Expand Down

0 comments on commit 7979a0a

Please sign in to comment.