Skip to content

Commit

Permalink
Add test based on reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
VibhuJawa committed Jul 26, 2023
1 parent 2995c62 commit f12d473
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _get_renumber_map(df):
map_starting_offset = map.iloc[0]
renumber_map = map[map_starting_offset:].dropna().reset_index(drop=True)
renumber_map_batch_indices = map[1:map_starting_offset].reset_index(drop=True)
renumber_map_batch_indices = renumber_map_batch_indices - map_starting_offset

# Drop all rows with NaN values
df.dropna(axis=0, how="all", inplace=True)
Expand Down
34 changes: 34 additions & 0 deletions python/cugraph-dgl/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,47 @@

import cudf
import cupy as cp
import numpy as np
import torch

from cugraph_dgl.dataloading.utils.sampling_helpers import cast_to_tensor
from cugraph_dgl.dataloading.utils.sampling_helpers import _get_renumber_map


def test_casting_empty_array():
    """Check that casting an empty int32 series yields a torch.int32 tensor."""
    # A zero-length series still carries a dtype; the cast must not lose it.
    empty_series = cudf.Series(cp.zeros(shape=0, dtype=cp.int32))
    assert cast_to_tensor(empty_series).dtype == torch.int32


def get_dummy_sampled_df():
    """Build the nine-row cudf frame used as input for ``_get_renumber_map``.

    The ``sources``, ``destinations``, ``batch_id`` and ``hop_id`` columns
    hold six populated rows followed by three NaN rows, while ``map`` is
    populated in every row. The whole frame is cast to int32, after which
    ``hop_id`` is narrowed to uint8.
    """
    nan = np.nan
    column_values = {
        "sources": [0, 0, 0, 0, 0, 0, nan, nan, nan],
        "destinations": [1, 2, 1, 2, 1, 2, nan, nan, nan],
        "batch_id": [0, 0, 1, 1, 2, 2, nan, nan, nan],
        "hop_id": [0, 1, 0, 1, 0, 1, nan, nan, nan],
        "map": [3, 6, 9, 10, 11, 12, 13, 14, 15],
    }

    # Assign column by column (in this order) onto an empty frame.
    frame = cudf.DataFrame()
    for column_name, values in column_values.items():
        frame[column_name] = values

    frame = frame.astype("int32")
    frame["hop_id"] = frame["hop_id"].astype("uint8")
    return frame


def test_get_renumber_map():
    """Check the frame, renumber map and batch indices from ``_get_renumber_map``."""
    result_df, node_map, batch_indices = _get_renumber_map(get_dummy_sampled_df())

    # The helper column must not be present in the returned frame.
    assert "map" not in result_df.columns

    want_map = torch.as_tensor(
        [10, 11, 12, 13, 14, 15], dtype=torch.int32, device="cuda"
    )
    assert torch.equal(node_map, want_map)

    want_batch_indices = torch.as_tensor(
        [3, 6], dtype=torch.int32, device="cuda"
    )
    assert torch.equal(batch_indices, want_batch_indices)

    # The three NaN rows of the fixture must be gone, leaving the six
    # populated rows.
    assert len(result_df) == 6

0 comments on commit f12d473

Please sign in to comment.