Skip to content

Commit

Permalink
[Type Hints] utils.is_undirected and utils.to_undirected (pyg-tea…
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored and JakubPietrakIntel committed Nov 25, 2022
1 parent 43358f8 commit c084ef3
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 27 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
18 changes: 18 additions & 0 deletions test/utils/test_undirected.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
from torch import Tensor

from torch_geometric.testing import is_full_test
from torch_geometric.utils import is_undirected, to_undirected


Expand All @@ -18,10 +20,26 @@ def test_is_undirected():

assert not is_undirected(torch.stack([row, col], dim=0))

if is_full_test():

@torch.jit.script
def jit(edge_index: Tensor) -> bool:
return is_undirected(edge_index)

assert not jit(torch.stack([row, col], dim=0))


def test_to_undirected():
row = torch.tensor([0, 1, 1])
col = torch.tensor([1, 0, 2])

edge_index = to_undirected(torch.stack([row, col], dim=0))
assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]

if is_full_test():

@torch.jit.script
def jit(edge_index: Tensor) -> Tensor:
return to_undirected(edge_index)

assert torch.equal(jit(torch.stack([row, col], dim=0)), edge_index)
11 changes: 6 additions & 5 deletions torch_geometric/nn/dense/diff_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@


def dense_diff_pool(
x: Tensor, adj: Tensor, s: Tensor, mask: Optional[Tensor] = None,
normalize: Optional[bool] = True
x: Tensor,
adj: Tensor,
s: Tensor,
mask: Optional[Tensor] = None,
normalize: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
r"""The differentiable pooling operator from the `"Hierarchical Graph
Representation Learning with Differentiable Pooling"
Expand Down Expand Up @@ -76,8 +79,6 @@ def dense_diff_pool(
if normalize is True:
link_loss = link_loss / adj.numel()

# Moved EPS from global to local variable for TorchScript support
EPS = 1e-15
ent_loss = (-s * torch.log(s + EPS)).sum(dim=-1).mean()
ent_loss = (-s * torch.log(s + 1e-15)).sum(dim=-1).mean()

return out, out_adj, link_loss, ent_loss
32 changes: 27 additions & 5 deletions torch_geometric/utils/sort_edge_index.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

from .num_nodes import maybe_num_nodes


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
sort_by_row=True):
# type: (Tensor, Optional[bool], Optional[int], bool) -> Tensor # noqa
pass


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
sort_by_row=True):
# type: (Tensor, Tensor, Optional[int], bool) -> Tuple[Tensor, Tensor] # noqa
pass


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None,
sort_by_row=True):
# type: (Tensor, List[Tensor], Optional[int], bool) -> Tuple[Tensor, List[Tensor]] # noqa
pass


def sort_edge_index(
edge_index: Tensor,
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
num_nodes: Optional[int] = None,
sort_by_row: bool = True,
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
Expand Down Expand Up @@ -53,9 +75,9 @@ def sort_edge_index(

edge_index = edge_index[:, perm]

if edge_attr is None:
return edge_index
elif isinstance(edge_attr, Tensor):
if isinstance(edge_attr, Tensor):
return edge_index, edge_attr[perm]
else:
elif isinstance(edge_attr, (list, tuple)):
return edge_index, [e[perm] for e in edge_attr]
else:
return edge_index
69 changes: 53 additions & 16 deletions torch_geometric/utils/undirected.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,21 @@
from .num_nodes import maybe_num_nodes


@torch.jit._overload
def is_undirected(edge_index, edge_attr=None, num_nodes=None):
# type: (Tensor, Optional[Tensor], Optional[int]) -> bool # noqa
pass


@torch.jit._overload
def is_undirected(edge_index, edge_attr=None, num_nodes=None):
# type: (Tensor, List[Tensor], Optional[int]) -> bool # noqa
pass


def is_undirected(
edge_index: Tensor,
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
num_nodes: Optional[int] = None,
) -> bool:
r"""Returns :obj:`True` if the graph given by :attr:`edge_index` is
Expand Down Expand Up @@ -42,31 +54,56 @@ def is_undirected(
"""
num_nodes = maybe_num_nodes(edge_index, num_nodes)

edge_attr = [] if edge_attr is None else edge_attr
edge_attr = [edge_attr] if isinstance(edge_attr, Tensor) else edge_attr
edge_attrs: List[Tensor] = []
if isinstance(edge_attr, Tensor):
edge_attrs.append(edge_attr)
elif isinstance(edge_attr, (list, tuple)):
edge_attrs = edge_attr

edge_index1, edge_attr1 = sort_edge_index(
edge_index1, edge_attrs1 = sort_edge_index(
edge_index,
edge_attr,
edge_attrs,
num_nodes=num_nodes,
sort_by_row=True,
)
edge_index2, edge_attr2 = sort_edge_index(
edge_index1,
edge_attr1,
edge_index2, edge_attrs2 = sort_edge_index(
edge_index,
edge_attrs,
num_nodes=num_nodes,
sort_by_row=False,
)

return (bool(torch.all(edge_index1[0] == edge_index2[1]))
and bool(torch.all(edge_index1[1] == edge_index2[0])) and all([
torch.all(e == e_T) for e, e_T in zip(edge_attr1, edge_attr2)
]))
if not torch.equal(edge_index1[0], edge_index2[1]):
return False
if not torch.equal(edge_index1[1], edge_index2[0]):
return False
for edge_attr1, edge_attr2 in zip(edge_attrs1, edge_attrs2):
if not torch.equal(edge_attr1, edge_attr2):
return False
return True


@torch.jit._overload
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
# type: (Tensor, Optional[bool], Optional[int], str) -> Tensor # noqa
pass


@torch.jit._overload
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
# type: (Tensor, Tensor, Optional[int], str) -> Tuple[Tensor, Tensor] # noqa
pass


@torch.jit._overload
def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
# type: (Tensor, List[Tensor], Optional[int], str) -> Tuple[Tensor, List[Tensor]] # noqa
pass


def to_undirected(
edge_index: Tensor,
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
edge_attr: Union[Optional[Tensor], List[Tensor]] = None,
num_nodes: Optional[int] = None,
reduce: str = "add",
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
Expand Down Expand Up @@ -116,13 +153,13 @@ def to_undirected(
edge_attr = None
num_nodes = edge_attr

row, col = edge_index
row, col = edge_index[0], edge_index[1]
row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
edge_index = torch.stack([row, col], dim=0)

if edge_attr is not None and isinstance(edge_attr, Tensor):
if isinstance(edge_attr, Tensor):
edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
elif edge_attr is not None:
elif isinstance(edge_attr, (list, tuple)):
edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr]

return coalesce(edge_index, edge_attr, num_nodes, reduce)

0 comments on commit c084ef3

Please sign in to comment.