
Commit

Merge remote-tracking branch 'origin/main' into stack-non-tensor
vmoens committed Feb 25, 2024
2 parents 7104a0b + b5f6c17 commit 6ca5e95
Showing 9 changed files with 301 additions and 91 deletions.
29 changes: 24 additions & 5 deletions tensordict/_td.py
@@ -331,7 +331,7 @@ def _to_module(
if not use_state_dict and isinstance(module, TensorDictBase):
if return_swap:
swap = module.copy()
module.update(self)
module._param_td = getattr(self, "_param_td", self)
return swap
else:
module.update(self)
@@ -407,7 +407,17 @@ def convert_type(x, y):
continue
child = __dict__["_modules"][key]
local_out = memo.get(id(child), NO_DEFAULT)

if local_out is NO_DEFAULT:
# if isinstance(child, TensorDictBase):
# # then child is a TensorDictParams
# from tensordict.nn import TensorDictParams
#
# local_out = child
# if not isinstance(value, TensorDictParams):
# value = TensorDictParams(value, no_convert=True)
# __dict__["_modules"][key] = value
# else:
local_out = value._to_module(
child,
inplace=inplace,
@@ -824,12 +834,14 @@ def _index_tensordict(
raise RuntimeError(
f"indexing a tensordict with td.batch_dims==0 is not permitted. Got index {index}."
)
if names is None:
names = self._get_names_idx(index)
if new_batch_size is not None:
batch_size = new_batch_size
else:
batch_size = _getitem_batch_size(batch_size, index)

if names is None:
names = self._get_names_idx(index)

source = {}
for key, item in self.items():
if isinstance(item, TensorDict):
@@ -1366,14 +1378,20 @@ def is_boolean(idx):
# this will convert a [None, :, :, 0, None, 0] in [None, 0, 1, None, 3]
count = 0
idx_to_take = []
no_more_tensors = False
for _idx in idx_names:
if _idx is None:
idx_to_take.append(None)
elif _is_number(_idx):
count += 1
elif isinstance(_idx, (torch.Tensor, np.ndarray)):
idx_to_take.extend([count] * _idx.ndim)
count += 1
if not no_more_tensors:
idx_to_take.extend([count] * _idx.ndim)
count += 1
no_more_tensors = True
else:
# skip this one
count += 1
else:
idx_to_take.append(count)
count += 1
@@ -1387,6 +1405,7 @@ def names(self, value):
self._rename_subtds(value)
self._erase_names()
return
value = list(value)
num_none = sum(v is None for v in value)
if num_none:
num_none -= 1
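
Illustrative usage of the ``names`` setter change above (the setter now normalizes its argument with ``value = list(value)``); the tensordict contents and dimension names below are assumptions, not part of the commit:

import torch
from tensordict import TensorDict

td = TensorDict({"x": torch.zeros(3, 4)}, batch_size=[3, 4])
# the setter first copies its argument into a list, so any iterable
# (tuple, generator, ...) can be assigned, not only a list
td.names = ("batch", "time")
print(td.names)  # ['batch', 'time']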
6 changes: 4 additions & 2 deletions tensordict/_torch_func.py
@@ -96,8 +96,10 @@ def _gather_tensor(tensor, dest=None):
return out

if out is None:
names = input.names if input._has_names() else None

if len(index.shape) == input.ndim and input._has_names():
names = input.names
else:
names = None
return TensorDict(
{
key: _gather_tensor(value)
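
A small sketch of the name-propagation rule adjusted above: ``torch.gather`` on a tensordict keeps the dimension names only when the index has the same number of dimensions as the tensordict. The names and shapes here are assumptions for illustration:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.arange(12).reshape(3, 4)}, batch_size=[3, 4])
td.names = ["batch", "feat"]

idx = torch.zeros(3, 2, dtype=torch.long)
out = torch.gather(td, 1, idx)  # index.ndim == td.ndim, so names survive
print(out.batch_size)  # torch.Size([3, 2])
print(out.names)       # ['batch', 'feat']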
136 changes: 90 additions & 46 deletions tensordict/functional.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import warnings

from typing import Sequence

import torch
@@ -77,7 +79,8 @@ def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T:

def pad_sequence(
list_of_tensordicts: Sequence[T],
batch_first: bool = True,
batch_first: bool | None = None,
pad_dim: int = 0,
padding_value: float = 0.0,
out: T | None = None,
device: DeviceType | None = None,
@@ -87,76 +90,117 @@ def pad_sequence(
Args:
list_of_tensordicts (List[TensorDictBase]): the list of instances to pad and stack.
batch_first (bool, optional): the ``batch_first`` equivalent of :func:`torch.nn.utils.rnn.pad_sequence`.
Defaults to ``True``.
pad_dim (int, optional): the dimension along which to pad every entry in the tensordict.
Defaults to ``0``.
padding_value (number, optional): the padding value. Defaults to ``0.0``.
out (TensorDictBase, optional): if provided, the destination where the data will be
written.
device (device compatible type, optional): if provided, the device where the
TensorDict output will be created.
return_mask (bool, optional): if ``True``, a "mask" entry will be returned.
It contains the mask of valid values in the stacked tensordict.
return_mask (bool, optional): if ``True``, a "masks" entry will be returned.
It contains a tensordict with the same structure as the stacked tensordict, where every entry holds the mask of valid values, of size ``torch.Size([stack_len, *new_shape])``,
with ``new_shape[pad_dim] = max_seq_length`` and the remaining dimensions of ``new_shape`` matching the previous shape of the contained tensors.
Examples:
>>> list_td = [
... TensorDict({"a": torch.zeros((3,))}, []),
... TensorDict({"a": torch.zeros((4,))}, []),
... TensorDict({"a": torch.zeros((3, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
... TensorDict({"a": torch.zeros((5, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
... ]
>>> padded_td = pad_sequence(list_td)
>>> padded_td = pad_sequence(list_td, return_mask=True)
>>> print(padded_td)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
a: Tensor(shape=torch.Size([2, 4, 8]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 5, 8]), device=cpu, dtype=torch.float32, is_shared=False),
masks: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.bool, is_shared=False),
b: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
"""
if batch_first is not None:
warnings.warn(
"The batch_first argument is deprecated and will be removed in a future release. The output will always be batch_first.",
category=DeprecationWarning,
)

if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")

# check that all tensordict match
if return_mask:
list_of_tensordicts = [
td.clone(False).set("mask", torch.ones(td.shape, dtype=torch.bool))
for td in list_of_tensordicts
]
update_batch_size = True
max_seq_length = float("-inf")
keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
shape = max(len(td) for td in list_of_tensordicts)
if shape == 0:
shape = [
len(list_of_tensordicts),
]
elif batch_first:
shape = [len(list_of_tensordicts), shape]
else:
shape = [shape, len(list_of_tensordicts)]
if out is None:
out = TensorDict(
{}, batch_size=torch.Size(shape), device=device, _run_checks=False
)
tmp_list_of_tensordicts = []
for td in list_of_tensordicts:

if return_mask:
tmp_list_of_tensordicts.append(td.clone(False))

for key in keys:
try:
out.set(
key,
torch.nn.utils.rnn.pad_sequence(
[td.get(key) for td in list_of_tensordicts],
batch_first=batch_first,
padding_value=padding_value,
),
tensor_shape = td.get(key).shape
pos_pad_dim = pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim

# track the maximum sequence length to update batch_size accordingly
if tensor_shape[pos_pad_dim] > max_seq_length:
max_seq_length = tensor_shape[pos_pad_dim]

# The mask should always contain the batch_size of the TensorDict
mask_shape = td.shape

# if the pad_dim is past the batch_size of the TensorDict, we need to add the new dimension to the mask
if pos_pad_dim >= td.ndim:
mask_shape += torch.Size([tensor_shape[pos_pad_dim]])
update_batch_size = False

if return_mask:
tmp_list_of_tensordicts[-1].set(
("masks", key),
torch.ones(mask_shape, dtype=torch.bool),
)
except Exception as err:
raise RuntimeError(f"pad_sequence failed for key {key}") from err
return out
else:
for key in keys:
out.set_(
if return_mask:
list_of_tensordicts = tmp_list_of_tensordicts

keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)

old_batch_size = list(list_of_tensordicts[0].batch_size)
if update_batch_size and len(old_batch_size) > 0:
old_batch_size[pad_dim] = max_seq_length
shape = [
len(list_of_tensordicts),
] + old_batch_size

if out is None:
out = list_of_tensordicts[0].empty(recurse=True).reshape(torch.Size(shape))

for key in keys:
try:
tensor_shape = list_of_tensordicts[0].get(key).shape
pos_pad_dim = (
(pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim)
if len(tensor_shape) > 1
else 0 # handles the case when the masks are 1-dimensional
)
out.set(
key,
torch.nn.utils.rnn.pad_sequence(
[td.get(key) for td in list_of_tensordicts],
batch_first=batch_first,
[
td.get(key).transpose(0, pos_pad_dim)
for td in list_of_tensordicts
],
batch_first=True,
padding_value=padding_value,
),
).transpose(1, pos_pad_dim + 1),
inplace=True,
)
return out
except Exception as err:
raise RuntimeError(f"pad_sequence failed for key {key}") from err
return out


def merge_tensordicts(*tensordicts: T) -> T:
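
A usage sketch of the reworked ``pad_sequence`` above, assuming the shapes follow the updated docstring semantics (keys and sizes are illustrative): entries are stacked along a new leading dimension, padded along ``pad_dim``, and ``return_mask=True`` adds a nested ``("masks", key)`` entry flagging the valid values.

import torch
from tensordict import TensorDict, pad_sequence

list_td = [
    TensorDict({"a": torch.zeros(6, 3)}, batch_size=[]),
    TensorDict({"a": torch.zeros(6, 5)}, batch_size=[]),
]
# pad along dim 1 of each entry instead of the default dim 0
padded = pad_sequence(list_td, pad_dim=1, return_mask=True)
print(padded["a"].shape)           # torch.Size([2, 6, 5])
print(padded["masks", "a"].dtype)  # torch.bool, True where values are valid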
23 changes: 11 additions & 12 deletions tensordict/nn/params.py
@@ -990,18 +990,17 @@ def _load_from_state_dict(
unexpected_keys,
error_msgs,
):
data = (
TensorDict(
{
key: val
for key, val in state_dict.items()
if key.startswith(prefix) and val is not None
},
[],
)
.unflatten_keys(".")
.get(prefix[:-1])
)
data = TensorDict(
{
key: val
for key, val in state_dict.items()
if key.startswith(prefix) and val is not None
},
[],
).unflatten_keys(".")
prefix = tuple(key for key in prefix.split(".") if key)
if prefix:
data = data.get(prefix)
self.data.load_state_dict(data)

def items(
11 changes: 11 additions & 0 deletions tensordict/tensorclass.py
@@ -200,6 +200,9 @@ def __torch_function__(
cls.load_state_dict = _load_state_dict
cls._memmap_ = _memmap_

cls.__enter__ = __enter__
cls.__exit__ = __exit__

# Memmap
cls.memmap_like = TensorDictBase.memmap_like
cls.memmap_ = TensorDictBase.memmap_
@@ -425,6 +428,14 @@ def _load_memmap(cls, prefix: Path, metadata: dict):
return cls._from_tensordict(td, non_tensordict)


def __enter__(self, *args, **kwargs):
return self._tensordict.__enter__(*args, **kwargs)


def __exit__(self, *args, **kwargs):
return self._tensordict.__exit__(*args, **kwargs)


def _getstate(self) -> dict[str, Any]:
"""Returns a state dict which consists of tensor and non_tensor dicts for serialization.
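
The ``__enter__``/``__exit__`` forwarders added above let a tensorclass instance be used as a context manager by delegating to its underlying tensordict. A rough sketch, assuming the usual reversible operations such as ``unlock_()`` (the class and field below are illustrative, not part of the commit):

import torch
from tensordict import tensorclass

@tensorclass
class MyData:
    a: torch.Tensor

data = MyData(a=torch.zeros(3), batch_size=[3]).lock_()
# __enter__/__exit__ are forwarded to data._tensordict, so reversible ops
# can be used as context managers directly on the tensorclass instance
with data.unlock_():
    data.a = data.a + 1
# on exit the instance is locked again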
5 changes: 4 additions & 1 deletion tensordict/utils.py
@@ -305,7 +305,10 @@ def expand_as_right(
f" tensor.ndimension()={tensor.ndimension()} and "
f"dest.ndimension()={dest.ndimension()}"
)
if not (tensor.shape == dest.shape[: tensor.ndimension()]):
if any(
tensor.shape[i] != dest.shape[i] and tensor.shape[i] != 1
for i in range(tensor.ndimension())
):
raise RuntimeError(
f"tensor shape is incompatible with dest shape, "
f"got: tensor.shape={tensor.shape}, dest={dest.shape}"
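
The relaxed shape check above lets singleton dimensions broadcast instead of raising. A minimal sketch (shapes are illustrative):

import torch
from tensordict.utils import expand_as_right

mask = torch.zeros(3, 1, dtype=torch.bool)
dest = torch.zeros(3, 4, 5)
# size-1 dims of `mask` now broadcast against `dest` rather than erroring
print(expand_as_right(mask, dest).shape)  # torch.Size([3, 4, 5])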