Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix multiple send recv #320

Merged
merged 25 commits into from
Dec 24, 2018
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/dgl/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def unique(input):
"""
pass

def full_1d(length, fill_value):
def full_1d(length, fill_value, dtype, ctx):
"""Create a 1D tensor full of the fill_value.

Parameters
Expand All @@ -627,6 +627,10 @@ def full_1d(length, fill_value):
The length of the vector.
fill_value : int
The filled value.
dtype : data type
It should be one of the values in the data type dict.
ctx : context
The device of the result tensor.

Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/backend/mxnet/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def unique(input):
tmp = np.unique(tmp)
return nd.array(tmp, ctx=input.context, dtype=input.dtype)

def full_1d(length, fill_value):
return nd.full((length,), fill_value)
def full_1d(length, fill_value, dtype, ctx):
return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)

def nonzero_1d(input):
# TODO: fallback to numpy is unfortunate
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
def unique(input):
return th.unique(input)

def full_1d(length, fill_value):
return th.full((length,), fill_value)
def full_1d(length, fill_value, dtype, ctx):
return th.full((length,), fill_value, dtype=dtype, device=ctx)

def nonzero_1d(input):
return th.nonzero(input).squeeze()
Expand Down
14 changes: 14 additions & 0 deletions python/dgl/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,20 @@ def _append(self, other):
# directly updating columns.
self._columns = {key: Column.create(data) for key, data in other.items()}
else:
# pad columns that are not provided in the other frame with initial values
for key, col in self.items():
if key not in other:
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
new_data = self.get_initializer(key)(
(other.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows,
self._num_rows + other.num_rows)
jermainewang marked this conversation as resolved.
Show resolved Hide resolved
)
other[key] = new_data
# append other to self
for key, col in other.items():
if key not in self._columns:
# the column does not exist; init a new column
Expand Down
54 changes: 27 additions & 27 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(self,
# graph
self._readonly=readonly
self._graph = create_graph_index(graph_data, multigraph, readonly)
# frame
# node and edge frame
if node_frame is None:
self._node_frame = FrameRef(Frame(num_rows=self.number_of_nodes()))
else:
Expand All @@ -188,10 +188,13 @@ def __init__(self,
self._edge_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
else:
self._edge_frame = edge_frame
# msg graph & frame
self._msg_graph = create_graph_index(multigraph=multigraph)
self._msg_frame = FrameRef()
self.reset_messages()
# message indicator:
# if self._msg_index[eid] == 1, then edge eid has message
self._msg_index = utils.zero_index(size=self.number_of_edges())
# message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame
self._msg_frame.set_initializer(dgl.init.zero_initializer)
BarclayII marked this conversation as resolved.
Show resolved Hide resolved
# registered functions
self._message_func = None
self._reduce_func = None
Expand Down Expand Up @@ -243,7 +246,6 @@ def add_nodes(self, num, data=None):
[1., 1., 1., 1.]])
"""
self._graph.add_nodes(num)
self._msg_graph.add_nodes(num)
if data is None:
# Initialize feature placeholders if there are features existing
self._node_frame.add_rows(num)
Expand Down Expand Up @@ -303,6 +305,9 @@ def add_edge(self, u, v, data=None):
self._edge_frame.add_rows(1)
else:
self._edge_frame.append(data)
# resize msg_index and msg_frame
self._msg_index = self._msg_index.append_zeros(1)
self._msg_frame.add_rows(1)

def add_edges(self, u, v, data=None):
"""Add multiple edges for list of source nodes u and destination nodes
Expand Down Expand Up @@ -353,12 +358,16 @@ def add_edges(self, u, v, data=None):
u = utils.toindex(u)
v = utils.toindex(v)
self._graph.add_edges(u, v)
num = max(len(u), len(v))
if data is None:
# Initialize feature placeholders if there are features existing
# NOTE: use max due to edge broadcasting syntax
self._edge_frame.add_rows(max(len(u), len(v)))
self._edge_frame.add_rows(num)
else:
self._edge_frame.append(data)
# initialize feature placeholder for messages
self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num)

def clear(self):
"""Remove all nodes and edges, as well as their features, from the
Expand All @@ -382,7 +391,7 @@ def clear(self):
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_graph.clear()
self._msg_index = utils.zero_index(0)
self._msg_frame.clear()

def clear_cache(self):
Expand All @@ -394,12 +403,6 @@ def clear_cache(self):
"""
self._graph.clear_cache()

def reset_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())

def number_of_nodes(self):
"""Return the number of nodes in the graph.

Expand Down Expand Up @@ -1168,7 +1171,9 @@ def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_graph.add_nodes(self._graph.number_of_nodes())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())

# copy attributes
def _batcher(lst):
if F.is_tensor(lst[0]):
Expand Down Expand Up @@ -1225,7 +1230,8 @@ def from_scipy_sparse_matrix(self, a):
self._graph.from_scipy_sparse_matrix(a)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_graph.add_nodes(self._graph.number_of_nodes())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())

def node_attr_schemes(self):
"""Return the node feature schemes.
Expand Down Expand Up @@ -1934,7 +1940,7 @@ def send(self, edges=ALL, message_func="default"):
message_func = self._message_func

if is_all(edges):
eid = ALL
eid = utils.toindex(slice(0, self.number_of_edges()))
u, v, _ = self._graph.edges()
elif isinstance(edges, tuple):
u, v = edges
Expand All @@ -1946,14 +1952,15 @@ def send(self, edges=ALL, message_func="default"):
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)

if len(eid) == 0:
# no edge to be triggered
return

