From 04f96421499cbfaf8f0823f96b0676f502408911 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 18 Jan 2024 09:46:32 +0000 Subject: [PATCH] amend --- tensordict/_td.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensordict/_td.py b/tensordict/_td.py index 1737741f0..1ad770f13 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -835,10 +835,15 @@ def _unbind(self, dim: int): names = [name for i, name in enumerate(names) if i != dim] device = self.device + is_shared = self._is_shared + is_memmap = self._is_memmap + def empty(): result = TensorDict( {}, batch_size=batch_size, names=names, _run_checks=False, device=device ) + result._is_shared = is_shared + result._is_memmap = is_memmap return result tds = tuple(empty() for _ in range(self.batch_size[dim]))