Skip to content

Commit

Permalink
Fix: a GraphBolt mask with no value (None) was always set to 0 automatically.
Browse files Browse the repository at this point in the history
  • Loading branch information
CfromBU committed Dec 12, 2024
1 parent 88f109f commit 74b2fb5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
12 changes: 7 additions & 5 deletions python/dgl/distributed/dist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def collate(self, items):
raise NotImplementedError

@staticmethod
def add_edge_attribute_to_graph(g, data_name):
def add_edge_attribute_to_graph(g, data_name, gb_padding=0):
"""Add data into the graph as an edge attribute.
For some cases such as prob/mask-based sampling on GraphBolt partitions,
Expand All @@ -329,7 +329,7 @@ def add_edge_attribute_to_graph(g, data_name):
The name of data that's stored in DistGraph.ndata/edata.
"""
if g._use_graphbolt and data_name:
g.add_edge_attribute(data_name)
g.add_edge_attribute(data_name, gb_padding)


class NodeCollator(Collator):
Expand Down Expand Up @@ -366,7 +366,7 @@ class NodeCollator(Collator):
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""

def __init__(self, g, nids, graph_sampler):
def __init__(self, g, nids, graph_sampler, gb_padding=0):
self.g = g
if not isinstance(nids, Mapping):
assert (
Expand All @@ -380,7 +380,7 @@ def __init__(self, g, nids, graph_sampler):
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)

@property
Expand Down Expand Up @@ -612,6 +612,7 @@ def __init__(
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
gb_padding=0,
):
self.g = g
if not isinstance(eids, Mapping):
Expand Down Expand Up @@ -642,7 +643,7 @@ def __init__(
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)

@property
Expand Down Expand Up @@ -864,6 +865,7 @@ def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
else:
dataloader_kwargs[k] = v

collator_kwargs["gb_padding"] = 1
if device is None:
# for the distributed case default to the CPU
device = "cpu"
Expand Down
16 changes: 10 additions & 6 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,16 @@ def _copy_data_from_shared_mem(name, shape):
class AddEdgeAttributeFromKVRequest(rpc.Request):
"""Add edge attribute from kvstore to local GraphBolt partition."""

def __init__(self, name, kv_names):
def __init__(self, name, kv_names, padding=0):
self._name = name
self._kv_names = kv_names
self._padding = padding

def __getstate__(self):
return self._name, self._kv_names
return self._name, self._kv_names, self._padding

def __setstate__(self, state):
self._name, self._kv_names = state
self._name, self._kv_names, self._padding = state

def process_request(self, server_state):
# For now, this is only used to add prob/mask data to the graph.
Expand All @@ -169,7 +170,10 @@ def process_request(self, server_state):
gpb = server_state.partition_book
# Initialize the edge attribute.
num_edges = g.total_num_edges
attr_data = torch.zeros(num_edges, dtype=data_type)
if self._padding == 0:
attr_data = torch.zeros(num_edges, dtype=data_type)
else:
attr_data = torch.full((num_edges,), self._padding, dtype=data_type)
# Map data from kvstore to the local partition for inner edges only.
num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
homo_eids = g.edge_attributes[EID][:num_inner_edges]
Expand Down Expand Up @@ -1620,7 +1624,7 @@ def _get_edata_names(self, etype=None):
edata_names.append(name)
return edata_names

def add_edge_attribute(self, name):
def add_edge_attribute(self, name, padding=0):
"""Add an edge attribute into GraphBolt partition from edge data.
Parameters
Expand All @@ -1643,7 +1647,7 @@ def add_edge_attribute(self, name):
]
rpc.send_request(
self._client._main_server_id,
AddEdgeAttributeFromKVRequest(name, kv_names),
AddEdgeAttributeFromKVRequest(name, kv_names, padding),
)
# Wait for the response.
assert rpc.recv_response()._name == name
Expand Down

0 comments on commit 74b2fb5

Please sign in to comment.