diff --git a/test/test_nn.py b/test/test_nn.py index da1cd8ab3..37d5975b0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8,6 +8,7 @@ import pickle import unittest import warnings +import weakref import pytest import torch @@ -3033,6 +3034,8 @@ def test_load_state_dict(self): ) params = TensorDict.from_module(net, as_module=True) + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + weakrefs = {weakref.ref(t) for t in params.values(True, True)} # Now with a module around it class MyModule(nn.Module): @@ -3048,6 +3051,8 @@ class MyModule(nn.Module): } module.load_state_dict(sd) assert (params == 0).all() + assert any(isinstance(p, nn.Parameter) for p in params.values(True, True)) + assert weakrefs == {weakref.ref(t) for t in params.values(True, True)} def test_inplace_ops(self): td = TensorDict(