Graph Pooling Framework 🚀 #6455

Closed
4 tasks done
rusty1s opened this issue Jan 18, 2023 · 9 comments

Comments

@rusty1s
Member

rusty1s commented Jan 18, 2023

🚀 The feature, motivation and pitch

The "Understanding Pooling in Graph Neural Networks" paper introduces a simple framework to unify (most) graph pooling approaches:

  1. Select: Selects input nodes to map to supernodes
  2. Reduce: Reduces the supernodes to singletons
  3. Connect: Decides how the new nodes are connected
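The three steps can be traced on a toy graph. The sketch below is only an illustration: plain Python lists stand in for tensors so it runs without dependencies, and the function names (`select_clusters`, `reduce_sum`, `connect_clusters`) are hypothetical, not part of PyG.

```python
# Toy Select -> Reduce -> Connect pipeline on a 4-node path graph.
# All names are hypothetical; lists stand in for tensors.

def select_clusters(num_nodes, cluster):
    """Select: return (node, supernode) pairs from a cluster assignment."""
    return [(node, cluster[node]) for node in range(num_nodes)]

def reduce_sum(x, mapping, num_clusters):
    """Reduce: sum the features of all nodes mapped to the same supernode."""
    out = [0.0] * num_clusters
    for node, supernode in mapping:
        out[supernode] += x[node]
    return out

def connect_clusters(mapping, edges):
    """Connect: link two supernodes if any of their members were linked."""
    assign = dict(mapping)
    coarse = {(assign[src], assign[dst]) for src, dst in edges}
    return sorted((s, d) for s, d in coarse if s != d)  # drop self-loops

x = [1.0, 2.0, 3.0, 4.0]            # one scalar feature per node
edges = [(0, 1), (1, 2), (2, 3)]    # path graph 0-1-2-3
cluster = [0, 0, 1, 1]              # nodes {0, 1} and {2, 3} are merged

mapping = select_clusters(len(x), cluster)
pooled_x = reduce_sum(x, mapping, num_clusters=2)
pooled_edges = connect_clusters(mapping, edges)
print(pooled_x, pooled_edges)
```

Here the cluster assignment is fixed by hand; in an actual pooling operator the selection step would compute it from the input graph.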

Currently, the pooling operators included in PyG are implemented in isolation from each other, with a lot of repetitive code and inconsistent interfaces, which makes them confusing and challenging to apply. By following the above approach, we can unify the existing implementations and accelerate new research on graph pooling. For example, we can introduce base classes for each of the aforementioned steps:

import torch


class Select(torch.nn.Module):
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Returns a bipartite `edge_index` mapping input nodes to supernodes."""
        raise NotImplementedError  # to be implemented by concrete operators


class Reduce:  # No need -> we can utilize `nn.aggr` for this
    pass


class Connect(torch.nn.Module):
    def forward(self, cluster_index: torch.Tensor, edge_index: torch.Tensor,
                *args, **kwargs):
        """Returns a coarsened graph."""
        raise NotImplementedError  # to be implemented by concrete operators

With this, e.g., graclus can be moved to a Select operator and TopK pooling can be moved to a Connect operator.
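As a rough sketch of what splitting an existing operator could look like, the snippet below separates a TopK-style operator into a score-based selection step and an edge-filtering connection step. This split is an assumption made for illustration, as are the names `topk_select` and `filter_edges`; plain Python lists are used instead of tensors.

```python
# Hypothetical split of a TopK-style operator into two steps.

def topk_select(scores, k):
    """Keep the k highest-scoring nodes; each becomes its own supernode."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept = sorted(order[:k])  # restore the original node order
    return [(node, supernode) for supernode, node in enumerate(kept)]

def filter_edges(mapping, edges):
    """Keep only edges between kept nodes, relabelled to supernode ids."""
    assign = dict(mapping)
    return [(assign[s], assign[d]) for s, d in edges
            if s in assign and d in assign]

scores = [0.1, 0.9, 0.5, 0.7]
topk_edges = [(0, 1), (1, 3), (2, 3)]

topk_mapping = topk_select(scores, k=2)
kept_edges = filter_edges(topk_mapping, topk_edges)
print(topk_mapping, kept_edges)
```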

Relevant twitter thread: https://twitter.com/riceasphait/status/1447867635442585601

Tasks:

Add concrete Select and Connect classes and update implementations of.

@wsad1
Member

wsad1 commented Jan 24, 2023

Would it be useful to add a Pooling class that all pooling operators inherit from?

import torch


class Pooling(torch.nn.Module):
    def forward(self, x, edge_index, *args, **kwargs):
        # Select -> Reduce -> Connect, then return the pooled graph.
        mapping = self.select(x, edge_index, *args, **kwargs)
        x = self.reduce(x, mapping)
        edge_index = self.connect(mapping, edge_index, *args, **kwargs)
        return x, edge_index

self.select, self.reduce and self.connect would be objects defined in the child class, e.g., graclus or TopK.
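A minimal sketch of that idea, assuming the child class simply supplies the three components: `PoolingSketch` and `PairPooling` are hypothetical names, and plain Python callables and lists are used in place of `torch.nn.Module` and tensors so the example runs on its own.

```python
class PoolingSketch:
    """Hypothetical base class that chains select, reduce and connect."""

    def __init__(self, select, reduce, connect):
        self.select, self.reduce, self.connect = select, reduce, connect

    def __call__(self, x, edges):
        mapping = self.select(x, edges)
        return self.reduce(x, mapping), self.connect(mapping, edges)


class PairPooling(PoolingSketch):
    """Toy child class: merges nodes 2i and 2i+1 into supernode i."""

    def __init__(self):
        super().__init__(self._select, self._reduce, self._connect)

    @staticmethod
    def _select(x, edges):
        return {node: node // 2 for node in range(len(x))}

    @staticmethod
    def _reduce(x, mapping):
        out = [0.0] * (max(mapping.values()) + 1)
        for node, supernode in mapping.items():
            out[supernode] += x[node]  # sum features per supernode
        return out

    @staticmethod
    def _connect(mapping, edges):
        coarse = {(mapping[s], mapping[d]) for s, d in edges}
        return sorted(e for e in coarse if e[0] != e[1])  # drop self-loops


pool = PairPooling()
pair_x, pair_edges = pool([1.0, 2.0, 3.0, 4.0, 5.0], [(0, 1), (1, 2), (3, 4)])
print(pair_x, pair_edges)
```

The base class only fixes the order of the three steps; everything operator-specific lives in the child class.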

@rusty1s
Member Author

rusty1s commented Jan 24, 2023

Yes, definitely :)

@wsad1 wsad1 assigned wsad1 and unassigned wsad1 Jan 28, 2023
@rusty1s rusty1s pinned this issue Jan 28, 2023
rusty1s added a commit that referenced this issue Jan 28, 2023
#6455
Adds base classes `Select`, `Connect` and `Pooling` for the Graph
pooling framework.
TODO in this PR
1. Implement `TopkPooling` with this framework to verify the interface.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
@danielegrattarola
Contributor

Hey, first author of the paper here!

This is super. If help is still wanted (as I see from the tags), I'd be happy to participate!

Worth mentioning that there is an implementation of SRC (Select, Reduce, Connect) in Spektral, along with a few layers implemented with it. It might be possible to translate that implementation from TF to Torch/PyG.

@wsad1
Member

wsad1 commented May 6, 2023

@danielegrattarola First of all, great paper, and sorry for the late reply.
If you are still interested, we could use your help moving SAGPooling and PANPooling to the new framework.
I just refactored TopKPooling here.

rusty1s added a commit that referenced this issue May 8, 2023
Towards #6455.
This PR makes the following changes:
1. Removes the `Pooling` base class. The current implementation isn't
flexible enough to support all pooling operators. Instead, pooling
operators will implement their own forward method, using `Select`,
`Connect` and `Aggregate` operators.
2. Updates the `Select` operator. It now returns `SelectOutput`, which
contains `node_index`, `cluster_index`, `num_clusters` and `weight`,
where `weight` is the weight given to a node's assignment to a cluster.
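To make the shape of such a return value concrete, here is a sketch of a container with the four fields named above. It is not the PyG class: the name `SelectOutputSketch`, the list-based field types and the validation checks are assumptions for illustration (the real object would presumably hold tensors).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SelectOutputSketch:
    """Hypothetical stand-in for a select-step result."""
    node_index: List[int]        # selected input nodes
    cluster_index: List[int]     # supernode each selected node maps to
    num_clusters: int            # number of supernodes
    weight: Optional[List[float]] = None  # weight of each assignment

    def __post_init__(self):
        if len(self.node_index) != len(self.cluster_index):
            raise ValueError("node_index and cluster_index must align")
        if any(c < 0 or c >= self.num_clusters for c in self.cluster_index):
            raise ValueError("cluster_index out of range")
        if self.weight is not None and len(self.weight) != len(self.node_index):
            raise ValueError("weight must have one entry per assignment")


# Nodes 1 and 3 are kept and become supernodes 0 and 1.
select_out = SelectOutputSketch(
    node_index=[1, 3], cluster_index=[0, 1], num_clusters=2,
    weight=[0.9, 0.7],
)
print(select_out)
```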

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
rusty1s added a commit that referenced this issue May 8, 2023
Towards #6455.
Do not review before #7307 is merged.

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
@puririshi98
Contributor

Benchmark of the topK function (screenshot not preserved). Benchmark script:
https://github.com/puririshi98/rgcn_pyg_lib_forward_bench/blob/main/topK_bench.py

@puririshi98
Contributor

topk microbench PR:
#7549

@puririshi98
Contributor

diff pool microbench PR:
#7550

@puririshi98
Contributor

#7361
PR is ready to merge

@puririshi98
Contributor

#7625

@rusty1s rusty1s unpinned this issue Jun 23, 2023