Fix Linear.load_state_dict(..., strict=False) (#5094)
* update

* changelog
rusty1s authored Jul 30, 2022
1 parent 75787ee commit 682efd7
Showing 3 changed files with 9 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -63,6 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Fixed `load_state_dict` in `Linear` with `strict=False` mode ([5094](https://github.com/pyg-team/pytorch_geometric/pull/5094))
 - Fixed typo in `MaskLabel.ratio_mask` ([5093](https://github.com/pyg-team/pytorch_geometric/pull/5093))
 - Fixed `data.num_node_features` computation for sparse matrices ([5089](https://github.com/pyg-team/pytorch_geometric/pull/5089))
 - Fixed `GenConv` test ([4993](https://github.com/pyg-team/pytorch_geometric/pull/4993))
4 changes: 4 additions & 0 deletions test/nn/dense/test_linear.py
@@ -47,6 +47,10 @@ def test_load_lazy_linear(dim1, dim2):
     assert hasattr(lin1, '_hook')
     assert hasattr(lin2, '_hook')
 
+    with pytest.raises(RuntimeError, match="in state_dict"):
+        lin1.load_state_dict({}, strict=True)
+    lin1.load_state_dict({}, strict=False)
+
 
 @pytest.mark.parametrize('lazy', [True, False])
 def test_identical_linear_default_initialization(lazy):
7 changes: 4 additions & 3 deletions torch_geometric/nn/dense/linear.py
@@ -137,15 +137,16 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
     def _lazy_load_hook(self, state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
 
-        weight = state_dict[prefix + 'weight']
-        if is_uninitialized_parameter(weight):
+        weight = state_dict.get(prefix + 'weight', None)
+
+        if weight is not None and is_uninitialized_parameter(weight):
             self.in_channels = -1
             self.weight = nn.parameter.UninitializedParameter()
             if not hasattr(self, '_hook'):
                 self._hook = self.register_forward_pre_hook(
                     self.initialize_parameters)
 
-        elif is_uninitialized_parameter(self.weight):
+        elif weight is not None and is_uninitialized_parameter(self.weight):
             self.in_channels = weight.size(-1)
             self.weight.materialize((self.out_channels, self.in_channels))
             if hasattr(self, '_hook'):
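The essence of the fix is replacing direct dictionary indexing with `dict.get`: when `load_state_dict` is called with `strict=False` and the checkpoint lacks a `weight` entry, the hook now skips instead of raising a `KeyError`, leaving PyTorch to report the key through its normal missing-key handling. A minimal, framework-free sketch of this pattern (the function and its return values are illustrative, not the actual PyG API):

```python
def lazy_load_hook(state_dict, prefix):
    """Illustrative stand-in for Linear._lazy_load_hook (not the real API)."""
    # Before the fix: `state_dict[prefix + 'weight']` raised a KeyError
    # whenever the key was absent, e.g. for load_state_dict({}, strict=False).
    weight = state_dict.get(prefix + 'weight', None)

    if weight is None:
        # Key absent: do nothing here; with strict=False, load_state_dict
        # itself records the missing key instead of crashing in the hook.
        return 'skipped'
    if weight == 'uninitialized':
        # Checkpoint holds a lazy (uninitialized) weight: stay lazy.
        return 'kept lazy'
    # A concrete tensor was saved: materialize parameters from its shape.
    return 'materialized'


print(lazy_load_hook({}, 'lin1.'))                                # skipped
print(lazy_load_hook({'lin1.weight': 'uninitialized'}, 'lin1.'))  # kept lazy
print(lazy_load_hook({'lin1.weight': [[1.0, 2.0]]}, 'lin1.'))     # materialized
```

The `weight is not None` guards in the real diff serve the same purpose: both branches of the hook must tolerate the key being absent, which is exactly what the new test exercises with `lin1.load_state_dict({}, strict=False)`.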
