From 36b1d6b25265c1ede2f035ae167d9acf5832f406 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Sun, 15 May 2022 15:29:08 +0800 Subject: [PATCH 1/5] add check on self-loops in hetero conv --- test/nn/conv/test_hetero_conv.py | 13 +++++++++++++ torch_geometric/nn/conv/hetero_conv.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py index 286a2e2f561b..ebb7163c5189 100644 --- a/test/nn/conv/test_hetero_conv.py +++ b/test/nn/conv/test_hetero_conv.py @@ -77,3 +77,16 @@ def test_hetero_conv_with_custom_conv(): assert len(out) == 2 assert out['paper'].size() == (50, 64) assert out['author'].size() == (30, 64) + + +class MessagePassingLoops(MessagePassing): + def __init__(self, add_self_loops: bool = True): + super().__init__() + self.add_self_loops = add_self_loops + + +def test_hetero_exception_self_loops(): + + model = MessagePassingLoops() + with pytest.raises(ValueError): + HeteroConv({("a", "to", "b"): model}) diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py index 738eb440b28b..4841df53cc37 100644 --- a/torch_geometric/nn/conv/hetero_conv.py +++ b/torch_geometric/nn/conv/hetero_conv.py @@ -46,6 +46,7 @@ class HeteroConv(Module): def __init__(self, convs: Dict[EdgeType, Module], aggr: Optional[str] = "sum"): super().__init__() + self.validate_convs(convs) src_node_types = set([key[0] for key in convs.keys()]) dst_node_types = set([key[-1] for key in convs.keys()]) @@ -59,6 +60,23 @@ def __init__(self, convs: Dict[EdgeType, Module], self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()}) self.aggr = aggr + @staticmethod + def validate_convs(convs: Dict[EdgeType, Module]) -> None: + """ + Checks to make sure that the convs provided for bipartite message + passing edges to not have an 'add_self_loops' argument that is set + to true. Raises :attr:`ValueError` if such a model exists. + """ + for edge_type, model in convs.items(): + if edge_type[0] != edge_type[-1]: + if hasattr(model, "add_self_loops"): + if model.add_self_loops: + raise ValueError( + f"bipartite edge_type '{edge_type}' has" + f"'add_self_loops' attribute set " + f"to true. This will lead to incorrect" + f"message passing results.") + def reset_parameters(self): for conv in self.convs.values(): conv.reset_parameters() From 0edc5e41f1e54e0ffcfecd6ad58da1151191e2e7 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Sun, 15 May 2022 15:54:10 +0800 Subject: [PATCH 2/5] add check on self-loops in to_hetero --- CHANGELOG.md | 1 + test/nn/conv/test_hetero_conv.py | 2 +- test/nn/test_to_hetero_transformer.py | 17 +++++++++++++ torch_geometric/nn/conv/hetero_conv.py | 12 +++------- torch_geometric/nn/to_hetero_transformer.py | 6 ++++- .../nn/to_hetero_with_bases_transformer.py | 24 ++++++++++++++++++- torch_geometric/utils/hetero.py | 12 ++++++++++ 7 files changed, 62 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3989cff8f5f0..8cf493fefa08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added a check in `HeteroConv` and `to_hetero` to make sure `add_self_loops` is not true ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647)) - Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635)) - Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626)) - Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644)) diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py index ebb7163c5189..7c4655bd2263 100644 --- a/test/nn/conv/test_hetero_conv.py +++ b/test/nn/conv/test_hetero_conv.py @@ -86,7 +86,7 @@ def __init__(self, add_self_loops: bool = True): def test_hetero_exception_self_loops(): - model = MessagePassingLoops() with pytest.raises(ValueError): HeteroConv({("a", "to", "b"): model}) + HeteroConv({("a", "to", "a"): model}) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index 95bb0e7ab426..f5b73e92e1c6 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -1,5 +1,6 @@ from typing import Tuple +import pytest import torch from torch import Tensor from torch.nn import Linear, ReLU, Sequential @@ -363,3 +364,19 @@ def test_graph_level_to_hetero(): model = to_hetero(model, metadata, aggr='mean', debug=False) out = model(x_dict, edge_index_dict, batch_dict) assert out.size() == (1, 64) + + +class MessagePassingLoops(MessagePassing): + def __init__(self, add_self_loops: bool = True): + super().__init__() + self.add_self_loops = add_self_loops + + def forward(self): + pass + + +def test_hetero_transformer_exception_self_loops(): + model = MessagePassingLoops() + with pytest.raises(ValueError): + to_hetero(model, (['a'], (('a', 'to', 'b'), )), aggr='mean') + to_hetero(model, (['a'], (('a', 'to', 'a'), )), aggr='mean') diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py index 4841df53cc37..ef7a81889875 100644 --- a/torch_geometric/nn/conv/hetero_conv.py +++ b/torch_geometric/nn/conv/hetero_conv.py @@ -7,6 +7,7 @@ from torch_geometric.nn.conv.hgt_conv import group from torch_geometric.typing import Adj, EdgeType, NodeType +from torch_geometric.utils.hetero import check_add_self_loops class HeteroConv(Module): @@ -67,15 +68,8 @@ def validate_convs(convs: Dict[EdgeType, Module]) -> None: passing edges to not have an 'add_self_loops' argument that is set to true. Raises :attr:`ValueError` if such a model exists. """ - for edge_type, model in convs.items(): - if edge_type[0] != edge_type[-1]: - if hasattr(model, "add_self_loops"): - if model.add_self_loops: - raise ValueError( - f"bipartite edge_type '{edge_type}' has" - f"'add_self_loops' attribute set " - f"to true. This will lead to incorrect" - f"message passing results.") + for edge_type, module in convs.items(): + check_add_self_loops(module, [edge_type]) def reset_parameters(self): for conv in self.convs.values(): diff --git a/torch_geometric/nn/to_hetero_transformer.py b/torch_geometric/nn/to_hetero_transformer.py index 6f1af095e147..25521edcff61 100644 --- a/torch_geometric/nn/to_hetero_transformer.py +++ b/torch_geometric/nn/to_hetero_transformer.py @@ -8,7 +8,10 @@ from torch_geometric.nn.fx import Transformer from torch_geometric.typing import EdgeType, Metadata, NodeType -from torch_geometric.utils.hetero import get_unused_node_types +from torch_geometric.utils.hetero import ( + check_add_self_loops, + get_unused_node_types, +) try: from torch.fx import Graph, GraphModule, Node @@ -132,6 +135,7 @@ def __init__( debug: bool = False, ): super().__init__(module, input_map, debug) + check_add_self_loops(module, metadata[1]) unused_node_types = get_unused_node_types(*metadata) if len(unused_node_types) > 0: diff --git a/torch_geometric/nn/to_hetero_with_bases_transformer.py b/torch_geometric/nn/to_hetero_with_bases_transformer.py index 8c0be95d5d2b..308ff32c0c57 100644 --- a/torch_geometric/nn/to_hetero_with_bases_transformer.py +++ b/torch_geometric/nn/to_hetero_with_bases_transformer.py @@ -2,6 +2,7 @@ import warnings from typing import Any, Dict, List, Optional, Union +import pytest import torch from torch import Tensor from torch.nn import Module, Parameter @@ -11,7 +12,10 @@ from torch_geometric.nn.dense import Linear from torch_geometric.nn.fx import Transformer from torch_geometric.typing import EdgeType, Metadata, NodeType -from torch_geometric.utils.hetero import get_unused_node_types +from torch_geometric.utils.hetero import ( + check_add_self_loops, + get_unused_node_types, +) try: from torch.fx import Graph, GraphModule, Node @@ -128,6 +132,7 @@ def forward(self, x, edge_index): debug (bool, optional): If set to :obj:`True`, will perform transformation in debug mode. (default: :obj:`False`) """ + transformer = ToHeteroWithBasesTransformer(module, metadata, num_bases, in_channels, input_map, debug) return transformer.transform() @@ -144,6 +149,7 @@ def __init__( debug: bool = False, ): super().__init__(module, input_map, debug) + check_add_self_loops(module, metadata[1]) unused_node_types = get_unused_node_types(*metadata) if len(unused_node_types) > 0: @@ -547,3 +553,19 @@ def split_output( def key2str(key: Union[NodeType, EdgeType]) -> str: key = '__'.join(key) if isinstance(key, tuple) else key return key.replace(' ', '_').replace('-', '_').replace(':', '_') + + +class MessagePassingLoops(MessagePassing): + def __init__(self, add_self_loops: bool = True): + super().__init__() + self.add_self_loops = add_self_loops + + def forward(self): + pass + + +def test_hetero_transformer_exception_self_loops(): + model = MessagePassingLoops() + with pytest.raises(ValueError): + to_hetero_with_bases(model, (('a'), (('a', 'to', 'b'), )), 1) + to_hetero_with_bases(model, (('a'), (('a', 'to', 'a'), )), 1) diff --git a/torch_geometric/utils/hetero.py b/torch_geometric/utils/hetero.py index 5102ce3fb9b0..f2fe115008c7 100644 --- a/torch_geometric/utils/hetero.py +++ b/torch_geometric/utils/hetero.py @@ -47,3 +47,15 @@ def get_unused_node_types(node_types: List[NodeType], edge_types: List[EdgeType]) -> Set[NodeType]: dst_node_types = set(edge_type[-1] for edge_type in edge_types) return set(node_types) - set(dst_node_types) + + +def check_add_self_loops(module: torch.nn.Module, edge_types: List[EdgeType]): + edge_types = [ + edge_type for edge_type in edge_types if edge_type[0] != edge_type[-1] + ] + if len(edge_types) > 0 and hasattr( + module, "add_self_loops") and module.add_self_loops: + raise ValueError( + f"'add_self_loops' attribute set to 'True' on module {module} " + f"for use with edge type(s) '{edge_types}'. This will lead to " + f"incorrect message passing results.") From 6b4277abda4a77d44a5cb75037427cefa99d7ce7 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Sun, 15 May 2022 16:01:36 +0800 Subject: [PATCH 3/5] fix old tets --- test/nn/conv/test_hetero_conv.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py index 7c4655bd2263..526b55298c7d 100644 --- a/test/nn/conv/test_hetero_conv.py +++ b/test/nn/conv/test_hetero_conv.py @@ -30,9 +30,10 @@ def test_hetero_conv(aggr): conv = HeteroConv( { - ('paper', 'to', 'paper'): GCNConv(-1, 64), + ('paper', 'to', 'paper'): GCNConv(-1, 64, add_self_loops=False), ('author', 'to', 'paper'): SAGEConv((-1, -1), 64), - ('paper', 'to', 'author'): GATConv((-1, -1), 64), + ('paper', 'to', 'author'): GATConv( + (-1, -1), 64, add_self_loops=False), }, aggr=aggr) assert len(list(conv.parameters())) > 0 From f7778bd3c5c4b20c62ef1a568d42a443633332ab Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Sun, 15 May 2022 16:17:09 +0800 Subject: [PATCH 4/5] remove pytest import --- .../nn/test_to_hetero_with_bases_transformer.py | 17 +++++++++++++++++ .../nn/to_hetero_with_bases_transformer.py | 17 ----------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/test/nn/test_to_hetero_with_bases_transformer.py b/test/nn/test_to_hetero_with_bases_transformer.py index f56dbeeac9ee..ed4a3598980d 100644 --- a/test/nn/test_to_hetero_with_bases_transformer.py +++ b/test/nn/test_to_hetero_with_bases_transformer.py @@ -1,5 +1,6 @@ from typing import Tuple +import pytest import torch from torch import Tensor from torch.nn import Linear, ReLU, Sequential @@ -279,3 +280,19 @@ def test_to_hetero_with_bases_and_rgcn_equal_output(): out3 = model(x_dict, adj_t_dict) out3 = torch.cat([out3['paper'], out3['author']], dim=0) assert torch.allclose(out1, out3, atol=1e-6) + + +class MessagePassingLoops(MessagePassing): + def __init__(self, add_self_loops: bool = True): + super().__init__() + self.add_self_loops = add_self_loops + + def forward(self): + pass + + +def test_hetero_transformer_exception_self_loops(): + model = MessagePassingLoops() + with pytest.raises(ValueError): + to_hetero_with_bases(model, (('a'), (('a', 'to', 'b'), )), 1) + to_hetero_with_bases(model, (('a'), (('a', 'to', 'a'), )), 1) diff --git a/torch_geometric/nn/to_hetero_with_bases_transformer.py b/torch_geometric/nn/to_hetero_with_bases_transformer.py index 308ff32c0c57..1b5fd62cb79c 100644 --- a/torch_geometric/nn/to_hetero_with_bases_transformer.py +++ b/torch_geometric/nn/to_hetero_with_bases_transformer.py @@ -2,7 +2,6 @@ import warnings from typing import Any, Dict, List, Optional, Union -import pytest import torch from torch import Tensor from torch.nn import Module, Parameter @@ -553,19 +552,3 @@ def split_output( def key2str(key: Union[NodeType, EdgeType]) -> str: key = '__'.join(key) if isinstance(key, tuple) else key return key.replace(' ', '_').replace('-', '_').replace(':', '_') - - -class MessagePassingLoops(MessagePassing): - def __init__(self, add_self_loops: bool = True): - super().__init__() - self.add_self_loops = add_self_loops - - def forward(self): - pass - - -def test_hetero_transformer_exception_self_loops(): - model = MessagePassingLoops() - with pytest.raises(ValueError): - to_hetero_with_bases(model, (('a'), (('a', 'to', 'b'), )), 1) - to_hetero_with_bases(model, (('a'), (('a', 'to', 'a'), )), 1) From 1cdd287f7e6948e4562267f6ceb9ae8c69fa898f Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sun, 15 May 2022 07:42:10 -0700 Subject: [PATCH 5/5] update --- CHANGELOG.md | 2 +- test/nn/conv/test_hetero_conv.py | 15 +++++------ test/nn/test_to_hetero_transformer.py | 27 ++++++++++++------- .../test_to_hetero_with_bases_transformer.py | 17 ------------ torch_geometric/nn/conv/hetero_conv.py | 14 +++------- torch_geometric/nn/to_hetero_transformer.py | 6 +++-- .../nn/to_hetero_with_bases_transformer.py | 7 +---- torch_geometric/utils/hetero.py | 9 +++---- 8 files changed, 37 insertions(+), 60 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cf493fefa08..b5af6f83fbe2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added -- Added a check in `HeteroConv` and `to_hetero` to make sure `add_self_loops` is not true ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647)) +- Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647)) - Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635)) - Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626)) - Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644)) diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py index 526b55298c7d..36d93a7ebf7e 100644 --- a/test/nn/conv/test_hetero_conv.py +++ b/test/nn/conv/test_hetero_conv.py @@ -30,7 +30,7 @@ def test_hetero_conv(aggr): conv = HeteroConv( { - ('paper', 'to', 'paper'): GCNConv(-1, 64, add_self_loops=False), + ('paper', 'to', 'paper'): GCNConv(-1, 64), ('author', 'to', 'paper'): SAGEConv((-1, -1), 64), ('paper', 'to', 'author'): GATConv( (-1, -1), 64, add_self_loops=False), @@ -81,13 +81,12 @@ def test_hetero_conv_with_custom_conv(): class MessagePassingLoops(MessagePassing): - def __init__(self, add_self_loops: bool = True): + def __init__(self): super().__init__() - self.add_self_loops = add_self_loops + self.add_self_loops = True -def test_hetero_exception_self_loops(): - model = MessagePassingLoops() - with pytest.raises(ValueError): - HeteroConv({("a", "to", "b"): model}) - HeteroConv({("a", "to", "a"): model}) +def test_hetero_conv_self_loop_error(): + HeteroConv({('a', 'to', 'a'): MessagePassingLoops()}) + with pytest.raises(ValueError, match="incorrect message passing"): + HeteroConv({('a', 'to', 'b'): MessagePassingLoops()}) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index f5b73e92e1c6..917502f87e4c 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -367,16 +367,25 @@ def test_graph_level_to_hetero(): class MessagePassingLoops(MessagePassing): - def __init__(self, add_self_loops: bool = True): + def __init__(self): + super().__init__() + self.add_self_loops = True + + def forward(self, x): + return x + + +class ModelLoops(torch.nn.Module): + def __init__(self): super().__init__() - self.add_self_loops = add_self_loops + self.conv = MessagePassingLoops() - def forward(self): - pass + def forward(self, x): + return self.conv(x) -def test_hetero_transformer_exception_self_loops(): - model = MessagePassingLoops() - with pytest.raises(ValueError): - to_hetero(model, (['a'], (('a', 'to', 'b'), )), aggr='mean') - to_hetero(model, (['a'], (('a', 'to', 'a'), )), aggr='mean') +def test_hetero_transformer_self_loop_error(): + to_hetero(ModelLoops(), metadata=(['a'], [('a', 'to', 'a')])) + with pytest.raises(ValueError, match="incorrect message passing"): + to_hetero(ModelLoops(), metadata=(['a', 'b'], [('a', 'to', 'b'), + ('b', 'to', 'a')])) diff --git a/test/nn/test_to_hetero_with_bases_transformer.py b/test/nn/test_to_hetero_with_bases_transformer.py index ed4a3598980d..f56dbeeac9ee 100644 --- a/test/nn/test_to_hetero_with_bases_transformer.py +++ b/test/nn/test_to_hetero_with_bases_transformer.py @@ -1,6 +1,5 @@ from typing import Tuple -import pytest import torch from torch import Tensor from torch.nn import Linear, ReLU, Sequential @@ -280,19 +279,3 @@ def test_to_hetero_with_bases_and_rgcn_equal_output(): out3 = model(x_dict, adj_t_dict) out3 = torch.cat([out3['paper'], out3['author']], dim=0) assert torch.allclose(out1, out3, atol=1e-6) - - -class MessagePassingLoops(MessagePassing): - def __init__(self, add_self_loops: bool = True): - super().__init__() - self.add_self_loops = add_self_loops - - def forward(self): - pass - - -def test_hetero_transformer_exception_self_loops(): - model = MessagePassingLoops() - with pytest.raises(ValueError): - to_hetero_with_bases(model, (('a'), (('a', 'to', 'b'), )), 1) - to_hetero_with_bases(model, (('a'), (('a', 'to', 'a'), )), 1) diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py index ef7a81889875..95253039534a 100644 --- a/torch_geometric/nn/conv/hetero_conv.py +++ b/torch_geometric/nn/conv/hetero_conv.py @@ -47,7 +47,9 @@ class HeteroConv(Module): def __init__(self, convs: Dict[EdgeType, Module], aggr: Optional[str] = "sum"): super().__init__() - self.validate_convs(convs) + + for edge_type, module in convs.items(): + check_add_self_loops(module, [edge_type]) src_node_types = set([key[0] for key in convs.keys()]) dst_node_types = set([key[-1] for key in convs.keys()]) @@ -61,16 +63,6 @@ def __init__(self, convs: Dict[EdgeType, Module], self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()}) self.aggr = aggr - @staticmethod - def validate_convs(convs: Dict[EdgeType, Module]) -> None: - """ - Checks to make sure that the convs provided for bipartite message - passing edges to not have an 'add_self_loops' argument that is set - to true. Raises :attr:`ValueError` if such a model exists. - """ - for edge_type, module in convs.items(): - check_add_self_loops(module, [edge_type]) - def reset_parameters(self): for conv in self.convs.values(): conv.reset_parameters() diff --git a/torch_geometric/nn/to_hetero_transformer.py b/torch_geometric/nn/to_hetero_transformer.py index 25521edcff61..233c3c8bc8ee 100644 --- a/torch_geometric/nn/to_hetero_transformer.py +++ b/torch_geometric/nn/to_hetero_transformer.py @@ -6,7 +6,7 @@ import torch from torch.nn import Module -from torch_geometric.nn.fx import Transformer +from torch_geometric.nn.fx import Transformer, get_submodule from torch_geometric.typing import EdgeType, Metadata, NodeType from torch_geometric.utils.hetero import ( check_add_self_loops, @@ -135,7 +135,6 @@ def __init__( debug: bool = False, ): super().__init__(module, input_map, debug) - check_add_self_loops(module, metadata[1]) unused_node_types = get_unused_node_types(*metadata) if len(unused_node_types) > 0: @@ -172,6 +171,9 @@ def call_message_passing_module(self, node: Node, target: Any, name: str): # Add calls to edge type-wise `MessagePassing` modules and aggregate # the outputs to node type-wise embeddings afterwards. + module = get_submodule(self.module, target) + check_add_self_loops(module, self.metadata[1]) + # Group edge-wise keys per destination: key_name, keys_per_dst = {}, defaultdict(list) for key in self.metadata[1]: diff --git a/torch_geometric/nn/to_hetero_with_bases_transformer.py b/torch_geometric/nn/to_hetero_with_bases_transformer.py index 1b5fd62cb79c..8c0be95d5d2b 100644 --- a/torch_geometric/nn/to_hetero_with_bases_transformer.py +++ b/torch_geometric/nn/to_hetero_with_bases_transformer.py @@ -11,10 +11,7 @@ from torch_geometric.nn.dense import Linear from torch_geometric.nn.fx import Transformer from torch_geometric.typing import EdgeType, Metadata, NodeType -from torch_geometric.utils.hetero import ( - check_add_self_loops, - get_unused_node_types, -) +from torch_geometric.utils.hetero import get_unused_node_types try: from torch.fx import Graph, GraphModule, Node @@ -131,7 +128,6 @@ def forward(self, x, edge_index): debug (bool, optional): If set to :obj:`True`, will perform transformation in debug mode. (default: :obj:`False`) """ - transformer = ToHeteroWithBasesTransformer(module, metadata, num_bases, in_channels, input_map, debug) return transformer.transform() @@ -148,7 +144,6 @@ def __init__( debug: bool = False, ): super().__init__(module, input_map, debug) - check_add_self_loops(module, metadata[1]) unused_node_types = get_unused_node_types(*metadata) if len(unused_node_types) > 0: diff --git a/torch_geometric/utils/hetero.py b/torch_geometric/utils/hetero.py index f2fe115008c7..b34ab47c4d5a 100644 --- a/torch_geometric/utils/hetero.py +++ b/torch_geometric/utils/hetero.py @@ -50,12 +50,9 @@ def get_unused_node_types(node_types: List[NodeType], def check_add_self_loops(module: torch.nn.Module, edge_types: List[EdgeType]): - edge_types = [ - edge_type for edge_type in edge_types if edge_type[0] != edge_type[-1] - ] - if len(edge_types) > 0 and hasattr( - module, "add_self_loops") and module.add_self_loops: + is_bipartite = any([key[0] != key[-1] for key in edge_types]) + if is_bipartite and getattr(module, 'add_self_loops', False): raise ValueError( - f"'add_self_loops' attribute set to 'True' on module {module} " + f"'add_self_loops' attribute set to 'True' on module '{module}' " f"for use with edge type(s) '{edge_types}'. This will lead to " f"incorrect message passing results.")