diff --git a/tensordict/base.py b/tensordict/base.py
index ffb2cac91..e64dfc1f6 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -11140,6 +11140,8 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         # During exit, updates mustn't be made in-place as the source and dest
         # storage location can be identical, resulting in a RuntimeError
+        if is_compiling():
+            self.clear_refs_for_compile_()
         if exc_type is not None and issubclass(exc_type, Exception):
             return False
         _last_op = self._last_op_queue.pop()
@@ -11149,11 +11151,27 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             # added or deleted
             _inv_caller = LAST_OP_MAPS.get(last_op)
             if _inv_caller is not None:
-                return _inv_caller(self, args, kwargs, out_wr())
+                prev_ref = out_wr()
+                return _inv_caller(self, args, kwargs, prev_ref)
             else:
                 raise NotImplementedError(f"Unrecognised function {last_op}.")
         return self
 
+    def clear_refs_for_compile_(self) -> T:
+        """Clears the weakrefs in order for the tensordict to get out of the compile region safely.
+
+        Use this whenever you hit `torch._dynamo.exc.Unsupported: reconstruct: WeakRefVariable()`
+        before returning a TensorDict.
+
+        Returns: self
+        """
+        self._last_op = None
+        for v in self.values(True, True, is_leaf=_is_tensor_collection):
+            if _is_tensorclass(type(v)):
+                v = v._tensordict
+            v._last_op = None
+        return self
+
     # Clone, select, exclude, empty
     def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) -> T:
         """Selects the keys of the tensordict and returns a new tensordict with only the selected keys.
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 8a6f5cb3f..da8faee51 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -156,6 +156,7 @@ def __subclasscheck__(self, subclass):
     "batch_size",
     "bytes",
     "cat_tensors",
+    "clear_refs_for_compile_",
     "data_ptr",
     "depth",
     "dim",
diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi
index f4131802d..a9a443ac2 100644
--- a/tensordict/tensorclass.pyi
+++ b/tensordict/tensorclass.pyi
@@ -436,6 +436,7 @@ class TensorClass:
     @device.setter
     def device(self, value: DeviceType) -> torch.device | None: ...
     def clear(self) -> T: ...
+    def clear_refs_for_compile_(self) -> T: ...
     @classmethod
     def fromkeys(cls, keys: list[NestedKey], value: Any = 0): ...
     def popitem(self) -> tuple[NestedKey, CompatibleType]: ...
diff --git a/tensordict/utils.py b/tensordict/utils.py
index dc2f0a769..bc7bffe45 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1217,6 +1217,10 @@ def new_func(self, *args, **kwargs):
     return new_func
 
 
+def _strong_ref(self):
+    return lambda self: self
+
+
 def _as_context_manager(attr=None):
     """Converts a method to a decorator.
 
@@ -1238,12 +1242,13 @@ def func_as_decorator(_self, *args, **kwargs):
             _attr_post = getattr(_self, attr)
             if out is not None:
                 if _attr_post is not _attr_pre:
+                    ref = weakref.ref(_self)
                     out._last_op = (
                         func.__name__,
                         (
                             args,
                             kwargs,
-                            weakref.ref(_self),
+                            ref,
                         ),
                     )
                 else:
@@ -1256,7 +1261,8 @@ def func_as_decorator(_self, *args, **kwargs):
         def func_as_decorator(_self, *args, **kwargs):
             out = func(_self, *args, **kwargs)
             if out is not None:
-                out._last_op = (func.__name__, (args, kwargs, weakref.ref(_self)))
+                ref = weakref.ref(_self)
+                out._last_op = (func.__name__, (args, kwargs, ref))
             return out
 
         return func_as_decorator
diff --git a/test/test_compile.py b/test/test_compile.py
index bb07900a9..cbd70b1ee 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -144,7 +144,8 @@ def reshape(td):
 
     def test_view(self, mode):
         def view(td):
-            return td.view(2, 2)
+            out = td.view(2, 2).clear_refs_for_compile_()
+            return out
 
         view_c = torch.compile(view, fullgraph=True, mode=mode)
         data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
@@ -155,7 +156,7 @@ def view(td):
 
     def test_transpose(self, mode):
         def transpose(td):
-            return td.transpose(0, 1)
+            return td.transpose(0, 1).clear_refs_for_compile_()
 
         transpose_c = torch.compile(transpose, fullgraph=True, mode=mode)
         data = TensorDict({"a": {"b": torch.arange(6).view(2, 3)}}, [2, 3])
@@ -211,7 +212,7 @@ def clone(td: TensorDict):
     @pytest.mark.parametrize("recurse", [True, False])
     def test_flatten_keys(self, recurse, mode):
         def flatten_keys(td: TensorDict):
-            return td.flatten_keys()
+            return td.flatten_keys().clear_refs_for_compile_()
 
         flatten_keys_c = torch.compile(flatten_keys, fullgraph=True, mode=mode)
         data = TensorDict({"a": {"b": 0, "c": 1}})
@@ -225,7 +226,7 @@ def flatten_keys(td: TensorDict):
     @pytest.mark.parametrize("recurse", [True, False])
     def test_unflatten_keys(self, recurse, mode):
         def unflatten_keys(td: TensorDict):
-            return td.unflatten_keys()
+            return td.unflatten_keys().clear_refs_for_compile_()
 
         unflatten_keys_c = torch.compile(unflatten_keys, fullgraph=True, mode=mode)
         data = TensorDict({"a.b": 0, "a.c": 1})
@@ -280,7 +281,7 @@ def locked_op(td):
             # Adding stuff uses cache, check that this doesn't break
             td2 = td + 1
             td3 = td + td2
-            return td3
+            return td3.clear_refs_for_compile_()
 
         td = TensorDict(
             {"a": torch.randn(1, 2, 3), "b": torch.zeros(1, 2, 3, dtype=torch.bool)},
@@ -761,7 +762,9 @@ def forward(self, x):
 
         def call(x, td):
             with td.to_module(module):
-                return module(x)
+                y = module(x)
+            td.clear_refs_for_compile_()
+            return y
 
         call_compile = torch.compile(call, fullgraph=True, mode=mode)
         x = torch.randn(2, 3)
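
Usage note (not part of the patch): the sketch below shows how `clear_refs_for_compile_` is intended to be called, mirroring the `test_view` change above. It assumes a tensordict build that includes this patch; the function name `make_view` is illustrative only.

import torch
from tensordict import TensorDict

def make_view(td):
    # view() records a weakref to `td` on the output's _last_op so the op can be
    # reverted when the result is used as a context manager; Dynamo cannot
    # reconstruct that weakref, so clear it before the TensorDict leaves the
    # compiled region.
    return td.view(2, 2).clear_refs_for_compile_()

make_view_c = torch.compile(make_view, fullgraph=True)
data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
print(make_view_c(data))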