Add padding capabilities to HeteroData.to_homogeneous() (#7374)
This is useful for feeding randomly generated `FakeHeteroDataset` data into a fused GNN such as `RGCNConv`. `FakeHeteroDataset` typically generates `x` tensors with a different number of features per node type, so without this PR the resulting `Data` object would not carry any node features.
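
A minimal sketch of that pipeline (using a default `FakeHeteroDataset` and an illustrative hidden size; the snippet itself is not part of this PR):

```python
from torch_geometric.datasets import FakeHeteroDataset
from torch_geometric.nn import RGCNConv

# FakeHeteroDataset draws a different feature dimensionality per node type;
# to_homogeneous() now zero-pads all node features to the largest of them:
data = FakeHeteroDataset()[0].to_homogeneous()

conv = RGCNConv(
    in_channels=data.num_features,
    out_channels=32,  # illustrative hidden size
    num_relations=int(data.edge_type.max()) + 1,
)
out = conv(data.x, data.edge_index, data.edge_type)
```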

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
3 people authored May 17, 2023
1 parent 2e64b0f commit ce84dd9
Showing 3 changed files with 42 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added padding capabilities to `HeteroData.to_homogeneous()` in case feature dimensionalities do not match ([#7374](https://github.com/pyg-team/pytorch_geometric/pull/7374))
- Added an optional `batch_size` argument to `fps`, `knn`, `knn_graph`, `radius` and `radius_graph` ([#7368](https://github.com/pyg-team/pytorch_geometric/pull/7368))
- Added `PrefetchLoader` capabilities ([#7376](https://github.com/pyg-team/pytorch_geometric/pull/7376), [#7378](https://github.com/pyg-team/pytorch_geometric/pull/7378), [#7383](https://github.com/pyg-team/pytorch_geometric/pull/7383))
- Added an example for hierarchical sampling ([#7244](https://github.com/pyg-team/pytorch_geometric/pull/7244))
16 changes: 16 additions & 0 deletions test/data/test_hetero_data.py
@@ -489,6 +489,22 @@ def test_to_homogeneous_and_vice_versa():
    assert out['author'].num_nodes == 200


def test_to_homogeneous_padding():
    data = HeteroData()
    data['paper'].x = torch.randn(100, 128)
    data['author'].x = torch.randn(50, 64)

    out = data.to_homogeneous()
    assert len(out) == 2
    # 'paper' nodes come first (node type 0), 'author' nodes follow (type 1):
    assert out.node_type.size() == (150, )
    assert out.node_type[:100].abs().sum() == 0
    assert out.node_type[100:].sub(1).abs().sum() == 0
    # 'author' features are zero-padded from 64 to 128 channels:
    assert out.x.size() == (150, 128)
    assert torch.equal(out.x[:100], data['paper'].x)
    assert torch.equal(out.x[100:, :64], data['author'].x)
    assert out.x[100:, 64:].abs().sum() == 0


def test_hetero_data_to_canonical():
    data = HeteroData()
    assert isinstance(data['user', 'product'], EdgeStorage)
29 changes: 25 additions & 4 deletions torch_geometric/data/hetero_data.py
@@ -828,10 +828,20 @@ def fill_dummy_(stores: List[BaseStorage],

    def _consistent_size(stores: List[BaseStorage]) -> List[str]:
        sizes_dict = get_sizes(stores)
-       return [
-           key for key, sizes in sizes_dict.items()
-           if len(sizes) == len(stores) and len(set(sizes)) == 1
-       ]
+       keys = []
+       for key, sizes in sizes_dict.items():
+           # The attribute needs to exist in all types:
+           if len(sizes) != len(stores):
+               continue
+           # The attribute needs to have the same number of dimensions:
+           lengths = set([len(size) for size in sizes])
+           if len(lengths) != 1:
+               continue
+           # The attribute needs to have the same size in all dimensions:
+           if len(sizes[0]) != 1 and len(set(sizes)) != 1:
+               continue
+           keys.append(key)
+       return keys

    if dummy_values:
        self = copy.copy(self)
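
For intuition, a standalone sketch of the `_consistent_size` filter above. It assumes `get_sizes()` (whose body is not shown in this diff) reports one size tuple per store with the concatenation dimension stripped away, so a regular 2-D feature matrix reduces to a one-element tuple; the attribute names and store count below are hypothetical:

```python
# Hypothetical output of get_sizes() for two node types:
sizes_dict = {
    'x': [(128,), (64,)],      # one free dim, sizes differ -> kept (padded later)
    'y': [(), ()],             # no free dims (e.g. 1-D labels) -> kept
    'pos': [(3,)],             # missing in one type -> dropped
    'face': [(3, 4), (3, 5)],  # several free dims that differ -> dropped
}
num_stores = 2

keys = []
for key, sizes in sizes_dict.items():
    if len(sizes) != num_stores:  # must exist in all types
        continue
    if len({len(size) for size in sizes}) != 1:  # same number of dims
        continue
    # Multi-dimensional remainders must match exactly; one-element sizes may
    # differ, since to_homogeneous() can zero-pad them:
    if len(sizes[0]) != 1 and len(set(sizes)) != 1:
        continue
    keys.append(key)

assert keys == ['x', 'y']
```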
@@ -855,6 +865,17 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
                continue
            values = [store[key] for store in self.node_stores]
            dim = self.__cat_dim__(key, values[0], self.node_stores[0])
+           dim = values[0].dim() + dim if dim < 0 else dim
+           # For two-dimensional features, we allow arbitrary shapes and pad
+           # them with zeros in case their sizes do not match:
+           if values[0].dim() == 2 and dim == 0:
+               _max = max([value.size(-1) for value in values])
+               for i, v in enumerate(values):
+                   if v.size(-1) < _max:
+                       values[i] = torch.cat(
+                           [v, v.new_zeros(v.size(0), _max - v.size(-1))],
+                           dim=-1,
+                       )
            value = torch.cat(values, dim) if len(values) > 1 else values[0]
            data[key] = value
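
A self-contained sketch of this padding rule, reusing the feature shapes from the new test above:

```python
import torch

# Two node types with 128- and 64-dimensional features:
values = [torch.randn(100, 128), torch.randn(50, 64)]

# Right-pad every feature matrix with zeros up to the largest feature dim:
_max = max(value.size(-1) for value in values)
values = [
    torch.cat([v, v.new_zeros(v.size(0), _max - v.size(-1))], dim=-1)
    for v in values
]

x = torch.cat(values, dim=0)  # concatenate along the node dimension
assert x.size() == (150, 128)
assert x[100:, 64:].abs().sum() == 0  # the padded region is all zeros
```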

