Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 23, 2024
1 parent 3adde4b commit 173f3a2
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pickle
import unittest
import warnings
import weakref

import pytest
import torch
Expand Down Expand Up @@ -3033,6 +3034,8 @@ def test_load_state_dict(self):
)

params = TensorDict.from_module(net, as_module=True)
assert any(isinstance(p, nn.Parameter) for p in params.values(True, True))
weakrefs = {weakref.ref(t) for t in params.values(True, True)}

# Now with a module around it
class MyModule(nn.Module):
Expand All @@ -3048,6 +3051,8 @@ class MyModule(nn.Module):
}
module.load_state_dict(sd)
assert (params == 0).all()
assert any(isinstance(p, nn.Parameter) for p in params.values(True, True))
assert weakrefs == {weakref.ref(t) for t in params.values(True, True)}

def test_inplace_ops(self):
td = TensorDict(
Expand Down

0 comments on commit 173f3a2

Please sign in to comment.