diff --git a/.github/scripts/version_script.bat b/.github/scripts/version_script.bat
index 8b9fc61a4..3c0db9269 100644
--- a/.github/scripts/version_script.bat
+++ b/.github/scripts/version_script.bat
@@ -1,3 +1,3 @@
 @echo off
-set TENSORDICT_BUILD_VERSION=0.6.2
+set TENSORDICT_BUILD_VERSION=0.7.0
 echo TENSORDICT_BUILD_VERSION is set to %TENSORDICT_BUILD_VERSION%
diff --git a/.github/scripts/version_script.sh b/.github/scripts/version_script.sh
index cf49f554d..c016dcbc4 100644
--- a/.github/scripts/version_script.sh
+++ b/.github/scripts/version_script.sh
@@ -1,3 +1,3 @@
 #!/bin/bash
-export TENSORDICT_BUILD_VERSION=0.6.2
+export TENSORDICT_BUILD_VERSION=0.7.0
 
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 394d53b43..75f03f10f 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -23,8 +23,11 @@ jobs:
     strategy:
       matrix:
         python_version: ["3.10"]
-        cuda_arch_version: ["12.1"]
+        cuda_arch_version: ["12.4"]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       repository: pytorch/tensordict
       upload-artifact: docs
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 5fd055f85..1f7c5a47d 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -18,6 +18,9 @@ concurrency:
 jobs:
   python-source-and-configs:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       repository: pytorch/tensordict
       script: |
@@ -46,6 +49,9 @@
 
   c-source:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       repository: pytorch/tensordict
       script: |
diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml
index 4bc66eab5..c4e679841 100644
--- a/.github/workflows/test-linux.yml
+++ b/.github/workflows/test-linux.yml
@@ -23,9 +23,12 @@ jobs:
     strategy:
       matrix:
         python_version: ["3.10"]
-        cuda_arch_version: ["12.1"]
+        cuda_arch_version: ["12.4"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.g5.4xlarge.nvidia.gpu
       repository: pytorch/tensordict
@@ -57,6 +60,9 @@
         python_version: ["3.9", "3.10", "3.11", "3.12"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.12xlarge
       repository: pytorch/tensordict
@@ -81,9 +87,12 @@
     strategy:
       matrix:
         python_version: ["3.10"]
-        cuda_arch_version: ["12.1"]
+        cuda_arch_version: ["12.4"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.g5.4xlarge.nvidia.gpu
       repository: pytorch/tensordict
@@ -116,6 +125,9 @@
         python_version: ["3.9", "3.12"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.12xlarge
       repository: pytorch/tensordict
diff --git a/.github/workflows/test-rl-gpu.yml b/.github/workflows/test-rl-gpu.yml
index e38059a94..57201eba7 100644
--- a/.github/workflows/test-rl-gpu.yml
+++ b/.github/workflows/test-rl-gpu.yml
@@ -23,9 +23,12 @@ jobs:
     strategy:
       matrix:
         python_version: ["3.10"]
-        cuda_arch_version: ["12.1"]
+        cuda_arch_version: ["12.4"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.g5.4xlarge.nvidia.gpu
       repository: pytorch/tensordict
diff --git a/.github/workflows/wheels-windows.yml b/.github/workflows/wheels-windows.yml
index 75044091f..15032eef3 100644
--- a/.github/workflows/wheels-windows.yml
+++ b/.github/workflows/wheels-windows.yml
@@ -34,7 +34,7 @@ jobs:
         shell: bash
         run: |
           python3 -mpip install wheel
-          TENSORDICT_BUILD_VERSION=0.6.2 python3 setup.py bdist_wheel
+          TENSORDICT_BUILD_VERSION=0.7.0 python3 setup.py bdist_wheel
       - name: Upload wheel for the test-wheel job
         uses: actions/upload-artifact@v4
         with:
diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst
index 0be9b9883..47c0ee725 100644
--- a/docs/source/reference/nn.rst
+++ b/docs/source/reference/nn.rst
@@ -152,7 +152,7 @@ to build distributions from network outputs and get summary statistics or sample
     >>> import torch
     >>> from tensordict import TensorDict
     >>> from tensordict.nn import TensorDictModule
-    >>> from tensordict.nn.distributions import NormalParamWrapper
+    >>> from tensordict.nn.distributions import NormalParamExtractor
     >>> from tensordict.nn.prototype import (
     ...     ProbabilisticTensorDictModule,
     ...     ProbabilisticTensorDictSequential,
@@ -161,9 +161,9 @@ to build distributions from network outputs and get summary statistics or sample
     >>> td = TensorDict(
     ...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
     ... )
-    >>> net = torch.nn.GRUCell(4, 8)
+    >>> net = torch.nn.Sequential(torch.nn.GRUCell(4, 8), NormalParamExtractor())
     >>> module = TensorDictModule(
-    ...     NormalParamWrapper(net), in_keys=["input", "hidden"], out_keys=["loc", "scale"]
+    ...     net, in_keys=["input", "hidden"], out_keys=["loc", "scale"]
     ... )
     >>> prob_module = ProbabilisticTensorDictModule(
     ...     in_keys=["loc", "scale"],
@@ -194,6 +194,7 @@ to build distributions from network outputs and get summary statistics or sample
     TensorDictModuleBase
     TensorDictModule
     ProbabilisticTensorDictModule
+    ProbabilisticTensorDictSequential
     TensorDictSequential
     TensorDictModuleWrapper
     CudaGraphModule
@@ -257,6 +258,10 @@ Distributions
     NormalParamExtractor
     OneHotCategorical
     TruncatedNormal
+    InteractionType
+    set_interaction_type
+    add_custom_mapping
+    mappings
 
 
 Utils
@@ -270,8 +275,8 @@ Utils
     make_tensordict
     dispatch
-    set_interaction_type
     inv_softplus
     biased_softplus
     set_skip_existing
     skip_existing
+    rand_one_hot
 
diff --git a/setup.py b/setup.py
index 1a84252b6..986752654 100644
--- a/setup.py
+++ b/setup.py
@@ -66,10 +66,10 @@ def _get_pytorch_version(is_nightly, is_local):
     # if "PYTORCH_VERSION" in os.environ:
     #     return f"torch=={os.environ['PYTORCH_VERSION']}"
     if is_nightly:
-        return "torch>=2.6.0.dev"
+        return "torch>=2.7.0.dev"
     if is_local:
         return "torch"
-    return "torch>=2.5.0"
+    return "torch>=2.6.0"
 
 
 def _get_packages():
diff --git a/tensordict/_td.py b/tensordict/_td.py
index 0ec6616eb..0147d63ca 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -315,7 +315,9 @@ def _new_unsafe(
         nested: bool = True,
         **kwargs: dict[str, Any] | None,
     ) -> TensorDict:
-        if is_compiling():
+        if is_compiling() and cls is TensorDict:
+            # If the cls is not TensorDict, we must escape this to keep the same class.
+            # That's unfortunate because as of now it graph breaks but that's the best we can do.
             return TensorDict(
                 source,
                 batch_size=batch_size,
@@ -2195,8 +2197,7 @@ def from_dict_instance(
         input_dict = copy(input_dict)
         for key, value in list(input_dict.items()):
             if isinstance(value, (dict,)):
-                # TODO: v0.7: remove the None
-                cur_value = self.get(key, None)
+                cur_value = self.get(key)
                 if cur_value is not None:
                     input_dict[key] = cur_value.from_dict_instance(
                         value,
diff --git a/tensordict/base.py b/tensordict/base.py
index 83ab269bc..a957e69ec 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -157,7 +157,7 @@ def __bool__(self):
 if "TD_GET_DEFAULTS_TO_NONE" in os.environ:
     _GET_DEFAULTS_TO_NONE = strtobool(os.environ["TD_GET_DEFAULTS_TO_NONE"])
 else:
-    _GET_DEFAULTS_TO_NONE = None
+    _GET_DEFAULTS_TO_NONE = True
 
 
 def set_get_defaults_to_none(set_to_none: bool = True):
@@ -172,7 +172,7 @@ def set_get_defaults_to_none(set_to_none: bool = True):
 
     """
     global _GET_DEFAULTS_TO_NONE
-    _GET_DEFAULTS_TO_NONE = set_to_none
+    _GET_DEFAULTS_TO_NONE = bool(set_to_none)
 
 
 def get_defaults_to_none(set_to_none: bool = True):
@@ -6390,22 +6390,19 @@ def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType:
         Args:
             key (str, tuple of str): key to be queried. If tuple of str it is
                 equivalent to chained calls of getattr.
-            default: default value if the key is not found in the tensordict.
+            default: default value if the key is not found in the tensordict. Defaults to ``None``.
 
         .. warning::
-            Currently, if a key is not present in the tensordict and no default
-            is passed, a `KeyError` is raised. From v0.7, this behaviour will be changed
-            and a `None` value will be returned instead. To adopt the new behaviour,
-            set the environment variable `export TD_GET_DEFAULTS_TO_NONE='1'` or call
-            :func`~tensordict.set_get_defaults_to_none`.
+            Previously, if a key was not present in the tensordict and no default
+            was passed, a `KeyError` was raised. From v0.7, this behavior has been changed
+            and a `None` value is returned instead (in accordance with the `dict.get` behavior).
+            To adopt the old behavior, set the environment variable `export TD_GET_DEFAULTS_TO_NONE='0'` or call
+            :func:`~tensordict.set_get_defaults_to_none` with ``False``.
 
         Examples:
             >>> td = TensorDict({"x": 1}, batch_size=[])
             >>> td.get("x")
             tensor(1)
-            >>> set_get_defaults_to_none(False) # Current default behaviour
-            >>> td.get("y") # Raises KeyError
-            >>> set_get_defaults_to_none(True)
             >>> td.get("y")
             None
         """
@@ -6413,37 +6410,19 @@ def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType:
         if not key:
             raise KeyError(_GENERIC_NESTED_ERR.format(key))
         # Find what the default is
-        has_default = False
         if args:
             default = args[0]
             if len(args) > 1 or kwargs:
                 raise TypeError("only one (keyword) argument is allowed.")
-            has_default = True
         elif kwargs:
             default = kwargs.pop("default")
             if args or kwargs:
                 raise TypeError("only one (keyword) argument is allowed.")
-            has_default = True
         elif _GET_DEFAULTS_TO_NONE:
             default = None
         else:
             default = NO_DEFAULT
-        try:
-            return self._get_tuple(key, default=default)
-        except KeyError:
-            if _GET_DEFAULTS_TO_NONE is None and not has_default:
-                # We raise an exception AND a warning because we want the user to know that this exception will
-                # not be raised in the future
-                warnings.warn(
-                    f"The entry ({key}) you have queried with `get` is not present in the tensordict. "
-                    "Currently, this raises an exception. "
-                    "To align with `dict.get`, this behaviour will be changed in v0.7 and a `None` value will "
-                    "be returned instead (no error will be raised). "
" - "To suppress this warning and use the new behaviour (recommended), call `tensordict.set_get_defaults_to_none(True)` or set the env variable `export TD_GET_DEFAULTS_TO_NONE='1'`. " - "To suppress this warning and keep the old behaviour, call `tensordict.set_get_defaults_to_none(False)` or set the env variable `export TD_GET_DEFAULTS_TO_NONE='0'`.", - category=DeprecationWarning, - ) - raise + return self._get_tuple(key, default=default) @abc.abstractmethod def _get_str(self, key, default): ... @@ -11119,6 +11098,8 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): # During exit, updates mustn't be made in-place as the source and dest # storage location can be identical, resulting in a RuntimeError + if is_compiling(): + self.clear_refs_for_compile_() if exc_type is not None and issubclass(exc_type, Exception): return False _last_op = self._last_op_queue.pop() @@ -11128,11 +11109,27 @@ def __exit__(self, exc_type, exc_val, exc_tb): # added or deleted _inv_caller = LAST_OP_MAPS.get(last_op) if _inv_caller is not None: - return _inv_caller(self, args, kwargs, out_wr()) + prev_ref = out_wr() + return _inv_caller(self, args, kwargs, prev_ref) else: raise NotImplementedError(f"Unrecognised function {last_op}.") return self + def clear_refs_for_compile_(self) -> T: + """Clears the weakrefs in order for the tensordict to get out of the compile region safely. + + Use this whenever you hit `torch._dynamo.exc.Unsupported: reconstruct: WeakRefVariable()` + before returning a TensorDict. + + Returns: self + """ + self._last_op = None + for v in self.values(True, True, is_leaf=_is_tensor_collection): + if _is_tensorclass(type(v)): + v = v._tensordict + v._last_op = None + return self + # Clone, select, exclude, empty def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) -> T: """Selects the keys of the tensordict and returns a new tensordict with only the selected keys. 
@@ -11559,7 +11556,11 @@ def from_any(
                 device=device,
                 batch_size=batch_size,
             )
-        if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"):
+        if (
+            isinstance(obj, np.ndarray)
+            and hasattr(obj.dtype, "names")
+            and obj.dtype.names is not None
+        ):
             return cls.from_struct_array(
                 obj,
                 auto_batch_size=auto_batch_size,
diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py
index 82f4eadf7..de28f6d49 100644
--- a/tensordict/nn/common.py
+++ b/tensordict/nn/common.py
@@ -1084,14 +1084,12 @@ def forward(
         if self._kwargs is not None:
             kwargs.update(
                 {
-                    # TODO: v0.7: remove the None
-                    kwarg: tensordict.get(in_key, None)
+                    kwarg: tensordict.get(in_key)
                     for kwarg, in_key in _zip_strict(self._kwargs, self.in_keys)
                 }
             )
             tensors = ()
         else:
-            # TODO: v0.7: remove the None
             tensors = tuple(
                 tensordict._get_tuple_maybe_non_tensor(
                     _unravel_key_to_tuple(in_key), None
@@ -1121,8 +1119,7 @@ def forward(
             keys = unravel_key_list(list(tensors.keys()))
             values = tensors.values()
             tensors = dict(_zip_strict(keys, values))
-            # TODO: v0.7: remove the None
-            tensors = tuple(tensors.get(key, None) for key in self.out_keys)
+            tensors = tuple(tensors.get(key) for key in self.out_keys)
         if not isinstance(tensors, tuple):
             tensors = (tensors,)
         tensordict_out = self._write_to_tensordict(
diff --git a/tensordict/nn/distributions/__init__.py b/tensordict/nn/distributions/__init__.py
index e626d8a64..c1f0a57f6 100644
--- a/tensordict/nn/distributions/__init__.py
+++ b/tensordict/nn/distributions/__init__.py
@@ -10,7 +10,6 @@
     AddStateIndependentNormalScale,
     Delta,
     NormalParamExtractor,
-    NormalParamWrapper,
 )
 from tensordict.nn.distributions.discrete import OneHotCategorical, rand_one_hot
 from tensordict.nn.distributions.truncated_normal import TruncatedNormal
diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py
index 242b6cfa1..c64f21002 100644
--- a/tensordict/nn/distributions/composite.py
+++ b/tensordict/nn/distributions/composite.py
@@ -127,8 +127,7 @@ def __init__(
             else:
                 write_name = name_unravel
                 name = name_unravel
-            # TODO: v0.7: remove the None
-            dist_params = params.get(name, None)
+            dist_params = params.get(name)
             kwargs = extra_kwargs.get(name, {})
             if dist_params is None:
                 raise KeyError
@@ -587,8 +586,7 @@ def icdf(self, sample: TensorDictBase) -> TensorDictBase:
             KeyError: If neither `` nor `_cdf` can be found in the input TensorDict for a component distribution.
         """
         for name, dist in self.dists.items():
-            # TODO: v0.7: remove the None
-            prob = sample.get(_add_suffix(name, "_cdf"), None)
+            prob = sample.get(_add_suffix(name, "_cdf"))
             if prob is None:
                 try:
                     prob = self.cdf(sample.get(name))
diff --git a/tensordict/nn/distributions/continuous.py b/tensordict/nn/distributions/continuous.py
index d3ec0a418..2c1c17c55 100644
--- a/tensordict/nn/distributions/continuous.py
+++ b/tensordict/nn/distributions/continuous.py
@@ -5,7 +5,6 @@
 
 from __future__ import annotations
 
-import warnings
 from numbers import Number
 from typing import Sequence
 
@@ -19,7 +18,6 @@
 # We need this to build the distribution maps
 __all__ = [
     "NormalParamExtractor",
-    "NormalParamWrapper",
     "AddStateIndependentNormalScale",
     "Delta",
 ]
@@ -29,59 +27,15 @@
 
 
 class NormalParamWrapper(nn.Module):
-    """A wrapper for normal distribution parameters.
-
-    Args:
-        operator (nn.Module): operator whose output will be transformed_in in location and scale parameters
-        scale_mapping (str, optional): positive mapping function to be used with the std.
-            default = "biased_softplus_1.0" (i.e. softplus map with bias such that fn(0.0) = 1.0)
-            choices: "softplus", "exp", "relu", "biased_softplus_1";
-        scale_lb (Number, optional): The minimum value that the variance can take. Default is 1e-4.
-
-    Examples:
-        >>> from torch import nn
-        >>> import torch
-        >>> module = nn.Linear(3, 4)
-        >>> module_normal = NormalParamWrapper(module)
-        >>> tensor = torch.randn(3)
-        >>> loc, scale = module_normal(tensor)
-        >>> print(loc.shape, scale.shape)
-        torch.Size([2]) torch.Size([2])
-        >>> assert (scale > 0).all()
-        >>> # with modules that return more than one tensor
-        >>> module = nn.LSTM(3, 4)
-        >>> module_normal = NormalParamWrapper(module)
-        >>> tensor = torch.randn(4, 2, 3)
-        >>> loc, scale, others = module_normal(tensor)
-        >>> print(loc.shape, scale.shape)
-        torch.Size([4, 2, 2]) torch.Size([4, 2, 2])
-        >>> assert (scale > 0).all()
-
-    """
-
     def __init__(
         self,
        operator: nn.Module,
        scale_mapping: str = "biased_softplus_1.0",
        scale_lb: Number = 1e-4,
    ) -> None:
-        warnings.warn(
-            "The NormalParamWrapper class will be deprecated in v0.7 in favor of :class:`~tensordict.nn.NormalParamExtractor`.",
-            category=DeprecationWarning,
+        raise RuntimeError(
+            "NormalParamWrapper has been deprecated in favor of `tensordict.nn.NormalParamExtractor`. Use that class instead."
         )
-        super().__init__()
-        self.operator = operator
-        self.scale_mapping = scale_mapping
-        self.scale_lb = scale_lb
-
-    def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        net_output = self.operator(*tensors)
-        others = ()
-        if not isinstance(net_output, torch.Tensor):
-            net_output, *others = net_output
-        loc, scale = net_output.chunk(2, -1)
-        scale = mappings(self.scale_mapping)(scale).clamp_min(self.scale_lb)
-        return (loc, scale, *others)
 
 
 class NormalParamExtractor(nn.Module):
diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py
index 1633b01ad..9d61d330b 100644
--- a/tensordict/nn/probabilistic.py
+++ b/tensordict/nn/probabilistic.py
@@ -453,7 +453,8 @@ def log_prob_key(self):
             f"Currently, it is assumed that composite_lp_aggregate() will return True: the log-probs will be aggregated "
             f"in a {self._log_prob_key} entry. From v0.9, this behaviour will be changed and individual log-probs will "
             f"be written in `('path', 'to', 'leaf', '_log_prob')`. To prepare for this change, "
-            f"call `set_composite_lp_aggregate(mode: bool).set()` at the beginning of your script. Use mode=True "
+            f"call `set_composite_lp_aggregate(mode: bool).set()` at the beginning of your script (or set the "
+            f"COMPOSITE_LP_AGGREGATE env variable). Use mode=True "
             f"to keep the current behaviour, and mode=False to use per-leaf log-probs.",
             category=DeprecationWarning,
         )
diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py
index 60920f23a..f96d341e9 100644
--- a/tensordict/nn/utils.py
+++ b/tensordict/nn/utils.py
@@ -71,12 +71,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.nn.functional.softplus(x + self.bias) + self.min_val
 
 
+def expln(x):
+    """A smooth, continuous positive mapping presented in "State-Dependent Exploration for Policy Gradient Methods".
+
+    https://people.idsia.ch/~juergen/ecml2008rueckstiess.pdf
+
+    """
+    out = torch.empty_like(x)
+    idx_neg = x <= 0
+    out[idx_neg] = x[idx_neg].exp()
+    out[~idx_neg] = x[~idx_neg].log1p() + 1
+    return out
+
+
 _MAPPINGS: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
     "softplus": torch.nn.functional.softplus,
     "exp": torch.exp,
     "relu": torch.relu,
     "biased_softplus": biased_softplus(1.0),
     "none": lambda x: x,
+    "expln": expln,
 }
@@ -450,7 +464,13 @@ def _generate_next_value_(name, start, count, last_values):
         return name.lower()
 
 
-_composite_lp_aggregate = _ContextManager()
+_composite_lp_aggregate = _ContextManager(
+    default=(
+        strtobool(os.getenv("COMPOSITE_LP_AGGREGATE"))
+        if os.getenv("COMPOSITE_LP_AGGREGATE") is not None
+        else None
+    )
+)
 
 
 def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
@@ -467,9 +487,9 @@ def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
         if not nowarn:
             warnings.warn(
                 "Composite log-prob aggregation wasn't defined explicitly and ``composite_lp_aggregate()`` will "
-                "currently return ``True``. However, from v0.9, this behaviour will change and ``composite_lp_aggregate`` will "
+                "currently return ``True``. However, from v0.9, this behavior will change and ``composite_lp_aggregate`` will "
                 "return ``False``. Please change your code accordingly by specifying the aggregation strategy via "
-                "`tensordict.nn.set_composite_lp_aggregate`.",
+                "`tensordict.nn.set_composite_lp_aggregate` or via the `COMPOSITE_LP_AGGREGATE` environment variable.",
                 category=DeprecationWarning,
             )
         return True
@@ -483,6 +503,8 @@ class set_composite_lp_aggregate(_DecoratorContextManager):
     will be summed into a single tensor with the shape of the root tensordict. This behaviour is being deprecated
     in favor of non-aggregated log-probs, which offer more flexibility and a somewhat more natural API (tensordict
     samples, tensordict log-probs, tensordict entropies).
+    The value of composite_lp_aggregate can also be controlled through the `COMPOSITE_LP_AGGREGATE` environment variable.
+
     Example:
         >>> _ = torch.manual_seed(0)
         >>> from tensordict import TensorDict
diff --git a/tensordict/persistent.py b/tensordict/persistent.py
index 2a3957a93..1b66f6041 100644
--- a/tensordict/persistent.py
+++ b/tensordict/persistent.py
@@ -344,8 +344,7 @@ def _process_array(self, key, array):
             )
             return out
         else:
-            # TODO: remove the None in v0.7
-            out = self._nested_tensordicts.get(key, None)
+            out = self._nested_tensordicts.get(key)
             if out is None:
                 out = self._nested_tensordicts[key] = PersistentTensorDict(
                     group=array,
@@ -398,8 +397,7 @@ def get_at(
                 return out.pin_memory()
             return out
         elif array is not default:
-            # TODO: remove the None in v0.7
-            out = self._nested_tensordicts.get(key, None)
+            out = self._nested_tensordicts.get(key)
             if out is None:
                 out = self._nested_tensordicts[key] = PersistentTensorDict(
                     group=array,
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 0f4e43988..60f51ef26 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -156,6 +156,7 @@ def __subclasscheck__(self, subclass):
     "batch_size",
     "bytes",
     "cat_tensors",
+    "clear_refs_for_compile_",
     "data_ptr",
     "depth",
     "dim",
@@ -2276,8 +2277,7 @@ def _get_at(self, key: NestedKey, idx, default: Any = NO_DEFAULT):
 def _data(self):
     # We allow data to be a field of the class too
     if "data" in self.__dataclass_fields__:
-        # TODO: remove the None in v0.7
-        data = self._tensordict.get("data", None)
+        data = self._tensordict.get("data")
         if data is None:
             data = self._non_tensordict.get("data")
         return data
@@ -3562,6 +3562,10 @@ def _from_list(cls, datalist: List, device: torch.device, ndim: int | None = Non
             stack_dim=0,
         )
 
+    def densify(self, layout: torch.layout = torch.strided):
+        # No need to do anything with a non tensor stack
+        return self
+
     def update(
         self,
         input_dict_or_td: dict[str, CompatibleType] | T,
diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi
index f4131802d..a9a443ac2 100644
--- a/tensordict/tensorclass.pyi
+++ b/tensordict/tensorclass.pyi
@@ -436,6 +436,7 @@ class TensorClass:
     @device.setter
     def device(self, value: DeviceType) -> torch.device | None: ...
     def clear(self) -> T: ...
+    def clear_refs_for_compile_(self) -> T: ...
     @classmethod
     def fromkeys(cls, keys: list[NestedKey], value: Any = 0): ...
     def popitem(self) -> tuple[NestedKey, CompatibleType]: ...
diff --git a/tensordict/utils.py b/tensordict/utils.py
index c2fd31ff4..c019092bf 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1217,6 +1217,10 @@ def new_func(self, *args, **kwargs):
     return new_func
 
 
+def _strong_ref(self):
+    return lambda self: self
+
+
 def _as_context_manager(attr=None):
     """Converts a method to a decorator.
 
@@ -1238,12 +1242,13 @@ def func_as_decorator(_self, *args, **kwargs):
             _attr_post = getattr(_self, attr)
             if out is not None:
                 if _attr_post is not _attr_pre:
+                    ref = weakref.ref(_self)
                     out._last_op = (
                         func.__name__,
                         (
                             args,
                             kwargs,
-                            weakref.ref(_self),
+                            ref,
                         ),
                     )
                 else:
@@ -1256,7 +1261,8 @@ def func_as_decorator(_self, *args, **kwargs):
             out = func(_self, *args, **kwargs)
             if out is not None:
-                out._last_op = (func.__name__, (args, kwargs, weakref.ref(_self)))
+                ref = weakref.ref(_self)
+                out._last_op = (func.__name__, (args, kwargs, ref))
             return out
 
         return func_as_decorator
@@ -1485,8 +1491,7 @@ def _default_hook(td: T, key: tuple[str, ...]) -> None:
     For example, ``td.set(("a", "b"))`` may require to create ``"a"``.
""" - # TODO: remove the None in v0.7 - out = td.get(key[0], None) + out = td.get(key[0]) if out is None: td._create_nested_str(key[0]) out = td._get_str(key[0], None) @@ -1502,8 +1507,7 @@ def _get_leaf_tensordict( if hook is not None: tensordict = hook(tensordict, key) else: - # TODO: remove the None in v0.7 - tensordict = tensordict.get(key[0], default=None) + tensordict = tensordict.get(key[0]) if tensordict is None: raise KeyError(f"No sub-tensordict with key {key[0]}.") key = key[1:] @@ -2261,9 +2265,8 @@ def isin( >>> torch.testing.assert_close(in_reference, expected_in_reference) """ # Get the data - # TODO: remove the None in v0.7 - reference_tensor = reference.get(key, default=None) - target_tensor = input.get(key, default=None) + reference_tensor = reference.get(key) + target_tensor = input.get(key) # Check key is present in both tensordict and reference_tensordict if not isinstance(target_tensor, torch.Tensor): @@ -2359,8 +2362,7 @@ def remove_duplicates( ... ) >>> assert (td == expected_output).all() """ - # TODO: remove the None in v0.7 - tensor = input.get(key, default=None) + tensor = input.get(key) # Check if the key is a TensorDict if tensor is None: diff --git a/test/test_compile.py b/test/test_compile.py index bb07900a9..2d6792348 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -144,7 +144,8 @@ def reshape(td): def test_view(self, mode): def view(td): - return td.view(2, 2) + out = td.view(2, 2).clear_refs_for_compile_() + return out view_c = torch.compile(view, fullgraph=True, mode=mode) data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) @@ -155,7 +156,7 @@ def view(td): def test_transpose(self, mode): def transpose(td): - return td.transpose(0, 1) + return td.transpose(0, 1).clear_refs_for_compile_() transpose_c = torch.compile(transpose, fullgraph=True, mode=mode) data = TensorDict({"a": {"b": torch.arange(6).view(2, 3)}}, [2, 3]) @@ -211,7 +212,7 @@ def clone(td: TensorDict): @pytest.mark.parametrize("recurse", [True, False]) def test_flatten_keys(self, recurse, mode): def flatten_keys(td: TensorDict): - return td.flatten_keys() + return td.flatten_keys().clear_refs_for_compile_() flatten_keys_c = torch.compile(flatten_keys, fullgraph=True, mode=mode) data = TensorDict({"a": {"b": 0, "c": 1}}) @@ -225,7 +226,7 @@ def flatten_keys(td: TensorDict): @pytest.mark.parametrize("recurse", [True, False]) def test_unflatten_keys(self, recurse, mode): def unflatten_keys(td: TensorDict): - return td.unflatten_keys() + return td.unflatten_keys().clear_refs_for_compile_() unflatten_keys_c = torch.compile(unflatten_keys, fullgraph=True, mode=mode) data = TensorDict({"a.b": 0, "a.c": 1}) @@ -280,7 +281,7 @@ def locked_op(td): # Adding stuff uses cache, check that this doesn't break td2 = td + 1 td3 = td + td2 - return td3 + return td3.clear_refs_for_compile_() td = TensorDict( {"a": torch.randn(1, 2, 3), "b": torch.zeros(1, 2, 3, dtype=torch.bool)}, @@ -586,6 +587,29 @@ def locked_op(tc): tc_op_c = locked_op_c(data) assert (tc_op == tc_op_c).all() + def test_td_new_unsafe(self, mode): + + class MyTd(TensorDict): + pass + + def func_td(): + return TensorDict._new_unsafe(a=torch.randn(3), batch_size=torch.Size(())) + + @torch.compile(fullgraph=True, mode=mode) + def func_c_td(): + return TensorDict._new_unsafe(a=torch.randn(3), batch_size=torch.Size(())) + + def func_mytd(): + return MyTd._new_unsafe(a=torch.randn(3), batch_size=torch.Size(())) + + # This will graph break + @torch.compile(mode=mode) + def func_c_mytd(): + return MyTd._new_unsafe(a=torch.randn(3), 
+
+        assert type(func_td()) is type(func_c_td())
+        assert type(func_mytd()) is type(func_c_mytd())
+
     @pytest.mark.skipif(
         TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
     )
@@ -761,7 +785,9 @@ def forward(self, x):
 
     def call(x, td):
         with td.to_module(module):
-            return module(x)
+            y = module(x)
+            td.clear_refs_for_compile_()
+            return y
 
     call_compile = torch.compile(call, fullgraph=True, mode=mode)
     x = torch.randn(2, 3)
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index cac9cab76..29831ccd1 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -3794,8 +3794,7 @@ def get_old_val(newval, oldval):
         td_1 = td.apply(get_old_val, td_c, inplace=inplace, default=None)
         if inplace:
             for key in td.keys(True, True):
-                # TODO: remove default None in v0.7
-                td_c_val = td_c.get(key, None)
+                td_c_val = td_c.get(key)
                 if td_c_val is not None:
                     assert (td_c[key] == td[key]).all()
                 else:
@@ -3804,8 +3803,7 @@ def get_old_val(newval, oldval):
                     assert (td_1[key] == td[key]).all()
         else:
             for key in td.keys(True, True):
-                # TODO: remove default None in v0.7
-                td_c_val = td_c.get(key, None)
+                td_c_val = td_c.get(key)
                 if td_c_val is not None:
                     assert (td_c[key] == td_1[key]).all()
                 else:
@@ -6558,8 +6556,7 @@ def test_sorted_keys(self, td_name, device):
             assert key1 == key2
         assert i == len(td.keys()) - 1
         if td.is_locked:
-            # TODO: remove default None in v0.7
-            assert td._cache.get("sorted_keys", None) is not None
+            assert td._cache.get("sorted_keys") is not None
             td.unlock_()
             assert td._cache is None
         elif td_name not in ("sub_td", "sub_td2"):  # we cannot lock sub tensordicts
@@ -6570,8 +6567,7 @@ def test_sorted_keys(self, td_name, device):
             assert target._cache is None
             td.lock_()
             _ = td.sorted_keys
-            # TODO: remove default None in v0.7
-            assert target._cache.get("sorted_keys", None) is not None
+            assert target._cache.get("sorted_keys") is not None
             td.unlock_()
             assert target._cache is None
@@ -9795,7 +9791,7 @@ def run_assertions():
             assert get_defaults_to_none()
             run_assertions()
             set_get_defaults_to_none(None)
-            assert get_defaults_to_none() is None
+            assert get_defaults_to_none() is False
             run_assertions()
         finally:
             set_get_defaults_to_none(set_back)
@@ -9819,7 +9815,7 @@ def run_assertions():
             assert get_defaults_to_none()
             run_assertions()
             set_get_defaults_to_none(None)
-            assert get_defaults_to_none() is None
+            assert get_defaults_to_none() is False
             run_assertions()
         finally:
             set_get_defaults_to_none(set_back)
@@ -9846,7 +9842,7 @@ def run_assertions():
             assert get_defaults_to_none()
             run_assertions()
             set_get_defaults_to_none(None)
-            assert get_defaults_to_none() is None
+            assert get_defaults_to_none() is False
             run_assertions()
         finally:
             set_get_defaults_to_none(set_back)
@@ -9873,7 +9869,7 @@ def run_assertions():
             assert get_defaults_to_none()
             run_assertions()
             set_get_defaults_to_none(None)
-            assert get_defaults_to_none() is None
+            assert get_defaults_to_none() is False
             run_assertions()
         finally:
             set_get_defaults_to_none(set_back)
@@ -9900,7 +9896,7 @@ def run_assertions():
             assert get_defaults_to_none()
             run_assertions()
             set_get_defaults_to_none(None)
-            assert get_defaults_to_none() is None
+            assert get_defaults_to_none() is False
             run_assertions()
         finally:
             set_get_defaults_to_none(set_back)
@@ -9927,7 +9923,7 @@ def run_assertions():
             assert get_defaults_to_none()
             run_assertions()
             set_get_defaults_to_none(None)
-            assert get_defaults_to_none() is None
+            assert get_defaults_to_none() is False
             run_assertions()
         finally:
             set_get_defaults_to_none(set_back)
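Note on the `get` change covered by the base.py and test_tensordict.py hunks above: with `_GET_DEFAULTS_TO_NONE = True`, `TensorDict.get` now mirrors `dict.get` and returns `None` for a missing key instead of raising. A minimal usage sketch, assuming tensordict 0.7 with these changes applied and `set_get_defaults_to_none` re-exported at the package root (as the docstring's `:func:` reference suggests); the keys are illustrative only:

    >>> from tensordict import TensorDict, set_get_defaults_to_none
    >>> td = TensorDict({"x": 1}, batch_size=[])
    >>> print(td.get("y"))                 # no KeyError in 0.7, mirrors dict.get
    None
    >>> set_get_defaults_to_none(False)    # opt back into the pre-0.7 behavior
    >>> td.get("y")                        # raises KeyError again
    >>> set_get_defaults_to_none(True)

The same switch is read from the `TD_GET_DEFAULTS_TO_NONE` environment variable at import time, as handled at the top of `tensordict/base.py` above.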
diff --git a/version.txt b/version.txt
index b61604874..faef31a43 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.6.2
+0.7.0
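For migration away from `NormalParamWrapper` (which now raises a `RuntimeError`), the nn.rst hunk above composes the network with `NormalParamExtractor` inside an `nn.Sequential`. A minimal sketch of that pattern, assuming tensordict 0.7; the `Linear(3, 8)` module and the shapes are illustrative only:

    >>> import torch
    >>> from tensordict.nn import NormalParamExtractor
    >>> net = torch.nn.Sequential(torch.nn.Linear(3, 8), NormalParamExtractor())
    >>> loc, scale = net(torch.randn(2, 3))
    >>> loc.shape, scale.shape
    (torch.Size([2, 4]), torch.Size([2, 4]))
    >>> assert (scale > 0).all()

`NormalParamExtractor` chunks the last dimension in two and passes the second half through a positive mapping (a biased softplus by default), which is what the deleted `NormalParamWrapper.forward` used to do around its wrapped operator.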