diff --git a/CHANGELOG.md b/CHANGELOG.md index d013a96a5c67..0ea5e1fe9492 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838)) - Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816)) - Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807)) @@ -37,7 +38,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) -- Added `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) ### Changed - Fixed a bug in `TUDataset` where `pre_filter` was not applied whenever `pre_transform` was present - Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828)) diff --git a/torch_geometric/nn/dense/diff_pool.py b/torch_geometric/nn/dense/diff_pool.py index 987dba1d43d2..6e6abb92e0cf 100644 --- a/torch_geometric/nn/dense/diff_pool.py +++ b/torch_geometric/nn/dense/diff_pool.py @@ -44,8 +44,9 @@ def dense_diff_pool(x, adj, s, mask=None, normalize=True): mask (BoolTensor, optional): Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating the valid nodes for each graph. (default: :obj:`None`) - normalize (Bool, optional): Normalization indicator - :if True, will divide link_loss by size of graph. (default: True) + normalize (bool, optional): If set to :obj:`False`, the link + prediction loss is not divided by :obj:`adj.numel()`. + (default: :obj:`True`) :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`, :class:`Tensor`)