diff --git a/tensordict/base.py b/tensordict/base.py
index 6c19b852e..af7b7fec3 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -18,6 +18,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from copy import copy
+from functools import wraps
 from pathlib import Path
 from textwrap import indent
 from typing import (
@@ -38,6 +39,7 @@
 import numpy as np
 import torch
 from tensordict.utils import (
+    _CloudpickleWrapper,
     _GENERIC_NESTED_ERR,
     _get_shape_from_args,
     _is_tensorclass,
@@ -3878,6 +3880,8 @@ def map(
         fn: Callable,
         dim: int = 0,
         num_workers: int | None = None,
+        *,
+        out: TensorDictBase | None = None,
         chunksize: int | None = None,
         num_chunks: int | None = None,
         pool: mp.Pool | None = None,
@@ -3904,6 +3908,15 @@ def map(
             num_workers (int, optional): the number of workers. Exclusive with
                 ``pool``. If none is provided, the number of workers will be
                 set to the number of cpus available.
+
+        Keyword Args:
+            out (TensorDictBase, optional): an optional container for the output.
+                Its batch-size along the provided ``dim`` must match that of
+                ``self``. If it is shared or memmap (:meth:`~.is_shared` or
+                :meth:`~.is_memmap` returns ``True``), it will be populated
+                within the remote processes, avoiding data transfers back to
+                the main process. Otherwise, the results will be sent back to
+                the current process and written in place into ``out``.
             chunksize (int, optional): The size of each chunk of data.
                 A ``chunksize`` of 0 will unbind the tensordict along the
                 desired dimension and restack it after the function is applied,
@@ -4003,6 +4016,7 @@ def map(
                 num_chunks=num_chunks,
                 pool=pool,
                 pbar=pbar,
+                out=out,
             )
         num_workers = pool._processes
         dim_orig = dim
@@ -4013,16 +4027,44 @@ def map(
         self_split = _split_tensordict(self, chunksize, num_chunks, num_workers, dim)
         call_chunksize = 1
+
+        def wrap_fn_with_out(fn, out):
+            # wrap ``fn`` so that each worker writes its result directly into
+            # its slice of the shared/memmap output and returns ``None``,
+            # avoiding any transfer of results through the pipe
+            @wraps(fn)
+            def newfn(item_and_out):
+                item, out = item_and_out
+                result = fn(item)
+                out.update_(result)
+                return
+
+            out_split = _split_tensordict(out, chunksize, num_chunks, num_workers, dim)
+            # the closure ``newfn`` cannot be serialized by the standard
+            # pickler, hence the cloudpickle wrapper; materialize the zip so
+            # that ``len(self_split)`` keeps working for the progress bar
+            return _CloudpickleWrapper(newfn), list(zip(self_split, out_split))
+
+        if out is not None and (out.is_shared() or out.is_memmap()):
+            fn, self_split = wrap_fn_with_out(fn, out)
+            out = None
+
         imap = pool.imap(fn, self_split, call_chunksize)
+
         if pbar and importlib.util.find_spec("tqdm", None) is not None:
             import tqdm

             imap = tqdm.tqdm(imap, total=len(self_split))

         imaplist = []
+        start = 0
         for item in imap:
             if item is not None:
-                imaplist.append(item)
+                if out is not None:
+                    # ``out`` is neither shared nor memmap: results come back
+                    # to this process and are written in place, indexing along
+                    # ``dim`` rather than the first batch dimension
+                    if chunksize == 0:
+                        out[(slice(None),) * dim + (start,)].update_(item)
+                        start += 1
+                    else:
+                        end = start + item.shape[dim]
+                        chunk = (slice(None),) * dim + (slice(start, end),)
+                        out[chunk].update_(item)
+                        start = end
+                else:
+                    imaplist.append(item)
         del imap

         # support inplace modif
@@ -4031,7 +4073,7 @@ def map(
-        if chunksize == 0:
-            out = torch.stack(imaplist, dim)
-        else:
-            out = torch.cat(imaplist, dim)
-        return out
+        # only stack/cat when results were collected on this process;
+        # otherwise ``out`` has already been populated
+        if imaplist:
+            if chunksize == 0:
+                out = torch.stack(imaplist, dim)
+            else:
+                out = torch.cat(imaplist, dim)
+        return out

     # Functorch compatibility
     @abc.abstractmethod
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 011a09c9b..0559af885 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -2034,3 +2034,21 @@ def remove_duplicates(
         return output, unique_indices

     return output
+
+
+class _CloudpickleWrapper(object):
+    """A picklable wrapper around an arbitrary callable.
+
+    ``multiprocessing`` serializes tasks with the standard ``pickle`` module,
+    which rejects closures and lambdas; this wrapper defers to ``cloudpickle``
+    instead.
+    """
+
+    def __init__(self, fn):
+        self.fn = fn
+
+    def __getstate__(self):
+        import cloudpickle
+
+        return cloudpickle.dumps(self.fn)
+
+    def __setstate__(self, ob: bytes):
+        import pickle
+
+        self.fn = pickle.loads(ob)
+
+    def __call__(self, *args, **kwargs):
+        return self.fn(*args, **kwargs)
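Taken together, the hunks above let ``map`` write its results straight into a preallocated container instead of stacking them on the main process. A minimal usage sketch, assuming this patch is applied; the function name ``double_a``, the shapes and the worker count are illustrative, not part of the patch:

import torch
from tensordict import TensorDict

def double_a(td):
    # module-level so the standard pickler can ship it to the workers
    return TensorDict({"a": td["a"] * 2}, td.shape)

if __name__ == "__main__":
    data = TensorDict({"a": torch.arange(100)}, [100])
    # a shared-memory output is populated inside the workers, so the
    # results never travel back through the multiprocessing pipe
    out = TensorDict({"a": torch.zeros(100, dtype=torch.int64)}, [100]).share_memory_()
    data.map(double_a, num_workers=2, chunksize=10, out=out)
    assert (out["a"] == torch.arange(100) * 2).all()

If ``out`` were a plain (non-shared, non-memmap) tensordict, the same call would fall back to the slower path: results are returned to the main process and copied into ``out`` chunk by chunk.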
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 10d5f8f26..c402ff67e 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -12,6 +12,7 @@
 import os
 import re
 import uuid
+from pathlib import Path

 import numpy as np
 import pytest
@@ -7576,6 +7577,23 @@ def test_map_inplace(self, chunksize):
         data = TensorDict({"tensor": torch.zeros(10)}, [10]).memmap_()
         data.map(self._assert_is_memmap, chunksize=chunksize, num_workers=2)

+    @staticmethod
+    def selectfn(input):
+        return input.select("a")
+
+    @pytest.mark.parametrize("chunksize", [0, 5])
+    @pytest.mark.parametrize("mmap", [True, False])
+    def test_map_with_out(self, mmap, chunksize, tmpdir):
+        tmpdir = Path(tmpdir)
+        input = TensorDict({"a": torch.arange(10), "b": torch.arange(10)}, [10])
+        if mmap:
+            input.memmap_(tmpdir / "input")
+        out = TensorDict({"a": torch.zeros(10, dtype=torch.int)}, [10])
+        if mmap:
+            out.memmap_(tmpdir / "output")
+        input.map(self.selectfn, num_workers=2, chunksize=chunksize, out=out)
+        assert (out["a"] == torch.arange(10)).all(), (chunksize, mmap)
+

 class TestNonTensorData:
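The ``_CloudpickleWrapper`` added to tensordict/utils.py is what lets the closure built in ``wrap_fn_with_out`` cross process boundaries: the wrapper itself pickles normally, but its ``__getstate__`` serializes the wrapped callable with cloudpickle, which, unlike the standard pickler, handles closures and lambdas. A minimal sketch of the mechanism, independent of ``map``; ``make_adder`` is an illustrative name and ``cloudpickle`` must be installed:

import multiprocessing as mp

from tensordict.utils import _CloudpickleWrapper

def make_adder(k):
    # a closure: the standard pickler used by mp.Pool cannot serialize it
    def add(x):
        return x + k
    return add

if __name__ == "__main__":
    with mp.Pool(2) as pool:
        # wrapping makes the closure transferable to the workers
        results = pool.map(_CloudpickleWrapper(make_adder(10)), [1, 2, 3])
    assert results == [11, 12, 13]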