with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid,
message_func=message_func)
Runtime.run(prog)

# update message graph and frame
self._msg_graph.add_edges(u, v)

def recv(self,
v=ALL,
reduce_func="default",
Expand Down Expand Up @@ -2039,10 +2046,6 @@ def recv(self,
apply_node_func = self._apply_node_func
assert reduce_func is not None

if self._msg_frame.num_rows == 0:
# no message has ever been sent
return

if is_all(v):
v = F.arange(0, self.number_of_nodes())
elif isinstance(v, int):
Expand All @@ -2060,9 +2063,6 @@ def recv(self,
inplace=inplace)
Runtime.run(prog)

# FIXME(minjie): multi send bug
self.reset_messages()

def send_and_recv(self,
edges,
message_func="default",
Expand Down
30 changes: 30 additions & 0 deletions python/dgl/runtime/ir/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class OpCode(object):
WRITE_DICT_ = 24
APPEND_ROW_ = 25
WRITE_ROW_INPLACE_ = 26
CLEAR_FRAME_ = 27

class Executor(object):
@abstractmethod
Expand Down Expand Up @@ -645,3 +646,32 @@ def run(self):
def APPEND_ROW_(fd1, fd2):
reg = IR_REGISTRY[OpCode.APPEND_ROW_]
get_current_prog().issue(reg['executor_cls'](fd1, fd2))

class ClearFrame_Executor(Executor):
def __init__(self, fd):
self.fd = fd

def opcode(self):
return OpCode.CLEAR_FRAME_

def arg_vars(self):
return [self.fd]

def ret_var(self):
return None

def run(self):
frame = self.fd.data
num_rows = frame.num_rows
frame.clear()
frame.add_rows(num_rows)

IR_REGISTRY[OpCode.CLEAR_FRAME_] = {
'name': 'CLEAR_FRAME_',
'args_type': [VarType.FEAT_DICT],
'ret_type': None,
'executor_cls': ClearFrame_Executor,
}
def CLEAR_FRAME_(fd):
reg = IR_REGISTRY[OpCode.CLEAR_FRAME_]
get_current_prog().issue(reg['executor_cls'](fd))
53 changes: 34 additions & 19 deletions python/dgl/runtime/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ def schedule_send(graph, u, v, eid, message_func):
# TODO(minjie): support builtin message func
message_func = _standardize_func_usage(message_func, 'message')
# vars
nf = var.FEAT_DICT(graph._node_frame)
ef = var.FEAT_DICT(graph._edge_frame)
mf = var.FEAT_DICT(graph._msg_frame)
u = var.IDX(u)
v = var.IDX(v)
eid = var.IDX(eid)
msg = _gen_send(graph, nf, ef, u, v, eid, message_func)
# TODO: handle duplicate messages
ir.APPEND_ROW_(mf, msg)
var_nf = var.FEAT_DICT(graph._node_frame)
var_ef = var.FEAT_DICT(graph._edge_frame)
var_mf = var.FEAT_DICT(graph._msg_frame)
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
msg = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, message_func)
ir.WRITE_ROW_(var_mf, var_eid, msg)
# set message indicator to 1
graph._msg_index = graph._msg_index.set_items(eid, 1)

def schedule_recv(graph,
recv_nodes,
Expand All @@ -74,9 +75,16 @@ def schedule_recv(graph,
inplace: bool
If True, the update will be done in place
"""
src, dst, mid = graph._msg_graph.in_edges(recv_nodes)
if len(mid) == 0:
# All recv nodes are 0-degree nodes; downgrade to apply nodes.
src, dst, eid = graph._graph.in_edges(recv_nodes)
if len(eid) > 0:
nonzero_idx = graph._msg_index.get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx)
src = src.get_items(nonzero_idx)
dst = dst.get_items(nonzero_idx)
if len(eid) == 0:
# Downgrade to apply nodes if
# 1) all recv nodes are 0-degree nodes
# 2) no send has been called
if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
else:
Expand All @@ -86,13 +94,19 @@ def schedule_recv(graph,
recv_nodes = utils.toindex(recv_nodes)
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
# reduce
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, mid), recv_nodes)
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
recv_nodes)
# apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf,
reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
# set message indicator to 0
graph._msg_index = graph._msg_index.set_items(eid, 0)
if not graph._msg_index.has_nonzero():
ir.CLEAR_FRAME_(var.FEAT_DICT(graph._msg_frame, name='mf'))

def schedule_snr(graph,
edge_tuples,
Expand Down Expand Up @@ -426,13 +440,14 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
recv_nodes : utils.Index
"""
call_type = "recv"
_, dst, mid = edge_tuples
_, dst, eid = edge_tuples
rfunc = _standardize_func_usage(reduce_func, 'reduce')
rfunc_is_list = utils.is_iterable(rfunc)
# Create a tmp frame to hold the feature data.
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
# TODO(minjie): should replace this with an IR call to make the program
# stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))

# vars
Expand All @@ -444,8 +459,8 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# UDF message + builtin reducer
# analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
# FIXME: refactor this when fixing the multi-recv bug
inc = spmv.build_inc_matrix_eid(graph._msg_frame.num_rows, mid, dst, recv_nodes)
inc = spmv.build_inc_matrix_eid(graph._msg_frame.num_rows, eid, dst,
recv_nodes)
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out)

if len(rfunc) == 0:
Expand All @@ -456,7 +471,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
rfunc = BundledFunction(rfunc)

# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, mid, dst,
db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst,
recv_nodes, nf, msg, out)
return out

Expand Down
Loading