diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 833c6405f..9d91ca0ba 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -988,18 +988,17 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ): - data = ( - TensorDict( - { - key: val - for key, val in state_dict.items() - if key.startswith(prefix) and val is not None - }, - [], - ) - .unflatten_keys(".") - .get(prefix[:-1]) - ) + data = TensorDict( + { + key: val + for key, val in state_dict.items() + if key.startswith(prefix) and val is not None + }, + [], + ).unflatten_keys(".") + prefix = tuple(key for key in prefix.split(".") if key) + if prefix: + data = data.get(prefix) self.data.load_state_dict(data) def items( diff --git a/test/test_nn.py b/test/test_nn.py index 69c7b2fcf..37d5975b0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8,6 +8,7 @@ import pickle import unittest import warnings +import weakref import pytest import torch @@ -3018,6 +3019,41 @@ class MyModule(nn.Module): assert (params_m == 0).all() assert not (params_m == 0).all() + def test_load_state_dict(self): + net = nn.Sequential( + nn.Linear(2, 2), + nn.Sequential( + nn.Linear(2, 2), + nn.Dropout(), + nn.BatchNorm1d(2), + nn.Sequential( + nn.Tanh(), + nn.Linear(2, 2), + ), + ), + ) + + params = TensorDict.from_module(net, as_module=True) + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + weakrefs = {weakref.ref(t) for t in params.values(True, True)} + + # Now with a module around it + class MyModule(nn.Module): + pass + + module = MyModule() + module.model = MyModule() + module.model.params = params + sd = module.state_dict() + sd = { + key: val * 0 if isinstance(val, torch.Tensor) else val + for key, val in sd.items() + } + module.load_state_dict(sd) + assert (params == 0).all() + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + assert weakrefs == {weakref.ref(t) for t in params.values(True, True)} + def test_inplace_ops(self): td = TensorDict( {