From 292e289839a5a94bbc04494bcab3f4ac95086fcc Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 4 Oct 2022 21:25:44 -0700 Subject: [PATCH] [Type Hints] `utils.contains_isolated_nodes` (#5603) * add type hints * changelog * update test --- CHANGELOG.md | 1 + test/utils/test_isolated.py | 6 ++++++ torch_geometric/utils/isolated.py | 8 +++++++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb426ad5d62d..7c4808cb9e7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240)) - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) ### Changed +- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602)) diff --git a/test/utils/test_isolated.py b/test/utils/test_isolated.py index 75709468dfe3..cb9c5ea42488 100644 --- a/test/utils/test_isolated.py +++ b/test/utils/test_isolated.py @@ -1,5 +1,6 @@ import torch +from torch_geometric.testing import is_full_test from torch_geometric.utils import ( contains_isolated_nodes, remove_isolated_nodes, @@ -11,6 +12,11 @@ def test_contains_isolated_nodes(): assert not contains_isolated_nodes(edge_index) assert contains_isolated_nodes(edge_index, num_nodes=3) + if is_full_test(): + jit = torch.jit.script(contains_isolated_nodes) + assert not jit(edge_index) + assert jit(edge_index, num_nodes=3) + edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 2, 0]]) assert contains_isolated_nodes(edge_index) diff --git a/torch_geometric/utils/isolated.py b/torch_geometric/utils/isolated.py index bb0f7076519c..6f7fb2c57fff 100644 --- a/torch_geometric/utils/isolated.py +++ b/torch_geometric/utils/isolated.py @@ -1,11 +1,17 @@ +from typing import Optional + import torch +from torch import Tensor from torch_geometric.utils import remove_self_loops, segregate_self_loops from .num_nodes import maybe_num_nodes -def contains_isolated_nodes(edge_index, num_nodes=None): +def contains_isolated_nodes( + edge_index: Tensor, + num_nodes: Optional[int] = None, +) -> bool: r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains isolated nodes.