diff --git a/test/nn/conv/test_fast_hgt_conv.py b/test/nn/conv/test_fast_hgt_conv.py
deleted file mode 100644
index 55d36bf49225..000000000000
--- a/test/nn/conv/test_fast_hgt_conv.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import torch
-
-from torch_geometric.nn import FastHGTConv
-from torch_geometric.testing import get_random_edge_index
-
-
-def test_fast_hgt_conv():
-    x_dict = {
-        'author': torch.randn(4, 16),
-        'paper': torch.randn(6, 16),
-    }
-    edge_index = get_random_edge_index(4, 6, num_edges=20)
-
-    edge_index_dict = {
-        ('author', 'writes', 'paper'): edge_index,
-        ('paper', 'written_by', 'author'): edge_index.flip([0]),
-    }
-
-    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))
-
-    conv = FastHGTConv(16, 16, metadata, heads=2)
-    assert str(conv) == 'FastHGTConv(-1, 16, heads=2)'
-    out_dict1 = conv(x_dict, edge_index_dict)
-    assert len(out_dict1) == 2
-    assert out_dict1['author'].size() == (4, 16)
-    assert out_dict1['paper'].size() == (6, 16)
diff --git a/test/nn/conv/test_hgt_conv.py b/test/nn/conv/test_hgt_conv.py
index 65e07d5ded03..bb41663d2e4a 100644
--- a/test/nn/conv/test_hgt_conv.py
+++ b/test/nn/conv/test_hgt_conv.py
@@ -2,7 +2,7 @@
 
 import torch_geometric.typing
 from torch_geometric.data import HeteroData
-from torch_geometric.nn import FastHGTConv, HGTConv
+from torch_geometric.nn import HGTConv
 from torch_geometric.profile import benchmark
 from torch_geometric.testing import get_random_edge_index
 from torch_geometric.typing import SparseTensor
@@ -178,39 +178,6 @@ def test_hgt_conv_out_of_place():
     assert x_dict['paper'].size() == (6, 32)
 
 
-def test_fast_hgt_conv():
-    x_dict = {
-        'v0': torch.randn(5, 4),
-        'v1': torch.randn(5, 4),
-        'v2': torch.randn(5, 4),
-    }
-
-    edge_index_dict = {
-        ('v0', 'e1', 'v0'): torch.randint(0, 5, size=(2, 10)),
-        ('v0', 'e2', 'v1'): torch.randint(0, 5, size=(2, 10)),
-    }
-
-    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))
-    conv1 = HGTConv(4, 2, metadata)
-    conv2 = FastHGTConv(4, 2, metadata)
-
-    # Make parameters match:
-    for my_param in conv1.parameters():
-        my_param.data.fill_(1)
-    for og_param in conv2.parameters():
-        og_param.data.fill_(1)
-
-    out_dict1 = conv1(x_dict, edge_index_dict)
-    out_dict2 = conv2(x_dict, edge_index_dict)
-
-    assert len(out_dict1) == len(out_dict2)
-    for key, out1 in out_dict1.items():
-        out2 = out_dict2[key]
-        if out1 is None and out2 is None:
-            continue
-        assert torch.allclose(out1, out2)
-
-
 if __name__ == '__main__':
     import argparse
 
diff --git a/torch_geometric/nn/conv/__init__.py b/torch_geometric/nn/conv/__init__.py
index a0998c3bf5d5..9410e03f0b3d 100644
--- a/torch_geometric/nn/conv/__init__.py
+++ b/torch_geometric/nn/conv/__init__.py
@@ -51,7 +51,6 @@
 from .pdn_conv import PDNConv
 from .general_conv import GeneralConv
 from .hgt_conv import HGTConv
-from .fast_hgt_conv import FastHGTConv
 from .heat_conv import HEATConv
 from .hetero_conv import HeteroConv
 from .han_conv import HANConv
@@ -119,7 +118,6 @@
     'PDNConv',
     'GeneralConv',
     'HGTConv',
-    'FastHGTConv',
    'HEATConv',
    'HeteroConv',
    'HANConv',
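
Note for downstream users: with `FastHGTConv` removed from the public API, the fused implementation is simply what `HGTConv` runs now. A minimal migration sketch, mirroring the deleted test (names and shapes are taken from it):

```python
import torch

from torch_geometric.nn import HGTConv
from torch_geometric.testing import get_random_edge_index

x_dict = {
    'author': torch.randn(4, 16),
    'paper': torch.randn(6, 16),
}
edge_index = get_random_edge_index(4, 6, num_edges=20)
edge_index_dict = {
    ('author', 'writes', 'paper'): edge_index,
    ('paper', 'written_by', 'author'): edge_index.flip([0]),
}
metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))

conv = HGTConv(16, 16, metadata, heads=2)  # was: FastHGTConv(16, 16, ...)
out_dict = conv(x_dict, edge_index_dict)
assert out_dict['author'].size() == (4, 16)
assert out_dict['paper'].size() == (6, 16)
```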
diff --git a/torch_geometric/nn/conv/fast_hgt_conv.py b/torch_geometric/nn/conv/fast_hgt_conv.py
deleted file mode 100644
index 5c0d772a897f..000000000000
--- a/torch_geometric/nn/conv/fast_hgt_conv.py
+++ /dev/null
@@ -1,204 +0,0 @@
-import math
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import Tensor
-from torch.nn import Parameter
-
-from torch_geometric.nn.conv import MessagePassing
-from torch_geometric.nn.dense import HeteroDictLinear, HeteroLinear
-from torch_geometric.nn.inits import ones
-from torch_geometric.nn.parameter_dict import ParameterDict
-from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
-from torch_geometric.utils import softmax
-from torch_geometric.utils.hetero import construct_bipartite_edge_index
-
-
-class FastHGTConv(MessagePassing):
-    r"""See :class:`HGTConv`."""
-    def __init__(
-        self,
-        in_channels: Union[int, Dict[str, int]],
-        out_channels: int,
-        metadata: Metadata,
-        heads: int = 1,
-        **kwargs,
-    ):
-        super().__init__(aggr='add', node_dim=0, **kwargs)
-
-        if out_channels % heads != 0:
-            raise ValueError(f"'out_channels' (got {out_channels}) must be "
-                             f"divisible by the number of heads (got {heads})")
-
-        if not isinstance(in_channels, dict):
-            in_channels = {node_type: in_channels for node_type in metadata[0]}
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.heads = heads
-        self.node_types = metadata[0]
-        self.edge_types = metadata[1]
-        self.dst_node_types = list(set(metadata[1][1]))
-        self.src_types = [edge_type[0] for edge_type in self.edge_types]
-
-        self.kqv_lin = HeteroDictLinear(self.in_channels,
-                                        self.out_channels * 3)
-
-        self.out_lin = HeteroDictLinear(self.out_channels, self.out_channels,
-                                        types=self.node_types)
-
-        dim = out_channels // heads
-        num_types = heads * len(self.edge_types)
-
-        self.k_rel = HeteroLinear(dim, dim, num_types, is_sorted=True,
-                                  bias=False)
-        self.v_rel = HeteroLinear(dim, dim, num_types, is_sorted=True,
-                                  bias=False)
-
-        self.skip = ParameterDict({
-            node_type: Parameter(torch.Tensor(1))
-            for node_type in self.node_types
-        })
-
-        self.p_rel = ParameterDict()
-        for edge_type in self.edge_types:
-            edge_type = '__'.join(edge_type)
-            self.p_rel[edge_type] = Parameter(torch.Tensor(1, heads))
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        super().reset_parameters()
-        self.kqv_lin.reset_parameters()
-        self.out_lin.reset_parameters()
-        self.k_rel.reset_parameters()
-        self.v_rel.reset_parameters()
-        ones(self.skip)
-        ones(self.p_rel)
-
-    def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]:
-        """Concatenates a dictionary of features."""
-        cumsum = 0
-        outs: List[Tensor] = []
-        offset: Dict[str, int] = {}
-        for key, x in x_dict.items():
-            outs.append(x)
-            offset[key] = cumsum
-            cumsum += x.size(0)
-        return torch.cat(outs, dim=0), offset
-
-    def _construct_src_node_feat(
-        self,
-        k_dict: Dict[str, Tensor],
-        v_dict: Dict[str, Tensor],
-    ) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]:
-        """Constructs the source node representations."""
-        count = 0
-        cumsum = 0
-        H, D = self.heads, self.out_channels // self.heads
-
-        # Flatten into a single tensor with shape [num_edge_types * heads, D]:
-        ks: List[Tensor] = []
-        vs: List[Tensor] = []
-        type_list: List[int] = []
-        offset: Dict[EdgeType] = {}
-        for edge_type in self.edge_types:
-            src, _, _ = edge_type
-
-            ks.append(k_dict[src].reshape(-1, D))
-            vs.append(v_dict[src].reshape(-1, D))
-
-            N = k_dict[src].size(0)
-            for _ in range(H):
-                type_list.append(torch.full((N, ), count, dtype=torch.long))
-                count += 1
-            offset[edge_type] = cumsum
-            cumsum += N
-
-        type_vec = torch.cat(type_list, dim=0)
-        k = self.k_rel(torch.cat(ks, dim=0), type_vec).view(-1, H, D)
-        v = self.v_rel(torch.cat(vs, dim=0), type_vec).view(-1, H, D)
-
-        return k, v, offset
-
-    def forward(
-        self,
-        x_dict: Dict[NodeType, Tensor],
-        edge_index_dict: Dict[EdgeType, Adj]  # Support both.
-    ) -> Dict[NodeType, Optional[Tensor]]:
-        r"""Runs the forward pass of the module.
-
-        Args:
-            x_dict (Dict[str, torch.Tensor]): A dictionary holding input node
-                features for each individual node type.
-            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A
-                dictionary holding graph connectivity information for each
-                individual edge type, either as a :class:`torch.Tensor` of
-                shape :obj:`[2, num_edges]` or a
-                :class:`torch_sparse.SparseTensor`.
-
-        :rtype: :obj:`Dict[str, Optional[torch.Tensor]]` - The output node
-            embeddings for each node type.
-            In case a node type does not receive any message, its output will
-            be set to :obj:`None`.
-        """
-        F = self.out_channels
-        H = self.heads
-        D = F // H
-
-        k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {}
-
-        # Compute K, Q, V over node types:
-        kqv_dict = self.kqv_lin(x_dict)
-        for key, val in kqv_dict.items():
-            k_dict[key] = val[:, :F].view(-1, H, D)
-            q_dict[key] = val[:, F:2 * F].view(-1, H, D)
-            v_dict[key] = val[:, 2 * F:].view(-1, H, D)
-
-        q, dst_offset = self._cat(q_dict)
-        k, v, src_offset = self._construct_src_node_feat(k_dict, v_dict)
-
-        edge_index, edge_attr = construct_bipartite_edge_index(
-            edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel)
-
-        out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr,
-                             size=None)
-
-        # Reconstruct output node embeddings dict:
-        for node_type, start_offset in dst_offset.items():
-            end_offset = start_offset + q_dict[node_type].size(0)
-            out_dict[node_type] = out[start_offset:end_offset]
-
-        # Transform output node embeddings:
-        a_dict = self.out_lin({
-            k: torch.nn.functional.gelu(v) if v is not None else v
-            for k, v in out_dict.items()
-        })
-
-        # Iterate over node types:
-        for node_type, out in out_dict.items():
-            if out is None or node_type not in self.dst_node_types:
-                out_dict[node_type] = None
-                continue
-            else:
-                out = a_dict[node_type]
-
-            if out.size(-1) == x_dict[node_type].size(-1):
-                alpha = self.skip[node_type].sigmoid()
-                out = alpha * out + (1 - alpha) * x_dict[node_type]
-            out_dict[node_type] = out
-
-        return out_dict
-
-    def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor,
-                index: Tensor, ptr: Optional[Tensor],
-                size_i: Optional[int]) -> Tensor:
-        alpha = (q_i * k_j).sum(dim=-1) * edge_attr
-        alpha = alpha / math.sqrt(q_i.size(-1))
-        alpha = softmax(alpha, index, ptr, size_i)
-        out = v_j * alpha.view(-1, self.heads, 1)
-        return out.view(-1, self.out_channels)
-
-    def __repr__(self) -> str:
-        return (f'{self.__class__.__name__}(-1, {self.out_channels}, '
-                f'heads={self.heads})')
diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py
index a59879857169..e1d4d2a9d26e 100644
--- a/torch_geometric/nn/conv/hetero_conv.py
+++ b/torch_geometric/nn/conv/hetero_conv.py
@@ -1,17 +1,32 @@
 import warnings
 from collections import defaultdict
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 import torch
 from torch import Tensor
 
 from torch_geometric.nn.conv import MessagePassing
-from torch_geometric.nn.conv.hgt_conv import group
 from torch_geometric.nn.module_dict import ModuleDict
 from torch_geometric.typing import Adj, EdgeType, NodeType
 from torch_geometric.utils.hetero import check_add_self_loops
 
 
+def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
+    if len(xs) == 0:
+        return None
+    elif aggr is None:
+        return torch.stack(xs, dim=1)
+    elif len(xs) == 1:
+        return xs[0]
+    elif aggr == "cat":
+        return torch.cat(xs, dim=-1)
+    else:
+        out = torch.stack(xs, dim=0)
+        out = getattr(torch, aggr)(out, dim=0)
+        out = out[0] if isinstance(out, tuple) else out
+        return out
+
+
 class HeteroConv(torch.nn.Module):
     r"""A generic wrapper for computing graph convolution on heterogeneous
     graphs.
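
For reviewers: `group` keeps its exact semantics after the move; only its home module changes, since the rewritten `HGTConv` no longer needs it. An illustrative sketch of those semantics (toy shapes, assuming the post-merge import path):

```python
# Illustrative only: how `group` combines per-edge-type outputs for one
# destination node type. Two edge types, 3 nodes x 4 features each.
import torch

from torch_geometric.nn.conv.hetero_conv import group

xs = [torch.ones(3, 4), 2 * torch.ones(3, 4)]

assert group([], 'sum') is None             # no incoming edge type -> None
assert group(xs, None).shape == (3, 2, 4)   # stack along a new dimension
assert group(xs, 'cat').shape == (3, 8)     # concatenate feature-wise
assert group(xs, 'sum').shape == (3, 4)     # element-wise reduction
assert group(xs, 'max').shape == (3, 4)     # torch.max returns (values,
                                            # indices); `group` keeps values
```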
diff --git a/torch_geometric/nn/conv/hgt_conv.py b/torch_geometric/nn/conv/hgt_conv.py
index 1f5bc8692a54..ccdeea16bce6 100644
--- a/torch_geometric/nn/conv/hgt_conv.py
+++ b/torch_geometric/nn/conv/hgt_conv.py
@@ -1,34 +1,17 @@
 import math
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
 
 from torch_geometric.nn.conv import MessagePassing
-from torch_geometric.nn.dense import Linear
-from torch_geometric.nn.inits import glorot, ones, reset
-from torch_geometric.nn.module_dict import ModuleDict
+from torch_geometric.nn.dense import HeteroDictLinear, HeteroLinear
+from torch_geometric.nn.inits import ones
 from torch_geometric.nn.parameter_dict import ParameterDict
-from torch_geometric.typing import EdgeType, Metadata, NodeType, SparseTensor
+from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
 from torch_geometric.utils import softmax
-
-
-def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
-    if len(xs) == 0:
-        return None
-    elif aggr is None:
-        return torch.stack(xs, dim=1)
-    elif len(xs) == 1:
-        return xs[0]
-    elif aggr == "cat":
-        return torch.cat(xs, dim=-1)
-    else:
-        out = torch.stack(xs, dim=0)
-        out = getattr(torch, aggr)(out, dim=0)
-        out = out[0] if isinstance(out, tuple) else out
-        return out
+from torch_geometric.utils.hetero import construct_bipartite_edge_index
 
 
 class HGTConv(MessagePassing):
@@ -72,7 +55,6 @@ def __init__(
         out_channels: int,
         metadata: Metadata,
         heads: int = 1,
-        group: str = "sum",
         **kwargs,
     ):
         super().__init__(aggr='add', node_dim=0, **kwargs)
@@ -87,48 +69,95 @@
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.heads = heads
-        self.group = group
-
-        self.k_lin = ModuleDict()
-        self.q_lin = ModuleDict()
-        self.v_lin = ModuleDict()
-        self.a_lin = ModuleDict()
-        self.skip = ParameterDict()
-        for node_type, in_channels in self.in_channels.items():
-            self.k_lin[node_type] = Linear(in_channels, out_channels)
-            self.q_lin[node_type] = Linear(in_channels, out_channels)
-            self.v_lin[node_type] = Linear(in_channels, out_channels)
-            self.a_lin[node_type] = Linear(out_channels, out_channels)
-            self.skip[node_type] = Parameter(torch.Tensor(1))
-
-        self.a_rel = ParameterDict()
-        self.m_rel = ParameterDict()
-        self.p_rel = ParameterDict()
+        self.node_types = metadata[0]
+        self.edge_types = metadata[1]
+        self.dst_node_types = list(set(edge_type[-1] for edge_type in metadata[1]))
+        self.src_types = [edge_type[0] for edge_type in self.edge_types]
+
+        self.kqv_lin = HeteroDictLinear(self.in_channels,
+                                        self.out_channels * 3)
+
+        self.out_lin = HeteroDictLinear(self.out_channels, self.out_channels,
+                                        types=self.node_types)
+
         dim = out_channels // heads
-        for edge_type in metadata[1]:
+        num_types = heads * len(self.edge_types)
+
+        self.k_rel = HeteroLinear(dim, dim, num_types, is_sorted=True,
+                                  bias=False)
+        self.v_rel = HeteroLinear(dim, dim, num_types, is_sorted=True,
+                                  bias=False)
+
+        self.skip = ParameterDict({
+            node_type: Parameter(torch.Tensor(1))
+            for node_type in self.node_types
+        })
+
+        self.p_rel = ParameterDict()
+        for edge_type in self.edge_types:
             edge_type = '__'.join(edge_type)
-            self.a_rel[edge_type] = Parameter(torch.Tensor(heads, dim, dim))
-            self.m_rel[edge_type] = Parameter(torch.Tensor(heads, dim, dim))
-            self.p_rel[edge_type] = Parameter(torch.Tensor(heads))
+            self.p_rel[edge_type] = Parameter(torch.Tensor(1, heads))
 
         self.reset_parameters()
 
     def reset_parameters(self):
         super().reset_parameters()
-        reset(self.k_lin)
-        reset(self.q_lin)
-        reset(self.v_lin)
-        reset(self.a_lin)
+        self.kqv_lin.reset_parameters()
+        self.out_lin.reset_parameters()
+        self.k_rel.reset_parameters()
+        self.v_rel.reset_parameters()
         ones(self.skip)
         ones(self.p_rel)
-        glorot(self.a_rel)
-        glorot(self.m_rel)
+
+    def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]:
+        """Concatenates a dictionary of features."""
+        cumsum = 0
+        outs: List[Tensor] = []
+        offset: Dict[str, int] = {}
+        for key, x in x_dict.items():
+            outs.append(x)
+            offset[key] = cumsum
+            cumsum += x.size(0)
+        return torch.cat(outs, dim=0), offset
+
+    def _construct_src_node_feat(
+        self,
+        k_dict: Dict[str, Tensor],
+        v_dict: Dict[str, Tensor],
+    ) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]:
+        """Constructs the source node representations."""
+        count = 0
+        cumsum = 0
+        H, D = self.heads, self.out_channels // self.heads
+
+        # Flatten into a single tensor with shape [num_edge_types * heads, D]:
+        ks: List[Tensor] = []
+        vs: List[Tensor] = []
+        type_list: List[Tensor] = []
+        offset: Dict[EdgeType, int] = {}
+        for edge_type in self.edge_types:
+            src, _, _ = edge_type
+
+            ks.append(k_dict[src].reshape(-1, D))
+            vs.append(v_dict[src].reshape(-1, D))
+
+            N = k_dict[src].size(0)
+            for _ in range(H):
+                type_list.append(torch.full((N, ), count, dtype=torch.long))
+                count += 1
+            offset[edge_type] = cumsum
+            cumsum += N
+
+        type_vec = torch.cat(type_list, dim=0)
+        k = self.k_rel(torch.cat(ks, dim=0), type_vec).view(-1, H, D)
+        v = self.v_rel(torch.cat(vs, dim=0), type_vec).view(-1, H, D)
+
+        return k, v, offset
 
     def forward(
         self,
         x_dict: Dict[NodeType, Tensor],
-        edge_index_dict: Union[Dict[EdgeType, Tensor],
-                               Dict[EdgeType, SparseTensor]]  # Support both.
+        edge_index_dict: Dict[EdgeType, Adj]  # Support both.
     ) -> Dict[NodeType, Optional[Tensor]]:
         r"""Runs the forward pass of the module.
@@ -146,42 +175,47 @@ def forward(
             In case a node type does not receive any message, its output will
             be set to :obj:`None`.
         """
-        H, D = self.heads, self.out_channels // self.heads
+        F = self.out_channels
+        H = self.heads
+        D = F // H
 
         k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {}
 
-        # Iterate over node-types:
-        for node_type, x in x_dict.items():
-            k_dict[node_type] = self.k_lin[node_type](x).view(-1, H, D)
-            q_dict[node_type] = self.q_lin[node_type](x).view(-1, H, D)
-            v_dict[node_type] = self.v_lin[node_type](x).view(-1, H, D)
-            out_dict[node_type] = []
+        # Compute K, Q, V over node types:
+        kqv_dict = self.kqv_lin(x_dict)
+        for key, val in kqv_dict.items():
+            k_dict[key] = val[:, :F].view(-1, H, D)
+            q_dict[key] = val[:, F:2 * F].view(-1, H, D)
+            v_dict[key] = val[:, 2 * F:].view(-1, H, D)
 
-        # Iterate over edge-types:
-        for edge_type, edge_index in edge_index_dict.items():
-            src_type, _, dst_type = edge_type
-            edge_type = '__'.join(edge_type)
+        q, dst_offset = self._cat(q_dict)
+        k, v, src_offset = self._construct_src_node_feat(k_dict, v_dict)
 
-            a_rel = self.a_rel[edge_type]
-            k = (k_dict[src_type].transpose(0, 1) @ a_rel).transpose(1, 0)
+        edge_index, edge_attr = construct_bipartite_edge_index(
+            edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel)
 
-            m_rel = self.m_rel[edge_type]
-            v = (v_dict[src_type].transpose(0, 1) @ m_rel).transpose(1, 0)
+        out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr,
+                             size=None)
 
-            # propagate_type: (k: Tensor, q: Tensor, v: Tensor, rel: Tensor)
-            out = self.propagate(edge_index, k=k, q=q_dict[dst_type], v=v,
-                                 rel=self.p_rel[edge_type], size=None)
-            out_dict[dst_type].append(out)
+        # Reconstruct output node embeddings dict:
+        for node_type, start_offset in dst_offset.items():
+            end_offset = start_offset + q_dict[node_type].size(0)
+            out_dict[node_type] = out[start_offset:end_offset]
 
-        # Iterate over node-types:
-        for node_type, outs in out_dict.items():
-            out = group(outs, self.group)
+        # Transform output node embeddings:
+        a_dict = self.out_lin({
+            k: torch.nn.functional.gelu(v) if v is not None else v
+            for k, v in out_dict.items()
+        })
 
-            if out is None:
+        # Iterate over node types:
+        for node_type, out in out_dict.items():
+            if node_type not in self.dst_node_types:
                 out_dict[node_type] = None
                 continue
+            else:
+                out = a_dict[node_type]
 
-            out = self.a_lin[node_type](F.gelu(out))
             if out.size(-1) == x_dict[node_type].size(-1):
                 alpha = self.skip[node_type].sigmoid()
                 out = alpha * out + (1 - alpha) * x_dict[node_type]
@@ -189,11 +223,10 @@ def forward(
 
         return out_dict
 
-    def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, rel: Tensor,
+    def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor,
                 index: Tensor, ptr: Optional[Tensor],
                 size_i: Optional[int]) -> Tensor:
-
-        alpha = (q_i * k_j).sum(dim=-1) * rel
+        alpha = (q_i * k_j).sum(dim=-1) * edge_attr
         alpha = alpha / math.sqrt(q_i.size(-1))
         alpha = softmax(alpha, index, ptr, size_i)
         out = v_j * alpha.view(-1, self.heads, 1)
         edge_attr_dict (Dict[Tuple[str, str, str], torch.Tensor]): A
             dictionary holding edge features for each individual edge type.
@@ -92,9 +92,14 @@
 
         # TODO Add support for SparseTensor w/o converting.
         is_sparse = isinstance(edge_index, SparseTensor)
+        is_native_sparse = isinstance(edge_index, Tensor) and 'sparse' in str(
+            edge_index.layout)
         if is_sparse:
             col, row, _ = edge_index.coo()
             edge_index = torch.stack([row, col], dim=0)
+        elif is_native_sparse:
+            edge_index = edge_index.to_sparse_coo().indices()
+            edge_index = edge_index.flip(0)
         else:
             edge_index = edge_index.clone()
 
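
For context on the offset arguments documented above: `construct_bipartite_edge_index` shifts every per-type `edge_index` into one shared integer id space, with sources laid out per edge type (matching `_construct_src_node_feat`) and destinations per node type (matching `_cat`). A hand-rolled sketch of that flattening idea (toy numbers; illustrative, not the library routine itself):

```python
import torch

edge_index_dict = {
    ('author', 'writes', 'paper'): torch.tensor([[0, 1], [0, 1]]),
    ('paper', 'cites', 'paper'): torch.tensor([[0, 1], [1, 0]]),
}
# Assumed layout: 2 authors then 2 papers on the source side (keyed by edge
# type), and 2 authors then 2 papers on the destination side (keyed by type).
src_offset = {('author', 'writes', 'paper'): 0, ('paper', 'cites', 'paper'): 2}
dst_offset = {'author': 0, 'paper': 2}

edge_indices = []
for edge_type, edge_index in edge_index_dict.items():
    edge_index = edge_index.clone()
    edge_index[0] += src_offset[edge_type]      # shift source node ids
    edge_index[1] += dst_offset[edge_type[-1]]  # shift destination node ids
    edge_indices.append(edge_index)

edge_index = torch.cat(edge_indices, dim=1)  # one homogeneous edge_index:
# tensor([[0, 1, 2, 3],
#         [2, 3, 3, 2]])
```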