Skip to content

Commit

Permalink
Fix RGCN+pyg-lib for LongTensor input (#5610)
Browse files Browse the repository at this point in the history
* Fix RGCN

* changelog
  • Loading branch information
rusty1s authored Oct 5, 2022
1 parent 292e289 commit f709376
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530))
Expand Down
2 changes: 1 addition & 1 deletion examples/rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch_geometric.utils import k_hop_subgraph

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str,
parser.add_argument('--dataset', type=str, default='AIFB',
choices=['AIFB', 'MUTAG', 'BGS', 'AM'])
args = parser.parse_args()

Expand Down
4 changes: 3 additions & 1 deletion torch_geometric/nn/conv/rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
out = out + h.contiguous().view(-1, self.out_channels)

else: # No regularization/Basis-decomposition ========================
if self._WITH_PYG_LIB and isinstance(edge_index, Tensor):
if (self._WITH_PYG_LIB and self.num_bases is None
and x_l.is_floating_point()
and isinstance(edge_index, Tensor)):
if not self.is_sorted:
if (edge_type[1:] < edge_type[:-1]).any():
edge_type, perm = edge_type.sort()
Expand Down

0 comments on commit f709376

Please sign in to comment.