
Commit 9197097

Update

[ghstack-poisoned]

vmoens committed Feb 4, 2025
2 parents d7b479c + 9113311
Showing 24 changed files with 177 additions and 148 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/version_script.bat
@@ -1,3 +1,3 @@
@echo off
set TENSORDICT_BUILD_VERSION=0.6.2
set TENSORDICT_BUILD_VERSION=0.7.0
echo TENSORDICT_BUILD_VERSION is set to %TENSORDICT_BUILD_VERSION%
2 changes: 1 addition & 1 deletion .github/scripts/version_script.sh
@@ -1,3 +1,3 @@
#!/bin/bash

export TENSORDICT_BUILD_VERSION=0.6.2
export TENSORDICT_BUILD_VERSION=0.7.0
5 changes: 4 additions & 1 deletion .github/workflows/docs.yml
@@ -23,8 +23,11 @@ jobs:
strategy:
matrix:
python_version: ["3.10"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
with:
repository: pytorch/tensordict
upload-artifact: docs
6 changes: 6 additions & 0 deletions .github/workflows/lint.yml
@@ -18,6 +18,9 @@ concurrency:
jobs:
python-source-and-configs:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
with:
repository: pytorch/tensordict
script: |
@@ -46,6 +49,9 @@
c-source:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
with:
repository: pytorch/tensordict
script: |
16 changes: 14 additions & 2 deletions .github/workflows/test-linux.yml
@@ -23,9 +23,12 @@ jobs:
strategy:
matrix:
python_version: ["3.10"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
with:
runner: linux.g5.4xlarge.nvidia.gpu
repository: pytorch/tensordict
@@ -57,6 +60,9 @@
python_version: ["3.9", "3.10", "3.11", "3.12"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
with:
runner: linux.12xlarge
repository: pytorch/tensordict
@@ -81,9 +87,12 @@ jobs:
strategy:
matrix:
python_version: ["3.10"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
with:
runner: linux.g5.4xlarge.nvidia.gpu
repository: pytorch/tensordict
@@ -116,6 +125,9 @@
python_version: ["3.9", "3.12"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
with:
runner: linux.12xlarge
repository: pytorch/tensordict
5 changes: 4 additions & 1 deletion .github/workflows/test-rl-gpu.yml
@@ -23,9 +23,12 @@ jobs:
strategy:
matrix:
python_version: ["3.10"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
with:
runner: linux.g5.4xlarge.nvidia.gpu
repository: pytorch/tensordict
2 changes: 1 addition & 1 deletion .github/workflows/wheels-windows.yml
@@ -34,7 +34,7 @@ jobs:
shell: bash
run: |
python3 -mpip install wheel
TENSORDICT_BUILD_VERSION=0.6.2 python3 setup.py bdist_wheel
TENSORDICT_BUILD_VERSION=0.7.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v4
with:
13 changes: 9 additions & 4 deletions docs/source/reference/nn.rst
@@ -152,7 +152,7 @@ to build distributions from network outputs and get summary statistics or sample
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn.distributions import NormalParamWrapper
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from tensordict.nn.prototype import (
... ProbabilisticTensorDictModule,
... ProbabilisticTensorDictSequential,
@@ -161,9 +161,9 @@
>>> td = TensorDict(
... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> net = torch.nn.Sequential(torch.nn.GRUCell(4, 8), NormalParamExtractor())
>>> module = TensorDictModule(
... NormalParamWrapper(net), in_keys=["input", "hidden"], out_keys=["loc", "scale"]
... net, in_keys=["input", "hidden"], out_keys=["loc", "scale"]
... )
>>> prob_module = ProbabilisticTensorDictModule(
... in_keys=["loc", "scale"],
@@ -194,6 +194,7 @@ to build distributions from network outputs and get summary statistics or sample
TensorDictModuleBase
TensorDictModule
ProbabilisticTensorDictModule
ProbabilisticTensorDictSequential
TensorDictSequential
TensorDictModuleWrapper
CudaGraphModule
@@ -257,6 +258,10 @@ Distributions
NormalParamExtractor
OneHotCategorical
TruncatedNormal
InteractionType
set_interaction_type
add_custom_mapping
mappings


Utils
@@ -270,8 +275,8 @@ Utils

make_tensordict
dispatch
set_interaction_type
inv_softplus
biased_softplus
set_skip_existing
skip_existing
rand_one_hot
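
For readers migrating off the removed NormalParamWrapper, here is a minimal runnable sketch of the pattern the updated docs point to, assuming the public tensordict>=0.7 API; the "obs"/"action" keys and the Linear layer are illustrative placeholders, not taken from the diff:

import torch
from tensordict import TensorDict
from tensordict.nn import (
    InteractionType,
    NormalParamExtractor,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
    set_interaction_type,
)

# NormalParamExtractor replaces NormalParamWrapper: appended to the network,
# it splits the last 2*d features into a (loc, scale) pair with a positive scale.
net = torch.nn.Sequential(torch.nn.Linear(3, 8), NormalParamExtractor())
module = TensorDictModule(net, in_keys=["obs"], out_keys=["loc", "scale"])
prob_module = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=torch.distributions.Normal,
    return_log_prob=True,
)
seq = ProbabilisticTensorDictSequential(module, prob_module)

td = TensorDict({"obs": torch.randn(5, 3)}, batch_size=[5])
with set_interaction_type(InteractionType.MEAN):  # deterministic read of the distribution
    out = seq(td)
print(out)

set_interaction_type is the entry the diff moves from the Utils listing into the Distributions listing, alongside the newly listed InteractionType, add_custom_mapping and mappings.
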
4 changes: 2 additions & 2 deletions setup.py
@@ -66,10 +66,10 @@ def _get_pytorch_version(is_nightly, is_local):
# if "PYTORCH_VERSION" in os.environ:
# return f"torch=={os.environ['PYTORCH_VERSION']}"
if is_nightly:
return "torch>=2.6.0.dev"
return "torch>=2.7.0.dev"
if is_local:
return "torch"
return "torch>=2.5.0"
return "torch>=2.6.0"


def _get_packages():
7 changes: 4 additions & 3 deletions tensordict/_td.py
@@ -315,7 +315,9 @@ def _new_unsafe(
nested: bool = True,
**kwargs: dict[str, Any] | None,
) -> TensorDict:
if is_compiling():
if is_compiling() and cls is TensorDict:
# If the cls is not TensorDict, we must escape this to keep the same class.
# That's unfortunate because as of now it graph breaks but that's the best we can do.
return TensorDict(
source,
batch_size=batch_size,
@@ -2195,8 +2197,7 @@ def from_dict_instance(
input_dict = copy(input_dict)
for key, value in list(input_dict.items()):
if isinstance(value, (dict,)):
# TODO: v0.7: remove the None
cur_value = self.get(key, None)
cur_value = self.get(key)
if cur_value is not None:
input_dict[key] = cur_value.from_dict_instance(
value,
65 changes: 33 additions & 32 deletions tensordict/base.py
@@ -157,7 +157,7 @@ def __bool__(self):
if "TD_GET_DEFAULTS_TO_NONE" in os.environ:
_GET_DEFAULTS_TO_NONE = strtobool(os.environ["TD_GET_DEFAULTS_TO_NONE"])
else:
_GET_DEFAULTS_TO_NONE = None
_GET_DEFAULTS_TO_NONE = True


def set_get_defaults_to_none(set_to_none: bool = True):
@@ -172,7 +172,7 @@ def set_get_defaults_to_none(set_to_none: bool = True):

"""
global _GET_DEFAULTS_TO_NONE
_GET_DEFAULTS_TO_NONE = set_to_none
_GET_DEFAULTS_TO_NONE = bool(set_to_none)


def get_defaults_to_none(set_to_none: bool = True):
@@ -6390,60 +6390,39 @@ def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType:
Args:
key (str, tuple of str): key to be queried. If tuple of str it is
equivalent to chained calls of getattr.
default: default value if the key is not found in the tensordict.
default: default value if the key is not found in the tensordict. Defaults to ``None``.

.. warning::
Currently, if a key is not present in the tensordict and no default
is passed, a `KeyError` is raised. From v0.7, this behaviour will be changed
and a `None` value will be returned instead. To adopt the new behaviour,
set the environment variable `export TD_GET_DEFAULTS_TO_NONE='1'` or call
:func`~tensordict.set_get_defaults_to_none`.
Previously, if a key was not present in the tensordict and no default
was passed, a `KeyError` was raised. From v0.7, this behaviour has been changed
and a `None` value is returned instead (in accordance with the `dict.get` behavior).
To adopt the old behavior, set the environment variable `export TD_GET_DEFAULTS_TO_NONE='0'` or call
:func:`~tensordict.set_get_defaults_to_none(False)`.

Examples:
>>> td = TensorDict({"x": 1}, batch_size=[])
>>> td.get("x")
tensor(1)
>>> set_get_defaults_to_none(False) # Current default behaviour
>>> td.get("y") # Raises KeyError
>>> set_get_defaults_to_none(True)
>>> td.get("y")
None
"""
key = _unravel_key_to_tuple(key)
if not key:
raise KeyError(_GENERIC_NESTED_ERR.format(key))
# Find what the default is
has_default = False
if args:
default = args[0]
if len(args) > 1 or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
has_default = True
elif kwargs:
default = kwargs.pop("default")
if args or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
has_default = True
elif _GET_DEFAULTS_TO_NONE:
default = None
else:
default = NO_DEFAULT
try:
return self._get_tuple(key, default=default)
except KeyError:
if _GET_DEFAULTS_TO_NONE is None and not has_default:
# We raise an exception AND a warning because we want the user to know that this exception will
# not be raised in the future
warnings.warn(
f"The entry ({key}) you have queried with `get` is not present in the tensordict. "
"Currently, this raises an exception. "
"To align with `dict.get`, this behaviour will be changed in v0.7 and a `None` value will "
"be returned instead (no error will be raised). "
"To suppress this warning and use the new behaviour (recommended), call `tensordict.set_get_defaults_to_none(True)` or set the env variable `export TD_GET_DEFAULTS_TO_NONE='1'`. "
"To suppress this warning and keep the old behaviour, call `tensordict.set_get_defaults_to_none(False)` or set the env variable `export TD_GET_DEFAULTS_TO_NONE='0'`.",
category=DeprecationWarning,
)
raise
return self._get_tuple(key, default=default)

@abc.abstractmethod
def _get_str(self, key, default): ...
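
A short sketch of the resulting get semantics described in the docstring above, assuming a 0.7 install with TD_GET_DEFAULTS_TO_NONE unset and that set_get_defaults_to_none is importable from the top-level tensordict package (as the :func:`~tensordict.set_get_defaults_to_none` reference suggests):

import torch
from tensordict import TensorDict, set_get_defaults_to_none

td = TensorDict({"x": torch.tensor(1)}, batch_size=[])

print(td.get("x"))     # tensor(1)
print(td.get("y"))     # None -- the new v0.7 default, aligned with dict.get
print(td.get("y", 0))  # 0    -- an explicit default still takes precedence

set_get_defaults_to_none(False)   # opt back into the <=0.6 behaviour
try:
    td.get("y")
except KeyError:
    print("missing keys raise, as before v0.7")
set_get_defaults_to_none(True)    # restore the new default
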
@@ -11119,6 +11098,8 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
# During exit, updates mustn't be made in-place as the source and dest
# storage location can be identical, resulting in a RuntimeError
if is_compiling():
self.clear_refs_for_compile_()
if exc_type is not None and issubclass(exc_type, Exception):
return False
_last_op = self._last_op_queue.pop()
@@ -11128,11 +11109,27 @@ def __exit__(self, exc_type, exc_val, exc_tb):
# added or deleted
_inv_caller = LAST_OP_MAPS.get(last_op)
if _inv_caller is not None:
return _inv_caller(self, args, kwargs, out_wr())
prev_ref = out_wr()
return _inv_caller(self, args, kwargs, prev_ref)
else:
raise NotImplementedError(f"Unrecognised function {last_op}.")
return self

def clear_refs_for_compile_(self) -> T:
"""Clears the weakrefs in order for the tensordict to get out of the compile region safely.

Use this whenever you hit `torch._dynamo.exc.Unsupported: reconstruct: WeakRefVariable()`
before returning a TensorDict.

Returns: self
"""
self._last_op = None
for v in self.values(True, True, is_leaf=_is_tensor_collection):
if _is_tensorclass(type(v)):
v = v._tensordict
v._last_op = None
return self

# Clone, select, exclude, empty
def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) -> T:
"""Selects the keys of the tensordict and returns a new tensordict with only the selected keys.
@@ -11559,7 +11556,11 @@ def from_any(
device=device,
batch_size=batch_size,
)
if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"):
if (
isinstance(obj, np.ndarray)
and hasattr(obj.dtype, "names")
and obj.dtype.names is not None
):
return cls.from_struct_array(
obj,
auto_batch_size=auto_batch_size,
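
On the from_any change just above: every numpy dtype exposes a names attribute (it is simply None for non-structured dtypes), so the hasattr test alone matched every array; the added "is not None" check is what actually separates structured arrays from plain ones. A small illustration, assuming numpy and the from_any/from_struct_array entry points shown in the diff:

import numpy as np
from tensordict import TensorDict

plain = np.zeros((3, 4), dtype=np.float32)
print(plain.dtype.names)        # None -> handled as a regular tensor

structured = np.zeros(3, dtype=[("pos", np.float32, (2,)), ("hit", np.bool_)])
print(structured.dtype.names)   # ('pos', 'hit') -> routed to from_struct_array

print(TensorDict.from_any(structured, auto_batch_size=True))
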
7 changes: 2 additions & 5 deletions tensordict/nn/common.py
@@ -1084,14 +1084,12 @@ def forward(
if self._kwargs is not None:
kwargs.update(
{
# TODO: v0.7: remove the None
kwarg: tensordict.get(in_key, None)
kwarg: tensordict.get(in_key)
for kwarg, in_key in _zip_strict(self._kwargs, self.in_keys)
}
)
tensors = ()
else:
# TODO: v0.7: remove the None
tensors = tuple(
tensordict._get_tuple_maybe_non_tensor(
_unravel_key_to_tuple(in_key), None
@@ -1121,8 +1119,7 @@ def forward(
keys = unravel_key_list(list(tensors.keys()))
values = tensors.values()
tensors = dict(_zip_strict(keys, values))
# TODO: v0.7: remove the None
tensors = tuple(tensors.get(key, None) for key in self.out_keys)
tensors = tuple(tensors.get(key) for key in self.out_keys)
if not isinstance(tensors, tuple):
tensors = (tensors,)
tensordict_out = self._write_to_tensordict(
1 change: 0 additions & 1 deletion tensordict/nn/distributions/__init__.py
@@ -10,7 +10,6 @@
AddStateIndependentNormalScale,
Delta,
NormalParamExtractor,
NormalParamWrapper,
)
from tensordict.nn.distributions.discrete import OneHotCategorical, rand_one_hot
from tensordict.nn.distributions.truncated_normal import TruncatedNormal
6 changes: 2 additions & 4 deletions tensordict/nn/distributions/composite.py
@@ -127,8 +127,7 @@ def __init__(
else:
write_name = name_unravel
name = name_unravel
# TODO: v0.7: remove the None
dist_params = params.get(name, None)
dist_params = params.get(name)
kwargs = extra_kwargs.get(name, {})
if dist_params is None:
raise KeyError
@@ -587,8 +586,7 @@ def icdf(self, sample: TensorDictBase) -> TensorDictBase:
KeyError: If neither `<sample_name>` nor `<sample_name>_cdf` can be found in the input TensorDict for a component distribution.
"""
for name, dist in self.dists.items():
# TODO: v0.7: remove the None
prob = sample.get(_add_suffix(name, "_cdf"), None)
prob = sample.get(_add_suffix(name, "_cdf"))
if prob is None:
try:
prob = self.cdf(sample.get(name))