-
Notifications
You must be signed in to change notification settings - Fork 3.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
to_captum
: Support for explaining heterogenous graphs
#5934
Conversation
Codecov Report
@@ Coverage Diff @@
## master #5934 +/- ##
==========================================
+ Coverage 84.70% 84.77% +0.06%
==========================================
Files 360 360
Lines 19862 19962 +100
==========================================
+ Hits 16825 16923 +98
- Misses 3037 3039 +2
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the implementation! :D I have two points/suggestions:
-
I think currently the edge_mask is still the regular edge_index, but it needs to be a mask, i.e.
torch.ones((1, num_edges ))
. This needs to be adjusted throughout the file. -
Also, I am still in favor of a uniform implementation for hetero and non-hetero data. This way the functionality for the user would be the same regardless of the data type. Suggesting a method that converts the data and the model:
to_captum(data=None, model=None, ...)
->captum_input
,captum_additional_args
,CaptumModel
/CaptumHeteroModel
.
or divided into two functions:
-
to_captum_data(...)
->captum_input
,captum_additional_args
-
to_captum_model(...)
->CaptumModel
/CaptumHeteroModel
However, I would keep the input of these methods the same for hetero and non-hetero data and internally check the type of data and model.
@RBendias added a method |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! I left a few suggestions regardig the docs.
Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com>
Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com>
Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com>
to_captum
: Support for explaining heterogenous graphs.to_captum
: Support for explaining heterogenous graphs
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
This PR adds `CaptumHeteroModel` , `to_captum_input` and `captum_output_to_dicts` which can be used to explain heterogenous graphs as follows. ``` data: HeteroData = (...) model = ... # A heterogenous model mask_type = ... captum_model: CaptumHeteroModel = to_captum_model(model, mask_type, output_idx=output_idx, data.metadata) inputs, additonal_forward_args = to_captum_input(data.x_dict, data.edge_index_dict, mask_type) ig = IntegratedGradients(captum_model) ig_attr_nodes_edges = ig.attribute(inputs, target=target, additional_forward_args=additonal_forward_args, internal_batch_size=1) x_attr_dict, edge_attr_dict = captum_output_to_dicts(ig_attr_nodes_edges, mask_type, data.metadata) ``` **TODOs in follow up PRs** 1. Add an example for `to_captum` with hetero data. 1. Move this behind the `ExplainerAlgorithm` interface being developed in pyg-team#5804. The interaace has to be extended to support `HeteroData`. Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charles Dufour <34485907+dufourc1@users.noreply.github.com> Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
This PR adds
CaptumHeteroModel
,to_captum_input
andcaptum_output_to_dicts
which can be used to explain heterogenous graphs as follows.TODOs in follow up PRs
to_captum
with hetero data.ExplainerAlgorithm
interface being developed in GNN explanation settings #5804. The interaace has to be extended to supportHeteroData
.