From 517015fbb691c91dee535f2441697eacee368418 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Fri, 1 Mar 2019 20:15:03 +0000 Subject: [PATCH 1/6] implement pytorch spmm with gather and scatter add --- python/dgl/backend/pytorch/tensor.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index a709e229f95e..f7bb815fe81b 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -135,15 +135,12 @@ def zeros_like(input): def ones(shape, dtype, ctx): return th.ones(shape, dtype=dtype, device=ctx) -if TH_VERSION.version[0] == 0: - # TODO(minjie): note this does not support autograd on the `x` tensor. - # should adopt a workaround using custom op. - def spmm(x, y): - return th.spmm(x, y) -else: - # torch v1.0+ - def spmm(x, y): - return th.sparse.mm(x, y) +def spmm(x, y): + dst, src = x._indices() + dim = 0 + dst = dst.view(-1, 1).expand_as(-1, y.shape[1]) + zeros = y.new_full((x.shape[0], y.shape[1]), 0) + return zeros.scatter_add(0, dst, y[src] * x._values().unsqueeze(-1)) def unsorted_1d_segment_sum(input, seg_id, n_segs, dim): y = th.zeros(n_segs, *input.shape[1:]).to(input) From f9610a0804229398e079aa9aa394ef698cda5065 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Fri, 1 Mar 2019 20:16:57 +0000 Subject: [PATCH 2/6] fix --- python/dgl/backend/pytorch/tensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index f7bb815fe81b..a1ccf52b9b6f 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -137,8 +137,7 @@ def ones(shape, dtype, ctx): def spmm(x, y): dst, src = x._indices() - dim = 0 - dst = dst.view(-1, 1).expand_as(-1, y.shape[1]) + dst = dst.view(-1, 1).expand(-1, y.shape[1]) zeros = y.new_full((x.shape[0], y.shape[1]), 0) return zeros.scatter_add(0, dst, y[src] * x._values().unsqueeze(-1)) From 5ce93ffa596cc11586a699cb9450375689150ef2 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Fri, 1 Mar 2019 23:41:37 +0000 Subject: [PATCH 3/6] replace torch take with index_select --- python/dgl/backend/pytorch/tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index a1ccf52b9b6f..054cd2ffa2d7 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -139,7 +139,8 @@ def spmm(x, y): dst, src = x._indices() dst = dst.view(-1, 1).expand(-1, y.shape[1]) zeros = y.new_full((x.shape[0], y.shape[1]), 0) - return zeros.scatter_add(0, dst, y[src] * x._values().unsqueeze(-1)) + message = th.index_select(y, 0, src) * x._values().unsqueeze(-1) + return zeros.scatter_add(0, dst, message) def unsorted_1d_segment_sum(input, seg_id, n_segs, dim): y = th.zeros(n_segs, *input.shape[1:]).to(input) From 13cb98a73e650929cb79ae8b380ebd9369473ebf Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Sat, 2 Mar 2019 02:41:36 +0000 Subject: [PATCH 4/6] comments --- python/dgl/backend/pytorch/tensor.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index 054cd2ffa2d7..6c489f8d03c1 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -137,10 +137,13 @@ def ones(shape, dtype, ctx): def spmm(x, y): dst, src = x._indices() - dst = dst.view(-1, 1).expand(-1, y.shape[1]) - zeros = y.new_full((x.shape[0], y.shape[1]), 0) - message = th.index_select(y, 0, src) * x._values().unsqueeze(-1) - return zeros.scatter_add(0, dst, message) + # scatter index + index = dst.view(-1, 1).expand(-1, y.shape[1]) + # zero tensor to be scatter_add to + out = y.new_full((x.shape[0], y.shape[1]), 0) + # look up src features and multiply by edge features + feature = th.index_select(y, 0, src) * x._values().unsqueeze(-1) + return out.scatter_add(0, index, feature) def unsorted_1d_segment_sum(input, seg_id, n_segs, dim): y = th.zeros(n_segs, *input.shape[1:]).to(input) From adecb868779874ea686ea04a993e6290f70d1f60 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Sat, 2 Mar 2019 03:40:06 +0000 Subject: [PATCH 5/6] comment about pytorch __getitem__ operator pitfall --- python/dgl/backend/pytorch/tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index 6c489f8d03c1..8ae2417569f4 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -142,6 +142,8 @@ def spmm(x, y): # zero tensor to be scatter_add to out = y.new_full((x.shape[0], y.shape[1]), 0) # look up src features and multiply by edge features + # Note: using y[src] instead of index_select will lead to terrible + # porformance in backward feature = th.index_select(y, 0, src) * x._values().unsqueeze(-1) return out.scatter_add(0, index, feature) From c266dd85b798929131f9159f1a4b70f35169ed31 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Sat, 2 Mar 2019 03:41:47 +0000 Subject: [PATCH 6/6] typo --- python/dgl/backend/pytorch/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index 8ae2417569f4..5fec0fe99c3f 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -143,7 +143,7 @@ def spmm(x, y): out = y.new_full((x.shape[0], y.shape[1]), 0) # look up src features and multiply by edge features # Note: using y[src] instead of index_select will lead to terrible - # porformance in backward + # performance in backward feature = th.index_select(y, 0, src) * x._values().unsqueeze(-1) return out.scatter_add(0, index, feature)