From bcd3a7cf7dc71b2483719534a4c6075d7b14fb82 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 30 Jul 2022 16:32:40 +0000 Subject: [PATCH 1/2] update --- test/nn/dense/test_linear.py | 4 ++++ torch_geometric/nn/dense/linear.py | 7 ++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/test/nn/dense/test_linear.py b/test/nn/dense/test_linear.py index 4b2213c8c8c2..694764646ab6 100644 --- a/test/nn/dense/test_linear.py +++ b/test/nn/dense/test_linear.py @@ -47,6 +47,10 @@ def test_load_lazy_linear(dim1, dim2): assert hasattr(lin1, '_hook') assert hasattr(lin2, '_hook') + with pytest.raises(RuntimeError, match="in state_dict"): + lin1.load_state_dict({}, strict=True) + lin1.load_state_dict({}, strict=False) + @pytest.mark.parametrize('lazy', [True, False]) def test_identical_linear_default_initialization(lazy): diff --git a/torch_geometric/nn/dense/linear.py b/torch_geometric/nn/dense/linear.py index fc411d6e149b..078de939cbdd 100644 --- a/torch_geometric/nn/dense/linear.py +++ b/torch_geometric/nn/dense/linear.py @@ -137,15 +137,16 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def _lazy_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - weight = state_dict[prefix + 'weight'] - if is_uninitialized_parameter(weight): + weight = state_dict.get(prefix + 'weight', None) + + if weight is not None and is_uninitialized_parameter(weight): self.in_channels = -1 self.weight = nn.parameter.UninitializedParameter() if not hasattr(self, '_hook'): self._hook = self.register_forward_pre_hook( self.initialize_parameters) - elif is_uninitialized_parameter(self.weight): + elif weight is not None and is_uninitialized_parameter(self.weight): self.in_channels = weight.size(-1) self.weight.materialize((self.out_channels, self.in_channels)) if hasattr(self, '_hook'): From 0174e626308380c739ed05584afb644cd50a2a3c Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 30 Jul 2022 16:34:09 +0000 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2afe3dd170e2..fdc48a5c08d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- Fixed `load_state_dict` in `Linear` with `strict=False` mode ([5094](https://github.com/pyg-team/pytorch_geometric/pull/5094)) - Fixed typo in `MaskLabel.ratio_mask` ([5093](https://github.com/pyg-team/pytorch_geometric/pull/5093)) - Fixed `data.num_node_features` computation for sparse matrices ([5089](https://github.com/pyg-team/pytorch_geometric/pull/5089)) - Fixed `GenConv` test ([4993](https://github.com/pyg-team/pytorch_geometric/pull/4993))