Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch Sparse Tensor support: LEConv, LGConv, NNConv, PANConv, SignedConv, and WLConv #6936

Merged
merged 11 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
- Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6932](https://github.com/pyg-team/pytorch_geometric/pull/6932), [#6937](https://github.com/pyg-team/pytorch_geometric/pull/6937), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950), [#6951](https://github.com/pyg-team/pytorch_geometric/pull/6951), [#6957](https://github.com/pyg-team/pytorch_geometric/pull/6957))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6932](https://github.com/pyg-team/pytorch_geometric/pull/6932), [#6936](https://github.com/pyg-team/pytorch_geometric/pull/6936), [#6937](https://github.com/pyg-team/pytorch_geometric/pull/6937), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950), [#6951](https://github.com/pyg-team/pytorch_geometric/pull/6951), [#6957](https://github.com/pyg-team/pytorch_geometric/pull/6957))
- Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
Expand Down
11 changes: 10 additions & 1 deletion test/nn/conv/test_le_conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import LEConv
from torch_geometric.testing import is_full_test
Expand All @@ -9,13 +10,21 @@ def test_le_conv():
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
num_nodes = edge_index.max().item() + 1
x = torch.randn((num_nodes, in_channels))
adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_csc_tensor()

conv = LEConv(in_channels, out_channels)
assert str(conv) == 'LEConv(16, 32)'
out = conv(x, edge_index)
assert out.size() == (num_nodes, out_channels)
assert torch.allclose(conv(x, adj1.t()), out)
assert torch.allclose(conv(x, adj2.t()), out)

if is_full_test():
t = '(Tensor, Tensor, OptTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x, edge_index).tolist() == out.tolist()
torch.allclose(jit(x, edge_index), out)

t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x, adj1.t()), out)
4 changes: 4 additions & 0 deletions test/nn/conv/test_lg_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ def test_lg_conv():
value = torch.rand(row.size(0))
adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
adj1 = adj2.set_value(None)
adj3 = adj1.to_torch_sparse_csc_tensor()
adj4 = adj2.to_torch_sparse_csc_tensor()

conv = LGConv()
assert str(conv) == 'LGConv()'
out1 = conv(x, edge_index)
assert out1.size() == (4, 8)
assert torch.allclose(conv(x, adj1.t()), out1)
assert torch.allclose(conv(x, adj3.t()), out1)
out2 = conv(x, edge_index, value)
assert out2.size() == (4, 8)
assert torch.allclose(conv(x, adj2.t()), out2)
assert torch.allclose(conv(x, adj4.t()), out2)

if is_full_test():
t = '(Tensor, Tensor, OptTensor) -> Tensor'
Expand Down
41 changes: 24 additions & 17 deletions test/nn/conv/test_nn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def test_nn_conv():
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
value = torch.rand(row.size(0), 3)
adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_coo_tensor()

nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32))
conv = NNConv(8, 32, nn=nn)
Expand All @@ -26,20 +27,22 @@ def test_nn_conv():
'))')
out = conv(x1, edge_index, value)
assert out.size() == (4, 32)
assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()
assert conv(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out)
assert torch.allclose(conv(x1, adj1.t()), out)
assert torch.allclose(conv(x1, adj2.transpose(0, 1).coalesce()), out)

if is_full_test():
t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, edge_index, value).tolist() == out.tolist()
assert jit(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()
assert torch.allclose(jit(x1, edge_index, value), out)
assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out)

t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(jit(x1, adj1.t()), out)

adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_coo_tensor()
conv = NNConv((8, 16), 32, nn=nn)
assert str(conv) == (
'NNConv((8, 16), 32, aggr=add, nn=Sequential(\n'
Expand All @@ -51,20 +54,24 @@ def test_nn_conv():
out2 = conv((x1, None), edge_index, value, (4, 2))
assert out1.size() == (2, 32)
assert out2.size() == (2, 32)
assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist()
assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
assert conv((x1, None), adj.t()).tolist() == out2.tolist()
assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
assert torch.allclose(conv((x1, x2), adj1.t()), out1)
assert torch.allclose(conv((x1, x2),
adj2.transpose(0, 1).coalesce()), out1)
assert torch.allclose(conv((x1, None), adj1.t()), out2)
assert torch.allclose(conv((x1, None),
adj2.transpose(0, 1).coalesce()), out2)

if is_full_test():
t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), edge_index, value).tolist() == out1.tolist()
assert jit((x1, x2), edge_index, value,
size=(4, 2)).tolist() == out1.tolist()
assert jit((x1, None), edge_index, value,
size=(4, 2)).tolist() == out2.tolist()
assert torch.allclose(jit((x1, x2), edge_index, value), out1)
assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),
out1)
assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),
out2)

