Commit 9a42112

Update

[ghstack-poisoned]

vmoens committed Dec 9, 2024
1 parent d3e2d5f commit 9a42112
Showing 6 changed files with 433 additions and 105 deletions.
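All three files below apply the same mechanical refactor: lambdas passed to `_fast_apply` (or as `key=` to `sorted`) are lifted into named local functions defined just above the call site. The commit itself carries no description, so the snippet below only illustrates the pattern in isolation; `fast_apply` here is a hypothetical stand-in for the internal `TensorDict._fast_apply`, not the real implementation.

```python
import torch

def fast_apply(fn, leaves):
    # Hypothetical stand-in for TensorDict._fast_apply: apply fn to every leaf tensor.
    return [fn(leaf) for leaf in leaves]

leaves = [torch.zeros(3), torch.ones(2)]

# Before: an anonymous lambda is handed to the traversal helper.
before = fast_apply(lambda x: x.contiguous(), leaves)

# After (the pattern used throughout this commit): the callable gets a name.
# Behaviour is unchanged; only the way the callable is defined differs.
def contiguous(x):
    return x.contiguous()

after = fast_apply(contiguous, leaves)
assert all(torch.equal(a, b) for a, b in zip(before, after))
```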
11 changes: 9 additions & 2 deletions tensordict/_lazy.py
@@ -1235,9 +1235,13 @@ def _add_batch_dim(self, *, in_dim, vmap_level):
         else:
             in_dim = in_dim - 1
             stack_dim = td.stack_dim
+
+        def addbatchdim(_arg):
+            return _add_batch_dim(_arg, in_dim, vmap_level)
+
         tds = [
             td._fast_apply(
-                lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level),
+                addbatchdim,
                 batch_size=[b for i, b in enumerate(td.batch_size) if i != in_dim],
                 names=(
                     [name for i, name in enumerate(td.names) if i != in_dim]
@@ -3568,7 +3572,10 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])
 
     def contiguous(self) -> T:
-        return self._fast_apply(lambda x: x.contiguous(), propagate_lock=True)
+        def contiguous(x):
+            return x.contiguous()
+
+        return self._fast_apply(contiguous, propagate_lock=True)
 
     def rename_key_(
         self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
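The two hunks above change internal call sites in tensordict/_lazy.py. The public counterpart of `_fast_apply` is `TensorDict.apply`, which accepts any callable, named function or lambda alike. A rough usage sketch, assuming the `tensordict` package is installed (the tensordict contents below are made up):

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(4, 3), "b": {"c": torch.zeros(4)}}, batch_size=[4])

def double(x):
    # Named callable, mirroring the refactor above; a lambda would behave identically.
    return x * 2

doubled = td.apply(double)
print(doubled["b", "c"].shape)  # torch.Size([4])
```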
53 changes: 41 additions & 12 deletions tensordict/_td.py
@@ -1871,19 +1871,27 @@ def repeat_interleave(
                 for i, s in enumerate(self.batch_size)
             ]
         )
-        return self._fast_apply(
-            lambda leaf: leaf.repeat_interleave(
+
+        def rep(leaf):
+            return leaf.repeat_interleave(
                 repeats=repeats, dim=dim_corrected, output_size=output_size
-            ),
+            )
+
+        return self._fast_apply(
+            rep,
             batch_size=new_batch_size,
             call_on_nested=True,
             propagate_lock=True,
         )
 
     def _repeat(self, *repeats: int) -> TensorDictBase:
         new_batch_size = torch.Size([i * r for i, r in zip(self.batch_size, repeats)])
+
+        def rep(leaf):
+            return leaf.repeat(*repeats, *((1,) * (leaf.ndim - self.ndim)))
+
         return self._fast_apply(
-            lambda leaf: leaf.repeat(*repeats, *((1,) * (leaf.ndim - self.ndim))),
+            rep,
             batch_size=new_batch_size,
             call_on_nested=True,
             propagate_lock=True,
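The hunk above rewrites TensorDict.repeat_interleave around a named `rep` helper. A quick usage sketch of the public method, assuming the `tensordict` package is installed:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.arange(3)}, batch_size=[3])

# Each batch element is repeated twice along dim 0; leaves follow the batch size.
rep = td.repeat_interleave(2, dim=0)
print(rep.batch_size)  # torch.Size([6])
print(rep["a"])        # tensor([0, 0, 1, 1, 2, 2])
```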
@@ -2008,8 +2016,11 @@ def _squeeze(tensor):
         if names:
             names.pop(dim)
 
+        def squeeze(x):
+            return x.squeeze(newdim)
+
         result = self._fast_apply(
-            lambda x: x.squeeze(newdim),
+            squeeze,
             batch_size=batch_size,
             names=names,
             inplace=False,
@@ -3362,9 +3373,13 @@ def keys(
             if not sort:
                 return _StringKeys(self._tensordict.keys())
             else:
+
+                def keyfunc(x):
+                    return ".".join(x) if isinstance(x, tuple) else x
+
                 return sorted(
                     _StringKeys(self._tensordict.keys()),
-                    key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
+                    key=keyfunc,
                 )
         else:
             return self._nested_keys(
@@ -3403,7 +3418,11 @@ def items(
         if not include_nested and not leaves_only:
             if not sort:
                 return self._tensordict.items()
-            return sorted(self._tensordict.items(), key=lambda x: x[0])
+
+            def keyfunc(x):
+                return x[0]
+
+            return sorted(self._tensordict.items(), key=keyfunc)
         elif include_nested and leaves_only and not sort:
             is_leaf = _default_is_leaf if is_leaf is None else is_leaf
             result = []
@@ -3449,9 +3468,11 @@ def values(
             if not sort:
                 return self._tensordict.values()
             else:
-                return list(zip(*sorted(self._tensordict.items(), key=lambda x: x[0])))[
-                    1
-                ]
+
+                def keyfunc(x):
+                    return x[0]
+
+                return list(zip(*sorted(self._tensordict.items(), key=keyfunc)))[1]
         else:
             return TensorDictBase.values(
                 self,
@@ -4061,8 +4082,12 @@ def expand(self, *args: int, inplace: bool = False) -> T:
             shape = tuple(args[0])
         else:
             shape = args
+
+        def expand(x):
+            return x.expand((*shape, *x.shape[self.ndim :]))
+
         return self._fast_apply(
-            lambda x: x.expand((*shape, *x.shape[self.ndim :])),
+            expand,
             batch_size=shape,
             propagate_lock=True,
         )
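Same pattern for expand: the lambda becomes a local `expand` function that broadcasts each leaf to the new batch shape plus its trailing feature dims. A small usage sketch of the public method, assuming the `tensordict` package is installed:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3)}, batch_size=[3])

# Expand the batch dims from (3,) to (4, 3); each leaf is expanded accordingly.
big = td.expand(4, 3)
print(big.batch_size)  # torch.Size([4, 3])
print(big["a"].shape)  # torch.Size([4, 3])
```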
@@ -4432,9 +4457,13 @@ def _iter():
             )
 
         if self.sort:
+
+            def keyfunc(key):
+                return ".".join(key) if isinstance(key, tuple) else key
+
             yield from sorted(
                 _iter(),
-                key=lambda key: ".".join(key) if isinstance(key, tuple) else key,
+                key=keyfunc,
             )
         else:
             yield from _iter()
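Several of the tensordict/_td.py hunks above swap the `key=lambda ...` passed to `sorted` for a named `keyfunc` that compares nested tuple keys by their dotted path. A self-contained sketch of that key function (the sample keys are made up):

```python
# TensorDict-style keys: plain strings and nested tuple keys mixed together.
keys = [("b", "c"), "a", ("a", "d")]

def keyfunc(key):
    # Tuple (nested) keys sort by their dotted path, as in the diff above.
    return ".".join(key) if isinstance(key, tuple) else key

print(sorted(keys, key=keyfunc))  # ['a', ('a', 'd'), ('b', 'c')]
```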
30 changes: 24 additions & 6 deletions tensordict/_torch_func.py
@@ -144,8 +144,11 @@ def _gather_tensor(tensor, dest_container=None, dest_key=None):
 
 @implements_for_td(torch.full_like)
 def _full_like(td: T, fill_value: float, *args, **kwargs: Any) -> T:
+    def full_like(x):
+        return torch.full_like(x, fill_value, *args, **kwargs)
+
     return td._fast_apply(
-        lambda x: torch.full_like(x, fill_value, *args, **kwargs),
+        full_like,
         inplace=True,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
@@ -154,8 +157,11 @@ def _full_like(td: T, fill_value: float, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.zeros_like)
 def _zeros_like(td: T, *args, **kwargs: Any) -> T:
+    def zeros_like(x):
+        return torch.zeros_like(x, *args, **kwargs)
+
     td_clone = td._fast_apply(
-        lambda x: torch.zeros_like(x, *args, **kwargs),
+        zeros_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
@@ -173,8 +179,11 @@ def _zeros_like(td: T, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.ones_like)
 def _ones_like(td: T, *args, **kwargs: Any) -> T:
+    def ones_like(x):
+        return torch.ones_like(x, *args, **kwargs)
+
     td_clone = td._fast_apply(
-        lambda x: torch.ones_like(x, *args, **kwargs),
+        ones_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
@@ -190,8 +199,11 @@ def _ones_like(td: T, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.rand_like)
 def _rand_like(td: T, *args, **kwargs: Any) -> T:
+    def rand_like(x):
+        return torch.rand_like(x, *args, **kwargs)
+
     td_clone = td._fast_apply(
-        lambda x: torch.rand_like(x, *args, **kwargs),
+        rand_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
@@ -207,8 +219,11 @@ def _rand_like(td: T, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.randn_like)
 def _randn_like(td: T, *args, **kwargs: Any) -> T:
+    def randn_like(x):
+        return torch.randn_like(x, *args, **kwargs)
+
     td_clone = td._fast_apply(
-        lambda x: torch.randn_like(x, *args, **kwargs),
+        randn_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
@@ -224,8 +239,11 @@ def _randn_like(td: T, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.empty_like)
 def _empty_like(td: T, *args, **kwargs) -> T:
+    def empty_like(x):
+        return torch.empty_like(x, *args, **kwargs)
+
     return td._fast_apply(
-        lambda x: torch.empty_like(x, *args, **kwargs),
+        empty_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
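The tensordict/_torch_func.py hunks above refactor the TensorDict overrides registered for torch's *_like factories via `implements_for_td`. A usage sketch of the behaviour those overrides provide, assuming the `tensordict` package is installed (the tensordict contents below are made up):

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(2, 3), "b": torch.randn(2)}, batch_size=[2])

# With the overrides above registered, the torch factory functions accept a
# TensorDict and return a new TensorDict with the same key structure.
zeros = torch.zeros_like(td)
full = torch.full_like(td, 3.14)

print(zeros["a"].abs().sum())  # tensor(0.)
print(full["b"])               # tensor([3.1400, 3.1400])
```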