[BugFix] clear_refs_for_compile_() to clear weakrefs when compiling
ghstack-source-id: ecbad083704930313dcfdd5bf3f6bb6b984030e8
Pull Request resolved: #1196
vmoens committed Jan 30, 2025
1 parent 9a25b88 commit ba13541
Showing 5 changed files with 38 additions and 9 deletions.
20 changes: 19 additions & 1 deletion tensordict/base.py
@@ -11140,6 +11140,8 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         # During exit, updates mustn't be made in-place as the source and dest
         # storage location can be identical, resulting in a RuntimeError
+        if is_compiling():
+            self.clear_refs_for_compile_()
         if exc_type is not None and issubclass(exc_type, Exception):
             return False
         _last_op = self._last_op_queue.pop()
@@ -11149,11 +11151,27 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         # added or deleted
         _inv_caller = LAST_OP_MAPS.get(last_op)
         if _inv_caller is not None:
-            return _inv_caller(self, args, kwargs, out_wr())
+            prev_ref = out_wr()
+            return _inv_caller(self, args, kwargs, prev_ref)
         else:
             raise NotImplementedError(f"Unrecognised function {last_op}.")
         return self

+    def clear_refs_for_compile_(self) -> T:
+        """Clears the weakrefs in order for the tensordict to get out of the compile region safely.
+
+        Use this whenever you hit `torch._dynamo.exc.Unsupported: reconstruct: WeakRefVariable()`
+        before returning a TensorDict.
+
+        Returns: self
+        """
+        self._last_op = None
+        for v in self.values(True, True, is_leaf=_is_tensor_collection):
+            if _is_tensorclass(type(v)):
+                v = v._tensordict
+            v._last_op = None
+        return self
+
     # Clone, select, exclude, empty
     def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) -> T:
         """Selects the keys of the tensordict and returns a new tensordict with only the selected keys.
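A minimal usage sketch of the new method, mirroring the pattern used in the test changes further down (keys and shapes here are illustrative, not taken from the commit): `clear_refs_for_compile_()` is called on the TensorDict before it leaves a `torch.compile(..., fullgraph=True)` region, so Dynamo does not have to reconstruct the stored weakref.

    import torch
    from tensordict import TensorDict

    def view_fn(td: TensorDict) -> TensorDict:
        # Clear the weakrefs stamped onto the result before it escapes the graph.
        return td.view(2, 2).clear_refs_for_compile_()

    view_c = torch.compile(view_fn, fullgraph=True)
    data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
    assert (view_c(data) == view_fn(data)).all()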
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
@@ -156,6 +156,7 @@ def __subclasscheck__(self, subclass):
     "batch_size",
     "bytes",
     "cat_tensors",
+    "clear_refs_for_compile_",
     "data_ptr",
     "depth",
     "dim",
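Because `clear_refs_for_compile_` is now listed among the methods forwarded by tensorclass (and declared in the stub below), it should also be callable on `@tensorclass` instances. A hedged sketch, where `MyData` is a hypothetical class used only for illustration:

    import torch
    from tensordict import tensorclass

    @tensorclass
    class MyData:
        x: torch.Tensor

    data = MyData(x=torch.arange(4), batch_size=[4])
    # Delegates to the underlying tensordict's implementation.
    data.clear_refs_for_compile_()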
1 change: 1 addition & 0 deletions tensordict/tensorclass.pyi
@@ -436,6 +436,7 @@ class TensorClass:
     @device.setter
     def device(self, value: DeviceType) -> torch.device | None: ...
     def clear(self) -> T: ...
+    def clear_refs_for_compile_(self) -> T: ...
     @classmethod
     def fromkeys(cls, keys: list[NestedKey], value: Any = 0): ...
     def popitem(self) -> tuple[NestedKey, CompatibleType]: ...
10 changes: 8 additions & 2 deletions tensordict/utils.py
@@ -1217,6 +1217,10 @@ def new_func(self, *args, **kwargs):
     return new_func


+def _strong_ref(self):
+    return lambda self: self
+
+
 def _as_context_manager(attr=None):
     """Converts a method to a decorator.
@@ -1238,12 +1242,13 @@ def func_as_decorator(_self, *args, **kwargs):
                 _attr_post = getattr(_self, attr)
                 if out is not None:
                     if _attr_post is not _attr_pre:
+                        ref = weakref.ref(_self)
                         out._last_op = (
                             func.__name__,
                             (
                                 args,
                                 kwargs,
-                                weakref.ref(_self),
+                                ref,
                             ),
                         )
                     else:
@@ -1256,7 +1261,8 @@ def func_as_decorator(_self, *args, **kwargs):
            def func_as_decorator(_self, *args, **kwargs):
                out = func(_self, *args, **kwargs)
                if out is not None:
-                    out._last_op = (func.__name__, (args, kwargs, weakref.ref(_self)))
+                    ref = weakref.ref(_self)
+                    out._last_op = (func.__name__, (args, kwargs, ref))
                return out

        return func_as_decorator
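For reference, a hedged eager-mode sketch of what the decorator above stores and what `clear_refs_for_compile_` discards: methods wrapped with `_as_context_manager` stamp their output with a `_last_op` tuple whose last element is `weakref.ref(_self)`, and it is this weakref that Dynamo cannot reconstruct when the tensordict escapes the compiled region. `_last_op` is an internal attribute, inspected here only to illustrate the mechanism.

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.arange(4)}, [4])
    out = td.view(2, 2)
    last_op = getattr(out, "_last_op", None)
    if last_op is not None:
        # (func.__name__, (args, kwargs, weakref.ref(td))), as set by func_as_decorator
        _, (_, _, ref) = last_op
        assert ref() is td
    out.clear_refs_for_compile_()  # drops the stored weakref(s)
    assert out._last_op is None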
15 changes: 9 additions & 6 deletions test/test_compile.py
@@ -144,7 +144,8 @@ def reshape(td):

     def test_view(self, mode):
         def view(td):
-            return td.view(2, 2)
+            out = td.view(2, 2).clear_refs_for_compile_()
+            return out

         view_c = torch.compile(view, fullgraph=True, mode=mode)
         data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
@@ -155,7 +156,7 @@ def view(td):

     def test_transpose(self, mode):
         def transpose(td):
-            return td.transpose(0, 1)
+            return td.transpose(0, 1).clear_refs_for_compile_()

         transpose_c = torch.compile(transpose, fullgraph=True, mode=mode)
         data = TensorDict({"a": {"b": torch.arange(6).view(2, 3)}}, [2, 3])
@@ -211,7 +212,7 @@ def clone(td: TensorDict):
     @pytest.mark.parametrize("recurse", [True, False])
     def test_flatten_keys(self, recurse, mode):
         def flatten_keys(td: TensorDict):
-            return td.flatten_keys()
+            return td.flatten_keys().clear_refs_for_compile_()

         flatten_keys_c = torch.compile(flatten_keys, fullgraph=True, mode=mode)
         data = TensorDict({"a": {"b": 0, "c": 1}})
@@ -225,7 +226,7 @@ def flatten_keys(td: TensorDict):
     @pytest.mark.parametrize("recurse", [True, False])
     def test_unflatten_keys(self, recurse, mode):
         def unflatten_keys(td: TensorDict):
-            return td.unflatten_keys()
+            return td.unflatten_keys().clear_refs_for_compile_()

         unflatten_keys_c = torch.compile(unflatten_keys, fullgraph=True, mode=mode)
         data = TensorDict({"a.b": 0, "a.c": 1})
@@ -280,7 +281,7 @@ def locked_op(td):
             # Adding stuff uses cache, check that this doesn't break
             td2 = td + 1
             td3 = td + td2
-            return td3
+            return td3.clear_refs_for_compile_()

         td = TensorDict(
             {"a": torch.randn(1, 2, 3), "b": torch.zeros(1, 2, 3, dtype=torch.bool)},
@@ -761,7 +762,9 @@ def forward(self, x):

         def call(x, td):
             with td.to_module(module):
-                return module(x)
+                y = module(x)
+                td.clear_refs_for_compile_()
+                return y

         call_compile = torch.compile(call, fullgraph=True, mode=mode)
         x = torch.randn(2, 3)