t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
assert jit((x1, None), adj.t()).tolist() == out2.tolist()
assert torch.allclose(jit((x1, x2), adj1.t()), out1)
assert torch.allclose(jit((x1, None), adj1.t()), out2)
10 changes: 8 additions & 2 deletions test/nn/conv/test_pan_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@ def test_pan_conv():
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_csc_tensor()

conv = PANConv(16, 32, filter_size=2)
assert str(conv) == 'PANConv(16, 32, filter_size=2)'
out1, M1 = conv(x, edge_index)
assert out1.size() == (4, 32)
out2, M2 = conv(x, adj.t())

out2, M2 = conv(x, adj1.t())
assert torch.allclose(out1, out2, atol=1e-6)
assert torch.allclose(M1.to_dense(), M2.to_dense())

out3, M3 = conv(x, adj2.t())
assert torch.allclose(out1, out3, atol=1e-6)
assert torch.allclose(M1.to_dense(), M3.to_dense())
38 changes: 21 additions & 17 deletions test/nn/conv/test_signed_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ def test_signed_conv():
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_csc_tensor()

conv1 = SignedConv(16, 32, first_aggr=True)
assert str(conv1) == 'SignedConv(16, 32, first_aggr=True)'
Expand All @@ -19,34 +20,38 @@ def test_signed_conv():

out1 = conv1(x, edge_index, edge_index)
assert out1.size() == (4, 64)
assert conv1(x, adj.t(), adj.t()).tolist() == out1.tolist()
assert torch.allclose(conv1(x, adj1.t(), adj1.t()), out1)
assert torch.allclose(conv1(x, adj2.t(), adj2.t()), out1)

out2 = conv2(out1, edge_index, edge_index)
assert out2.size() == (4, 96)
assert conv2(out1, adj.t(), adj.t()).tolist() == out2.tolist()
assert torch.allclose(conv2(out1, adj1.t(), adj1.t()), out2)
assert torch.allclose(conv2(out1, adj2.t(), adj2.t()), out2)

if is_full_test():
t = '(Tensor, Tensor, Tensor) -> Tensor'
jit1 = torch.jit.script(conv1.jittable(t))
jit2 = torch.jit.script(conv2.jittable(t))
assert jit1(x, edge_index, edge_index).tolist() == out1.tolist()
assert jit2(out1, edge_index, edge_index).tolist() == out2.tolist()
assert torch.allclose(jit1(x, edge_index, edge_index), out1)
assert torch.allclose(jit2(out1, edge_index, edge_index), out2)

t = '(Tensor, SparseTensor, SparseTensor) -> Tensor'
jit1 = torch.jit.script(conv1.jittable(t))
jit2 = torch.jit.script(conv2.jittable(t))
assert jit1(x, adj.t(), adj.t()).tolist() == out1.tolist()
assert jit2(out1, adj.t(), adj.t()).tolist() == out2.tolist()
assert torch.allclose(jit1(x, adj1.t(), adj1.t()), out1)
assert torch.allclose(jit2(out1, adj1.t(), adj1.t()), out2)

adj = adj.sparse_resize((4, 2))
assert torch.allclose(conv1((x, x[:2]), edge_index, edge_index), out1[:2],
atol=1e-6)
assert torch.allclose(conv1((x, x[:2]), adj.t(), adj.t()), out1[:2],
atol=1e-6)
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_csc_tensor()
assert torch.allclose(conv1((x, x[:2]), edge_index, edge_index), out1[:2])
assert torch.allclose(conv1((x, x[:2]), adj1.t(), adj1.t()), out1[:2])
assert torch.allclose(conv1((x, x[:2]), adj2.t(), adj2.t()), out1[:2])
assert torch.allclose(conv2((out1, out1[:2]), edge_index, edge_index),
out2[:2], atol=1e-6)
assert torch.allclose(conv2((out1, out1[:2]), adj.t(), adj.t()), out2[:2],
atol=1e-6)
assert torch.allclose(conv2((out1, out1[:2]), adj1.t(), adj1.t()),
out2[:2], atol=1e-6)
assert torch.allclose(conv2((out1, out1[:2]), adj2.t(), adj2.t()),
out2[:2], atol=1e-6)

