
New attribute self._load_hook in Linear class since 2.0.3 raises KeyError when executing the load_state_dict function #5080

Closed
DavdGao opened this issue Jul 28, 2022 · 1 comment · Fixed by #5094

DavdGao commented Jul 28, 2022

🐛 Describe the bug

In PyTorch, load_state_dict(state_dict, strict) accepts an empty dict (state_dict == {}) when strict is False.
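As a quick check of this behaviour, here is a minimal sketch using a plain torch.nn.Linear (assuming PyTorch is installed; the layer sizes are arbitrary). With strict=False, load_state_dict does not raise; it returns a named tuple whose missing_keys lists every parameter that was not found in the dict.

```python
import torch

# A plain torch.nn.Linear registers no load pre-hook that indexes into
# state_dict, so an empty dict is accepted when strict=False.
layer = torch.nn.Linear(4, 2)
result = layer.load_state_dict({}, strict=False)

# The returned named tuple reports what was absent instead of raising.
print(result.missing_keys)     # ['weight', 'bias']
print(result.unexpected_keys)  # []
```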
However, since version 2.0.3 the Linear class in torch_geometric/nn/dense/linear.py has a new attribute self._load_hook, so executing Linear(xxxx).load_state_dict({}, strict=False) makes the Linear class run its self._lazy_load_hook function, which looks as follows:

    def _lazy_load_hook(self, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs):

        weight = state_dict[prefix + 'weight']  # KeyError if the key is absent
        if is_uninitialized_parameter(weight):
            self.in_channels = -1
            self.weight = nn.parameter.UninitializedParameter()
            if not hasattr(self, '_hook'):
                self._hook = self.register_forward_pre_hook(
                    self.initialize_parameters)

        elif is_uninitialized_parameter(self.weight):
            self.in_channels = weight.size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            if hasattr(self, '_hook'):
                self._hook.remove()
                delattr(self, '_hook')

Since state_dict is empty, the line weight = state_dict[prefix + 'weight'] raises a KeyError.
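A minimal sketch reproducing the same failure pattern without PyG, assuming only PyTorch. Both classes below (StrictHookLinear, GuardedHookLinear) are hypothetical; they register a load pre-hook through PyTorch's private nn.Module._register_load_state_dict_pre_hook, presumably the mechanism behind self._load_hook. The first indexes state_dict directly, like the hook quoted above; the second uses dict.get and returns early when the key is absent. That guard is one possible fix and is not necessarily the change that was actually pushed.

```python
import torch
from torch import nn


class StrictHookLinear(nn.Linear):
    """Hypothetical module: its load pre-hook indexes state_dict directly."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self._register_load_state_dict_pre_hook(self._load_pre_hook)

    def _load_pre_hook(self, state_dict, prefix, local_metadata, strict,
                       missing_keys, unexpected_keys, error_msgs):
        weight = state_dict[prefix + 'weight']  # KeyError on an empty dict
        # (lazy-initialization logic on `weight` would go here)


class GuardedHookLinear(nn.Linear):
    """Hypothetical module: the pre-hook skips itself if 'weight' is absent."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self._register_load_state_dict_pre_hook(self._load_pre_hook)

    def _load_pre_hook(self, state_dict, prefix, local_metadata, strict,
                       missing_keys, unexpected_keys, error_msgs):
        weight = state_dict.get(prefix + 'weight')
        if weight is None:
            # Nothing to inspect; let load_state_dict report the missing key.
            return
        # (lazy-initialization logic on `weight` would go here)


try:
    StrictHookLinear(4, 2).load_state_dict({}, strict=False)
except KeyError as err:
    print('KeyError:', err)  # KeyError: 'weight'

result = GuardedHookLinear(4, 2).load_state_dict({}, strict=False)
print(result.missing_keys)  # ['weight', 'bias']
```

With the guard, an empty state_dict falls through to PyTorch's normal bookkeeping, so strict=False behaves the same as for a plain nn.Linear.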

Environment

  • PyG version:
  • PyTorch version:
  • OS:
  • Python version:
  • CUDA/cuDNN version:
  • How you installed PyTorch and PyG (conda, pip, source):
  • Any other relevant information (e.g., version of torch-scatter):

rusty1s commented Jul 30, 2022

Thanks for reporting. I just pushed a fix :)
