From 3cc32a9775c397f4174ca28852bee2818319c0ce Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Fri, 1 Mar 2019 22:51:51 -0500 Subject: [PATCH] [Feature] Reduce messages with scatter_add in PyTorch (#427) * implement pytorch spmm with gather and scatter add * fix * replace torch take with index_select * comments * comment about pytorch __getitem__ operator pitfall * typo --- python/dgl/backend/pytorch/tensor.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index a709e229f95e..5fec0fe99c3f 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -135,15 +135,17 @@ def zeros_like(input): def ones(shape, dtype, ctx): return th.ones(shape, dtype=dtype, device=ctx) -if TH_VERSION.version[0] == 0: - # TODO(minjie): note this does not support autograd on the `x` tensor. - # should adopt a workaround using custom op. - def spmm(x, y): - return th.spmm(x, y) -else: - # torch v1.0+ - def spmm(x, y): - return th.sparse.mm(x, y) +def spmm(x, y): + dst, src = x._indices() + # scatter index + index = dst.view(-1, 1).expand(-1, y.shape[1]) + # zero tensor to be scatter_add to + out = y.new_full((x.shape[0], y.shape[1]), 0) + # look up src features and multiply by edge features + # Note: using y[src] instead of index_select will lead to terrible + # performance in backward + feature = th.index_select(y, 0, src) * x._values().unsqueeze(-1) + return out.scatter_add(0, index, feature) def unsorted_1d_segment_sum(input, seg_id, n_segs, dim): y = th.zeros(n_segs, *input.shape[1:]).to(input)