Add padding capabilities to HeteroData.to_homogeneous() #7374

Merged (12 commits) on May 17, 2023
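A minimal usage sketch of the behavior this PR adds, mirroring the new test below (variable names here are illustrative): when two-dimensional node features of different node types disagree in their last dimension, `to_homogeneous()` now zero-pads the smaller ones before concatenating.

```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data['paper'].x = torch.randn(100, 128)
data['author'].x = torch.randn(50, 64)  # smaller feature dimension

# 'author' features are right-padded with zeros from 64 to 128 columns so
# that a single homogeneous feature matrix can be built:
homo = data.to_homogeneous()
print(homo.x.shape)          # torch.Size([150, 128])
print(homo.node_type.shape)  # torch.Size([150])
```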
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added padding capabilities to `HeteroData.to_homogeneous()` in case feature dimensionalities do not match ([#7374](https://github.com/pyg-team/pytorch_geometric/pull/7374))
- Added an optional `batch_size` argument to `fps`, `knn`, `knn_graph`, `radius` and `radius_graph` ([#7368](https://github.com/pyg-team/pytorch_geometric/pull/7368))
- Added `PrefetchLoader` capabilities ([#7376](https://github.com/pyg-team/pytorch_geometric/pull/7376), [#7378](https://github.com/pyg-team/pytorch_geometric/pull/7378), [#7383](https://github.com/pyg-team/pytorch_geometric/pull/7383))
- Added an example for hierarchical sampling ([#7244](https://github.com/pyg-team/pytorch_geometric/pull/7244))
16 changes: 16 additions & 0 deletions test/data/test_hetero_data.py
@@ -489,6 +489,22 @@ def test_to_homogeneous_and_vice_versa():
assert out['author'].num_nodes == 200


+def test_to_homogeneous_padding():
+    data = HeteroData()
+    data['paper'].x = torch.randn(100, 128)
+    data['author'].x = torch.randn(50, 64)
+
+    out = data.to_homogeneous()
+    assert len(out) == 2
+    assert out.node_type.size() == (150, )
+    assert out.node_type[:100].abs().sum() == 0
+    assert out.node_type[100:].sub(1).abs().sum() == 0
+    assert out.x.size() == (150, 128)
+    assert torch.equal(out.x[:100], data['paper'].x)
+    assert torch.equal(out.x[100:, :64], data['author'].x)
+    assert out.x[100:, 64:].abs().sum() == 0
+
+
def test_hetero_data_to_canonical():
    data = HeteroData()
    assert isinstance(data['user', 'product'], EdgeStorage)
29 changes: 25 additions & 4 deletions torch_geometric/data/hetero_data.py
@@ -828,10 +828,20 @@ def fill_dummy_(stores: List[BaseStorage],

def _consistent_size(stores: List[BaseStorage]) -> List[str]:
    sizes_dict = get_sizes(stores)
-    return [
-        key for key, sizes in sizes_dict.items()
-        if len(sizes) == len(stores) and len(set(sizes)) == 1
-    ]
+    keys = []
+    for key, sizes in sizes_dict.items():
+        # The attribute needs to exist in all types:
+        if len(sizes) != len(stores):
+            continue
+        # The attributes need to have the same number of dimensions:
+        lengths = set([len(size) for size in sizes])
+        if len(lengths) != 1:
+            continue
+        # The attributes need to have the same size in all dimensions
+        # (a single mismatching feature dimension will be zero-padded):
+        if len(sizes[0]) != 1 and len(set(sizes)) != 1:
+            continue
+        keys.append(key)
+    return keys

if dummy_values:
    self = copy.copy(self)
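To make the `_consistent_size()` filter above concrete, here is a standalone re-implementation applied to hypothetical `get_sizes()` outputs (the keys and size tuples below are made up for illustration, not taken from the PR). The length-1 exemption is what lets mismatching feature sizes through to the new padding logic.

```python
from typing import Dict, List, Tuple

def consistent_keys(sizes_dict: Dict[str, List[Tuple[int, ...]]],
                    num_stores: int) -> List[str]:
    # Keep a key only if it appears in every store, all of its size tuples
    # have the same length, and either that length is 1 or the tuples
    # match exactly:
    keys = []
    for key, sizes in sizes_dict.items():
        if len(sizes) != num_stores:
            continue
        if len({len(size) for size in sizes}) != 1:
            continue
        if len(sizes[0]) != 1 and len(set(sizes)) != 1:
            continue
        keys.append(key)
    return keys

print(consistent_keys({
    'a': [(128, ), (64, )],  # single size entry, mismatch allowed -> kept
    'b': [(3, 4), (3, 4)],   # longer tuples but identical -> kept
    'c': [(3, 4), (3, 5)],   # longer tuples and mismatching -> dropped
    'd': [(16, )],           # missing in one store -> dropped
}, num_stores=2))
# ['a', 'b']
```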
@@ -855,6 +865,17 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
        continue
    values = [store[key] for store in self.node_stores]
    dim = self.__cat_dim__(key, values[0], self.node_stores[0])
+    dim = values[0].dim() + dim if dim < 0 else dim
+    # For two-dimensional features, we allow arbitrary shapes and pad them
+    # with zeros if necessary in case their sizes do not match:
+    if values[0].dim() == 2 and dim == 0:
+        _max = max([value.size(-1) for value in values])
+        for i, v in enumerate(values):
+            if v.size(-1) < _max:
+                values[i] = torch.cat(
+                    [v, v.new_zeros(v.size(0), _max - v.size(-1))],
+                    dim=-1,
+                )
    value = torch.cat(values, dim) if len(values) > 1 else values[0]
    data[key] = value

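As a standalone sketch of the padding step added above, with the helper name `pad_and_cat` chosen here for illustration (it is not part of the PR):

```python
from typing import List

import torch
from torch import Tensor

def pad_and_cat(values: List[Tensor]) -> Tensor:
    # Zero-pad two-dimensional tensors on the right so that their last
    # dimensions match, then concatenate them along the first dimension:
    max_dim = max(value.size(-1) for value in values)
    values = [
        torch.cat([v, v.new_zeros(v.size(0), max_dim - v.size(-1))], dim=-1)
        if v.size(-1) < max_dim else v
        for v in values
    ]
    return torch.cat(values, dim=0)

x = pad_and_cat([torch.randn(100, 128), torch.randn(50, 64)])
print(x.shape)  # torch.Size([150, 128])
```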