From f709376d2ed5f0fd4fe9482b4850f5cd3d656691 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 4 Oct 2022 22:00:31 -0700 Subject: [PATCH] Fix `RGCN+pyg-lib` for `LongTensor` input (#5610) * Fix RGCN * changelog --- CHANGELOG.md | 1 + examples/rgcn.py | 2 +- torch_geometric/nn/conv/rgcn_conv.py | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c4808cb9e7d..1cb3b145da50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240)) - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) ### Changed +- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) - Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530)) diff --git a/examples/rgcn.py b/examples/rgcn.py index e2eb4669faee..00c738cc7fc7 100644 --- a/examples/rgcn.py +++ b/examples/rgcn.py @@ -9,7 +9,7 @@ from torch_geometric.utils import k_hop_subgraph parser = argparse.ArgumentParser() -parser.add_argument('--dataset', type=str, +parser.add_argument('--dataset', type=str, default='AIFB', choices=['AIFB', 'MUTAG', 'BGS', 'AM']) args = parser.parse_args() diff --git a/torch_geometric/nn/conv/rgcn_conv.py b/torch_geometric/nn/conv/rgcn_conv.py index 8eedbbfaf417..7e5e85af6813 100644 --- a/torch_geometric/nn/conv/rgcn_conv.py +++ b/torch_geometric/nn/conv/rgcn_conv.py @@ -225,7 +225,9 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]], out = out + h.contiguous().view(-1, self.out_channels) else: # No regularization/Basis-decomposition ======================== - if self._WITH_PYG_LIB and isinstance(edge_index, Tensor): + if (self._WITH_PYG_LIB and self.num_bases is None + and x_l.is_floating_point() + and isinstance(edge_index, Tensor)): if not self.is_sorted: if (edge_type[1:] < edge_type[:-1]).any(): edge_type, perm = edge_type.sort()