Skip to content

Commit

Permalink
[Type Hints] utils.normalized_cut (#5665)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
akash-vartak and rusty1s authored Oct 13, 2022
1 parent 0acbed7 commit c466ffc
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
6 changes: 6 additions & 0 deletions test/utils/test_normalized_cut.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

from torch_geometric.testing import is_full_test
from torch_geometric.utils import normalized_cut


Expand All @@ -11,3 +12,8 @@ def test_normalized_cut():

output = normalized_cut(torch.stack([row, col], dim=0), edge_attr)
assert output.tolist() == expected_output

if is_full_test():
jit = torch.jit.script(normalized_cut)
output = jit(torch.stack([row, col], dim=0), edge_attr)
assert output.tolist() == expected_output
8 changes: 7 additions & 1 deletion torch_geometric/utils/normalized_cut.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Optional

from torch import Tensor

from torch_geometric.utils import degree


def normalized_cut(edge_index, edge_attr, num_nodes: Optional[int] = None):
def normalized_cut(
edge_index: Tensor,
edge_attr: Tensor,
num_nodes: Optional[int] = None,
) -> Tensor:
r"""Computes the normalized cut :math:`\mathbf{e}_{i,j} \cdot
\left( \frac{1}{\deg(i)} + \frac{1}{\deg(j)} \right)` of a weighted graph
given by edge indices and edge attributes.
Expand Down

0 comments on commit c466ffc

Please sign in to comment.