From 8cd60091cd87d6e45ecf9ed65b2b4d79ad16ac34 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Dec 2023 12:28:59 +0000 Subject: [PATCH] init --- tensordict/base.py | 17 ++++++++++++----- tensordict/utils.py | 2 ++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 883fe5d9a..c8c458766 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3179,8 +3179,12 @@ def map( num_workers (int, optional): the number of workers. Exclusive with ``pool``. If none is provided, the number of workers will be set to the number of cpus available. - chunksize (int, optional): The size of each chunk of data. If none - is provided, the number of chunks will equate the number + chunksize (int, optional): The size of each chunk of data. + A ``chunksize`` of 0 will unbind the tensordict along the + desired dimension and restack it after the function is applied, + whereas ``chunksize>0`` will split the tensordict and call + :func:`torch.cat` on the resulting list of tensordicts. + If none is provided, the number of chunks will equal the number of workers. 
For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically @@ -3274,9 +3278,12 @@ def map( raise ValueError(f"Got incompatible dimension {dim_orig}") self_split = _split_tensordict(self, chunksize, num_chunks, num_workers, dim) - chunksize = 1 - imap = pool.imap(fn, self_split, chunksize) - out = torch.cat(list(imap), dim) + call_chunksize = 1 + imap = pool.imap(fn, self_split, call_chunksize) + if chunksize == 0: + out = torch.stack(list(imap), dim) + else: + out = torch.cat(list(imap), dim) return out # Functorch compatibility diff --git a/tensordict/utils.py b/tensordict/utils.py index 4c8308be8..2e69f88c4 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1153,6 +1153,8 @@ def _split_tensordict(td, chunksize, num_chunks, num_workers, dim): num_chunks = min(td.shape[dim], num_chunks) return td.chunk(num_chunks, dim=dim) else: + if chunksize == 0: + return td.unbind(dim=dim) chunksize = min(td.shape[dim], chunksize) return td.split(chunksize, dim=dim)