Skip to content

Commit

Permalink
[BugFix] Fix load_state_dict for TensorDictParams (#689)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 26, 2024
1 parent f87d3fa commit 6661b5d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
23 changes: 11 additions & 12 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,18 +988,17 @@ def _load_from_state_dict(
unexpected_keys,
error_msgs,
):
data = (
TensorDict(
{
key: val
for key, val in state_dict.items()
if key.startswith(prefix) and val is not None
},
[],
)
.unflatten_keys(".")
.get(prefix[:-1])
)
data = TensorDict(
{
key: val
for key, val in state_dict.items()
if key.startswith(prefix) and val is not None
},
[],
).unflatten_keys(".")
prefix = tuple(key for key in prefix.split(".") if key)
if prefix:
data = data.get(prefix)
self.data.load_state_dict(data)

def items(
Expand Down
36 changes: 36 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pickle
import unittest
import warnings
import weakref

import pytest
import torch
Expand Down Expand Up @@ -3018,6 +3019,41 @@ class MyModule(nn.Module):
assert (params_m == 0).all()
assert not (params_m == 0).all()

def test_load_state_dict(self):
net = nn.Sequential(
nn.Linear(2, 2),
nn.Sequential(
nn.Linear(2, 2),
nn.Dropout(),
nn.BatchNorm1d(2),
nn.Sequential(
nn.Tanh(),
nn.Linear(2, 2),
),
),
)

params = TensorDict.from_module(net, as_module=True)
assert any(isinstance(p, nn.Parameter) for p in params.values(True, True))
weakrefs = {weakref.ref(t) for t in params.values(True, True)}

# Now with a module around it
class MyModule(nn.Module):
pass

module = MyModule()
module.model = MyModule()
module.model.params = params
sd = module.state_dict()
sd = {
key: val * 0 if isinstance(val, torch.Tensor) else val
for key, val in sd.items()
}
module.load_state_dict(sd)
assert (params == 0).all()
assert any(isinstance(p, nn.Parameter) for p in params.values(True, True))
assert weakrefs == {weakref.ref(t) for t in params.values(True, True)}

def test_inplace_ops(self):
td = TensorDict(
{
Expand Down

0 comments on commit 6661b5d

Please sign in to comment.