Skip to content

Commit

Permalink
pyg-lib ToHeteroModule (#5992)
Browse files Browse the repository at this point in the history
[benchmark of new
HeteroModule](https://github.com/puririshi98/rgcn_pyg_lib_forward_bench/blob/main/heteromodule_bench.py):
<img width="907" alt="image"
src="https://user-images.githubusercontent.com/20074092/202813853-557421f4-e5cd-4ba4-a523-dcdb328e7808.png">
sufficient acceleration achieved
Remaining TODO for the next PR:
- integrate into transformer setup

Co-authored-by: Rishi Puri <riship@riship-mlt.client.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rishi Puri <riship@riship-mlt.nvidia.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
5 people authored Jan 21, 2023
1 parent 08ce552 commit 041508d
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `ToHeteroLinear` and `ToHeteroMessagePassing` modules to accelerate `to_hetero` functionality ([#5992](https://github.com/pyg-team/pytorch_geometric/pull/5992), [#6456](https://github.com/pyg-team/pytorch_geometric/pull/6456))
- Added `GraphMaskExplainer` ([#6284](https://github.com/pyg-team/pytorch_geometric/pull/6284))
- Added the `GRBCD` and `PRBCD` adversarial attack models ([#5972](https://github.com/pyg-team/pytorch_geometric/pull/5972))
- Added `dropout` option to `SetTransformer` and `GraphMultisetTransformer` ([#6484](https://github.com/pyg-team/pytorch_geometric/pull/6484))
Expand Down
59 changes: 59 additions & 0 deletions test/nn/test_to_hetero_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest # noqa
import torch

from torch_geometric.nn.conv import SAGEConv
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.to_hetero_module import (
ToHeteroLinear,
ToHeteroMessagePassing,
)


@pytest.mark.parametrize('LinearCls', [torch.nn.Linear, Linear])
def test_to_hetero_linear(LinearCls):
    """Check that the dict-based and fused forward passes agree.

    Two types ('1' with 5 rows, '2' with 4 rows) are fed once as a
    dictionary and once as a single stacked tensor plus a type vector; both
    paths must yield the same 32-dimensional outputs.
    """
    feats_first = torch.randn(5, 16)
    feats_second = torch.randn(4, 16)
    x_dict = {'1': feats_first, '2': feats_second}

    # Stacked representation: rows of type '1' first, then rows of type '2'.
    x = torch.cat(list(x_dict.values()), dim=0)
    type_vec = torch.tensor([0] * 5 + [1] * 4)

    module = ToHeteroLinear(LinearCls(16, 32), list(x_dict))

    # Dictionary-based forward pass.
    out_dict = module(x_dict)
    assert len(out_dict) == 2
    for key, num_rows in (('1', 5), ('2', 4)):
        assert out_dict[key].size() == (num_rows, 32)

    # Fused forward pass on the stacked tensor.
    out = module(x, type_vec)
    assert out.size() == (9, 32)

    # Both paths must produce the same values per type.
    assert torch.allclose(out_dict['1'], out[:5])
    assert torch.allclose(out_dict['2'], out[5:])


def test_to_hetero_message_passing():
    """Check that the dict-based and fused forward passes agree.

    The same bipartite graph between node types '1' (5 nodes) and '2'
    (4 nodes) is expressed once with per-type dictionaries holding local
    indices and once as a single graph with global indices plus node/edge
    type vectors.
    """
    feats_first = torch.randn(5, 16)
    feats_second = torch.randn(4, 16)
    x_dict = {'1': feats_first, '2': feats_second}
    x = torch.cat(list(x_dict.values()), dim=0)
    node_type = torch.tensor([0] * 5 + [1] * 4)

    # Per-relation connectivity, indexed locally within each node type.
    forward_edges = torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 3]])
    backward_edges = torch.tensor([[0, 0, 1, 2, 3], [0, 1, 2, 3, 4]])
    edge_index_dict = {
        ('1', 'to', '2'): forward_edges,
        ('2', 'to', '1'): backward_edges,
    }

    # The same edges with global indices (nodes of type '2' start at 5).
    sources = [0, 1, 2, 3, 4, 5, 5, 6, 7, 8]
    targets = [5, 5, 6, 7, 8, 0, 1, 2, 3, 4]
    edge_index = torch.tensor([sources, targets])
    edge_type = torch.tensor([0] * 5 + [1] * 5)

    module = ToHeteroMessagePassing(
        SAGEConv(16, 32),
        list(x_dict),
        list(edge_index_dict),
    )

    # Dictionary-based forward pass.
    out_dict = module(x_dict, edge_index_dict)
    assert len(out_dict) == 2
    for key, num_rows in (('1', 5), ('2', 4)):
        assert out_dict[key].size() == (num_rows, 32)

    # Fused forward pass on the stacked graph.
    out = module(x, edge_index, node_type, edge_type)
    assert out.size() == (9, 32)

    # Both paths must produce the same values per node type.
    assert torch.allclose(out_dict['1'], out[:5])
    assert torch.allclose(out_dict['2'], out[5:])
176 changes: 176 additions & 0 deletions torch_geometric/nn/to_hetero_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import copy
import warnings
from typing import Dict, List, Optional, Union

import torch
from torch import Tensor

import torch_geometric
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.utils import scatter


class ToHeteroLinear(torch.nn.Module):
    r"""Type-wise linear layer derived from a homogeneous linear layer.

    The given :obj:`module` is only inspected for its channel sizes and for
    whether it carries a bias. A new :class:`HeteroLinear` with one set of
    weights per entry in :obj:`types` is created from these settings.

    Args:
        module (torch.nn.Module): A :class:`torch_geometric.nn.Linear` or
            :class:`torch.nn.Linear` layer acting as the template.
        types (List[NodeType] or List[EdgeType]): The node or edge types.
            The position of a type in this list is the integer index used
            for it in :obj:`type_vec`.

    Raises:
        ValueError: If :obj:`module` is neither of the supported linear
            layer classes.
    """
    def __init__(
        self,
        module: torch.nn.Module,
        types: Union[List[NodeType], List[EdgeType]],
    ):
        # Imported inside the constructor rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from torch_geometric.nn import HeteroLinear, Linear

        super().__init__()

        self.types = types

        if isinstance(module, Linear):
            in_channels = module.in_channels
            out_channels = module.out_channels

        elif isinstance(module, torch.nn.Linear):
            in_channels = module.in_features
            out_channels = module.out_features

        else:
            raise ValueError(
                f"Expected 'Linear' module (got '{type(module)}')")

        bias = module.bias is not None

        # TODO: Need to handle `in_channels=-1` case.
        # TODO We currently assume that `x` is sorted according to `type`.
        self.hetero_module = HeteroLinear(
            in_channels,
            out_channels,
            num_types=len(types),
            is_sorted=True,
            bias=bias,
        )

    def fused_forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
        r"""Applies the type-wise linear layer to a single stacked tensor.

        Args:
            x (Tensor): The stacked input features of all types.
            type_vec (Tensor): The type index of each row in :obj:`x`.

        Returns:
            Tensor: The output of the wrapped :class:`HeteroLinear` module.
        """
        return self.hetero_module(x, type_vec)

    def dict_forward(
        self,
        x_dict: Dict[Union[NodeType, EdgeType], Tensor],
    ) -> Dict[Union[NodeType, EdgeType], Tensor]:
        r"""Applies the type-wise linear layer to a dictionary of inputs.

        Args:
            x_dict (Dict): Input features per type. Every type in
                :obj:`self.types` must be present as a key.

        Returns:
            Dict: Output features per type, keyed like :obj:`self.types`.
        """
        if not torch_geometric.typing.WITH_PYG_LIB:
            # Without pyg-lib, run each type through its own linear layer.
            # NOTE(review): relies on `HeteroLinear` exposing its per-type
            # layers as `lins` when pyg-lib is unavailable -- confirm.
            return {
                key: self.hetero_module.lins[i](x_dict[key])
                for i, key in enumerate(self.types)
            }

        # With pyg-lib, stack all inputs in the order of `self.types`, run a
        # single fused call, and split the result back per type.
        sizes = [x_dict[key].size(0) for key in self.types]
        x = torch.cat([x_dict[key] for key in self.types], dim=0)
        type_vec = torch.arange(len(self.types), device=x.device)
        type_vec = type_vec.repeat_interleave(
            torch.tensor(sizes, device=x.device))
        outs = self.hetero_module(x, type_vec).split(sizes)
        return dict(zip(self.types, outs))

    def forward(
        self,
        x: Union[Tensor, Dict[Union[NodeType, EdgeType], Tensor]],
        type_vec: Optional[Tensor] = None,
    ) -> Union[Tensor, Dict[Union[NodeType, EdgeType], Tensor]]:
        r"""Dispatches to :meth:`dict_forward` or :meth:`fused_forward`.

        Args:
            x (Tensor or Dict): Either a dictionary of input features per
                type, or a single stacked tensor.
            type_vec (Tensor, optional): The type index of each row of
                :obj:`x`. Required when :obj:`x` is a tensor and ignored
                when :obj:`x` is a dictionary. (default: :obj:`None`)

        Returns:
            Tensor or Dict: A dictionary for dictionary input, a tensor for
            tensor input.

        Raises:
            ValueError: If :obj:`x` is a tensor without a :obj:`type_vec`,
                or is neither a dictionary nor a tensor.
        """
        if isinstance(x, dict):
            return self.dict_forward(x)

        elif isinstance(x, Tensor) and type_vec is not None:
            return self.fused_forward(x, type_vec)

        raise ValueError(f"Encountered invalid forward types in "
                         f"'{self.__class__.__name__}'")


class ToHeteroMessagePassing(torch.nn.Module):
    r"""Heterogeneous message passing layer built from a homogeneous one.

    The given :obj:`module` is deep-copied once per edge type, the copies
    are wrapped in a :class:`HeteroConv` using the aggregation :obj:`aggr`,
    and the parameters of the resulting module are reset.

    Args:
        module (torch.nn.Module): The :class:`MessagePassing` layer to
            duplicate.
        node_types (List[NodeType]): The node types. The position of a type
            in this list is the integer index used for it in
            :obj:`node_type`.
        edge_types (List[EdgeType]): The edge types, given as
            :obj:`(src, rel, dst)` triplets. The position of a type in this
            list is the integer index used for it in :obj:`edge_type`.
        aggr (str, optional): The aggregation passed on to
            :class:`HeteroConv`. (default: :obj:`"sum"`)

    Raises:
        ValueError: If :obj:`module` is not a :class:`MessagePassing`
            instance.
    """
    def __init__(
        self,
        module: torch.nn.Module,
        node_types: List[NodeType],
        edge_types: List[EdgeType],
        aggr: str = 'sum',
    ):
        # Imported inside the constructor rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from torch_geometric.nn import HeteroConv, MessagePassing

        super().__init__()

        self.node_types = node_types
        self.node_type_to_index = {key: i for i, key in enumerate(node_types)}
        self.edge_types = edge_types

        if not isinstance(module, MessagePassing):
            raise ValueError(f"Expected 'MessagePassing' module "
                             f"(got '{type(module)}')")

        # All copies start from identical weights, so a parametrized module
        # without `reset_parameters()` cannot be re-initialized per type.
        if (not hasattr(module, 'reset_parameters')
                and sum(p.numel() for p in module.parameters()) > 0):
            warnings.warn(f"'{module}' will be duplicated, but its parameters "
                          f"cannot be reset. To suppress this warning, add a "
                          f"'reset_parameters()' method to '{module}'")

        convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
        self.hetero_module = HeteroConv(convs, aggr)
        self.hetero_module.reset_parameters()

    def fused_forward(self, x: Tensor, edge_index: Tensor, node_type: Tensor,
                      edge_type: Tensor) -> Tensor:
        r"""Runs the layer on a single stacked graph.

        The stacked inputs are split per type, global node indices are
        converted to indices local to each node type, the dictionary-based
        module is invoked, and the per-type outputs are concatenated in the
        order of :obj:`self.node_types`.

        Args:
            x (Tensor): The stacked node features of all node types.
            edge_index (Tensor): The stacked edge indices of all edge types,
                holding global node indices.
            node_type (Tensor): The node type index of each row in
                :obj:`x`.
            edge_type (Tensor): The edge type index of each column in
                :obj:`edge_index`.

        Returns:
            Tensor: The stacked output features of all node types.
        """
        # TODO This currently does not fuse at all :(
        # TODO We currently assume that `x` and `edge_index` are both sorted
        # according to `type`.

        # Number of nodes/edges per type.
        node_sizes = scatter(torch.ones_like(node_type), node_type, dim=0,
                             dim_size=len(self.node_types), reduce='sum')
        edge_sizes = scatter(torch.ones_like(edge_type), edge_type, dim=0,
                             dim_size=len(self.edge_types), reduce='sum')

        # Index of the first node of each type within `x`, i.e. the
        # exclusive prefix sum of `node_sizes`. The final cumulative sum
        # (the total number of nodes) is not an offset and is dropped, which
        # leaves exactly one entry per node type.
        offsets = torch.cat(
            [node_type.new_zeros(1),
             node_sizes.cumsum(0)[:-1]])

        xs = x.split(node_sizes.tolist())
        x_dict = dict(zip(self.node_types, xs))

        # TODO Consider out-sourcing to its own function.
        # The splits are views into the cloned tensor, so the in-place
        # shifts below never modify the caller's `edge_index`.
        edge_indices = edge_index.clone().split(edge_sizes.tolist(), dim=1)
        for (src, _, dst), index in zip(self.edge_types, edge_indices):
            index[0] -= offsets[self.node_type_to_index[src]]
            index[1] -= offsets[self.node_type_to_index[dst]]

        edge_index_dict = dict(zip(self.edge_types, edge_indices))

        out_dict = self.hetero_module(x_dict, edge_index_dict)
        # NOTE(review): presumably every node type has to be the destination
        # of at least one edge type for its key to exist in `out_dict` --
        # confirm against `HeteroConv`.
        return torch.cat([out_dict[key] for key in self.node_types], dim=0)

    def dict_forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        **kwargs,
    ) -> Dict[NodeType, Tensor]:
        r"""Runs the layer on dictionaries of inputs.

        Args:
            x_dict (Dict[NodeType, Tensor]): Node features per node type.
            edge_index_dict (Dict[EdgeType, Tensor]): Edge indices per edge
                type.
            **kwargs: Additional arguments forwarded unchanged to the
                wrapped :class:`HeteroConv` module.

        Returns:
            Dict[NodeType, Tensor]: The output of the wrapped
            :class:`HeteroConv` module.
        """
        return self.hetero_module(x_dict, edge_index_dict, **kwargs)

    def forward(
        self,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        node_type: OptTensor = None,
        edge_type: OptTensor = None,
        **kwargs,
    ) -> Union[Tensor, Dict[NodeType, Tensor]]:
        r"""Dispatches to :meth:`dict_forward` or :meth:`fused_forward`.

        Args:
            x (Tensor or Dict): Node features, either per node type or as a
                single stacked tensor.
            edge_index (Tensor or Dict): Edge indices, either per edge type
                or as a single stacked tensor.
            node_type (Tensor, optional): The node type index of each row of
                :obj:`x`. Required in fused mode. (default: :obj:`None`)
            edge_type (Tensor, optional): The edge type index of each column
                of :obj:`edge_index`. Required in fused mode.
                (default: :obj:`None`)
            **kwargs: Additional arguments, only supported for dictionary
                inputs.

        Returns:
            Tensor or Dict: A dictionary for dictionary inputs, a tensor
            for tensor inputs.

        Raises:
            ValueError: If additional arguments are given in fused mode, or
                if the inputs match neither the dictionary nor the fused
                calling convention.
        """
        if isinstance(x, dict) and isinstance(edge_index, dict):
            return self.dict_forward(x, edge_index, **kwargs)

        elif (isinstance(x, Tensor) and isinstance(edge_index, Tensor)
              and node_type is not None and edge_type is not None):

            if len(kwargs) > 0:
                raise ValueError("Additional forward arguments not yet "
                                 "supported in fused mode")

            return self.fused_forward(x, edge_index, node_type, edge_type)

        raise ValueError(f"Encountered invalid forward types in "
                         f"'{self.__class__.__name__}'")

0 comments on commit 041508d

Please sign in to comment.