torch.sparse.mm benchmark script (#6950)
```
Aggregator: sum, Layout: torch.sparse_coo
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 0.0291s   | 1.1011s    | 1.1302s |
+--------+-----------+------------+---------+
Aggregator: sum, Layout: torch.sparse_csr
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 0.0234s   | 1.0986s    | 1.1220s |
+--------+-----------+------------+---------+
Aggregator: sum, Layout: torch.sparse_csc
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 1.0967s   | 0.0306s    | 1.1273s |
+--------+-----------+------------+---------+
Aggregator: mean, Layout: torch.sparse_coo
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 0.0320s   | 0.9957s    | 1.0277s |
+--------+-----------+------------+---------+
Aggregator: mean, Layout: torch.sparse_csr
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 0.0395s   | 1.0311s    | 1.0706s |
+--------+-----------+------------+---------+
Aggregator: mean, Layout: torch.sparse_csc
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 1.1745s   | 0.0431s    | 1.2176s |
+--------+-----------+------------+---------+
Aggregator: min, Layout: torch.sparse_coo
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 0.1950s   | 1.0798s    | 1.2748s |
+--------+-----------+------------+---------+
Aggregator: min, Layout: torch.sparse_csr
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 0.2009s   | 1.0845s    | 1.2855s |
+--------+-----------+------------+---------+
Aggregator: min, Layout: torch.sparse_csc
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 1.2554s   | 1.0785s    | 2.3339s |
+--------+-----------+------------+---------+
Aggregator: max, Layout: torch.sparse_coo
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 0.1663s   | 1.0461s    | 1.2124s |
+--------+-----------+------------+---------+
Aggregator: max, Layout: torch.sparse_csr
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 0.1613s   | 1.0435s    | 1.2048s |
+--------+-----------+------------+---------+
Aggregator: max, Layout: torch.sparse_csc
+--------+-----------+------------+---------+
| Name   | Forward   | Backward   | Total   |
|--------+-----------+------------+---------|
| spmm   | 1.2571s   | 1.0804s    | 2.3375s |
+--------+-----------+------------+---------+
```
rusty1s authored Mar 18, 2023
1 parent b5ecfd9 · commit c73597f
Showing 6 changed files with 134 additions and 42 deletions.
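
The tables in the commit message come from the new `__main__` entry point in `test/utils/test_spmm.py` (see the diff below). Since `min`/`max` reductions are only benchmarked on CPU, the numbers above correspond to a CPU run. A condensed sketch of an equivalent invocation, with sizes and step counts taken from the script:

```python
import itertools
import warnings

import torch

from torch_geometric.profile import benchmark
from torch_geometric.utils import spmm, to_torch_coo_tensor

# Silence the on-the-fly layout-conversion warnings, as the script does:
warnings.filterwarnings('ignore', '.*Converting sparse tensor.*')

device = 'cpu'  # 'min'/'max' raise NotImplementedError on CUDA
num_nodes, num_edges = 10_000, 200_000
x = torch.randn(num_nodes, 64, device=device)
edge_index = torch.randint(num_nodes, (2, num_edges), device=device)

for reduce, layout in itertools.product(
        ['sum', 'mean', 'min', 'max'],
        [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc]):
    print(f'Aggregator: {reduce}, Layout: {layout}')

    adj = to_torch_coo_tensor(edge_index, size=num_nodes)
    adj = adj.to_sparse(layout=layout)

    benchmark(
        funcs=[spmm],
        func_names=['spmm'],
        args=(adj, x, reduce),
        num_steps=50,    # the script uses 500 steps on CUDA
        num_warmups=10,  # and 100 warmups
        backward=True,   # '--backward' also times the backward pass
    )
```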
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -95,7 +95,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
- Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950))
- Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
28 changes: 13 additions & 15 deletions test/nn/conv/test_message_passing.py
@@ -54,23 +54,21 @@ def test_my_conv_basic():
row, col = edge_index
value = torch.randn(row.size(0))
adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
torch_adj_t = adj.to_torch_sparse_csr_tensor().t()
torch_adj_t = torch_adj_t.to_sparse(layout=torch.sparse_csr)
torch_adj = adj.to_torch_sparse_csr_tensor()

conv = MyConv(8, 32)
out = conv(x1, edge_index, value)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out, atol=1e-6)
assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, torch_adj_t), out, atol=1e-6)
assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
conv.fuse = False
assert torch.allclose(conv(x1, adj.t()), out)
assert torch.allclose(conv(x1, torch_adj_t), out, atol=1e-6)
assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
conv.fuse = True

adj = adj.sparse_resize((4, 2))
torch_adj_t = adj.to_torch_sparse_csr_tensor().t()
torch_adj_t = torch_adj_t.to_sparse(layout=torch.sparse_csr)
torch_adj = adj.to_torch_sparse_csr_tensor()

conv = MyConv((8, 16), 32)
out1 = conv((x1, x2), edge_index, value)
@@ -79,21 +77,21 @@ def test_my_conv_basic():
assert out2.size() == (2, 32)
assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
assert torch.allclose(conv((x1, x2), adj.t()), out1)
assert torch.allclose(conv((x1, x2), torch_adj_t), out1, atol=1e-6)
assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
assert torch.allclose(conv((x1, None), adj.t()), out2)
assert torch.allclose(conv((x1, None), torch_adj_t), out2, atol=1e-6)
assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)
conv.fuse = False
assert torch.allclose(conv((x1, x2), adj.t()), out1)
assert torch.allclose(conv((x1, x2), torch_adj_t), out1, atol=1e-6)
assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
assert torch.allclose(conv((x1, None), adj.t()), out2)
assert torch.allclose(conv((x1, None), torch_adj_t), out2, atol=1e-6)
conv.fuse = True
assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)

# Test backward compatibility for `torch.sparse` tensors:
# Test gradient computation for `torch.sparse` tensors:
conv.fuse = True
torch_adj_t = torch_adj_t.requires_grad_()
conv((x1, x2), torch_adj_t).sum().backward()
assert torch_adj_t.grad is not None
torch_adj = torch_adj.requires_grad_()
out = conv((x1, x2), torch_adj.t().to_sparse_csr())
out.sum().backward()
assert torch_adj.grad is not None


def test_my_conv_out_of_bounds():
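
The rewritten gradient check above relies on PyTorch >= 2.0 autograd support for sparse CSR operands. A minimal standalone sketch of the same mechanism, independent of `MyConv` (shapes are made up):

```python
import torch

# Gradients flow back to a sparse CSR matrix used as the left operand
# of a sparse-dense matrix multiplication (assumes PyTorch >= 2.0):
adj = torch.randn(4, 4).relu().to_sparse_csr().requires_grad_()
x = torch.randn(4, 8)

out = torch.sparse.mm(adj, x)
out.sum().backward()

assert adj.grad is not None
```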
4 changes: 1 addition & 3 deletions test/nn/test_compile.py
@@ -3,7 +3,7 @@
import torch

from torch_geometric.profile import benchmark
from torch_geometric.testing import onlyLinux, withCUDA, withPackage
from torch_geometric.testing import onlyLinux, withCUDA
from torch_geometric.utils import scatter


@@ -41,7 +41,6 @@ def fused_gather_scatter(x, edge_index, reduce=['sum', 'mean', 'max']):

@withCUDA
@onlyLinux
@withPackage('torch>=2.0.0')
def test_torch_compile(device):
x = torch.randn(10, 16, device=device)
edge_index = torch.randint(0, x.size(0), (2, 40), device=device)
@@ -71,7 +70,6 @@ def test_torch_compile(device):

@withCUDA
@onlyLinux
@withPackage('torch>=2.0.0')
def test_dynamic_torch_compile(device):
compiled_gather_scatter = torch.compile(gather_scatter)

85 changes: 70 additions & 15 deletions test/utils/test_spmm.py
@@ -1,10 +1,14 @@
import itertools
import warnings

import pytest
import torch
from torch import Tensor

from torch_geometric.profile import benchmark
from torch_geometric.testing import withCUDA
from torch_geometric.typing import WITH_PT2, SparseTensor
from torch_geometric.utils import spmm
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import spmm, to_torch_coo_tensor


@withCUDA
@@ -14,7 +18,7 @@ def test_spmm_basic(device, reduce):
other = torch.randn(4, 8, device=device)

out1 = src @ other
out2 = spmm(src.to_sparse(layout=torch.sparse_csr), other, reduce=reduce)
out2 = spmm(src.to_sparse_csr(), other, reduce=reduce)
out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)
assert out1.size() == (5, 8)
if reduce == 'sum':
@@ -24,7 +28,7 @@

# Test `mean` reduction with isolated nodes:
src[0] = 0.
out2 = spmm(src.to_sparse(layout=torch.sparse_csr), other, reduce=reduce)
out2 = spmm(src.to_sparse_csr(), other, reduce=reduce)
out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)
assert out1.size() == (5, 8)
assert torch.allclose(out2, out3, atol=1e-6)
@@ -36,17 +40,32 @@ def test_spmm_reduce(device, reduce):
src = torch.randn(5, 4, device=device)
other = torch.randn(4, 8, device=device)

if WITH_PT2:
if src.is_cuda:
with pytest.raises(NotImplementedError, match="doesn't exist"):
spmm(src.to_sparse(layout=torch.sparse_csr), other, reduce)
else:
out1 = spmm(src.to_sparse(layout=torch.sparse_csr), other, reduce)
out2 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)
assert torch.allclose(out1, out2)
if src.is_cuda:
with pytest.raises(NotImplementedError, match="not yet supported"):
spmm(src.to_sparse_csr(), other, reduce)
else:
with pytest.raises(ValueError, match="not supported"):
spmm(src.to_sparse(), other, reduce)
out1 = spmm(src.to_sparse_csr(), other, reduce)
out2 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)
assert torch.allclose(out1, out2)


@withCUDA
@pytest.mark.parametrize(
'layout', [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc])
@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max'])
def test_spmm_layout(device, layout, reduce):
src = torch.randn(5, 4, device=device)
src = src.to_sparse(layout=layout)
other = torch.randn(4, 8, device=device)

if src.is_cuda and reduce in {'min', 'max'}:
with pytest.raises(NotImplementedError, match="not yet supported"):
spmm(src, other, reduce=reduce)
elif layout != torch.sparse_csr:
with pytest.warns(UserWarning, match="Converting sparse tensor"):
spmm(src, other, reduce=reduce)
else:
spmm(src, other, reduce=reduce)


@pytest.mark.parametrize('reduce', ['sum', 'mean'])
Expand All @@ -65,9 +84,45 @@ def jit_torch(src: Tensor, other: Tensor, reduce: str) -> Tensor:

out1 = src @ other
out2 = jit_torch_sparse(SparseTensor.from_dense(src), other, reduce=reduce)
out3 = jit_torch(src.to_sparse(layout=torch.sparse_csr), other, reduce)
out3 = jit_torch(src.to_sparse_csr(), other, reduce)
assert out1.size() == (5, 8)
if reduce == 'sum':
assert torch.allclose(out1, out2, atol=1e-6)
assert torch.allclose(out1, out3, atol=1e-6)
assert torch.allclose(out2, out3, atol=1e-6)


if __name__ == '__main__':
import argparse

warnings.filterwarnings('ignore', ".*Sparse CSR tensor support.*")
warnings.filterwarnings('ignore', ".*Converting sparse tensor to CSR.*")

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--backward', action='store_true')
args = parser.parse_args()

num_nodes, num_edges = 10_000, 200_000
x = torch.randn(num_nodes, 64, device=args.device)
edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)

reductions = ['sum', 'mean']
if not x.is_cuda:
reductions.extend(['min', 'max'])
layouts = [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc]

for reduce, layout in itertools.product(reductions, layouts):
print(f'Aggregator: {reduce}, Layout: {layout}')

adj = to_torch_coo_tensor(edge_index, size=num_nodes)
adj = adj.to_sparse(layout=layout)

benchmark(
funcs=[spmm],
func_names=['spmm'],
args=(adj, x, reduce),
num_steps=50 if args.device == 'cpu' else 500,
num_warmups=10 if args.device == 'cpu' else 100,
backward=args.backward,
)
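
As the new `test_spmm_layout` test documents, `spmm` now accepts all three `torch.sparse` layouts but converts non-CSR inputs on the fly. A small usage sketch of that behavior (assuming PyTorch >= 2.0 on CPU):

```python
import warnings

import torch

from torch_geometric.utils import spmm

src = torch.randn(5, 4).to_sparse()  # COO layout
other = torch.randn(4, 8)

# A COO input is converted to CSR internally, with a warning that
# suggests doing the conversion once beforehand:
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter('always')
    out = spmm(src, other, reduce='sum')
assert any('Converting sparse tensor' in str(x.message) for x in w)

# Converting up front avoids the repeated conversion:
out = spmm(src.to_sparse_csr(), other, reduce='sum')
```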
6 changes: 4 additions & 2 deletions torch_geometric/profile/benchmark.py
@@ -4,6 +4,8 @@
import torch
from torch import Tensor

from torch_geometric.utils import is_torch_sparse_tensor


def benchmark(
funcs: List[Callable],
@@ -51,8 +53,8 @@ def benchmark(
for i in range(num_warmups + num_steps):
args = [
arg.detach().requires_grad_(backward)
if isinstance(arg, Tensor) and arg.is_floating_point() else arg
for arg in args
if isinstance(arg, Tensor) and arg.is_floating_point()
and not is_torch_sparse_tensor(arg) else arg for arg in args
]

if torch.cuda.is_available():
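
The guard added above uses `torch_geometric.utils.is_torch_sparse_tensor`, which is true only for `torch.sparse` tensors (not for dense tensors or `torch_sparse.SparseTensor`), so that `requires_grad` is only toggled on dense floating-point arguments. A quick sketch of the helper:

```python
import torch

from torch_geometric.utils import is_torch_sparse_tensor

dense = torch.eye(3)
assert not is_torch_sparse_tensor(dense)
assert is_torch_sparse_tensor(dense.to_sparse())      # COO
assert is_torch_sparse_tensor(dense.to_sparse_csr())  # CSR
```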
51 changes: 45 additions & 6 deletions torch_geometric/utils/spmm.py
@@ -5,7 +5,7 @@

import torch_geometric.typing
from torch_geometric.typing import Adj, SparseTensor, torch_sparse
from torch_geometric.utils import degree, is_torch_sparse_tensor
from torch_geometric.utils import degree, is_torch_sparse_tensor, scatter


@torch.jit._overload
@@ -54,20 +54,59 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
    # `torch.sparse.mm` only supports reductions on CPU for PyTorch>=2.0.
    # This will currently throw an error for CUDA tensors.
if torch_geometric.typing.WITH_PT2:
if src.layout != torch.sparse_csr:

if src.is_cuda and (reduce == 'min' or reduce == 'max'):
raise NotImplementedError(f"`{reduce}` reduction is not yet "
f"supported for 'torch.sparse.Tensor' "
f"on device '{src.device}'")

# Always convert COO to CSR for more efficient processing:
if src.layout == torch.sparse_coo:
warnings.warn(f"Converting sparse tensor to CSR format for more "
f"efficient processing. Consider converting your "
f"sparse tensor to CSR format beforehand to avoid "
f"repeated conversion (got '{src.layout}')")
src = src.to_sparse_csr()
src = src.to_sparse(layout=torch.sparse_csr)

# Warn in case of CSC format without gradient computation:
if src.layout == torch.sparse_csc and not other.requires_grad:
warnings.warn(f"Converting sparse tensor to CSR format for more "
f"efficient processing. Consider converting your "
f"sparse tensor to CSR format beforehand to avoid "
f"repeated conversion (got '{src.layout}')")

# Use the default code path for `sum` reduction (works on CPU/GPU):
if reduce == 'sum':
return torch.sparse.mm(src, other)
elif reduce == 'mean' and src.is_cuda:
ptr = src.crow_indices()
deg = ptr[1:] - ptr[:-1]

# Use the default code path with custom reduction (works on CPU):
if src.layout == torch.sparse_csr and not src.is_cuda:
return torch.sparse.mm(src, other, reduce)

# Simulate `mean` reduction by dividing by degree:
if reduce == 'mean':
if src.layout == torch.sparse_csr:
ptr = src.crow_indices()
deg = ptr[1:] - ptr[:-1]
else:
assert src.layout == torch.sparse_csc
deg = scatter(src.values(), src.row_indices(), dim=0,
dim_size=src.size(0), reduce='sum')

return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)

# TODO The `torch.sparse.mm` code path with the `reduce` argument does
# not yet support CSC :(
if src.layout == torch.sparse_csc:
warnings.warn(f"Converting sparse tensor to CSR format for more "
f"efficient processing. Consider converting your "
f"sparse tensor to CSR format beforehand to avoid "
f"repeated conversion (got '{src.layout}')")
src = src.to_sparse(layout=torch.sparse_csr)

return torch.sparse.mm(src, other, reduce)

# PyTorch < 2.0 only supports sparse COO format:
if reduce == 'sum':
return torch.sparse.mm(src, other)
elif reduce == 'mean':
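
The `mean` fallback above applies the identity `mean = sum / degree`, reading the row degree off the CSR row pointers (or recomputing it via `scatter` from the CSC row indices). A minimal sketch of the CSR case, checked against a dense reference (shapes are made up; assumes PyTorch >= 2.0):

```python
import torch

src = torch.randn(5, 4).relu().to_sparse_csr()
other = torch.randn(4, 8)

# Row degree from the CSR row-pointer array (`crow_indices`):
ptr = src.crow_indices()
deg = (ptr[1:] - ptr[:-1]).clamp(min=1).view(-1, 1)

# `mean` reduction as a plain sum-SpMM divided by the per-row degree:
out = torch.sparse.mm(src, other) / deg

# Dense reference computes the same sum-then-divide:
expected = (src.to_dense() @ other) / deg
assert torch.allclose(out, expected, atol=1e-6)
```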
