Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files (browse the repository at this point in the history)
  • Loading branch information
kunmukh authored Oct 27, 2022
2 parents 32c1a5d + ea4d9e8 commit cfc8a3e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
18 changes: 17 additions & 1 deletion tests/tools/create_chunked_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def rand_edges(num_src, num_dst, num_edges):
paper_label = np.random.choice(num_classes, num_papers)
paper_year = np.random.choice(2022, num_papers)
paper_orig_ids = np.arange(0, num_papers)
writes_orig_ids = np.arange(0, g.num_edges('writes'))

# masks.
if include_masks:
Expand Down Expand Up @@ -93,26 +94,38 @@ def rand_edges(num_src, num_dst, num_edges):
paper_feat_path = os.path.join(input_dir, 'paper/feat.npy')
with open(paper_feat_path, 'wb') as f:
np.save(f, paper_feat)
g.nodes['paper'].data['feat'] = torch.from_numpy(paper_feat)

paper_label_path = os.path.join(input_dir, 'paper/label.npy')
with open(paper_label_path, 'wb') as f:
np.save(f, paper_label)
g.nodes['paper'].data['label'] = torch.from_numpy(paper_label)

paper_year_path = os.path.join(input_dir, 'paper/year.npy')
with open(paper_year_path, 'wb') as f:
np.save(f, paper_year)
g.nodes['paper'].data['year'] = torch.from_numpy(paper_year)

paper_orig_ids_path = os.path.join(input_dir, 'paper/orig_ids.npy')
with open(paper_orig_ids_path, 'wb') as f:
np.save(f, paper_orig_ids)
g.nodes['paper'].data['orig_ids'] = torch.from_numpy(paper_orig_ids)

cite_count_path = os.path.join(input_dir, 'cites/count.npy')
with open(cite_count_path, 'wb') as f:
np.save(f, cite_count)
g.edges['cites'].data['count'] = torch.from_numpy(cite_count)

write_year_path = os.path.join(input_dir, 'writes/year.npy')
with open(write_year_path, 'wb') as f:
np.save(f, write_year)
g.edges['writes'].data['year'] = torch.from_numpy(write_year)
g.edges['rev_writes'].data['year'] = torch.from_numpy(write_year)

writes_orig_ids_path = os.path.join(input_dir, 'writes/orig_ids.npy')
with open(writes_orig_ids_path, 'wb') as f:
np.save(f, writes_orig_ids)
g.edges['writes'].data['orig_ids'] = torch.from_numpy(writes_orig_ids)

node_data = None
if include_masks:
Expand Down Expand Up @@ -193,7 +206,10 @@ def rand_edges(num_src, num_dst, num_edges):

edge_data = {
'cites': {'count': cite_count_path},
'writes': {'year': write_year_path},
'writes': {
'year': write_year_path,
'orig_ids': writes_orig_ids_path
},
'rev_writes': {'year': write_year_path},
}

Expand Down
19 changes: 15 additions & 4 deletions tests/tools/test_dist_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def _verify_graph_feats(
ndata = node_feats[ntype + "/" + name][local_nids]
assert torch.equal(ndata, true_feats)

for etype in g.etypes:
for c_etype in g.canonical_etypes:
etype = c_etype[1]
etype_id = g.get_etype_id(etype)
inner_edge_mask = _get_inner_edge_mask(part, etype_id)
inner_eids = part.edata[dgl.EID][inner_edge_mask]
Expand All @@ -75,7 +76,7 @@ def _verify_graph_feats(
continue
true_feats = g.edges[etype].data[name][orig_id]
edata = edge_feats[etype + "/" + name][local_eids]
assert torch.equal(edata == true_feats)
assert torch.equal(edata, true_feats)


@pytest.mark.parametrize("num_chunks", [1, 8])
Expand Down Expand Up @@ -119,13 +120,17 @@ def test_chunk_graph(num_chunks):

# check node_data
output_node_data_dir = os.path.join(output_dir, "node_data", "paper")
for feat in ["feat", "label", "year"]:
for feat in ["feat", "label", "year", "orig_ids"]:
feat_data = []
for i in range(num_chunks):
chunk_f_name = "{}-{}.npy".format(feat, i)
chunk_f_name = os.path.join(output_node_data_dir, chunk_f_name)
assert os.path.isfile(chunk_f_name)
feat_array = np.load(chunk_f_name)
assert feat_array.shape[0] == num_papers // num_chunks
feat_data.append(feat_array)
feat_data = np.concatenate(feat_data, 0)
assert torch.equal(torch.from_numpy(feat_data), g.nodes['paper'].data[feat])

# check edge_data
num_edges = {
Expand All @@ -137,15 +142,21 @@ def test_chunk_graph(num_chunks):
for etype, feat in [
["paper:cites:paper", "count"],
["author:writes:paper", "year"],
["author:writes:paper", "orig_ids"],
["paper:rev_writes:author", "year"],
]:
feat_data = []
output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
for i in range(num_chunks):
chunk_f_name = "{}-{}.npy".format(feat, i)
chunk_f_name = os.path.join(output_edge_sub_dir, chunk_f_name)
assert os.path.isfile(chunk_f_name)
feat_array = np.load(chunk_f_name)
assert feat_array.shape[0] == num_edges[etype] // num_chunks
feat_data.append(feat_array)
feat_data = np.concatenate(feat_data, 0)
assert torch.equal(torch.from_numpy(feat_data),
g.edges[etype.split(':')[1]].data[feat])


@pytest.mark.parametrize("num_chunks", [1, 3, 8])
Expand Down
9 changes: 9 additions & 0 deletions tools/distpartitioning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,15 @@ def write_edge_features(edge_features, edge_file):
edge_file : string
File in which the edge information is serialized
"""
# TODO[Rui]: Below is a temporary fix for etype and will be
# further refined in the near future as we'll shift to canonical
# etypes entirely.
def format_etype(etype):
    """Normalize an edge-feature key of the form '<etype>/<feat_name>'.

    The etype part may be a canonical triplet 'src:etype:dst' (e.g.
    'author:writes:paper/year') or already a plain etype (e.g.
    'writes/year'). In both cases the returned key is
    '<plain_etype>/<feat_name>'.
    """
    # Split only on the last '/', so a feature name is never cut in half.
    etype, _, name = etype.rpartition('/')
    parts = etype.split(':')
    # Canonical form has three parts ('src:etype:dst'); the middle one is
    # the plain etype. A plain etype has no ':' and is kept as-is, instead
    # of raising IndexError as the unconditional [1] indexing did.
    etype = parts[1] if len(parts) == 3 else parts[0]
    return etype + '/' + name
edge_features = {format_etype(etype):
data for etype, data in edge_features.items()}
dgl.data.utils.save_tensors(edge_file, edge_features)

def write_graph_dgl(graph_file, graph_obj, formats, sort_etypes):
Expand Down

0 comments on commit cfc8a3e

Please sign in to comment.