From 785cac107bca0228d109dce6d8d5ca63300eaf81 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Sep 2024 08:31:56 +0100 Subject: [PATCH] [BugFix] dynamo compat refactors ghstack-source-id: 7681ecf34d26be5d50f780fe8dd98e1dc7974ec8 Pull Request resolved: https://github.com/pytorch/tensordict/pull/975 --- benchmarks/compile/tensordict_nn_test.py | 77 ++++++++++ tensordict/__init__.py | 4 +- tensordict/_contextlib.py | 181 +++++++++++++++++++++++ tensordict/_td.py | 12 +- tensordict/base.py | 139 ++--------------- tensordict/nn/params.py | 23 ++- tensordict/nn/probabilistic.py | 29 ++-- tensordict/nn/utils.py | 7 +- tensordict/utils.py | 82 +++++----- test/test_compile.py | 31 +--- test/test_nn.py | 6 +- test/test_tensordict.py | 2 +- 12 files changed, 371 insertions(+), 222 deletions(-) diff --git a/benchmarks/compile/tensordict_nn_test.py b/benchmarks/compile/tensordict_nn_test.py index cf4c71567..40571a31d 100644 --- a/benchmarks/compile/tensordict_nn_test.py +++ b/benchmarks/compile/tensordict_nn_test.py @@ -329,6 +329,83 @@ def call_with_backward(*args): benchmark(call_with_backward, x) +@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) +def test_vmap_func_call_cm_runtime(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + module = mlp(device=device, depth=10, num_cells=16, feature_dim=16) + # module = torch.nn.Transformer(16, dim_feedforward=64, device=device) + td = TensorDict.from_module(module) + td = TensorDictParams(td.data.expand(10).clone().zero_()) + + def call(x, td): + # with needs registering + with td.to_module(module): + return module(x) + + call_vmap = torch.vmap(call, (None, 0)) + if mode == "compile": + call_vmap = torch.compile(call_vmap) + elif mode == "compile-overhead": + call_vmap = torch.compile(call_vmap, mode="reduce-overhead") + + x = torch.randn(2, 2, 16) + call_vmap(x, td) + call_vmap(x, td) + benchmark(call_vmap, x, td) + + +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) +@pytest.mark.parametrize("plain_decorator", [None, False, True]) +def test_vmap_func_call_runtime_and_backward(mode, plain_decorator, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + module = mlp(device=device, depth=10, num_cells=16, feature_dim=16) + # module = torch.nn.Transformer(16, dim_feedforward=64, device=device) + td = TensorDict.from_module(module) + td = TensorDictParams(td.data.expand(10).clone().zero_()) + if not plain_decorator: + + def call(x, td): + if torch.cuda.is_available(): + torch.compiler.cudagraph_mark_step_begin() + # with needs registering + params = td.to_module(module, return_swap=True) + result = module(x) + params.to_module(module, return_swap=False) + return result + + else: + + def call(x, td): + if torch.cuda.is_available(): + torch.compiler.cudagraph_mark_step_begin() + # with needs registering + with td.to_module(module): + return module(x) + + call_vmap = torch.vmap(call, (None, 0)) + if mode == "compile": + call_vmap = torch.compile(call_vmap) + elif mode == "compile-overhead": + call_vmap = torch.compile(call_vmap, mode="reduce-overhead") + + if mode == "compile": + call_vmap = torch.compile(call_vmap, fullgraph=not plain_decorator) + elif mode == "compile-overhead": + call_vmap = torch.compile( + call_vmap, fullgraph=not plain_decorator, mode="reduce-overhead" + ) + + def call_with_backward(*args): + 
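+        # reduce the vmapped output to a scalar so the benchmark times the full forward + backward pass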
call_vmap(*args).mean().backward() + + x = torch.randn(2, 2, 16) + call_with_backward(x, td) + call_with_backward(x, td) + benchmark(call_with_backward, x, td) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/tensordict/__init__.py b/tensordict/__init__.py index a340d1276..19ea17e82 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -29,12 +29,10 @@ lazy_legacy, NestedKey, set_lazy_legacy, -) -from tensordict._pytree import * -from tensordict._C import ( # @manual=//pytorch/tensordict:_C unravel_key, unravel_key_list, ) +from tensordict._pytree import * from tensordict.nn import TensorDictParams try: diff --git a/tensordict/_contextlib.py b/tensordict/_contextlib.py index 004a2d09a..1ebf5f64c 100644 --- a/tensordict/_contextlib.py +++ b/tensordict/_contextlib.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import contextlib # This is a copy from https://github.com/pytorch/pytorch/blob/main/torch/utils/_contextlib.py#L120 # We use it for compatibility with torch >= 1.10 where the implementation fails @@ -16,6 +17,10 @@ import warnings from typing import Any, Callable, cast, TypeVar +import numpy as np +from torch.compiler import is_dynamo_compiling + + # Used for annotating the decorator usage of _DecoratorContextManager (e.g., # 'no_grad' and 'enable_grad'). # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators @@ -155,3 +160,179 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def clone(self): # override this method if your children class takes __init__ parameters return type(self)() + + +# TD cm functions +LAST_OP_MAPS = {} + + +def _reverse_lock(self, args, kwargs, out): + return self.unlock_() + + +LAST_OP_MAPS["lock_"] = _reverse_lock + + +def _reverse_unlock(self, args, kwargs, out): + return self.lock_() + + +LAST_OP_MAPS["unlock_"] = _reverse_unlock + + +def _reverse_transpose(self, args, kwargs, out): + dim0, dim1 = args + if not out.is_locked: + return out.update(self.transpose(dim0, dim1), inplace=False) + else: + return out.update_(self.transpose(dim0, dim1)) + + +LAST_OP_MAPS["transpose"] = _reverse_transpose + + +def _reverse_flatten_keys(self, args, kwargs, out): + sep = args[0] if args else "." + if not out.is_locked: + return out.update(self.unflatten_keys(sep), inplace=False) + else: + return out.update_(self.unflatten_keys(sep)) + + +LAST_OP_MAPS["flatten_keys"] = _reverse_flatten_keys + + +def _reverse_unflatten_keys(self, args, kwargs, out): + sep = args[0] if args else "." 
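+    # invert unflatten_keys: re-flatten with the same separator so the keys written back into `out` match its flat layout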
+ if not out.is_locked: + return out.update(self.flatten_keys(sep), inplace=False) + else: + return out.update_(self.flatten_keys(sep)) + + +LAST_OP_MAPS["unflatten_keys"] = _reverse_unflatten_keys + + +def _reverse_flatten(self, args, kwargs, out): + if len(args) == 2: + dim0, dim1 = args + elif len(args) == 1: + dim0 = args[0] + dim1 = kwargs.get("end_dim", -1) + else: + dim0 = kwargs.get("start_dim", 0) + dim1 = kwargs.get("end_dim", -1) + if dim1 < 0: + dim1 = out.ndim + dim1 + if dim0 < 0: + dim0 = out.ndim + dim0 + + if not out.is_locked: + return out.update( + self.unflatten(dim0, out.shape[dim0 : dim1 + 1]), inplace=False + ) + else: + return out.update_(self.unflatten(dim0, out.shape[dim0 : dim1 + 1])) + + +LAST_OP_MAPS["flatten"] = _reverse_flatten + + +def _reverse_unflatten(self, args, kwargs, out): + if args: + dim0 = args[0] + if len(args) > 1: + unflattened_size = args[1] + else: + unflattened_size = kwargs.get("unflattened_size") + else: + dim0 = kwargs.get("dim") + unflattened_size = kwargs.get("unflattened_size") + if dim0 < 0: + dim0 = out.ndim + dim0 + dim1 = dim0 + len(unflattened_size) - 1 + if not out.is_locked: + unflattened = self.flatten(dim0, dim1) + return out.update(unflattened, inplace=False) + else: + unflattened = self.flatten(dim0, dim1) + return out.update_(unflattened) + + +LAST_OP_MAPS["unflatten"] = _reverse_unflatten + + +def _reverse_permute(self, args, kwargs, out): + from tensordict.utils import _get_shape_from_args + + dims_list = _get_shape_from_args(*args, kwarg_name="dims", **kwargs) + dims_list = [dim if dim >= 0 else self.ndim + dim for dim in dims_list] + # inverse map + inv_dims_list = np.argsort(dims_list) + if not out.is_locked: + return out.update(self.permute(inv_dims_list), inplace=False) + else: + return out.update_(self.permute(inv_dims_list)) + + +LAST_OP_MAPS["permute"] = _reverse_permute + + +def _reverse_view(self, args, kwargs, out): + if not out.is_locked: + return out.update(self.view(out.shape), inplace=False) + else: + return out.update_(self.view(out.shape)) + + +LAST_OP_MAPS["view"] = _reverse_view + + +def _reverse_unsqueeze(self, args, kwargs, out): + if args: + (dim,) = args + elif kwargs: + dim = kwargs["dim"] + else: + raise RuntimeError( + "Cannot use td.unsqueeze() as a decorator if the dimension is implicit." + ) + if not out.is_locked: + return out.update(self.squeeze(dim), inplace=False) + else: + return out.update_(self.squeeze(dim)) + + +LAST_OP_MAPS["unsqueeze"] = _reverse_unsqueeze + + +def _reverse_squeeze(self, args, kwargs, out): + if args: + (dim,) = args + elif kwargs: + dim = kwargs["dim"] + else: + raise RuntimeError( + "Cannot use td.squeeze() as a decorator if the dimension is implicit." + ) + if not out.is_locked: + return out.update(self.unsqueeze(dim), inplace=False) + else: + return out.update_(self.unsqueeze(dim)) + + +LAST_OP_MAPS["squeeze"] = _reverse_squeeze + + +def _reverse_to_module(self, args, kwargs, out): + try: + with out.unlock_() if not is_dynamo_compiling() else contextlib.nullcontext(): + return self.to_module(*args, **kwargs, swap_dest=out) + except AttributeError: + # This is a bit unsafe but we assume that out won't have an unlock_() if it's not a TD + raise RuntimeError( + "to_module cannot be used as a decorator when return_swap=False." 
+ ) + + +LAST_OP_MAPS["to_module"] = _reverse_to_module diff --git a/tensordict/_td.py b/tensordict/_td.py index 0914461ed..ef49e4b89 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -68,7 +68,6 @@ _sub_index, _unravel_key_to_tuple, _zip_strict, - Buffer, cache, convert_ellipsis_to_idx, DeviceType, @@ -103,6 +102,11 @@ except ImportError: # torch 2.0 from torch._dynamo import is_compiling as is_dynamo_compiling +try: + from torch.nn.parameter import Buffer +except ImportError: + from tensordict.utils import Buffer + _register_tensor_class(ftdim.Tensor) __base__setattr__ = torch.nn.Module.__setattr__ @@ -247,8 +251,8 @@ def __init__( self._tensordict = _StringOnlyDict() - if names and is_dynamo_compiling(): - graph_break() + # if names and is_dynamo_compiling(): + # graph_break() has_device = device is not None sub_non_blocking = False call_sync = False @@ -2971,7 +2975,7 @@ def _clone(self, recurse: bool = True) -> T: source={key: _clone_value(value, recurse) for key, value in self.items()}, batch_size=self.batch_size, device=self.device, - names=copy(self._td_dim_names) if self._has_names() else None, + names=self._maybe_names(), ) # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too # This may be undesirable, not sure if this should be the default behaviour diff --git a/tensordict/base.py b/tensordict/base.py index d6e678775..8428e8f05 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -47,13 +47,13 @@ import orjson as json import torch +from tensordict._contextlib import LAST_OP_MAPS from tensordict.memmap import MemoryMappedTensor from tensordict.utils import ( _as_context_manager, _CloudpickleWrapper, _DTYPE2STRDTYPE, _GENERIC_NESTED_ERR, - _get_shape_from_args, _is_non_tensor, _is_number, _is_tensorclass, @@ -72,7 +72,6 @@ _td_fields, _unravel_key_to_tuple, _zip_strict, - Buffer, cache, convert_ellipsis_to_idx, DeviceType, @@ -104,6 +103,12 @@ from torch._dynamo import is_compiling as is_dynamo_compiling +try: + from torch.nn.parameter import Buffer +except ImportError: + from tensordict.utils import Buffer + + # NO_DEFAULT is used as a placeholder whenever the default is not provided. # Using None is not an option since `td.get(key)` is a valid usage. class _NoDefault(enum.IntEnum): @@ -436,7 +441,7 @@ def __delitem__(self, key: NestedKey) -> T: def __getstate__(self): result = dict(self.__dict__) - for key in ("_last_op", "_cache", "__last_op_queue", "__lock_parents_weakrefs"): + for key in ("_last_op", "_cache", "__lock_parents_weakrefs"): result.pop(key, None) return result @@ -444,7 +449,6 @@ def __setstate__(self, state): for key, value in state.items(): setattr(self, key, value) self._cache = None - self.__last_op_queue = None self._last_op = None if self._is_locked: # this can cause avoidable overhead, as we will be locking the leaves @@ -8602,18 +8606,9 @@ def _validate_value( self.names = value.names[: self.batch_dims] return value - # Context manager functionality - @property - def _last_op_queue(self): - # this is used to keep track of the last operation when using - # the tensordict as a context manager. 
- last_op_queue = self.__dict__.get("__last_op_queue") - if last_op_queue is None: - last_op_queue = collections.deque() - self.__dict__["__last_op_queue"] = last_op_queue - return last_op_queue - def __enter__(self): + if not hasattr(self, "_last_op_queue"): + self._last_op_queue = collections.deque() self._last_op_queue.append(self._last_op) return self @@ -8627,117 +8622,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): last_op, (args, kwargs, out) = _last_op # TODO: transpose, flatten etc. as decorator should lock the content to make sure that no key is # added or deleted - if last_op == type(self).lock_.__name__: - return self.unlock_() - elif last_op == type(self).unlock_.__name__: - return self.lock_() - elif last_op == type(self).transpose.__name__: - dim0, dim1 = args - if not out.is_locked: - return out.update(self.transpose(dim0, dim1), inplace=False) - else: - return out.update_(self.transpose(dim0, dim1)) - elif last_op == type(self).flatten_keys.__name__: - sep = args[0] if args else "." - if not out.is_locked: - return out.update(self.unflatten_keys(sep), inplace=False) - else: - return out.update_(self.unflatten_keys(sep)) - elif last_op == type(self).unflatten_keys.__name__: - sep = args[0] if args else "." - if not out.is_locked: - return out.update(self.flatten_keys(sep), inplace=False) - else: - return out.update_(self.flatten_keys(sep)) - elif last_op == type(self).flatten.__name__: - if len(args) == 2: - dim0, dim1 = args - elif len(args) == 1: - dim0 = args[0] - dim1 = kwargs.get("end_dim", -1) - else: - dim0 = kwargs.get("start_dim", 0) - dim1 = kwargs.get("end_dim", -1) - if dim1 < 0: - dim1 = out.ndim + dim1 - if dim0 < 0: - dim0 = out.ndim + dim0 - - if not out.is_locked: - return out.update( - self.unflatten(dim0, out.shape[dim0 : dim1 + 1]), inplace=False - ) - else: - return out.update_(self.unflatten(dim0, out.shape[dim0 : dim1 + 1])) - - elif last_op == type(self).unflatten.__name__: - if args: - dim0 = args[0] - if len(args) > 1: - unflattened_size = args[1] - else: - unflattened_size = kwargs.get("unflattened_size") - else: - dim0 = kwargs.get("dim") - unflattened_size = kwargs.get("unflattened_size") - if dim0 < 0: - dim0 = out.ndim + dim0 - dim1 = dim0 + len(unflattened_size) - 1 - if not out.is_locked: - unflattened = self.flatten(dim0, dim1) - return out.update(unflattened, inplace=False) - else: - unflattened = self.flatten(dim0, dim1) - return out.update_(unflattened) - - elif last_op == type(self).permute.__name__: - dims_list = _get_shape_from_args(*args, kwarg_name="dims", **kwargs) - dims_list = [dim if dim >= 0 else self.ndim + dim for dim in dims_list] - # inverse map - inv_dims_list = np.argsort(dims_list) - if not out.is_locked: - return out.update(self.permute(inv_dims_list), inplace=False) - else: - return out.update_(self.permute(inv_dims_list)) - elif last_op == type(self).view.__name__: - if not out.is_locked: - return out.update(self.view(out.shape), inplace=False) - else: - return out.update_(self.view(out.shape)) - elif last_op == type(self).unsqueeze.__name__: - if args: - (dim,) = args - elif kwargs: - dim = kwargs["dim"] - else: - raise RuntimeError( - "Cannot use td.unsqueeze() as a decorator if the dimension is implicit." 
- ) - if not out.is_locked: - return out.update(self.squeeze(dim), inplace=False) - else: - return out.update_(self.squeeze(dim)) - elif last_op == type(self).squeeze.__name__: - if args: - (dim,) = args - elif kwargs: - dim = kwargs["dim"] - else: - raise RuntimeError( - "Cannot use td.squeeze() as a decorator if the dimension is implicit." - ) - if not out.is_locked: - return out.update(self.unsqueeze(dim), inplace=False) - else: - return out.update_(self.unsqueeze(dim)) - elif last_op == type(self).to_module.__name__: - if is_tensor_collection(out): - with out.unlock_(): - return self.to_module(*args, **kwargs, swap_dest=out) - else: - raise RuntimeError( - "to_module cannot be used as a decorator when return_swap=False." - ) + _inv_caller = LAST_OP_MAPS.get(last_op) + if _inv_caller is not None: + return _inv_caller(self, args, kwargs, out) else: raise NotImplementedError(f"Unrecognised function {last_op}.") return self diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 2da21f53f..094992388 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -33,7 +33,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.utils import ( _LOCK_ERROR, - Buffer, + BufferLegacy, erase_cache, IndexType, lock_blocked, @@ -52,6 +52,11 @@ _has_funcdim = False +try: + from torch.nn.parameter import Buffer +except ImportError: + from tensordict.utils import Buffer + def _apply_leaves(data, fn): if isinstance(data, TensorDict): @@ -111,9 +116,14 @@ def _maybe_make_param_or_buffer(tensor): and not isinstance(tensor, nn.Parameter) and tensor.dtype in (torch.float, torch.double, torch.half) ): - # convert all non-parameters to buffers - # dataptr = tensor.data.data_ptr() - tensor = Buffer(tensor) + if tensor.grad_fn is None: + # convert all non-parameters to buffers + # dataptr = tensor.data.data_ptr() + tensor = Buffer(tensor) + else: + # We want to keep the grad_fn of tensors, e.g. 
param.expand(10) should point to the original param + tensor = BufferLegacy(tensor) + # assert tensor.data.data_ptr() == dataptr return tensor @@ -329,7 +339,6 @@ def __init__( self._reset_params() self._is_locked = False self._locked_tensordicts = [] - self.__last_op_queue = None self._get_post_hook = [] def register_get_post_hook(self, hook): @@ -609,7 +618,7 @@ def _clone(tensor, memo=memo): tensor.data.clone(), requires_grad=tensor.requires_grad ) else: - result = Buffer(tensor.data.clone(), requires_grad=tensor.requires_grad) + result = Buffer(tensor.data.clone()) memo[tensor] = result return result @@ -1203,7 +1212,7 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.data = buffer_applied out_buffer = buffer else: - out_buffer = Buffer(buffer_applied, buffer.requires_grad) + out_buffer = Buffer(buffer_applied) self._buffers[key] = out_buffer if buffer.grad is not None: diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 96b07ed4b..9eb58e5ee 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -12,18 +12,19 @@ from typing import Any, Callable, Dict, List, Optional from warnings import warn -from tensordict._contextlib import _DecoratorContextManager from tensordict.nn import CompositeDistribution from tensordict.nn.common import dispatch, TensorDictModule, TensorDictModuleBase from tensordict.nn.distributions import Delta, distributions_maps from tensordict.nn.sequence import TensorDictSequential -from tensordict.nn.utils import set_skip_existing +from tensordict.nn.utils import _set_skip_existing_None from tensordict.tensordict import TensorDictBase from tensordict.utils import _zip_strict, NestedKey from torch import distributions as D, Tensor +from torch.utils._contextlib import _DecoratorContextManager + __all__ = ["ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential"] @@ -420,7 +421,7 @@ def SAMPLE_LOG_PROB_KEY(self): return self.log_prob_key @dispatch(auto_batch_size=False) - @set_skip_existing(None) + @_set_skip_existing_None() def forward( self, tensordict: TensorDictBase, @@ -469,9 +470,9 @@ def _dist_sample( interaction_type = self.default_interaction_type if interaction_type is InteractionType.DETERMINISTIC: - try: + if hasattr(dist, "deterministic_sample"): return dist.deterministic_sample - except AttributeError: + else: try: support = dist.support fallback = ( @@ -520,13 +521,15 @@ def _dist_sample( ) elif interaction_type is InteractionType.MEAN: - try: - return dist.mean - except (AttributeError, NotImplementedError): - if dist.has_rsample: - return dist.rsample((self.n_empirical_estimate,)).mean(0) - else: - return dist.sample((self.n_empirical_estimate,)).mean(0) + if hasattr(dist, "mean"): + try: + return dist.mean + except NotImplementedError: + pass + if dist.has_rsample: + return dist.rsample((self.n_empirical_estimate,)).mean(0) + else: + return dist.sample((self.n_empirical_estimate,)).mean(0) elif interaction_type is InteractionType.RANDOM: if dist.has_rsample: @@ -644,7 +647,7 @@ def build_dist_from_params(self, tensordict: TensorDictBase) -> D.Distribution: return dest_module.get_dist(tensordict) @dispatch(auto_batch_size=False) - @set_skip_existing(None) + @_set_skip_existing_None() def forward( self, tensordict: TensorDictBase, diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index d1ce32eae..152223a4e 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -28,7 +28,7 @@ _SKIP_EXISTING = False -from tensordict._contextlib import 
_DecoratorContextManager +from torch.utils._contextlib import _DecoratorContextManager def inv_softplus(bias: float | torch.Tensor) -> float | torch.Tensor: @@ -386,7 +386,10 @@ def _rebuild_buffer(data, requires_grad, backward_hooks): # For backward compatibility in imports -from tensordict.utils import Buffer # noqa +try: + from torch.nn.parameter import Buffer # noqa +except ImportError: + from tensordict.utils import Buffer # noqa def _auto_make_functional(): diff --git a/tensordict/utils.py b/tensordict/utils.py index 70d1606bd..d75c00db6 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -17,7 +17,7 @@ import sys import time import warnings -from collections import defaultdict, OrderedDict +from collections import defaultdict from collections.abc import KeysView from copy import copy from functools import wraps @@ -46,16 +46,15 @@ unravel_key_list as unravel_key_list_cpp, unravel_keys as unravel_keys_cpp, ) -from tensordict._contextlib import _DecoratorContextManager from torch import Tensor from torch._C import _disabled_torch_function_impl from torch.nn.parameter import ( - _ParameterMeta, UninitializedBuffer, UninitializedParameter, UninitializedTensorMixin, ) +from torch.utils._contextlib import _DecoratorContextManager from torch.utils.data._utils.worker import _generate_state try: @@ -1820,51 +1819,54 @@ def _get_shape_from_args(*args, kwarg_name="size", **kwargs): return size -class Buffer(Tensor, metaclass=_ParameterMeta): - r"""A kind of Tensor that is to be considered a module buffer. +if hasattr(torch.nn, "Buffer"): + _parent_buffer_cls = torch.nn.Buffer - Args: - data (Tensor): buffer tensor. - requires_grad (bool, optional): if the buffer requires gradient. See - :ref:`locally-disable-grad-doc` for more details. Default: `False` - """ + class Buffer: # noqa: D101 + ... - def __new__(cls, data=None, requires_grad=False): - if data is None: - data = torch.empty(0) + class _BufferMeta: ... - if type(data) is torch.Tensor or type(data) is Buffer: - return data.as_subclass(cls) +else: - # Path for custom tensors: set a flag on the instance to indicate parameter-ness. - if requires_grad: - t = data.detach().requires_grad_(requires_grad) - else: - t = data - t._is_buffer = True - return t + class _BufferMeta(torch._C._TensorMeta): + # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag. 
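+        # (same mechanism torch.nn.parameter._ParameterMeta uses for Parameter via the _is_param flag)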
+ def __instancecheck__(self, instance): + if self is Buffer: + if isinstance(instance, torch.Tensor) and getattr( + instance, "_is_buffer", False + ): + return True + return super().__instancecheck__(instance) - def __deepcopy__(self, memo): - if id(self) in memo: - return memo[id(self)] - else: - result = type(self)( - self.data.clone(memory_format=torch.preserve_format), self.requires_grad - ) - memo[id(self)] = result - return result + class Buffer(torch.Tensor, metaclass=_BufferMeta): + """A replicate of torch.nn.Buffer if not available (prior to torch v2.5).""" - def __repr__(self): - return "Buffer containing:\n" + super(Buffer, self).__repr__() + def __new__(cls, data=None, *, persistent=True): + if data is None: + data = torch.empty(0) - def __reduce_ex__(self, proto): - # See Note [Don't serialize hooks] - return ( - torch._utils._rebuild_parameter, - (self.data, self.requires_grad, OrderedDict()), - ) + t = data.detach().requires_grad_(data.requires_grad) + t.persistent = persistent + t._is_buffer = True + return t + + __torch_function__ = _disabled_torch_function_impl + + _parent_buffer_cls = Buffer + + +class BufferLegacy(_parent_buffer_cls): + """A buffer subclass that keeps the grad fn history.""" - __torch_function__ = _disabled_torch_function_impl + def __new__(cls, data=None, *, persistent=True): + if data is None: + data = torch.empty(0) + + t = data + t.persistent = persistent + t._is_buffer = True + return t def _getitem_batch_size(batch_size, index): diff --git a/test/test_compile.py b/test/test_compile.py index 0bf8dabf3..1adcd2a69 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -222,8 +222,8 @@ def make_td_with_names(data): make_td_with_names, fullgraph=True, mode=mode ) make_td_with_names(data_dict) - with pytest.raises(torch._dynamo.exc.Unsupported): - make_td_with_names_c(data_dict) + # with pytest.raises(torch._dynamo.exc.Unsupported): + make_td_with_names_c(data_dict) @pytest.mark.skipif( not torch.cuda.is_available(), reason="cuda required to test device casting" @@ -623,25 +623,12 @@ def test_functional_error(self, mode): td_zero = TensorDictParams(td.data.clone()) td_zero.zero_() - def call(x, td): - with td.to_module(module): - return module(x) - - call_compile = torch.compile(call, fullgraph=True, mode=mode) - x = torch.randn(2, 3) - with pytest.raises( - torch._dynamo.exc.Unsupported, - match="UserDefinedObjectVariable|Unsupported context manager", - ): - call_compile(x, td_zero) os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "0" try: def call(x, td): - params = td.to_module(module, return_swap=True) - result = module(x) - params.to_module(module, return_swap=True, swap_dest=td) - return result + with td.to_module(module): + return module(x) call_compile = torch.compile(call, fullgraph=True, mode=mode) x = torch.randn(2, 3) @@ -683,14 +670,8 @@ def forward(self, x): td_zero.zero_() def call(x, td): - # TOFIX: `with` needs registering - # with td.to_module(module): - # return module(x) - - params = td.to_module(module, return_swap=True) - result = module(x) - params.to_module(module, return_swap=True, swap_dest=td) - return result + with td.to_module(module): + return module(x) call_compile = torch.compile(call, fullgraph=True, mode=mode) x = torch.randn(2, 3) diff --git a/test/test_nn.py b/test/test_nn.py index 220f3eacd..a3d4bd2a8 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -44,10 +44,10 @@ from tensordict.nn.utils import ( _set_auto_make_functional, _set_dispatch_td_nn_modules, - Buffer, set_skip_existing, skip_existing, 
) + from torch import distributions, nn from torch.distributions import Normal from torch.utils._pytree import tree_map @@ -66,6 +66,10 @@ except ImportError as err: _has_functorch = False FUNCTORCH_ERR = str(err) +try: + from torch.nn.parameter import Buffer +except ImportError: + from tensordict.utils import Buffer # Capture all warnings diff --git a/test/test_tensordict.py b/test/test_tensordict.py index f31c8988e..583c592cf 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1080,7 +1080,7 @@ def exec_module(params, x): y = torch.vmap(exec_module, (0, None))(params, x) y.sum().backward() for k, p in modules[0].named_parameters(): - assert p.grad is None if k.startswith("1") else p.grad is not None + assert p.grad is None if k.startswith("1") else p.grad is not None, k assert all( param.grad is not None for param in params.values(True, True)
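
A minimal illustration of the decorator/context-manager pattern that the
_reverse_* functions registered in LAST_OP_MAPS implement (the keys and shapes
below are made up for the sketch, not taken from the patch): on __exit__ the
registered reverse op writes the edits made on the transformed view back into
the original tensordict.

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
    with td.transpose(0, 1) as tdt:
        # edits made on the transposed view...
        tdt["b"] = torch.ones(4, 3)
    # ...are transposed back and written into the original when the block exits
    assert td["b"].shape == torch.Size([3, 4])

Moving the per-op reversal logic into the LAST_OP_MAPS registry lets __exit__
resolve the reverse op with a single dict lookup instead of the long if/elif
chain removed from tensordict/base.py, in line with the dynamo-compat refactor
this patch targets.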