Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

to_captum: Support for explaining heterogenous graphs #5934

Merged
merged 60 commits into from
Nov 22, 2022
Merged

Conversation

wsad1
Copy link
Member

@wsad1 wsad1 commented Nov 9, 2022

This PR adds
CaptumHeteroModel , to_captum_input and captum_output_to_dicts which can be used to explain heterogenous graphs as follows.

data: HeteroData = (...)
model = ... # A heterogenous model
mask_type = ...

captum_model: CaptumHeteroModel = to_captum_model(model, mask_type, output_idx=output_idx, data.metadata)
inputs, additonal_forward_args = to_captum_input(data.x_dict, data.edge_index_dict, mask_type)
ig = IntegratedGradients(captum_model)
ig_attr_nodes_edges = ig.attribute(inputs, target=target,
    additional_forward_args=additonal_forward_args, internal_batch_size=1)
x_attr_dict, edge_attr_dict = captum_output_to_dicts(ig_attr_nodes_edges, mask_type, data.metadata)

TODOs in follow up PRs

  1. Add an example for to_captum with hetero data.
  2. Move this behind the ExplainerAlgorithm interface being developed in GNN explanation settings  #5804. The interaace has to be extended to support HeteroData.

@wsad1 wsad1 requested a review from RBendias November 9, 2022 06:21
@github-actions github-actions bot added the nn label Nov 9, 2022
@wsad1 wsad1 requested review from RexYing and rusty1s November 9, 2022 06:21
@wsad1 wsad1 self-assigned this Nov 9, 2022
@github-actions github-actions bot added the nn label Nov 9, 2022
@codecov
Copy link

codecov bot commented Nov 9, 2022

Codecov Report

Merging #5934 (b342bcc) into master (01037db) will increase coverage by 0.06%.
The diff coverage is 98.09%.

@@            Coverage Diff             @@
##           master    #5934      +/-   ##
==========================================
+ Coverage   84.70%   84.77%   +0.06%     
==========================================
  Files         360      360              
  Lines       19862    19962     +100     
==========================================
+ Hits        16825    16923      +98     
- Misses       3037     3039       +2     
Impacted Files Coverage Δ
torch_geometric/nn/models/explainer.py 96.83% <98.07%> (+0.75%) ⬆️
torch_geometric/nn/models/__init__.py 100.00% <100.00%> (ø)

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

Copy link
Contributor

@RBendias RBendias left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the implementation! :D I have two points/suggestions:

  1. I think currently the edge_mask is still the regular edge_index, but it needs to be a mask, i.e. torch.ones((1, num_edges )). This needs to be adjusted throughout the file.

  2. Also, I am still in favor of a uniform implementation for hetero and non-hetero data. This way the functionality for the user would be the same regardless of the data type. Suggesting a method that converts the data and the model:

    • to_captum(data=None, model=None, ...) -> captum_input, captum_additional_args, CaptumModel / CaptumHeteroModel.

    or divided into two functions:

    • to_captum_data(...) -> captum_input, captum_additional_args

    • to_captum_model(...) -> CaptumModel / CaptumHeteroModel

    However, I would keep the input of these methods the same for hetero and non-hetero data and internally check the type of data and model.

torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
@wsad1
Copy link
Member Author

wsad1 commented Nov 10, 2022

@RBendias added a method to_captum_input which returns the input and additonal_forward_args as you suggested. Will add tests in a while.

@wsad1 wsad1 requested a review from RBendias November 10, 2022 08:07
Copy link
Contributor

@RBendias RBendias left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I left a few suggestions regardig the docs.

torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
wsad1 and others added 6 commits November 17, 2022 18:43
Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com>
Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com>
Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com>
@rusty1s rusty1s changed the title to_captum: Support for explaining heterogenous graphs. to_captum: Support for explaining heterogenous graphs Nov 18, 2022
test/nn/models/test_hetero_explainer.py Outdated Show resolved Hide resolved
test/nn/models/test_hetero_explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
torch_geometric/nn/models/explainer.py Outdated Show resolved Hide resolved
wsad1 and others added 10 commits November 19, 2022 10:24
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
@wsad1 wsad1 requested a review from rusty1s November 19, 2022 07:03
@wsad1 wsad1 merged commit a76d897 into master Nov 22, 2022
@wsad1 wsad1 deleted the captum_hetero branch November 22, 2022 04:59
JakubPietrakIntel pushed a commit to JakubPietrakIntel/pytorch_geometric that referenced this pull request Nov 25, 2022
This PR adds
`CaptumHeteroModel` , `to_captum_input` and `captum_output_to_dicts`
which can be used to explain heterogenous graphs as follows.
```
data: HeteroData = (...)
model = ... # A heterogenous model
mask_type = ...

captum_model: CaptumHeteroModel = to_captum_model(model, mask_type, output_idx=output_idx, data.metadata)
inputs, additonal_forward_args = to_captum_input(data.x_dict, data.edge_index_dict, mask_type)
ig = IntegratedGradients(captum_model)
ig_attr_nodes_edges = ig.attribute(inputs, target=target,
    additional_forward_args=additonal_forward_args, internal_batch_size=1)
x_attr_dict, edge_attr_dict = captum_output_to_dicts(ig_attr_nodes_edges, mask_type, data.metadata)
```
**TODOs in follow up PRs**
1. Add an example for `to_captum` with hetero data.
1. Move this behind the `ExplainerAlgorithm` interface being developed
in pyg-team#5804. The interaace has to be extended to support `HeteroData`.

Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charles Dufour <34485907+dufourc1@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants