How to update the trainable parameters for edge models & How did you set the edge scores as masks in the message-passing? #7

Closed
smiles724 opened this issue Feb 21, 2023 · 7 comments


@smiles724

Hi, thanks for sharing the code.

I noticed that you apply self.mlp to pairs of node representations to obtain the edge scores. These edge scores are then used to select the edges of the causal subgraph.
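The selection step described above can be pictured with a minimal pure-Python sketch: rank edges by score, keep the top `k` as the candidate causal subgraph, and treat the remaining edges as its complement. The helper name and the explicit `k` argument are illustrative assumptions, not code from the DIR-GNN repository.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def top_k_edge_split(edges: Sequence[Edge], scores: Sequence[float],
                     k: int) -> Tuple[List[Edge], List[Edge]]:
    """Split edges into the k highest-scoring ones and the rest.

    Hypothetical helper for illustration. Ties keep their original order
    because Python's sort is stable (also with reverse=True).
    """
    if len(edges) != len(scores):
        raise ValueError("edges and scores must have the same length")
    k = max(0, min(k, len(edges)))
    ranked = sorted(range(len(edges)), key=lambda i: scores[i], reverse=True)
    keep = set(ranked[:k])
    selected = [e for i, e in enumerate(edges) if i in keep]
    rest = [e for i, e in enumerate(edges) if i not in keep]
    return selected, rest


edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
scores = [0.1, 0.9, 0.5, 0.3]
causal, non_causal = top_k_edge_split(edges, scores, k=2)
print(causal)      # [(1, 2), (2, 3)]
print(non_causal)  # [(0, 1), (3, 0)]
```

Both outputs preserve the original edge order; only membership is decided by the ranking.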

However, two points confuse me. (1) You mentioned that M_ij is calculated as sigmoid(Z_i^T Z_j), rather than by a parametric network that produces the mask matrix. (2) As far as I can tell, gradients cannot be backpropagated to the parameters of this edge model self.mlp during training; in other words, its parameters stay fixed. Could you explain a bit more so that I can better understand how this edge model works?
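To make the contrast concrete: the formula in the paper is a parameter-free inner-product decoder, while the code applies an MLP to the pair of node representations. A minimal pure-Python sketch of both follows; the MLP shape (one ReLU hidden layer on the concatenated pair) and all weights are illustrative assumptions, not the actual `self.mlp` in DIR-GNN.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def inner_product_score(z_i: Sequence[float], z_j: Sequence[float]) -> float:
    """Parameter-free decoder: M_ij = sigmoid(Z_i^T Z_j)."""
    return sigmoid(sum(a * b for a, b in zip(z_i, z_j)))


def mlp_score(z_i: Sequence[float], z_j: Sequence[float],
              w_hidden: List[List[float]], b_hidden: List[float],
              w_out: List[float], b_out: float) -> float:
    """One-hidden-layer MLP (ReLU, then sigmoid) on the concatenation [Z_i, Z_j].

    w_hidden, b_hidden, w_out and b_out are the extra learnable parameters
    (hypothetical shapes chosen for this example)."""
    pair = list(z_i) + list(z_j)
    hidden = [max(0.0, sum(w * p for w, p in zip(row, pair)) + b)
              for row, b in zip(w_hidden, b_hidden)]
    return sigmoid(sum(w * h for w, h in zip(w_out, hidden)) + b_out)


z_i, z_j = [1.0, 0.0], [0.0, 1.0]
params = dict(w_hidden=[[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]],
              b_hidden=[0.0, 0.0], w_out=[2.0, -2.0], b_out=0.0)  # arbitrary values

print(inner_product_score(z_i, z_j))   # 0.5 (orthogonal embeddings)
print(mlp_score(z_i, z_j, **params))   # about 0.98
print(mlp_score(z_j, z_i, **params))   # about 0.02: order of the pair matters
```

The inner product is symmetric in i and j and has nothing to learn beyond the node encoder, whereas the MLP adds its own parameters and, on a concatenated pair, need not be symmetric.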

@Wuyxin
Owner

Wuyxin commented Feb 21, 2023

Hi, thanks for the questions!

  1. Yes, we used an MLP instead of an inner product to decode the edge probability. This additional module provides extra capacity (beyond the shared node encoder) for the model to learn what is a causal feature and what is not.
  2. The parameters of the MLP are updated by setting the generated mask on the GNN: https://github.com/Wuyxin/DIR-GNN/blob/main/train/spmotif_dir.py#L195
     Specifically, we do it with:

         module.__edge_mask__ = mask

     In the internal implementation of PyG (2.0.0), the edge score then weights each message during message passing:

         if self.__explain__:
             edge_mask = self.__edge_mask__.sigmoid()
             ...
             out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

     This includes the sigmoid function that corresponds to the paper, and it lets the edge mask output by the MLP take part in training so that it is optimized.
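Why the mask is trainable can be illustrated with a minimal pure-Python sketch of sigmoid-gated sum aggregation on scalar node features. It is a hypothetical stand-in, not PyG's `MessagePassing` code: `gated_sum_aggregate` scales each message by the sigmoid of an edge logit, and `grad_wrt_logits` writes out the chain-rule derivative by hand. Because that derivative is nonzero, an upstream edge-scoring network would receive gradient and be updated.

```python
import math
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # directed edge (src, dst): src sends a message to dst


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gated_sum_aggregate(x: Sequence[float], edges: Sequence[Edge],
                        edge_logits: Sequence[float]) -> List[float]:
    """Sum incoming messages x[src] at each dst, scaling each one by
    sigmoid(logit), i.e. out = out * edge_mask with edge_mask = sigmoid(logits)."""
    out = [0.0] * len(x)
    for (src, dst), logit in zip(edges, edge_logits):
        out[dst] += sigmoid(logit) * x[src]
    return out


def grad_wrt_logits(x: Sequence[float], edges: Sequence[Edge],
                    edge_logits: Sequence[float],
                    upstream: Sequence[float]) -> List[float]:
    """Analytic d(loss)/d(logit_e) for loss = sum_i upstream[i] * out[i].

    Uses d sigmoid(z)/dz = sigmoid(z) * (1 - sigmoid(z)); this is the signal
    that would flow back into whatever network produced the logits."""
    grads = []
    for (src, dst), logit in zip(edges, edge_logits):
        s = sigmoid(logit)
        grads.append(upstream[dst] * x[src] * s * (1.0 - s))
    return grads


x = [1.0, 2.0, 3.0]
edges = [(0, 1), (2, 1), (1, 0)]
logits = [0.0, 2.0, -1.0]  # would come from the edge MLP in a real model
print(gated_sum_aggregate(x, edges, logits))
print(grad_wrt_logits(x, edges, logits, upstream=[1.0, 1.0, 1.0]))  # all nonzero
```

In an autograd framework such as PyTorch this derivative is computed automatically; writing it out only makes explicit that the gate sits on the differentiable path from the loss to the edge logits.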

@smiles724
Author

Hi, thanks for your immediate response.

I agree that PyG provides a way to weight the message-passing process. However, there are two additional issues that concern me.

First, the edge score output by the MLP takes part in message propagation, but it is not explicitly used to select edges. In your code, you treat the ranking of edge scores as the ranking of edge importance. This lacks theoretical support and looks more like an empirical trick. The graph encoder has already run two convolutional layers on the entire graph, and the message between each pair of nodes may be fully propagated (since you use the updated features in the subsequent steps). Thus, you cannot simply claim that a high edge score means the edge is more critical. The biggest problem is that, since we do not have ground-truth labels for the explanatory subgraphs, we cannot impose effective supervision on which edges the subgraph selects.

Second, it is neat to use the built-in `__edge_mask__` mechanism to incorporate the edge score into the computation. However, I believe it would be better to rewrite the message-passing function, for two reasons. First, the design is hard to understand at first glance because it relies on the internals of PyG's explainer support. Second, it may be misleading, since `__explain__` and `__edge_mask__` are intended specifically for explaining the predictions of GNNs. If someone wants to train and explain the model at the same time, your model will conflict with additional explainer models such as PGExplainer.
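The conflict concern can be illustrated with a hypothetical pure-Python sketch (the classes below are not PyG's): a layer that reads its weights from a shared mutable attribute, loosely analogous to `__edge_mask__`, versus one that takes the weights as an explicit argument. With the shared slot, whoever writes last wins.

```python
from typing import List, Optional, Sequence, Tuple

Edge = Tuple[int, int]  # directed edge (src, dst)


def weighted_sum(x: Sequence[float], edges: Sequence[Edge],
                 weights: Sequence[float]) -> List[float]:
    """Sum incoming messages x[src] at each dst, scaled by the edge weight."""
    out = [0.0] * len(x)
    for (src, dst), w in zip(edges, weights):
        out[dst] += w * x[src]
    return out


class HiddenMaskLayer:
    """Reads edge weights from a mutable attribute set by outside code."""

    def __init__(self) -> None:
        self.edge_mask: Optional[List[float]] = None

    def forward(self, x: Sequence[float], edges: Sequence[Edge]) -> List[float]:
        weights = self.edge_mask if self.edge_mask is not None else [1.0] * len(edges)
        return weighted_sum(x, edges, weights)


class ExplicitWeightLayer:
    """Takes edge weights as an argument; no shared state between callers."""

    def forward(self, x: Sequence[float], edges: Sequence[Edge],
                edge_weight: Optional[Sequence[float]] = None) -> List[float]:
        weights = edge_weight if edge_weight is not None else [1.0] * len(edges)
        return weighted_sum(x, edges, weights)


x = [1.0, 2.0]
edges = [(0, 1), (1, 0)]
model_mask = [1.0, 0.0]      # weights the rationale generator wants
explainer_mask = [0.5, 0.5]  # weights a post-hoc explainer writes later

hidden = HiddenMaskLayer()
hidden.edge_mask = model_mask
hidden.edge_mask = explainer_mask  # the second writer silently replaces the first
print(hidden.forward(x, edges))    # [1.0, 0.5]: the model's mask is gone

explicit = ExplicitWeightLayer()
print(explicit.forward(x, edges, edge_weight=model_mask))      # [0.0, 1.0]
print(explicit.forward(x, edges, edge_weight=explainer_mask))  # [1.0, 0.5]
```

Some PyG layers, such as `GCNConv`, already accept an optional `edge_weight` argument in `forward`, which is one way to follow the explicit design without touching explainer internals.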

I like your idea very much. It is a neat idea to first explore the subgraph and then make predictions on that substructure.

@Wuyxin
Owner

Wuyxin commented Feb 22, 2023

Hi!

For the first question, "the message between each pair of nodes may be fully propagated" is actually not correct. The edge score generated by the MLP is used to constrain the message passed to the end node. Here is the code again:

    out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

If the edge score (after the sigmoid) is zero, then there is no interaction between the two nodes, right? (Since out = 0, i.e., the message is 0.) So we do use the edge score explicitly to select edges.
Also, we do have ground truth in the Spurious-Motif data, and we compute the precision to validate this point.
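The argument above can be checked with a small pure-Python sketch (hypothetical helper names, not PyG code): a message weighted by exactly 0 has the same effect as deleting the edge. One caveat worth keeping in mind is that a sigmoid output only approaches 0, so a strongly negative score suppresses an edge rather than removing it outright.

```python
import math
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # directed edge (src, dst)


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def weighted_sum_aggregate(x: Sequence[float], edges: Sequence[Edge],
                           weights: Sequence[float]) -> List[float]:
    """Sum incoming messages x[src] at each dst, each scaled by its edge weight."""
    out = [0.0] * len(x)
    for (src, dst), w in zip(edges, weights):
        out[dst] += w * x[src]
    return out


x = [1.0, 2.0, 3.0]
edges = [(0, 2), (1, 2)]

# Weight exactly 0 on edge (1, 2): identical to deleting that edge.
with_zero = weighted_sum_aggregate(x, edges, [1.0, 0.0])
without_edge = weighted_sum_aggregate(x, [(0, 2)], [1.0])
print(with_zero == without_edge)  # True

# A sigmoid gate is never exactly 0: a very negative logit only makes
# the message from node 1 negligible.
near_zero = weighted_sum_aggregate(x, edges, [1.0, sigmoid(-30.0)])
print(abs(near_zero[2] - without_edge[2]))  # tiny but nonzero
```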

For the second question, thanks for the suggestions! I will add more instructions to the README. But in fact, we are explaining the predictions of GNNs in some sense. It is just that, instead of generating explanations post hoc, we build rationales inside the GNN. We also compare the explanations of GNNExplainer with our generated rationales in the Appendix of our paper. So I would say there won't be conflicts (unless you want two kinds of explanations, and I am not sure that is necessary).

Hope my answers help!!

@smiles724
Author

Thanks for your immediate and very clear response.

Sure, the edge score is used to update the messages. However, what I meant by "the message between each pair of nodes may be fully propagated" is that in CausalAttNet you use the complete graph topology to propagate messages between vertices (see the conv1 and conv2 layers). The node representations you feed into the two subsequent classifiers (the causal classifier and the shortcut classifier) therefore already have a global view rather than a local one. In my opinion, it would be better to compute the edge score directly, without first running any convolution over the entire graph. Otherwise, the edge score reflects the importance of edges only implicitly.

No offense intended, but although you do compute precision on the Spurious-Motif data, the values do not seem high enough. So I am still not convinced that the network really learns to select edges :)

@Wuyxin
Owner

Wuyxin commented Feb 22, 2023

That's a good point! But I don't think computing the edge score without any convolution will work, since the rationale generator needs to capture neighborhood information. If you apply an MLP directly to the node features, a simple failure case is a graph whose node features are all the same: every edge then gets the same score. We discussed this in Appendix G.1.
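The failure case described above can be reproduced in miniature. In the pure-Python sketch below, `pair_scorer` is a hypothetical stand-in for an edge MLP with fixed weights; with identical raw features it assigns every edge the same score, while one round of sum aggregation makes the features (and therefore the scores) depend on graph structure.

```python
import math
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # undirected edge


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def pair_scorer(a: float, b: float) -> float:
    """Stand-in for an edge MLP on two scalar node features (hypothetical weights)."""
    return sigmoid(0.5 * a + 0.5 * b - 1.0)


def edge_scores(features: Sequence[float], edges: Sequence[Edge]) -> List[float]:
    return [pair_scorer(features[u], features[v]) for u, v in edges]


def sum_aggregate(features: Sequence[float], edges: Sequence[Edge]) -> List[float]:
    """One round of undirected sum aggregation: each node adds its neighbours' features."""
    out = list(features)
    for u, v in edges:
        out[u] += features[v]
        out[v] += features[u]
    return out


edges = [(0, 1), (1, 2), (2, 3), (1, 4)]
raw = [1.0] * 5
print(edge_scores(raw, edges))   # [0.5, 0.5, 0.5, 0.5]: no way to tell edges apart
agg = sum_aggregate(raw, edges)  # [2.0, 4.0, 3.0, 2.0, 2.0]: now degree-dependent
print(edge_scores(agg, edges))   # scores differ across edges
```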

I think another gap between our understandings is this: generating good edge scores during training does not depend only on CausalAttNet, right? It also depends on the objective that pushes the boundary between causality and spuriousness. With or without the first two shared conv layers, the subgraph structure is supposed to conform to the DIR principle. But I agree that the shared layers could mix information between the causal and non-causal parts.

So actually, another way to avoid the problem you describe is simply to input the original graph to g, so that the node embeddings used for prediction are generated inside g, and CausalAttNet only outputs the edge scores. That said, my instinct is that it won't make much difference: messages are only propagated two hops, so there won't be much noise interfering with the separate modeling of the causal and non-causal subgraphs.
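The "only two hops" argument rests on a standard property of message passing: after L layers, a node's embedding can only depend on nodes within L hops. A minimal sketch that computes this receptive field with breadth-first search (hypothetical helper, plain adjacency lists):

```python
from collections import deque
from typing import Dict, List, Set


def receptive_field(adj: Dict[int, List[int]], node: int, num_layers: int) -> Set[int]:
    """Nodes within `num_layers` hops of `node` (breadth-first search).

    Assumes each message-passing layer aggregates from direct neighbours and
    keeps the node's own state, so after L layers `node` can only depend on
    its L-hop neighbourhood."""
    seen = {node}
    queue = deque([(node, 0)])
    while queue:
        v, dist = queue.popleft()
        if dist == num_layers:
            continue
        for u in adj[v]:
            if u not in seen:
                seen.add(u)
                queue.append((u, dist + 1))
    return seen


# Path graph 0-1-2-3-4-5 as an undirected adjacency list.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}
print(sorted(receptive_field(adj, 0, 2)))  # [0, 1, 2]
print(sorted(receptive_field(adj, 3, 2)))  # [1, 2, 3, 4, 5]
```

With two shared layers, nodes more than two hops apart never exchange information, which is the intuition behind the expectation above; how much mixing this still allows depends on the size and diameter of the graphs in a given dataset.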

For the last point, the question is honestly quite general, but I will keep it short. Think of a cat-vs-dog classifier: if we ask which features are most important for the model to classify a dog image, the model will likely highlight not the whole dog but some small parts of it, and the importance of the remaining pixels could be quite random. Humans, by contrast, look at the whole dog. (I feel I am making it long...) I mean, I am sure that applying a subgraph-matching algorithm to the dataset would give far better precision (because we know both the data distribution and how subgraph matching works). But here we are evaluating the model's rationales in an approximate way based on human intuition, and human intuition does not always agree with the model, even on synthetic data. So, based on the precision metric, we see that the model's rationales improve, since they move in a direction that aligns with human intuition. But low precision doesn't necessarily mean the model is bad at selecting edges; it could simply be using a different reasoning mechanism. My point is that it makes more sense to look at the relative improvement than at the absolute value. But yes, I think there is a lot of room to improve our current results.
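For readers following the precision discussion, one common way to score a selected edge set against a ground-truth motif is the fraction of selected edges that belong to the motif. The pure-Python sketch below treats edges as undirected and scores a single graph; the exact definition used in the DIR paper (for example, how it is averaged over graphs or how many edges are selected) is not given in this thread and may differ.

```python
from typing import Iterable, Tuple

Edge = Tuple[int, int]


def edge_precision(selected: Iterable[Edge], ground_truth: Iterable[Edge]) -> float:
    """Fraction of selected (undirected) edges that are ground-truth motif edges."""
    sel = {frozenset(e) for e in selected}
    gt = {frozenset(e) for e in ground_truth}
    if not sel:
        return 0.0
    return len(sel & gt) / len(sel)


selected = [(0, 1), (1, 2), (3, 4)]
motif = [(1, 0), (2, 1), (2, 0)]
print(edge_precision(selected, motif))  # 2 of the 3 selected edges are motif edges
```

Comparing this number across models on the same data, rather than reading its absolute value, matches the relative-improvement view argued above.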

@Wuyxin
Owner

Wuyxin commented Feb 22, 2023

Also, since we are no longer talking about trainable parameters, it would be better to close this issue and open a new one if you have further questions!

Wuyxin changed the title from "How to update the trainable parameters for edge models?" to "How to update the trainable parameters for edge models & How did you set the edge scores as masks in the message-passing?" on Feb 22, 2023
@smiles724
Author

Oh, right after writing that, I also realized my mistake in suggesting an MLP that outputs edge scores directly.

I also believe it is better to use CausalAttNet only to produce edge scores and to input the original graph to the GNN. Practically speaking, there might not be much difference, but the logic behind the two computations is quite different. Separating the causal part from the non-causal part is exactly one of the highlights of your paper; mixing them makes it harder for readers to appreciate the importance of the causal part.
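The two data flows being compared can be sketched with plain functions. Everything below is a hypothetical toy (scalar features, a sum-aggregation encoder, constant edge weights), meant only to show which features reach the predictor `g` in each design, not to reproduce DIR-GNN.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # directed edge (src, dst)


def encoder(x: Sequence[float], edges: Sequence[Edge]) -> List[float]:
    """Toy shared encoder: each node adds its incoming neighbours' features."""
    h = list(x)
    for src, dst in edges:
        h[dst] += x[src]
    return h


def scorer(h: Sequence[float], edges: Sequence[Edge]) -> List[float]:
    """Toy edge scorer; a real model would apply a learned MLP to (h[src], h[dst])."""
    return [1.0 for _ in edges]


def classifier_g(feats: Sequence[float], edges: Sequence[Edge],
                 weights: Sequence[float]) -> float:
    """Toy predictor g: one weighted aggregation step, then a sum readout."""
    out = list(feats)
    for (src, dst), w in zip(edges, weights):
        out[dst] += w * feats[src]
    return sum(out)


def forward_shared(x: Sequence[float], edges: Sequence[Edge]) -> float:
    """g consumes the encoder's embeddings (the design discussed in this thread)."""
    h = encoder(x, edges)
    return classifier_g(h, edges, scorer(h, edges))


def forward_separate(x: Sequence[float], edges: Sequence[Edge]) -> float:
    """Encoder and scorer only produce edge weights; g starts from raw features."""
    h = encoder(x, edges)
    return classifier_g(x, edges, scorer(h, edges))


x = [1.0, 2.0]
edges = [(0, 1)]
print(forward_shared(x, edges))    # 5.0
print(forward_separate(x, edges))  # 4.0
```

In the separate variant, the features entering `g` have not already been mixed over the full graph; any cross-node information `g` uses passes through its own weighted edges.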

Don't worry about my "complaint" about your precision values. I have always held the view that excellent works are not only those that completely solve a problem but also those that give us great inspiration. As you said, your results already outperform others. Previous studies have pointed out that DL models do not need to comprehend the whole picture to recognize an object. So you are right: we do not need to be too critical about the metric. (Besides, I guess you have read one of my papers on subgraph matching to explain GNNs :) thx)

Last but not least, I really enjoyed the discussion and learned a lot. You are a young but talented researcher. Thanks for taking the time to answer this issue. Let me close it with respect.
