Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replacing HGTConv with FastHGTConv #7117

Merged
merged 18 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions test/nn/conv/test_fast_hgt_conv.py

This file was deleted.

35 changes: 1 addition & 34 deletions test/nn/conv/test_hgt_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch_geometric.typing
from torch_geometric.data import HeteroData
from torch_geometric.nn import FastHGTConv, HGTConv
from torch_geometric.nn import HGTConv
from torch_geometric.profile import benchmark
from torch_geometric.testing import get_random_edge_index
from torch_geometric.typing import SparseTensor
Expand Down Expand Up @@ -178,39 +178,6 @@ def test_hgt_conv_out_of_place():
assert x_dict['paper'].size() == (6, 32)


def test_fast_hgt_conv():
x_dict = {
'v0': torch.randn(5, 4),
'v1': torch.randn(5, 4),
'v2': torch.randn(5, 4),
}

edge_index_dict = {
('v0', 'e1', 'v0'): torch.randint(0, 5, size=(2, 10)),
('v0', 'e2', 'v1'): torch.randint(0, 5, size=(2, 10)),
}

metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))
conv1 = HGTConv(4, 2, metadata)
conv2 = FastHGTConv(4, 2, metadata)

# Make parameters match:
for my_param in conv1.parameters():
my_param.data.fill_(1)
for og_param in conv2.parameters():
og_param.data.fill_(1)

out_dict1 = conv1(x_dict, edge_index_dict)
out_dict2 = conv2(x_dict, edge_index_dict)

assert len(out_dict1) == len(out_dict2)
for key, out1 in out_dict1.items():
out2 = out_dict2[key]
if out1 is None and out2 is None:
continue
assert torch.allclose(out1, out2)


if __name__ == '__main__':
import argparse

Expand Down
2 changes: 0 additions & 2 deletions torch_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from .pdn_conv import PDNConv
from .general_conv import GeneralConv
from .hgt_conv import HGTConv
from .fast_hgt_conv import FastHGTConv
from .heat_conv import HEATConv
from .hetero_conv import HeteroConv
from .han_conv import HANConv
Expand Down Expand Up @@ -119,7 +118,6 @@
'PDNConv',
'GeneralConv',
'HGTConv',
'FastHGTConv',
'HEATConv',
'HeteroConv',
'HANConv',
Expand Down
204 changes: 0 additions & 204 deletions torch_geometric/nn/conv/fast_hgt_conv.py

This file was deleted.

19 changes: 17 additions & 2 deletions torch_geometric/nn/conv/hetero_conv.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
import warnings
from collections import defaultdict
from typing import Dict, Optional
from typing import Dict, List, Optional

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.hgt_conv import group
from torch_geometric.nn.module_dict import ModuleDict
from torch_geometric.typing import Adj, EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops


def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
if len(xs) == 0:
return None
elif aggr is None:
return torch.stack(xs, dim=1)
elif len(xs) == 1:
return xs[0]
elif aggr == "cat":
return torch.cat(xs, dim=-1)
else:
out = torch.stack(xs, dim=0)
out = getattr(torch, aggr)(out, dim=0)
out = out[0] if isinstance(out, tuple) else out
return out


class HeteroConv(torch.nn.Module):
r"""A generic wrapper for computing graph convolution on heterogeneous
graphs.
Expand Down
Loading