From 95e0f12b866fa9e40208250c51505f44fefe0afc Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 23 Feb 2024 15:32:29 -0800 Subject: [PATCH 1/3] init --- tensordict/nn/params.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index f51f1a1ca..81a124c81 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -990,18 +990,17 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ): - data = ( - TensorDict( - { - key: val - for key, val in state_dict.items() - if key.startswith(prefix) and val is not None - }, - [], - ) - .unflatten_keys(".") - .get(prefix[:-1]) - ) + data = TensorDict( + { + key: val + for key, val in state_dict.items() + if key.startswith(prefix) and val is not None + }, + [], + ).unflatten_keys(".") + prefix = tuple(key for key in prefix.split(".") if key) + if prefix: + data = data.get(prefix) self.data.load_state_dict(data) def items( From 3adde4bc3a2727bd4dbfa7f7270f4fc13e86f854 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 23 Feb 2024 15:39:12 -0800 Subject: [PATCH 2/3] init --- test/test_nn.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/test_nn.py b/test/test_nn.py index 69c7b2fcf..da1cd8ab3 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3018,6 +3018,37 @@ class MyModule(nn.Module): assert (params_m == 0).all() assert not (params_m == 0).all() + def test_load_state_dict(self): + net = nn.Sequential( + nn.Linear(2, 2), + nn.Sequential( + nn.Linear(2, 2), + nn.Dropout(), + nn.BatchNorm1d(2), + nn.Sequential( + nn.Tanh(), + nn.Linear(2, 2), + ), + ), + ) + + params = TensorDict.from_module(net, as_module=True) + + # Now with a module around it + class MyModule(nn.Module): + pass + + module = MyModule() + module.model = MyModule() + module.model.params = params + sd = module.state_dict() + sd = { + key: val * 0 if isinstance(val, torch.Tensor) else val + for key, val in sd.items() + } + module.load_state_dict(sd) + assert (params == 0).all() + def test_inplace_ops(self): td = TensorDict( { From 173f3a2d9ffe45fb6bd2dc4d808920a19d2052c4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 23 Feb 2024 15:46:02 -0800 Subject: [PATCH 3/3] init --- test/test_nn.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_nn.py b/test/test_nn.py index da1cd8ab3..37d5975b0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8,6 +8,7 @@ import pickle import unittest import warnings +import weakref import pytest import torch @@ -3033,6 +3034,8 @@ def test_load_state_dict(self): ) params = TensorDict.from_module(net, as_module=True) + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + weakrefs = {weakref.ref(t) for t in params.values(True, True)} # Now with a module around it class MyModule(nn.Module): @@ -3048,6 +3051,8 @@ class MyModule(nn.Module): } module.load_state_dict(sd) assert (params == 0).all() + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + assert weakrefs == {weakref.ref(t) for t in params.values(True, True)} def test_inplace_ops(self): td = TensorDict(