Skip to content

Commit

Permalink
[Type Hints] utils.contains_isolated_nodes (#5603)
Browse files Browse the repository at this point in the history
* add type hints

* changelog

* update test
  • Loading branch information
rusty1s authored Oct 5, 2022
1 parent 14fa128 commit 292e289
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
6 changes: 6 additions & 0 deletions test/utils/test_isolated.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

from torch_geometric.testing import is_full_test
from torch_geometric.utils import (
contains_isolated_nodes,
remove_isolated_nodes,
Expand All @@ -11,6 +12,11 @@ def test_contains_isolated_nodes():
assert not contains_isolated_nodes(edge_index)
assert contains_isolated_nodes(edge_index, num_nodes=3)

if is_full_test():
jit = torch.jit.script(contains_isolated_nodes)
assert not jit(edge_index)
assert jit(edge_index, num_nodes=3)

edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 2, 0]])
assert contains_isolated_nodes(edge_index)

Expand Down
8 changes: 7 additions & 1 deletion torch_geometric/utils/isolated.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.utils import remove_self_loops, segregate_self_loops

from .num_nodes import maybe_num_nodes


def contains_isolated_nodes(edge_index, num_nodes=None):
def contains_isolated_nodes(
edge_index: Tensor,
num_nodes: Optional[int] = None,
) -> bool:
r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains
isolated nodes.
Expand Down

0 comments on commit 292e289

Please sign in to comment.