Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.sparse.mm benchmark script #6950

Merged
merged 8 commits into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
- Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950))
- Added `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
Expand Down
28 changes: 13 additions & 15 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,21 @@ def test_my_conv_basic():
row, col = edge_index
value = torch.randn(row.size(0))
adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
torch_adj_t = adj.to_torch_sparse_csr_tensor().t()
torch_adj_t = torch_adj_t.to_sparse(layout=torch.sparse_csr)
torch_adj = adj.to_torch_sparse_csr_tensor()

conv = MyConv(8, 32)
out = conv(x1, edge_index, value)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out, atol=1e-6)
assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, torch_adj_t), out, atol=1e-6)
assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
conv.fuse = False
assert torch.allclose(conv(x1, adj.t()), out)
assert torch.allclose(conv(x1, torch_adj_t), out, atol=1e-6)
assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
conv.fuse = True

adj = adj.sparse_resize((4, 2))
torch_adj_t = adj.to_torch_sparse_csr_tensor().t()
torch_adj_t = torch_adj_t.to_sparse(layout=torch.sparse_csr)
torch_adj = adj.to_torch_sparse_csr_tensor()

conv = MyConv((8, 16), 32)
out1 = conv((x1, x2), edge_index, value)
Expand All @@ -79,21 +77,21 @@ def test_my_conv_basic():
assert out2.size() == (2, 32)
assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
assert torch.allclose(conv((x1, x2), adj.t()), out1)
assert torch.allclose(conv((x1, x2), torch_adj_t), out1, atol=1e-6)
assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
assert torch.allclose(conv((x1, None), adj.t()), out2)
assert torch.allclose(conv((x1, None), torch_adj_t), out2, atol=1e-6)
assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)
conv.fuse = False
assert torch.allclose(conv((x1, x2), adj.t()), out1)
assert torch.allclose(conv((x1, x2), torch_adj_t), out1, atol=1e-6)
assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
assert torch.allclose(conv((x1, None), adj.t()), out2)
assert torch.allclose(conv((x1, None), torch_adj_t), out2, atol=1e-6)
conv.fuse = True
assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)

# Test backward compatibility for `torch.sparse` tensors:
# Test gradient computation for `torch.sparse` tensors:
conv.fuse = True
torch_adj_t = torch_adj_t.requires_grad_()
conv((x1, x2), torch_adj_t).sum().backward()
assert torch_adj_t.grad is not None
torch_adj = torch_adj.requires_grad_()
out = conv((x1, x2), torch_adj.t().to_sparse_csr())
out.sum().backward()
assert torch_adj.grad is not None


def test_my_conv_out_of_bounds():
Expand Down
4 changes: 1 addition & 3 deletions test/nn/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from torch_geometric.profile import benchmark
from torch_geometric.testing import onlyLinux, withCUDA, withPackage
from torch_geometric.testing import onlyLinux, withCUDA
from torch_geometric.utils import scatter


Expand Down Expand Up @@ -41,7 +41,6 @@ def fused_gather_scatter(x, edge_index, reduce=['sum', 'mean', 'max']):

@withCUDA
@onlyLinux
@withPackage('torch>=2.0.0')
def test_torch_compile(device):
x = torch.randn(10, 16, device=device)
edge_index = torch.randint(0, x.size(0), (2, 40), device=device)
Expand Down Expand Up @@ -71,7 +70,6 @@ def test_torch_compile(device):

@withCUDA
@onlyLinux
@withPackage('torch>=2.0.0')
def test_dynamic_torch_compile(device):
compiled_gather_scatter = torch.compile(gather_scatter)

Expand Down
85 changes: 70 additions & 15 deletions test/utils/test_spmm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import itertools
import warnings

import pytest
import torch
from torch import Tensor

from torch_geometric.profile import benchmark
from torch_geometric.testing import withCUDA
from torch_geometric.typing import WITH_PT2, SparseTensor
from torch_geometric.utils import spmm
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import spmm, to_torch_coo_tensor


@withCUDA
Expand All @@ -14,7 +18,7 @@ def test_spmm_basic(device, reduce):
other = torch.randn(4, 8, device=device)

out1 = src @ other
out2 = spmm(src.to_sparse(layout=torch.sparse_csr), other, reduce=reduce)
out2 = spmm(src.to_sparse_csr(), other, reduce=reduce)
out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)
assert out1.size() == (5, 8)
if reduce == 'sum':
Expand All @@ -24,7 +28,7 @@ def test_spmm_basic(device, reduce):

# Test `mean` reduction with isolated nodes:
src[0] = 0.
out2 = spmm(src.to_sparse(layout=torch.sparse_csr), other, reduce=reduce)
out2 = spmm(src.to_sparse_csr(), other, reduce=reduce)
out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)
assert out1.size() == (5, 8)
assert torch.allclose(out2, out3, atol=1e-6)
Expand All @@ -36,17 +40,32 @@ def test_spmm_reduce(device, reduce):
src = torch.randn(5, 4, device=device)
other = torch.randn(4, 8, device=device)

if WITH_PT2:
if src.is_cuda:
with pytest.raises(NotImplementedError, match="doesn't exist"):
spmm(src.to_sparse(layout=torch.sparse_csr), other, reduce)
else:
out1 = spmm(src.to_sparse(layout=torch.sparse_csr), other, reduce)
out2 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)
assert torch.allclose(out1, out2)
if src.is_cuda:
with pytest.raises(NotImplementedError, match="not yet supported"):
spmm(src.to_sparse_csr(), other, reduce)
else:
with pytest.raises(ValueError, match="not supported"):
spmm(src.to_sparse(), other, reduce)
out1 = spmm(src.to_sparse_csr(), other, reduce)
out2 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)
assert torch.allclose(out1, out2)


@withCUDA
@pytest.mark.parametrize(
    'layout', [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc])
@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max'])
def test_spmm_layout(device, layout, reduce):
    """Exercise `spmm` for every sparse layout / reduction combination.

    Three outcomes are expected, checked in this order:

    * `min`/`max` on a CUDA input raises `NotImplementedError`.
    * A CSR input runs through `spmm` without further checks.
    * Any other layout emits the "Converting sparse tensor" `UserWarning`.

    The result of `spmm` itself is not inspected here.
    """
    dense = torch.randn(5, 4, device=device)
    src = dense.to_sparse(layout=layout)
    other = torch.randn(4, 8, device=device)

    # Unsupported reductions on CUDA take precedence over layout handling.
    if src.is_cuda and reduce in {'min', 'max'}:
        with pytest.raises(NotImplementedError, match="not yet supported"):
            spmm(src, other, reduce=reduce)
        return

    # CSR is the only layout expected to pass through without a warning.
    if layout == torch.sparse_csr:
        spmm(src, other, reduce=reduce)
        return

    with pytest.warns(UserWarning, match="Converting sparse tensor"):
        spmm(src, other, reduce=reduce)


@pytest.mark.parametrize('reduce', ['sum', 'mean'])
Expand All @@ -65,9 +84,45 @@ def jit_torch(src: Tensor, other: Tensor, reduce: str) -> Tensor:

out1 = src @ other
out2 = jit_torch_sparse(SparseTensor.from_dense(src), other, reduce=reduce)
out3 = jit_torch(src.to_sparse(layout=torch.sparse_csr), other, reduce)
out3 = jit_torch(src.to_sparse_csr(), other, reduce)
assert out1.size() == (5, 8)
if reduce == 'sum':
assert torch.allclose(out1, out2, atol=1e-6)
assert torch.allclose(out1, out3, atol=1e-6)
assert torch.allclose(out2, out3, atol=1e-6)


if __name__ == '__main__':
    # Stand-alone benchmark: times `spmm` for every combination of reduction
    # and sparse layout on a random graph.
    import argparse

    # Silence the layout-related warnings so they do not flood the output.
    warnings.filterwarnings('ignore', ".*Sparse CSR tensor support.*")
    warnings.filterwarnings('ignore', ".*Converting sparse tensor to CSR.*")

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--backward', action='store_true')
    args = parser.parse_args()

    num_nodes, num_edges = 10_000, 200_000
    x = torch.randn(num_nodes, 64, device=args.device)
    edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)

    # The COO adjacency depends on neither the reduction nor the target
    # layout, so build it once instead of once per loop iteration.
    coo_adj = to_torch_coo_tensor(edge_index, size=num_nodes)

    # `min`/`max` are only benchmarked when `x` is not a CUDA tensor.
    reductions = ['sum', 'mean']
    if not x.is_cuda:
        reductions.extend(['min', 'max'])
    layouts = [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc]

    # Fewer iterations on CPU to keep the total runtime reasonable.
    on_cpu = args.device == 'cpu'
    num_steps = 50 if on_cpu else 500
    num_warmups = 10 if on_cpu else 100

    for reduce, layout in itertools.product(reductions, layouts):
        print(f'Aggregator: {reduce}, Layout: {layout}')

        adj = coo_adj.to_sparse(layout=layout)

        benchmark(
            funcs=[spmm],
            func_names=['spmm'],
            args=(adj, x, reduce),
            num_steps=num_steps,
            num_warmups=num_warmups,
            backward=args.backward,
        )
6 changes: 4 additions & 2 deletions torch_geometric/profile/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch
from torch import Tensor

from torch_geometric.utils import is_torch_sparse_tensor


def benchmark(
funcs: List[Callable],
Expand Down Expand Up @@ -51,8 +53,8 @@ def benchmark(
for i in range(num_warmups + num_steps):
args = [
arg.detach().requires_grad_(backward)
if isinstance(arg, Tensor) and arg.is_floating_point() else arg
for arg in args
if isinstance(arg, Tensor) and arg.is_floating_point()
and not is_torch_sparse_tensor(arg) else arg for arg in args
]

if torch.cuda.is_available():
Expand Down
51 changes: 45 additions & 6 deletions torch_geometric/utils/spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch_geometric.typing
from torch_geometric.typing import Adj, SparseTensor, torch_sparse
from torch_geometric.utils import degree, is_torch_sparse_tensor
from torch_geometric.utils import degree, is_torch_sparse_tensor, scatter


@torch.jit._overload
Expand Down Expand Up @@ -54,20 +54,59 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
# `torch.sparse.mm` only supports reductions on CPU for PyTorch>=2.0.
# This will currently throw an error for CUDA tensors.
if torch_geometric.typing.WITH_PT2:
if src.layout != torch.sparse_csr:

if src.is_cuda and (reduce == 'min' or reduce == 'max'):
raise NotImplementedError(f"`{reduce}` reduction is not yet "
f"supported for 'torch.sparse.Tensor' "
f"on device '{src.device}'")

# Always convert COO to CSR for more efficient processing:
if src.layout == torch.sparse_coo:
warnings.warn(f"Converting sparse tensor to CSR format for more "
f"efficient processing. Consider converting your "
f"sparse tensor to CSR format beforehand to avoid "
f"repeated conversion (got '{src.layout}')")
src = src.to_sparse_csr()
src = src.to_sparse(layout=torch.sparse_csr)

# Warn in case of CSC format without gradient computation:
if src.layout == torch.sparse_csc and not other.requires_grad:
warnings.warn(f"Converting sparse tensor to CSR format for more "
f"efficient processing. Consider converting your "
f"sparse tensor to CSR format beforehand to avoid "
f"repeated conversion (got '{src.layout}')")

# Use the default code path for `sum` reduction (works on CPU/GPU):
if reduce == 'sum':
return torch.sparse.mm(src, other)
elif reduce == 'mean' and src.is_cuda:
ptr = src.crow_indices()
deg = ptr[1:] - ptr[:-1]

# Use the default code path with custom reduction (works on CPU):
if src.layout == torch.sparse_csr and not src.is_cuda:
return torch.sparse.mm(src, other, reduce)

# Simulate `mean` reduction by dividing by degree:
if reduce == 'mean':
if src.layout == torch.sparse_csr:
ptr = src.crow_indices()
deg = ptr[1:] - ptr[:-1]
else:
assert src.layout == torch.sparse_csc
deg = scatter(src.values(), src.row_indices(), dim=0,
dim_size=src.size(0), reduce='sum')

return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)

# TODO The `torch.sparse.mm` code path with the `reduce` argument does
# not yet support CSC :(
if src.layout == torch.sparse_csc:
warnings.warn(f"Converting sparse tensor to CSR format for more "
f"efficient processing. Consider converting your "
f"sparse tensor to CSR format beforehand to avoid "
f"repeated conversion (got '{src.layout}')")
src = src.to_sparse(layout=torch.sparse_csr)

return torch.sparse.mm(src, other, reduce)

# PyTorch < 2.0 only supports sparse COO format:
if reduce == 'sum':
return torch.sparse.mm(src, other)
elif reduce == 'mean':
Expand Down