From 00b263f7a8eb72717d3655d4da8682b57331a1fa Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Fri, 17 Mar 2023 08:52:06 +0800 Subject: [PATCH 1/9] LEconv and LGConv --- test/nn/conv/test_le_conv.py | 12 +++++++++++- test/nn/conv/test_lg_conv.py | 4 ++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/test/nn/conv/test_le_conv.py b/test/nn/conv/test_le_conv.py index bd0165ed9ebe..765a2dd2eafe 100644 --- a/test/nn/conv/test_le_conv.py +++ b/test/nn/conv/test_le_conv.py @@ -1,4 +1,5 @@ import torch +from torch_sparse import SparseTensor from torch_geometric.nn import LEConv from torch_geometric.testing import is_full_test @@ -8,14 +9,23 @@ def test_le_conv(): in_channels, out_channels = (16, 32) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) num_nodes = edge_index.max().item() + 1 + row, col = edge_index x = torch.randn((num_nodes, in_channels)) + adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + adj2 = adj1.to_torch_sparse_coo_tensor() conv = LEConv(in_channels, out_channels) assert str(conv) == 'LEConv(16, 32)' out = conv(x, edge_index) assert out.size() == (num_nodes, out_channels) + assert torch.allclose(conv(x, adj1.t()), out, atol=1e-6) + assert torch.allclose(conv(x, adj2.t()), out, atol=1e-6) if is_full_test(): t = '(Tensor, Tensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit(x, edge_index).tolist() == out.tolist() + torch.allclose(jit(x, edge_index), out, atol=1e-6) + + t = '(Tensor, SparseTensor, OptTensor) -> Tensor' + jit = torch.jit.script(conv.jittable(t)) + assert torch.allclose(jit(x, adj1.t()), out) diff --git a/test/nn/conv/test_lg_conv.py b/test/nn/conv/test_lg_conv.py index 02ed16921fce..d85f5e9b3f21 100644 --- a/test/nn/conv/test_lg_conv.py +++ b/test/nn/conv/test_lg_conv.py @@ -12,15 +12,19 @@ def test_lg_conv(): value = torch.rand(row.size(0)) adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4)) adj1 = adj2.set_value(None) + adj3 = adj1.to_torch_sparse_coo_tensor() + adj4 = adj2.to_torch_sparse_coo_tensor() conv = LGConv() assert str(conv) == 'LGConv()' out1 = conv(x, edge_index) assert out1.size() == (4, 8) assert torch.allclose(conv(x, adj1.t()), out1) + assert torch.allclose(conv(x, adj3.t()), out1) out2 = conv(x, edge_index, value) assert out2.size() == (4, 8) assert torch.allclose(conv(x, adj2.t()), out2) + assert torch.allclose(conv(x, adj4.t()), out2) if is_full_test(): t = '(Tensor, Tensor, OptTensor) -> Tensor' From 622b7e33304b3ff3110bbb886c54c20e498f67be Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Fri, 17 Mar 2023 08:59:28 +0800 Subject: [PATCH 2/9] NNConv --- test/nn/conv/test_nn_conv.py | 45 ++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/test/nn/conv/test_nn_conv.py b/test/nn/conv/test_nn_conv.py index 5888e4365161..68684ca0eebb 100644 --- a/test/nn/conv/test_nn_conv.py +++ b/test/nn/conv/test_nn_conv.py @@ -14,7 +14,8 @@ def test_nn_conv(): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) row, col = edge_index value = torch.rand(row.size(0), 3) - adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4)) + adj1 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4)) + adj2 = adj1.to_torch_sparse_coo_tensor() nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32)) conv = NNConv(8, 32, nn=nn) @@ -26,20 +27,24 @@ def test_nn_conv(): '))') out = conv(x1, edge_index, value) assert out.size() == (4, 32) - assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist() - assert conv(x1, adj.t()).tolist() == out.tolist() + assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out, + atol=1e-6) + assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) + assert torch.allclose(conv(x1, adj2.transpose(0, 1)), out, atol=1e-6) if is_full_test(): t = '(Tensor, Tensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit(x1, edge_index, value).tolist() == out.tolist() - assert jit(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist() + assert torch.allclose(jit(x1, edge_index, value), out, atol=1e-6) + assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out, + atol=1e-6) t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit(x1, adj.t()).tolist() == out.tolist() + assert torch.allclose(jit(x1, adj1.t()), out, atol=1e-6) - adj = adj.sparse_resize((4, 2)) + adj1 = adj1.sparse_resize((4, 2)) + adj2 = adj1.to_torch_sparse_coo_tensor() conv = NNConv((8, 16), 32, nn=nn) assert str(conv) == ( 'NNConv((8, 16), 32, aggr=add, nn=Sequential(\n' @@ -51,20 +56,26 @@ def test_nn_conv(): out2 = conv((x1, None), edge_index, value, (4, 2)) assert out1.size() == (2, 32) assert out2.size() == (2, 32) - assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist() - assert conv((x1, x2), adj.t()).tolist() == out1.tolist() - assert conv((x1, None), adj.t()).tolist() == out2.tolist() + assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1, + atol=1e-6) + assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6) + assert torch.allclose(conv((x1, x2), adj2.transpose(0, 1)), out1, + atol=1e-6) + assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6) + assert torch.allclose(conv((x1, None), adj2.transpose(0, 1)), out2, + atol=1e-6) if is_full_test(): t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit((x1, x2), edge_index, value).tolist() == out1.tolist() - assert jit((x1, x2), edge_index, value, - size=(4, 2)).tolist() == out1.tolist() - assert jit((x1, None), edge_index, value, - size=(4, 2)).tolist() == out2.tolist() + assert torch.allclose(jit((x1, x2), edge_index, value), out1, + atol=1e-6) + assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)), + out1, atol=1e-6) + assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)), + out2, atol=1e-6) t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit((x1, x2), adj.t()).tolist() == out1.tolist() - assert jit((x1, None), adj.t()).tolist() == out2.tolist() + assert torch.allclose(jit((x1, x2), adj1.t()), out1, atol=1e-6) + assert torch.allclose(jit((x1, None), adj1.t()), out2, atol=1e-6) From 80422594887162924e5e2237ec2695af0f450468 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Fri, 17 Mar 2023 09:04:30 +0800 Subject: [PATCH 3/9] PANConv --- test/nn/conv/test_pan_conv.py | 10 ++++++++-- torch_geometric/nn/conv/pan_conv.py | 10 +++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/test/nn/conv/test_pan_conv.py b/test/nn/conv/test_pan_conv.py index a5761b2b931c..e62b6bbc2d53 100644 --- a/test/nn/conv/test_pan_conv.py +++ b/test/nn/conv/test_pan_conv.py @@ -8,12 +8,18 @@ def test_pan_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) row, col = edge_index - adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + adj2 = adj1.to_torch_sparse_coo_tensor() conv = PANConv(16, 32, filter_size=2) assert str(conv) == 'PANConv(16, 32, filter_size=2)' out1, M1 = conv(x, edge_index) assert out1.size() == (4, 32) - out2, M2 = conv(x, adj.t()) + + out2, M2 = conv(x, adj1.t()) assert torch.allclose(out1, out2, atol=1e-6) assert torch.allclose(M1.to_dense(), M2.to_dense()) + + out3, M3 = conv(x, adj2.t()) + assert torch.allclose(out1, out3, atol=1e-6) + assert torch.allclose(M1.to_dense(), M3.to_dense()) diff --git a/torch_geometric/nn/conv/pan_conv.py b/torch_geometric/nn/conv/pan_conv.py index 018645b97cf9..65e3b126d3d9 100644 --- a/torch_geometric/nn/conv/pan_conv.py +++ b/torch_geometric/nn/conv/pan_conv.py @@ -7,7 +7,7 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, SparseTensor -from torch_geometric.utils import spmm +from torch_geometric.utils import is_torch_sparse_tensor, spmm class PANConv(MessagePassing): @@ -66,8 +66,12 @@ def forward(self, x: Tensor, adj_t: Optional[SparseTensor] = None if isinstance(edge_index, Tensor): - adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], - sparse_sizes=(x.size(0), x.size(0))) + if is_torch_sparse_tensor(edge_index): + # TODO: handle PyTorch sparse tensor directly + adj_t = SparseTensor.from_torch_sparse_coo_tensor(edge_index) + else: + adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], + sparse_sizes=(x.size(0), x.size(0))) elif isinstance(edge_index, SparseTensor): adj_t = edge_index.set_value(None) assert adj_t is not None From 34cce92e22b1ca8fb2740dc5f15f15ad1507489e Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Fri, 17 Mar 2023 09:10:50 +0800 Subject: [PATCH 4/9] SignedConv --- test/nn/conv/test_signed_conv.py | 35 ++++++++++++++++---------- torch_geometric/nn/conv/signed_conv.py | 3 ++- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/test/nn/conv/test_signed_conv.py b/test/nn/conv/test_signed_conv.py index f7604ddb203f..f8060c25807c 100644 --- a/test/nn/conv/test_signed_conv.py +++ b/test/nn/conv/test_signed_conv.py @@ -9,7 +9,8 @@ def test_signed_conv(): x = torch.randn(4, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) row, col = edge_index - adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + adj2 = adj1.to_torch_sparse_coo_tensor() conv1 = SignedConv(16, 32, first_aggr=True) assert str(conv1) == 'SignedConv(16, 32, first_aggr=True)' @@ -19,34 +20,42 @@ def test_signed_conv(): out1 = conv1(x, edge_index, edge_index) assert out1.size() == (4, 64) - assert conv1(x, adj.t(), adj.t()).tolist() == out1.tolist() + assert torch.allclose(conv1(x, adj1.t(), adj1.t()), out1, atol=1e-6) + assert torch.allclose(conv1(x, adj2.t(), adj2.t()), out1, atol=1e-6) out2 = conv2(out1, edge_index, edge_index) assert out2.size() == (4, 96) - assert conv2(out1, adj.t(), adj.t()).tolist() == out2.tolist() + assert torch.allclose(conv2(out1, adj1.t(), adj1.t()), out2, atol=1e-6) + assert torch.allclose(conv2(out1, adj2.t(), adj2.t()), out2, atol=1e-6) if is_full_test(): t = '(Tensor, Tensor, Tensor) -> Tensor' jit1 = torch.jit.script(conv1.jittable(t)) jit2 = torch.jit.script(conv2.jittable(t)) - assert jit1(x, edge_index, edge_index).tolist() == out1.tolist() - assert jit2(out1, edge_index, edge_index).tolist() == out2.tolist() + assert torch.allclose(jit1(x, edge_index, edge_index), out1, atol=1e-6) + assert torch.allclose(jit2(out1, edge_index, edge_index), out2, + atol=1e-6) t = '(Tensor, SparseTensor, SparseTensor) -> Tensor' jit1 = torch.jit.script(conv1.jittable(t)) jit2 = torch.jit.script(conv2.jittable(t)) - assert jit1(x, adj.t(), adj.t()).tolist() == out1.tolist() - assert jit2(out1, adj.t(), adj.t()).tolist() == out2.tolist() + assert torch.allclose(jit1(x, adj1.t(), adj1.t()), out1, atol=1e-6) + assert torch.allclose(jit2(out1, adj1.t(), adj1.t()), out2, atol=1e-6) - adj = adj.sparse_resize((4, 2)) + adj1 = adj1.sparse_resize((4, 2)) + adj2 = adj1.to_torch_sparse_coo_tensor() assert torch.allclose(conv1((x, x[:2]), edge_index, edge_index), out1[:2], atol=1e-6) - assert torch.allclose(conv1((x, x[:2]), adj.t(), adj.t()), out1[:2], + assert torch.allclose(conv1((x, x[:2]), adj1.t(), adj1.t()), out1[:2], + atol=1e-6) + assert torch.allclose(conv1((x, x[:2]), adj2.t(), adj2.t()), out1[:2], atol=1e-6) assert torch.allclose(conv2((out1, out1[:2]), edge_index, edge_index), out2[:2], atol=1e-6) - assert torch.allclose(conv2((out1, out1[:2]), adj.t(), adj.t()), out2[:2], - atol=1e-6) + assert torch.allclose(conv2((out1, out1[:2]), adj1.t(), adj1.t()), + out2[:2], atol=1e-6) + assert torch.allclose(conv2((out1, out1[:2]), adj2.t(), adj2.t()), + out2[:2], atol=1e-6) if is_full_test(): t = '(PairTensor, Tensor, Tensor) -> Tensor' @@ -60,7 +69,7 @@ def test_signed_conv(): t = '(PairTensor, SparseTensor, SparseTensor) -> Tensor' jit1 = torch.jit.script(conv1.jittable(t)) jit2 = torch.jit.script(conv2.jittable(t)) - assert torch.allclose(jit1((x, x[:2]), adj.t(), adj.t()), out1[:2], + assert torch.allclose(jit1((x, x[:2]), adj1.t(), adj1.t()), out1[:2], atol=1e-6) - assert torch.allclose(jit2((out1, out1[:2]), adj.t(), adj.t()), + assert torch.allclose(jit2((out1, out1[:2]), adj1.t(), adj1.t()), out2[:2], atol=1e-6) diff --git a/torch_geometric/nn/conv/signed_conv.py b/torch_geometric/nn/conv/signed_conv.py index 23e209fda96e..07ed050c9514 100644 --- a/torch_geometric/nn/conv/signed_conv.py +++ b/torch_geometric/nn/conv/signed_conv.py @@ -138,7 +138,8 @@ def message(self, x_j: Tensor) -> Tensor: def message_and_aggregate(self, adj_t: SparseTensor, x: PairTensor) -> Tensor: - adj_t = adj_t.set_value(None) + if isinstance(adj_t, SparseTensor): + adj_t = adj_t.set_value(None, layout=None) return spmm(adj_t, x[0], reduce=self.aggr) def __repr__(self) -> str: From a2002993844a1c06a21374909d46b16485e65551 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Fri, 17 Mar 2023 09:17:07 +0800 Subject: [PATCH 5/9] WLConv --- test/nn/conv/test_wl_conv.py | 11 +++++++---- torch_geometric/nn/conv/wl_conv.py | 10 +++++++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/test/nn/conv/test_wl_conv.py b/test/nn/conv/test_wl_conv.py index ee7b9c25c897..2c2d09a004f5 100644 --- a/test/nn/conv/test_wl_conv.py +++ b/test/nn/conv/test_wl_conv.py @@ -9,16 +9,19 @@ def test_wl_conv(): x1 = torch.tensor([1, 0, 0, 1]) x2 = F.one_hot(x1).to(torch.float) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) - adj_t = SparseTensor.from_edge_index(edge_index).t() + adj_t1 = SparseTensor.from_edge_index(edge_index).t() + adj_t2 = adj_t1.to_torch_sparse_coo_tensor() conv = WLConv() assert str(conv) == 'WLConv()' out = conv(x1, edge_index) assert out.tolist() == [0, 1, 1, 0] - assert conv(x2, edge_index).tolist() == out.tolist() - assert conv(x1, adj_t).tolist() == out.tolist() - assert conv(x2, adj_t).tolist() == out.tolist() + assert torch.allclose(conv(x2, edge_index), out, atol=1e-6) + assert torch.allclose(conv(x1, adj_t1), out, atol=1e-6) + assert torch.allclose(conv(x1, adj_t2), out, atol=1e-6) + assert torch.allclose(conv(x2, adj_t1), out, atol=1e-6) + assert torch.allclose(conv(x2, adj_t2), out, atol=1e-6) assert conv.histogram(out).tolist() == [[2, 2]] assert torch.allclose(conv.histogram(out, norm=True), diff --git a/torch_geometric/nn/conv/wl_conv.py b/torch_geometric/nn/conv/wl_conv.py index 497975b6145c..4a0f2fa516a5 100644 --- a/torch_geometric/nn/conv/wl_conv.py +++ b/torch_geometric/nn/conv/wl_conv.py @@ -4,7 +4,7 @@ from torch import Tensor from torch_geometric.typing import Adj, SparseTensor -from torch_geometric.utils import scatter +from torch_geometric.utils import is_torch_sparse_tensor, scatter class WLConv(torch.nn.Module): @@ -42,8 +42,12 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor: adj_t = edge_index if not isinstance(adj_t, SparseTensor): - adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], - sparse_sizes=(x.size(0), x.size(0))) + if is_torch_sparse_tensor(edge_index): + # TODO: handle PyTorch sparse tensor directly + adj_t = SparseTensor.from_torch_sparse_coo_tensor(edge_index) + else: + adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], + sparse_sizes=(x.size(0), x.size(0))) out = [] _, col, _ = adj_t.coo() From 3e63d04cac2e33525719dfb1216d0385528aaf7a Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Fri, 17 Mar 2023 09:20:14 +0800 Subject: [PATCH 6/9] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c217234b31b1..c1c8d0421915 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -95,7 +95,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193)) - Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187)) - Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161)) -- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897)) +- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6936](https://github.com/pyg-team/pytorch_geometric/pull/6936)) - Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154)) - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124)) - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117)) From 587a15bf1de2825983aaf3f3b1eacd5c248d07ec Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 20 Mar 2023 06:15:52 +0000 Subject: [PATCH 7/9] changelog --- test/nn/conv/test_le_conv.py | 5 ++-- test/nn/conv/test_lg_conv.py | 4 ++-- test/nn/conv/test_nn_conv.py | 36 +++++++++++++---------------- test/nn/conv/test_pan_conv.py | 2 +- test/nn/conv/test_signed_conv.py | 33 +++++++++++--------------- test/nn/conv/test_wl_conv.py | 14 +++++------ torch_geometric/nn/conv/pan_conv.py | 14 ++++++++--- torch_geometric/nn/conv/wl_conv.py | 31 +++++++++++++++---------- torch_geometric/utils/sparse.py | 2 +- 9 files changed, 73 insertions(+), 68 deletions(-) diff --git a/test/nn/conv/test_le_conv.py b/test/nn/conv/test_le_conv.py index 765a2dd2eafe..b46aed1fe0b9 100644 --- a/test/nn/conv/test_le_conv.py +++ b/test/nn/conv/test_le_conv.py @@ -9,10 +9,9 @@ def test_le_conv(): in_channels, out_channels = (16, 32) edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) num_nodes = edge_index.max().item() + 1 - row, col = edge_index x = torch.randn((num_nodes, in_channels)) - adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) - adj2 = adj1.to_torch_sparse_coo_tensor() + adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + adj2 = adj1.to_torch_sparse_csc_tensor() conv = LEConv(in_channels, out_channels) assert str(conv) == 'LEConv(16, 32)' diff --git a/test/nn/conv/test_lg_conv.py b/test/nn/conv/test_lg_conv.py index d85f5e9b3f21..2ad940f16423 100644 --- a/test/nn/conv/test_lg_conv.py +++ b/test/nn/conv/test_lg_conv.py @@ -12,8 +12,8 @@ def test_lg_conv(): value = torch.rand(row.size(0)) adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4)) adj1 = adj2.set_value(None) - adj3 = adj1.to_torch_sparse_coo_tensor() - adj4 = adj2.to_torch_sparse_coo_tensor() + adj3 = adj1.to_torch_sparse_csc_tensor() + adj4 = adj2.to_torch_sparse_csc_tensor() conv = LGConv() assert str(conv) == 'LGConv()' diff --git a/test/nn/conv/test_nn_conv.py b/test/nn/conv/test_nn_conv.py index 68684ca0eebb..0e5cf84a557f 100644 --- a/test/nn/conv/test_nn_conv.py +++ b/test/nn/conv/test_nn_conv.py @@ -27,21 +27,19 @@ def test_nn_conv(): '))') out = conv(x1, edge_index, value) assert out.size() == (4, 32) - assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out, - atol=1e-6) - assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) - assert torch.allclose(conv(x1, adj2.transpose(0, 1)), out, atol=1e-6) + assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out) + assert torch.allclose(conv(x1, adj1.t()), out) + assert torch.allclose(conv(x1, adj2.transpose(0, 1).coalesce()), out) if is_full_test(): t = '(Tensor, Tensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit(x1, edge_index, value), out, atol=1e-6) - assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out, - atol=1e-6) + assert torch.allclose(jit(x1, edge_index, value), out) + assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out) t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit(x1, adj1.t()), out, atol=1e-6) + assert torch.allclose(jit(x1, adj1.t()), out) adj1 = adj1.sparse_resize((4, 2)) adj2 = adj1.to_torch_sparse_coo_tensor() @@ -56,24 +54,22 @@ def test_nn_conv(): out2 = conv((x1, None), edge_index, value, (4, 2)) assert out1.size() == (2, 32) assert out2.size() == (2, 32) - assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1, - atol=1e-6) - assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6) - assert torch.allclose(conv((x1, x2), adj2.transpose(0, 1)), out1, - atol=1e-6) - assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6) - assert torch.allclose(conv((x1, None), adj2.transpose(0, 1)), out2, - atol=1e-6) + assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) + assert torch.allclose(conv((x1, x2), adj1.t()), out1) + assert torch.allclose(conv((x1, x2), + adj2.transpose(0, 1).coalesce()), out1) + assert torch.allclose(conv((x1, None), adj1.t()), out2) + assert torch.allclose(conv((x1, None), + adj2.transpose(0, 1).coalesce()), out2) if is_full_test(): t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit((x1, x2), edge_index, value), out1, - atol=1e-6) + assert torch.allclose(jit((x1, x2), edge_index, value), out1) assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)), - out1, atol=1e-6) + out1) assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)), - out2, atol=1e-6) + out2) t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) diff --git a/test/nn/conv/test_pan_conv.py b/test/nn/conv/test_pan_conv.py index e62b6bbc2d53..3858f14bf387 100644 --- a/test/nn/conv/test_pan_conv.py +++ b/test/nn/conv/test_pan_conv.py @@ -9,7 +9,7 @@ def test_pan_conv(): edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) row, col = edge_index adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) - adj2 = adj1.to_torch_sparse_coo_tensor() + adj2 = adj1.to_torch_sparse_csc_tensor() conv = PANConv(16, 32, filter_size=2) assert str(conv) == 'PANConv(16, 32, filter_size=2)' diff --git a/test/nn/conv/test_signed_conv.py b/test/nn/conv/test_signed_conv.py index f8060c25807c..d4d006ee1cff 100644 --- a/test/nn/conv/test_signed_conv.py +++ b/test/nn/conv/test_signed_conv.py @@ -10,7 +10,7 @@ def test_signed_conv(): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) row, col = edge_index adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) - adj2 = adj1.to_torch_sparse_coo_tensor() + adj2 = adj1.to_torch_sparse_csc_tensor() conv1 = SignedConv(16, 32, first_aggr=True) assert str(conv1) == 'SignedConv(16, 32, first_aggr=True)' @@ -20,36 +20,32 @@ def test_signed_conv(): out1 = conv1(x, edge_index, edge_index) assert out1.size() == (4, 64) - assert torch.allclose(conv1(x, adj1.t(), adj1.t()), out1, atol=1e-6) - assert torch.allclose(conv1(x, adj2.t(), adj2.t()), out1, atol=1e-6) + assert torch.allclose(conv1(x, adj1.t(), adj1.t()), out1) + assert torch.allclose(conv1(x, adj2.t(), adj2.t()), out1) out2 = conv2(out1, edge_index, edge_index) assert out2.size() == (4, 96) - assert torch.allclose(conv2(out1, adj1.t(), adj1.t()), out2, atol=1e-6) - assert torch.allclose(conv2(out1, adj2.t(), adj2.t()), out2, atol=1e-6) + assert torch.allclose(conv2(out1, adj1.t(), adj1.t()), out2) + assert torch.allclose(conv2(out1, adj2.t(), adj2.t()), out2) if is_full_test(): t = '(Tensor, Tensor, Tensor) -> Tensor' jit1 = torch.jit.script(conv1.jittable(t)) jit2 = torch.jit.script(conv2.jittable(t)) - assert torch.allclose(jit1(x, edge_index, edge_index), out1, atol=1e-6) - assert torch.allclose(jit2(out1, edge_index, edge_index), out2, - atol=1e-6) + assert torch.allclose(jit1(x, edge_index, edge_index), out1) + assert torch.allclose(jit2(out1, edge_index, edge_index), out2) t = '(Tensor, SparseTensor, SparseTensor) -> Tensor' jit1 = torch.jit.script(conv1.jittable(t)) jit2 = torch.jit.script(conv2.jittable(t)) - assert torch.allclose(jit1(x, adj1.t(), adj1.t()), out1, atol=1e-6) - assert torch.allclose(jit2(out1, adj1.t(), adj1.t()), out2, atol=1e-6) + assert torch.allclose(jit1(x, adj1.t(), adj1.t()), out1) + assert torch.allclose(jit2(out1, adj1.t(), adj1.t()), out2) adj1 = adj1.sparse_resize((4, 2)) - adj2 = adj1.to_torch_sparse_coo_tensor() - assert torch.allclose(conv1((x, x[:2]), edge_index, edge_index), out1[:2], - atol=1e-6) - assert torch.allclose(conv1((x, x[:2]), adj1.t(), adj1.t()), out1[:2], - atol=1e-6) - assert torch.allclose(conv1((x, x[:2]), adj2.t(), adj2.t()), out1[:2], - atol=1e-6) + adj2 = adj1.to_torch_sparse_csc_tensor() + assert torch.allclose(conv1((x, x[:2]), edge_index, edge_index), out1[:2]) + assert torch.allclose(conv1((x, x[:2]), adj1.t(), adj1.t()), out1[:2]) + assert torch.allclose(conv1((x, x[:2]), adj2.t(), adj2.t()), out1[:2]) assert torch.allclose(conv2((out1, out1[:2]), edge_index, edge_index), out2[:2], atol=1e-6) assert torch.allclose(conv2((out1, out1[:2]), adj1.t(), adj1.t()), @@ -69,7 +65,6 @@ def test_signed_conv(): t = '(PairTensor, SparseTensor, SparseTensor) -> Tensor' jit1 = torch.jit.script(conv1.jittable(t)) jit2 = torch.jit.script(conv2.jittable(t)) - assert torch.allclose(jit1((x, x[:2]), adj1.t(), adj1.t()), out1[:2], - atol=1e-6) + assert torch.allclose(jit1((x, x[:2]), adj1.t(), adj1.t()), out1[:2]) assert torch.allclose(jit2((out1, out1[:2]), adj1.t(), adj1.t()), out2[:2], atol=1e-6) diff --git a/test/nn/conv/test_wl_conv.py b/test/nn/conv/test_wl_conv.py index 2c2d09a004f5..fc8ac55a1ccc 100644 --- a/test/nn/conv/test_wl_conv.py +++ b/test/nn/conv/test_wl_conv.py @@ -9,19 +9,19 @@ def test_wl_conv(): x1 = torch.tensor([1, 0, 0, 1]) x2 = F.one_hot(x1).to(torch.float) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) - adj_t1 = SparseTensor.from_edge_index(edge_index).t() - adj_t2 = adj_t1.to_torch_sparse_coo_tensor() + adj1 = SparseTensor.from_edge_index(edge_index) + adj2 = adj1.to_torch_sparse_csc_tensor() conv = WLConv() assert str(conv) == 'WLConv()' out = conv(x1, edge_index) assert out.tolist() == [0, 1, 1, 0] - assert torch.allclose(conv(x2, edge_index), out, atol=1e-6) - assert torch.allclose(conv(x1, adj_t1), out, atol=1e-6) - assert torch.allclose(conv(x1, adj_t2), out, atol=1e-6) - assert torch.allclose(conv(x2, adj_t1), out, atol=1e-6) - assert torch.allclose(conv(x2, adj_t2), out, atol=1e-6) + assert torch.allclose(conv(x2, edge_index), out) + assert torch.allclose(conv(x1, adj1.t()), out) + assert torch.allclose(conv(x1, adj2.t()), out) + assert torch.allclose(conv(x2, adj1.t()), out) + assert torch.allclose(conv(x2, adj2.t()), out) assert conv.histogram(out).tolist() == [[2, 2]] assert torch.allclose(conv.histogram(out, norm=True), diff --git a/torch_geometric/nn/conv/pan_conv.py b/torch_geometric/nn/conv/pan_conv.py index 65e3b126d3d9..704b2a6ea529 100644 --- a/torch_geometric/nn/conv/pan_conv.py +++ b/torch_geometric/nn/conv/pan_conv.py @@ -67,14 +67,22 @@ def forward(self, x: Tensor, adj_t: Optional[SparseTensor] = None if isinstance(edge_index, Tensor): if is_torch_sparse_tensor(edge_index): - # TODO: handle PyTorch sparse tensor directly - adj_t = SparseTensor.from_torch_sparse_coo_tensor(edge_index) + # TODO Handle PyTorch sparse tensor directly. + if edge_index.layout == torch.sparse_coo: + adj_t = SparseTensor.from_torch_sparse_coo_tensor( + edge_index) + elif edge_index.layout == torch.sparse_csr: + adj_t = SparseTensor.from_torch_sparse_csr_tensor( + edge_index) + else: + raise ValueError(f"Unexpected sparse tensor layout " + f"(got '{edge_index.layout}')") else: adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], sparse_sizes=(x.size(0), x.size(0))) + elif isinstance(edge_index, SparseTensor): adj_t = edge_index.set_value(None) - assert adj_t is not None adj_t = self.panentropy(adj_t, dtype=x.dtype) diff --git a/torch_geometric/nn/conv/wl_conv.py b/torch_geometric/nn/conv/wl_conv.py index 4a0f2fa516a5..0bfac0b5e621 100644 --- a/torch_geometric/nn/conv/wl_conv.py +++ b/torch_geometric/nn/conv/wl_conv.py @@ -4,7 +4,13 @@ from torch import Tensor from torch_geometric.typing import Adj, SparseTensor -from torch_geometric.utils import is_torch_sparse_tensor, scatter +from torch_geometric.utils import ( + degree, + is_sparse, + scatter, + sort_edge_index, + to_edge_index, +) class WLConv(torch.nn.Module): @@ -40,19 +46,20 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor: x = x.argmax(dim=-1) # one-hot -> integer. assert x.dtype == torch.long - adj_t = edge_index - if not isinstance(adj_t, SparseTensor): - if is_torch_sparse_tensor(edge_index): - # TODO: handle PyTorch sparse tensor directly - adj_t = SparseTensor.from_torch_sparse_coo_tensor(edge_index) - else: - adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], - sparse_sizes=(x.size(0), x.size(0))) + if is_sparse(edge_index): + col_and_row, _ = to_edge_index(edge_index) + col = col_and_row[0] + row = col_and_row[1] + else: + edge_index = sort_edge_index(edge_index, num_nodes=x.size(0), + sort_by_row=False) + row, col = edge_index[0], edge_index[1] + + # `col` is sorted, so we can use it to `split` neighbors to groups: + deg = degree(col, x.size(0), dtype=torch.long).tolist() out = [] - _, col, _ = adj_t.coo() - deg = adj_t.storage.rowcount().tolist() - for node, neighbors in zip(x.tolist(), x[col].split(deg)): + for node, neighbors in zip(x.tolist(), x[row].split(deg)): idx = hash(tuple([node] + neighbors.sort()[0].tolist())) if idx not in self.hashmap: self.hashmap[idx] = len(self.hashmap) diff --git a/torch_geometric/utils/sparse.py b/torch_geometric/utils/sparse.py index c24bec521e35..8de55e8a4534 100644 --- a/torch_geometric/utils/sparse.py +++ b/torch_geometric/utils/sparse.py @@ -294,7 +294,7 @@ def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor: device=value.device, ) - raise ValueError(f"Expected sparse tensor layout (got '{adj.layout}')") + raise ValueError(f"Unexpected sparse tensor layout (got '{adj.layout}')") def ptr2index(ptr: Tensor) -> Tensor: From 1da0fc569e97f56ef475293bcf169dcc71e159c7 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 20 Mar 2023 06:20:04 +0000 Subject: [PATCH 8/9] changelog --- test/nn/conv/test_le_conv.py | 6 +++--- test/nn/conv/test_nn_conv.py | 4 ++-- test/nn/conv/test_wl_conv.py | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/nn/conv/test_le_conv.py b/test/nn/conv/test_le_conv.py index b46aed1fe0b9..2e7d858aa8d4 100644 --- a/test/nn/conv/test_le_conv.py +++ b/test/nn/conv/test_le_conv.py @@ -17,13 +17,13 @@ def test_le_conv(): assert str(conv) == 'LEConv(16, 32)' out = conv(x, edge_index) assert out.size() == (num_nodes, out_channels) - assert torch.allclose(conv(x, adj1.t()), out, atol=1e-6) - assert torch.allclose(conv(x, adj2.t()), out, atol=1e-6) + assert torch.allclose(conv(x, adj1.t()), out) + assert torch.allclose(conv(x, adj2.t()), out) if is_full_test(): t = '(Tensor, Tensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - torch.allclose(jit(x, edge_index), out, atol=1e-6) + torch.allclose(jit(x, edge_index), out) t = '(Tensor, SparseTensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) diff --git a/test/nn/conv/test_nn_conv.py b/test/nn/conv/test_nn_conv.py index 0e5cf84a557f..dac419594418 100644 --- a/test/nn/conv/test_nn_conv.py +++ b/test/nn/conv/test_nn_conv.py @@ -73,5 +73,5 @@ def test_nn_conv(): t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit((x1, x2), adj1.t()), out1, atol=1e-6) - assert torch.allclose(jit((x1, None), adj1.t()), out2, atol=1e-6) + assert torch.allclose(jit((x1, x2), adj1.t()), out1) + assert torch.allclose(jit((x1, None), adj1.t()), out2) diff --git a/test/nn/conv/test_wl_conv.py b/test/nn/conv/test_wl_conv.py index fc8ac55a1ccc..5378263616ce 100644 --- a/test/nn/conv/test_wl_conv.py +++ b/test/nn/conv/test_wl_conv.py @@ -17,11 +17,11 @@ def test_wl_conv(): out = conv(x1, edge_index) assert out.tolist() == [0, 1, 1, 0] - assert torch.allclose(conv(x2, edge_index), out) - assert torch.allclose(conv(x1, adj1.t()), out) - assert torch.allclose(conv(x1, adj2.t()), out) - assert torch.allclose(conv(x2, adj1.t()), out) - assert torch.allclose(conv(x2, adj2.t()), out) + assert torch.equal(conv(x2, edge_index), out) + assert torch.equal(conv(x1, adj1.t()), out) + assert torch.equal(conv(x1, adj2.t()), out) + assert torch.equal(conv(x2, adj1.t()), out) + assert torch.equal(conv(x2, adj2.t()), out) assert conv.histogram(out).tolist() == [[2, 2]] assert torch.allclose(conv.histogram(out, norm=True), From 07fd9c1176182a98b00abc6da36a53b740229482 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 20 Mar 2023 06:22:11 +0000 Subject: [PATCH 9/9] changelog --- torch_geometric/nn/conv/wl_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/wl_conv.py b/torch_geometric/nn/conv/wl_conv.py index 0bfac0b5e621..d50e501f775b 100644 --- a/torch_geometric/nn/conv/wl_conv.py +++ b/torch_geometric/nn/conv/wl_conv.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from torch_geometric.typing import Adj, SparseTensor +from torch_geometric.typing import Adj from torch_geometric.utils import ( degree, is_sparse,