Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check on add_self_loops in HeteroConv and to_hetero #4647

Merged
merged 6 commits into from
May 15, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647))
- Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))
- Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626))
- Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644))
Expand Down
15 changes: 14 additions & 1 deletion test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def test_hetero_conv(aggr):
{
('paper', 'to', 'paper'): GCNConv(-1, 64),
('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv(
(-1, -1), 64, add_self_loops=False),
}, aggr=aggr)

assert len(list(conv.parameters())) > 0
Expand Down Expand Up @@ -77,3 +78,15 @@ def test_hetero_conv_with_custom_conv():
assert len(out) == 2
assert out['paper'].size() == (50, 64)
assert out['author'].size() == (30, 64)


class MessagePassingLoops(MessagePassing):
def __init__(self):
super().__init__()
self.add_self_loops = True


def test_hetero_conv_self_loop_error():
HeteroConv({('a', 'to', 'a'): MessagePassingLoops()})
with pytest.raises(ValueError, match="incorrect message passing"):
HeteroConv({('a', 'to', 'b'): MessagePassingLoops()})
26 changes: 26 additions & 0 deletions test/nn/test_to_hetero_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Tuple

import pytest
import torch
from torch import Tensor
from torch.nn import Linear, ReLU, Sequential
Expand Down Expand Up @@ -363,3 +364,28 @@ def test_graph_level_to_hetero():
model = to_hetero(model, metadata, aggr='mean', debug=False)
out = model(x_dict, edge_index_dict, batch_dict)
assert out.size() == (1, 64)


class MessagePassingLoops(MessagePassing):
def __init__(self):
super().__init__()
self.add_self_loops = True

def forward(self, x):
return x


class ModelLoops(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = MessagePassingLoops()

def forward(self, x):
return self.conv(x)


def test_hetero_transformer_self_loop_error():
to_hetero(ModelLoops(), metadata=(['a'], [('a', 'to', 'a')]))
with pytest.raises(ValueError, match="incorrect message passing"):
to_hetero(ModelLoops(), metadata=(['a', 'b'], [('a', 'to', 'b'),
('b', 'to', 'a')]))
4 changes: 4 additions & 0 deletions torch_geometric/nn/conv/hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from torch_geometric.nn.conv.hgt_conv import group
from torch_geometric.typing import Adj, EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops


class HeteroConv(Module):
Expand Down Expand Up @@ -47,6 +48,9 @@ def __init__(self, convs: Dict[EdgeType, Module],
aggr: Optional[str] = "sum"):
super().__init__()

for edge_type, module in convs.items():
check_add_self_loops(module, [edge_type])

src_node_types = set([key[0] for key in convs.keys()])
dst_node_types = set([key[-1] for key in convs.keys()])
if len(src_node_types - dst_node_types) > 0:
Expand Down
10 changes: 8 additions & 2 deletions torch_geometric/nn/to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import torch
from torch.nn import Module

from torch_geometric.nn.fx import Transformer
from torch_geometric.nn.fx import Transformer, get_submodule
from torch_geometric.typing import EdgeType, Metadata, NodeType
from torch_geometric.utils.hetero import get_unused_node_types
from torch_geometric.utils.hetero import (
check_add_self_loops,
get_unused_node_types,
)

try:
from torch.fx import Graph, GraphModule, Node
Expand Down Expand Up @@ -168,6 +171,9 @@ def call_message_passing_module(self, node: Node, target: Any, name: str):
# Add calls to edge type-wise `MessagePassing` modules and aggregate
# the outputs to node type-wise embeddings afterwards.

module = get_submodule(self.module, target)
check_add_self_loops(module, self.metadata[1])

# Group edge-wise keys per destination:
key_name, keys_per_dst = {}, defaultdict(list)
for key in self.metadata[1]:
Expand Down
9 changes: 9 additions & 0 deletions torch_geometric/utils/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,12 @@ def get_unused_node_types(node_types: List[NodeType],
edge_types: List[EdgeType]) -> Set[NodeType]:
dst_node_types = set(edge_type[-1] for edge_type in edge_types)
return set(node_types) - set(dst_node_types)


def check_add_self_loops(module: torch.nn.Module, edge_types: List[EdgeType]):
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
is_bipartite = any([key[0] != key[-1] for key in edge_types])
if is_bipartite and getattr(module, 'add_self_loops', False):
raise ValueError(
f"'add_self_loops' attribute set to 'True' on module '{module}' "
f"for use with edge type(s) '{edge_types}'. This will lead to "
f"incorrect message passing results.")