if is_full_test():
t = '(PairTensor, Tensor, Tensor) -> Tensor'
Expand All @@ -60,7 +65,6 @@ def test_signed_conv():
t = '(PairTensor, SparseTensor, SparseTensor) -> Tensor'
jit1 = torch.jit.script(conv1.jittable(t))
jit2 = torch.jit.script(conv2.jittable(t))
assert torch.allclose(jit1((x, x[:2]), adj.t(), adj.t()), out1[:2],
atol=1e-6)
assert torch.allclose(jit2((out1, out1[:2]), adj.t(), adj.t()),
assert torch.allclose(jit1((x, x[:2]), adj1.t(), adj1.t()), out1[:2])
assert torch.allclose(jit2((out1, out1[:2]), adj1.t(), adj1.t()),
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
out2[:2], atol=1e-6)
11 changes: 7 additions & 4 deletions test/nn/conv/test_wl_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ def test_wl_conv():
x1 = torch.tensor([1, 0, 0, 1])
x2 = F.one_hot(x1).to(torch.float)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
adj_t = SparseTensor.from_edge_index(edge_index).t()
adj1 = SparseTensor.from_edge_index(edge_index)
adj2 = adj1.to_torch_sparse_csc_tensor()

conv = WLConv()
assert str(conv) == 'WLConv()'

out = conv(x1, edge_index)
assert out.tolist() == [0, 1, 1, 0]
assert conv(x2, edge_index).tolist() == out.tolist()
assert conv(x1, adj_t).tolist() == out.tolist()
assert conv(x2, adj_t).tolist() == out.tolist()
assert torch.equal(conv(x2, edge_index), out)
assert torch.equal(conv(x1, adj1.t()), out)
assert torch.equal(conv(x1, adj2.t()), out)
assert torch.equal(conv(x2, adj1.t()), out)
assert torch.equal(conv(x2, adj2.t()), out)

assert conv.histogram(out).tolist() == [[2, 2]]
assert torch.allclose(conv.histogram(out, norm=True),
Expand Down
20 changes: 16 additions & 4 deletions torch_geometric/nn/conv/pan_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, SparseTensor
from torch_geometric.utils import spmm
from torch_geometric.utils import is_torch_sparse_tensor, spmm


class PANConv(MessagePassing):
Expand Down Expand Up @@ -66,11 +66,23 @@ def forward(self, x: Tensor,

adj_t: Optional[SparseTensor] = None
if isinstance(edge_index, Tensor):
adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
sparse_sizes=(x.size(0), x.size(0)))
if is_torch_sparse_tensor(edge_index):
# TODO Handle PyTorch sparse tensor directly.
if edge_index.layout == torch.sparse_coo:
adj_t = SparseTensor.from_torch_sparse_coo_tensor(
edge_index)
elif edge_index.layout == torch.sparse_csr:
adj_t = SparseTensor.from_torch_sparse_csr_tensor(
edge_index)
else:
raise ValueError(f"Unexpected sparse tensor layout "
f"(got '{edge_index.layout}')")
else:
adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
sparse_sizes=(x.size(0), x.size(0)))

elif isinstance(edge_index, SparseTensor):
adj_t = edge_index.set_value(None)
assert adj_t is not None

adj_t = self.panentropy(adj_t, dtype=x.dtype)

Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/signed_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def message(self, x_j: Tensor) -> Tensor:

def message_and_aggregate(self, adj_t: SparseTensor,
x: PairTensor) -> Tensor:
adj_t = adj_t.set_value(None)
if isinstance(adj_t, SparseTensor):
adj_t = adj_t.set_value(None, layout=None)
return spmm(adj_t, x[0], reduce=self.aggr)

def __repr__(self) -> str:
Expand Down
29 changes: 20 additions & 9 deletions torch_geometric/nn/conv/wl_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
import torch
from torch import Tensor

from torch_geometric.typing import Adj, SparseTensor
from torch_geometric.utils import scatter
from torch_geometric.typing import Adj
from torch_geometric.utils import (
degree,
is_sparse,
scatter,
sort_edge_index,
to_edge_index,
)


class WLConv(torch.nn.Module):
Expand Down Expand Up @@ -40,15 +46,20 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
x = x.argmax(dim=-1) # one-hot -> integer.
assert x.dtype == torch.long

adj_t = edge_index
if not isinstance(adj_t, SparseTensor):
adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
sparse_sizes=(x.size(0), x.size(0)))
if is_sparse(edge_index):
col_and_row, _ = to_edge_index(edge_index)
col = col_and_row[0]
row = col_and_row[1]
else:
edge_index = sort_edge_index(edge_index, num_nodes=x.size(0),
sort_by_row=False)
row, col = edge_index[0], edge_index[1]

# `col` is sorted, so we can use it to `split` neighbors to groups:
deg = degree(col, x.size(0), dtype=torch.long).tolist()

out = []
_, col, _ = adj_t.coo()
deg = adj_t.storage.rowcount().tolist()
for node, neighbors in zip(x.tolist(), x[col].split(deg)):
for node, neighbors in zip(x.tolist(), x[row].split(deg)):
idx = hash(tuple([node] + neighbors.sort()[0].tolist()))
if idx not in self.hashmap:
self.hashmap[idx] = len(self.hashmap)
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:
device=value.device,
)

raise ValueError(f"Expected sparse tensor layout (got '{adj.layout}')")
raise ValueError(f"Unexpected sparse tensor layout (got '{adj.layout}')")


def ptr2index(ptr: Tensor) -> Tensor:
Expand Down