Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 23, 2024
1 parent 79b0c01 commit cf222ba
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
13 changes: 11 additions & 2 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,7 @@ def _to_module(
if not use_state_dict and isinstance(module, TensorDictBase):
if return_swap:
swap = module.copy()
with module.unlock_():
module.update(self)
module._param_td = getattr(self, "_param_td", self)
return swap
else:
module.update(self)
Expand Down Expand Up @@ -407,7 +406,17 @@ def convert_type(x, y):
continue
child = __dict__["_modules"][key]
local_out = memo.get(id(child), NO_DEFAULT)

if local_out is NO_DEFAULT:
# if isinstance(child, TensorDictBase):
# # then child is a TensorDictParams
# from tensordict.nn import TensorDictParams
#
# local_out = child
# if not isinstance(value, TensorDictParams):
# value = TensorDictParams(value, no_convert=True)
# __dict__["_modules"][key] = value
# else:
local_out = value._to_module(
child,
inplace=inplace,
Expand Down
18 changes: 14 additions & 4 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2972,7 +2972,8 @@ def test_tdparams_clone_tying(self):
td_clone = td.clone()
assert td_clone["c"] is td_clone["a", "b", "c"]

def test_func_on_tdparams(self):
@pytest.mark.parametrize("with_batch", [False, True])
def test_func_on_tdparams(self, with_batch):
# TensorDictParams isn't represented in a nested way, so we must check that calling to_module on it works correctly
net = nn.Sequential(
nn.Linear(2, 2),
Expand All @@ -2987,9 +2988,13 @@ def test_func_on_tdparams(self):
),
)

params = TensorDict.from_module(net, as_module=True)
if with_batch:
params = TensorDict.from_modules(net, net, as_module=True)
params0 = params[0].expand(3).clone().apply(lambda x: x.data * 0)
else:
params = TensorDict.from_module(net, as_module=True)
params0 = params.apply(lambda x: x.data * 0)

params0 = params.apply(lambda x: x.data * 0)
assert (params0 == 0).all()
with params0.to_module(params):
assert (params == 0).all()
Expand All @@ -3002,7 +3007,12 @@ class MyModule(nn.Module):
m = MyModule()
m.params = params
params_m = TensorDict.from_module(m, as_module=True)
params_m0 = params_m.apply(lambda x: x.data * 0)
if with_batch:
params_m0 = params_m.clone()
params_m0["params"] = params_m0["params"][0].expand(3).clone()
else:
params_m0 = params_m
params_m0 = params_m0.apply(lambda x: x.data * 0)
assert (params_m0 == 0).all()
with params_m0.to_module(m):
assert (params_m == 0).all()
Expand Down

0 comments on commit cf222ba

Please sign in to comment.