Skip to content

Commit

Permalink
Utilize HeteroData.set_value_dict in code base (#6974)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Mar 20, 2023
1 parent 81f2daa commit b6ccfba
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 31 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added the `DistMult` KGE model ([#6958](https://github.com/pyg-team/pytorch_geometric/pull/6958))
- Added `HeteroData.set_value_dict` functionality ([#6961](https://github.com/pyg-team/pytorch_geometric/pull/6961))
- Added `HeteroData.set_value_dict` functionality ([#6961](https://github.com/pyg-team/pytorch_geometric/pull/6961), [#6974](https://github.com/pyg-team/pytorch_geometric/pull/6974))
- Added PyTorch >= 2.0 support ([#6934](https://github.com/pyg-team/pytorch_geometric/pull/6934))
- Added PyTorch Lightning >= 2.0 support ([#6929](https://github.com/pyg-team/pytorch_geometric/pull/6929))
- Added the `ComplEx` KGE model ([#6898](https://github.com/pyg-team/pytorch_geometric/pull/6898))
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def set_value_dict(
print(data['paper'].x)
"""
for k, v in value_dict.items():
for k, v in (value_dict or {}).items():
self[k][key] = v
return self

Expand Down
8 changes: 2 additions & 6 deletions torch_geometric/explain/algorithm/captum_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,8 @@ def forward(
return Explanation(node_mask=node_mask, edge_mask=edge_mask)

explanation = HeteroExplanation()
if node_mask is not None:
for node_type, mask in node_mask.items():
explanation.node_mask_dict[node_type] = mask
if edge_mask is not None:
for edge_type, mask in edge_mask.items():
explanation.edge_mask_dict[edge_type] = mask
explanation.set_value_dict('node_mask', node_mask)
explanation.set_value_dict('edge_mask', edge_mask)
return explanation

def supports(self) -> bool:
Expand Down
12 changes: 5 additions & 7 deletions torch_geometric/explain/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,23 +222,21 @@ def __call__(
explanation[key] = arg

elif isinstance(explanation, HeteroExplanation):
assert isinstance(x, dict)
# TODO Add `explanation._model_args`
for node_type, value in x.items():
explanation[node_type].x = value

assert isinstance(x, dict)
explanation.set_value_dict('x', x)

assert isinstance(edge_index, dict)
for edge_type, value in edge_index.items():
explanation[edge_type].edge_index = value
explanation.set_value_dict('edge_index', edge_index)

for key, arg in kwargs.items(): # Add remaining `kwargs`:
if isinstance(arg, dict):
# Keyword arguments are likely named `{attr_name}_dict`
# while we only want to assign the `{attr_name}` to the
# `HeteroExplanation` object:
key = key[:-5] if key.endswith('_dict') else key
for type_name, value in arg.items():
explanation[type_name][key] = value
explanation.set_value_dict(key, arg)
else:
explanation[key] = arg

Expand Down
13 changes: 5 additions & 8 deletions torch_geometric/loader/link_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,11 @@ def filter_fn(
if 'n_id' not in data[key]:
data[key].n_id = node

if out.edge is not None:
for key, edge in out.edge.items():
if 'e_id' not in data[key]:
data[key].e_id = edge

if out.batch is not None:
for key, batch in out.batch.items():
data[key].batch = batch
for key, edge in (out.edge or {}).items():
if 'e_id' not in data[key]:
data[key].e_id = edge

data.set_value_dict('batch', out.batch)

input_type = self.input_data.input_type
data[input_type].input_id = out.metadata[0]
Expand Down
11 changes: 3 additions & 8 deletions torch_geometric/loader/node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,9 @@ def filter_fn(
if 'e_id' not in data[key]:
data[key].e_id = edge

for key, batch in (out.batch or {}).items():
data[key].batch = batch

for key, value in (out.num_sampled_nodes or {}).items():
data[key].num_sampled_nodes = value

for key, value in (out.num_sampled_edges or {}).items():
data[key].num_sampled_edges = value
data.set_value_dict('batch', out.batch)
data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes)
data.set_value_dict('num_sampled_edges', out.num_sampled_edges)

input_type = self.input_data.input_type
data[input_type].input_id = out.metadata[0]
Expand Down

0 comments on commit b6ccfba

Please sign in to comment.