Skip to content

Commit

Permalink
to_captum: Support for explaining heterogeneous graphs (#5934)
Browse files Browse the repository at this point in the history
This PR adds
`CaptumHeteroModel` , `to_captum_input` and `captum_output_to_dicts`
which can be used to explain heterogeneous graphs as follows.
```
data: HeteroData = (...)
model = ... # A heterogeneous model
mask_type = ...

captum_model: CaptumHeteroModel = to_captum_model(model, mask_type, output_idx=output_idx, metadata=data.metadata())
inputs, additional_forward_args = to_captum_input(data.x_dict, data.edge_index_dict, mask_type)
ig = IntegratedGradients(captum_model)
ig_attr_nodes_edges = ig.attribute(inputs, target=target,
    additional_forward_args=additional_forward_args, internal_batch_size=1)
x_attr_dict, edge_attr_dict = captum_output_to_dicts(ig_attr_nodes_edges, mask_type, data.metadata())
```
**TODOs in follow up PRs**
1. Add an example for `to_captum` with hetero data.
1. Move this behind the `ExplainerAlgorithm` interface being developed
in #5804. The interface has to be extended to support `HeteroData`.

Co-authored-by: Ramona Bendias <ramona.bendias@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charles Dufour <34485907+dufourc1@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
5 people authored Nov 22, 2022
1 parent 01037db commit a76d897
Show file tree
Hide file tree
Showing 6 changed files with 502 additions and 59 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934))
- Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007))
- Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834))
- Added `LRGBDataset` to include 5 datasets from the [Long Range Graph Benchmark](https://openreview.net/pdf?id=in7XC5RcjEn) ([#5935](https://github.com/pyg-team/pytorch_geometric/pull/5935))
Expand Down
10 changes: 5 additions & 5 deletions examples/captum_explainability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Explainer, GCNConv, to_captum
from torch_geometric.nn import Explainer, GCNConv, to_captum_model

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
Expand Down Expand Up @@ -49,7 +49,7 @@ def forward(self, x, edge_index):

# Captum assumes that for all given input tensors, dimension 0 is
# equal to the number of samples. Therefore, we use unsqueeze(0).
captum_model = to_captum(model, mask_type='edge', output_idx=output_idx)
captum_model = to_captum_model(model, mask_type='edge', output_idx=output_idx)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
Expand All @@ -69,7 +69,7 @@ def forward(self, x, edge_index):
# Node explainability
# ===================

captum_model = to_captum(model, mask_type='node', output_idx=output_idx)
captum_model = to_captum_model(model, mask_type='node', output_idx=output_idx)

ig = IntegratedGradients(captum_model)
ig_attr_node = ig.attribute(data.x.unsqueeze(0), target=target,
Expand All @@ -88,8 +88,8 @@ def forward(self, x, edge_index):
# Node and edge explainability
# ============================

captum_model = to_captum(model, mask_type='node_and_edge',
output_idx=output_idx)
captum_model = to_captum_model(model, mask_type='node_and_edge',
output_idx=output_idx)

ig = IntegratedGradients(captum_model)
ig_attr_node, ig_attr_edge = ig.attribute(
Expand Down
137 changes: 128 additions & 9 deletions test/nn/models/test_hetero_explainer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,38 @@
import pytest
import torch

from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, HeteroConv, SAGEConv, to_hetero
from torch_geometric.nn.models.explainer import clear_masks, set_hetero_masks
from torch_geometric.nn import (
GCNConv,
HeteroConv,
SAGEConv,
captum_output_to_dicts,
to_captum_input,
to_captum_model,
to_hetero,
)
from torch_geometric.nn.models.explainer import (
CaptumHeteroModel,
clear_masks,
set_hetero_masks,
)
from torch_geometric.testing import withPackage
from torch_geometric.typing import Metadata

mask_types = ['edge', 'node_and_edge', 'node']
methods = [
'Saliency',
'InputXGradient',
'Deconvolution',
'FeatureAblation',
'ShapleyValueSampling',
'IntegratedGradients',
'GradientShap',
'Occlusion',
'GuidedBackprop',
'KernelShap',
'Lime',
]


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
Expand All @@ -11,6 +41,16 @@ def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
return torch.stack([row, col], dim=0)


def get_hetero_data():
    """Build a small two-node-type ('paper'/'author') heterogeneous graph
    used as a shared fixture by the tests below."""
    hetero = HeteroData()
    hetero['paper'].x = torch.randn(8, 16)
    hetero['author'].x = torch.randn(10, 8)
    hetero['paper', 'paper'].edge_index = get_edge_index(8, 8, 10)
    hetero['author', 'paper'].edge_index = get_edge_index(10, 8, 10)
    hetero['paper', 'author'].edge_index = get_edge_index(8, 10, 10)
    return hetero


class HeteroModel(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -45,14 +85,20 @@ def forward(self, x, edge_index):
return self.conv2(x, edge_index)


def test_set_clear_mask():
data = HeteroData()
data['paper'].x = torch.randn(50, 16)
data['author'].x = torch.randn(30, 8)
data['paper', 'paper'].edge_index = get_edge_index(50, 50, 200)
data['author', 'paper'].edge_index = get_edge_index(30, 50, 100)
data['paper', 'author'].edge_index = get_edge_index(50, 30, 100)
class HeteroSAGE(torch.nn.Module):
    """A homogeneous ``GraphSAGE`` converted to a heterogeneous model via
    ``to_hetero``; ``forward`` returns only the 'paper' node outputs."""
    def __init__(self, metadata: Metadata):
        super().__init__()
        self.graph_sage = to_hetero(GraphSAGE(), metadata, debug=False)

    def forward(self, x_dict, edge_index_dict,
                additonal_arg=None) -> torch.Tensor:
        # The Captum wrapper must forward the extra argument through to us;
        # fail loudly if it was dropped. (Parameter name keeps the original
        # spelling since callers pass it by position/keyword.)
        assert additonal_arg is not None
        out_dict = self.graph_sage(x_dict, edge_index_dict)
        return out_dict['paper']


def test_set_clear_mask():
data = get_hetero_data()
edge_mask_dict = {
('paper', 'to', 'paper'): torch.ones(200),
('author', 'to', 'paper'): torch.ones(100),
Expand Down Expand Up @@ -89,3 +135,76 @@ def test_set_clear_mask():
str_edge_type = '__'.join(edge_type)
assert model.conv1[str_edge_type]._edge_mask is None
assert not model.conv1[str_edge_type].explain


@withPackage('captum')
@pytest.mark.parametrize('mask_type', mask_types)
@pytest.mark.parametrize('method', methods)
def test_captum_attribution_methods_hetero(mask_type, method):
    """Runs each supported Captum attribution method on a hetero model and
    checks that the raw attributions convert back into node-/edge-type dicts
    of the expected shapes via ``captum_output_to_dicts``."""
    from captum import attr  # noqa
    data = get_hetero_data()
    metadata = data.metadata()
    model = HeteroSAGE(metadata)
    captum_model = to_captum_model(model, mask_type, 0, metadata)
    explainer = getattr(attr, method)(captum_model)
    assert isinstance(captum_model, CaptumHeteroModel)

    # Extra positional arg exercises the additional-forward-args plumbing
    # (HeteroSAGE.forward asserts it arrives).
    args = ['additional_arg1']
    # Renamed from `input` to avoid shadowing the builtin:
    inputs, additional_forward_args = to_captum_input(data.x_dict,
                                                      data.edge_index_dict,
                                                      mask_type, *args)
    # `Occlusion` needs one sliding-window shape per input tensor: (3, 3)
    # windows over each [1, N, F] node-feature tensor, (5,) windows over
    # each [1, E] edge-mask tensor (two node types, three edge types).
    if mask_type == 'node':
        sliding_window_shapes = ((3, 3), (3, 3))
    elif mask_type == 'edge':
        sliding_window_shapes = ((5, ), (5, ), (5, ))
    else:
        sliding_window_shapes = ((3, 3), (3, 3), (5, ), (5, ), (5, ))

    # Each Captum method has a slightly different `attribute` signature:
    if method == 'IntegratedGradients':
        attributions, delta = explainer.attribute(
            inputs, target=0, internal_batch_size=1,
            additional_forward_args=additional_forward_args,
            return_convergence_delta=True)
    elif method == 'GradientShap':
        attributions, delta = explainer.attribute(
            inputs, target=0, return_convergence_delta=True, baselines=inputs,
            n_samples=1, additional_forward_args=additional_forward_args)
    elif method == 'DeepLiftShap' or method == 'DeepLift':
        attributions, delta = explainer.attribute(
            inputs, target=0, return_convergence_delta=True, baselines=inputs,
            additional_forward_args=additional_forward_args)
    elif method == 'Occlusion':
        attributions = explainer.attribute(
            inputs, target=0, sliding_window_shapes=sliding_window_shapes,
            additional_forward_args=additional_forward_args)
    else:
        attributions = explainer.attribute(
            inputs, target=0, additional_forward_args=additional_forward_args)

    if mask_type == 'node':
        # One attribution tensor per node type:
        assert len(attributions) == len(metadata[0])
        x_attr_dict, _ = captum_output_to_dicts(attributions, mask_type,
                                                metadata)
        for node_type in metadata[0]:
            num_nodes = data[node_type].num_nodes
            num_node_feats = data[node_type].x.shape[1]
            assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats)
    elif mask_type == 'edge':
        # One attribution tensor per edge type:
        assert len(attributions) == len(metadata[1])
        _, edge_attr_dict = captum_output_to_dicts(attributions, mask_type,
                                                   metadata)
        for edge_type in metadata[1]:
            num_edges = data[edge_type].edge_index.shape[1]
            assert edge_attr_dict[edge_type].shape == (num_edges, )
    else:
        # 'node_and_edge': node attributions come first, then edge ones:
        assert len(attributions) == len(metadata[0]) + len(metadata[1])
        x_attr_dict, edge_attr_dict = captum_output_to_dicts(
            attributions, mask_type, metadata)
        for edge_type in metadata[1]:
            num_edges = data[edge_type].edge_index.shape[1]
            assert edge_attr_dict[edge_type].shape == (num_edges, )

        for node_type in metadata[0]:
            num_nodes = data[node_type].num_nodes
            num_node_feats = data[node_type].x.shape[1]
            assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats)
99 changes: 85 additions & 14 deletions test/nn/models/test_to_captum.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import torch

from torch_geometric.nn import GAT, GCN, Explainer, SAGEConv, to_captum
from torch_geometric.data import Data, HeteroData
from torch_geometric.nn import GAT, GCN, Explainer, SAGEConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.models import to_captum_input, to_captum_model
from torch_geometric.testing import withPackage

x = torch.randn(8, 3, requires_grad=True)
Expand Down Expand Up @@ -31,7 +33,8 @@
@pytest.mark.parametrize('model', [GCN, GAT])
@pytest.mark.parametrize('output_idx', [None, 1])
def test_to_captum(model, mask_type, output_idx):
captum_model = to_captum(model, mask_type=mask_type, output_idx=output_idx)
captum_model = to_captum_model(model, mask_type=mask_type,
output_idx=output_idx)
pre_out = model(x, edge_index)
if mask_type == 'node':
mask = x * 0.0
Expand Down Expand Up @@ -61,22 +64,16 @@ def test_to_captum(model, mask_type, output_idx):
def test_captum_attribution_methods(mask_type, method):
from captum import attr # noqa

captum_model = to_captum(GCN, mask_type, 0)
input_mask = torch.ones((1, edge_index.shape[1]), dtype=torch.float,
requires_grad=True)
captum_model = to_captum_model(GCN, mask_type, 0)
explainer = getattr(attr, method)(captum_model)

data = Data(x, edge_index)
input, additional_forward_args = to_captum_input(data.x, data.edge_index,
mask_type)
if mask_type == 'node':
input = x.clone().unsqueeze(0)
additional_forward_args = (edge_index, )
sliding_window_shapes = (3, 3)
elif mask_type == 'edge':
input = input_mask
additional_forward_args = (x, edge_index)
sliding_window_shapes = (5, )
elif mask_type == 'node_and_edge':
input = (x.clone().unsqueeze(0), input_mask)
additional_forward_args = (edge_index, )
sliding_window_shapes = ((3, 3), (5, ))

if method == 'IntegratedGradients':
Expand All @@ -100,9 +97,9 @@ def test_captum_attribution_methods(mask_type, method):
attributions = explainer.attribute(
input, target=0, additional_forward_args=additional_forward_args)
if mask_type == 'node':
assert attributions.shape == (1, 8, 3)
assert attributions[0].shape == (1, 8, 3)
elif mask_type == 'edge':
assert attributions.shape == (1, 14)
assert attributions[0].shape == (1, 14)
else:
assert attributions[0].shape == (1, 8, 3)
assert attributions[1].shape == (1, 14)
Expand Down Expand Up @@ -144,3 +141,77 @@ def explain_message(self, inputs, x_i, x_j):

assert torch.allclose(conv.x_i, x[edge_index[1]])
assert torch.allclose(conv.x_j, x[edge_index[0]])


@withPackage('captum')
@pytest.mark.parametrize('mask_type', ['node', 'edge', 'node_and_edge'])
def test_to_captum_input(mask_type):
    """Checks the number, shape and content of the tensors produced by
    ``to_captum_input`` for both homogeneous ``Data`` and ``HeteroData``.

    Bug fix: the ``HeteroData`` section previously called
    ``torch.allclose(...)`` without ``assert``, silently discarding the
    comparison results, so those checks never ran.
    """
    num_nodes = x.shape[0]
    num_node_feats = x.shape[1]
    num_edges = edge_index.shape[1]

    # Check for Data:
    data = Data(x, edge_index)
    args = 'test_args'
    # Renamed from `input` to avoid shadowing the builtin:
    inputs, additional_forward_args = to_captum_input(data.x, data.edge_index,
                                                      mask_type, args)
    if mask_type == 'node':
        # Single node-feature input, batched with a leading dim of 1:
        assert len(inputs) == 1
        assert inputs[0].shape == (1, num_nodes, num_node_feats)
        assert len(additional_forward_args) == 2
        assert torch.allclose(additional_forward_args[0], edge_index)
    elif mask_type == 'edge':
        # Single all-ones edge mask:
        assert len(inputs) == 1
        assert inputs[0].shape == (1, num_edges)
        assert inputs[0].sum() == num_edges
        assert len(additional_forward_args) == 3
        assert torch.allclose(additional_forward_args[0], x)
        assert torch.allclose(additional_forward_args[1], edge_index)
    else:
        # Node features first, then the edge mask:
        assert len(inputs) == 2
        assert inputs[0].shape == (1, num_nodes, num_node_feats)
        assert inputs[1].shape == (1, num_edges)
        assert inputs[1].sum() == num_edges
        assert len(additional_forward_args) == 2
        assert torch.allclose(additional_forward_args[0], edge_index)

    # Check for HeteroData (two node types, two edge types):
    data = HeteroData()
    x2 = torch.rand(8, 3)
    data['paper'].x = x
    data['author'].x = x2
    data['paper', 'to', 'author'].edge_index = edge_index
    data['author', 'to', 'paper'].edge_index = edge_index.flip([0])
    inputs, additional_forward_args = to_captum_input(data.x_dict,
                                                      data.edge_index_dict,
                                                      mask_type, args)
    if mask_type == 'node':
        # One node-feature input per node type:
        assert len(inputs) == 2
        assert inputs[0].shape == (1, num_nodes, num_node_feats)
        assert inputs[1].shape == (1, num_nodes, num_node_feats)
        assert len(additional_forward_args) == 2
        for key in data.edge_types:
            assert torch.allclose(additional_forward_args[0][key],
                                  data[key].edge_index)
    elif mask_type == 'edge':
        # One all-ones edge mask per edge type:
        assert len(inputs) == 2
        assert inputs[0].shape == (1, num_edges)
        assert inputs[1].shape == (1, num_edges)
        assert inputs[1].sum() == inputs[0].sum() == num_edges
        assert len(additional_forward_args) == 3
        for key in data.node_types:
            assert torch.allclose(additional_forward_args[0][key],
                                  data[key].x)
        for key in data.edge_types:
            assert torch.allclose(additional_forward_args[1][key],
                                  data[key].edge_index)
    else:
        # Node features per node type, then edge masks per edge type:
        assert len(inputs) == 4
        assert inputs[0].shape == (1, num_nodes, num_node_feats)
        assert inputs[1].shape == (1, num_nodes, num_node_feats)
        assert inputs[2].shape == (1, num_edges)
        assert inputs[3].shape == (1, num_edges)
        assert inputs[3].sum() == inputs[2].sum() == num_edges
        assert len(additional_forward_args) == 2
        for key in data.edge_types:
            assert torch.allclose(additional_forward_args[0][key],
                                  data[key].edge_index)
6 changes: 5 additions & 1 deletion torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from .graph_unet import GraphUNet
from .schnet import SchNet
from .dimenet import DimeNet, DimeNetPlusPlus
from .explainer import Explainer, to_captum
from .explainer import (Explainer, to_captum, to_captum_model, to_captum_input,
captum_output_to_dicts)
from .gnn_explainer import GNNExplainer
from .metapath2vec import MetaPath2Vec
from .deepgcn import DeepGCNLayer
Expand Down Expand Up @@ -47,6 +48,9 @@
'DimeNetPlusPlus',
'Explainer',
'to_captum',
'to_captum_model',
'to_captum_input',
'captum_output_to_dicts',
'GNNExplainer',
'MetaPath2Vec',
'DeepGCNLayer',
Expand Down
Loading

0 comments on commit a76d897

Please sign in to comment.