[Feature] map with preallocated output (#667)
vmoens authored Feb 7, 2024
1 parent 74fd67b commit 751091a
Showing 3 changed files with 80 additions and 2 deletions.
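
In short, ``map`` can now write its results into a caller-provided container instead of allocating a new one. A minimal usage sketch of the new ``out`` argument (the function, shapes and dtypes below are illustrative, not taken from this commit):

    import torch
    from tensordict import TensorDict

    def double_a(td):
        # runs in a worker process on one chunk of the input
        return td.apply(lambda x: x * 2)

    if __name__ == "__main__":
        data = TensorDict({"a": torch.arange(10)}, batch_size=[10])
        # preallocated output; its size along the mapped dim matches the input
        out = TensorDict({"a": torch.zeros(10, dtype=torch.long)}, batch_size=[10])
        data.map(double_a, num_workers=2, chunksize=5, out=out)
        assert (out["a"] == torch.arange(10) * 2).all()
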
46 changes: 44 additions & 2 deletions tensordict/base.py
@@ -18,6 +18,7 @@

from concurrent.futures import ThreadPoolExecutor
from copy import copy
from functools import wraps
from pathlib import Path
from textwrap import indent
from typing import (
@@ -38,6 +39,7 @@
import numpy as np
import torch
from tensordict.utils import (
    _CloudpickleWrapper,
    _GENERIC_NESTED_ERR,
    _get_shape_from_args,
    _is_tensorclass,
@@ -3878,6 +3880,8 @@ def map(
    fn: Callable,
    dim: int = 0,
    num_workers: int | None = None,
    *,
    out: TensorDictBase = None,
    chunksize: int | None = None,
    num_chunks: int | None = None,
    pool: mp.Pool | None = None,
@@ -3904,6 +3908,15 @@
num_workers (int, optional): the number of workers. Exclusive with ``pool``.
    If none is provided, the number of workers will be set to the
    number of CPUs available.

Keyword Args:
    out (TensorDictBase, optional): an optional container for the output.
        Its batch size along the provided ``dim`` must match that of ``self``.
        If it is shared or memory-mapped (:meth:`~.is_shared` or
        :meth:`~.is_memmap` returns ``True``), it will be populated within
        the remote processes, avoiding data transfers back to the main
        process. Otherwise, the data from each slice of ``self`` will be
        sent to its worker, collected on the current process and written
        in place into ``out``.
    chunksize (int, optional): The size of each chunk of data.
        A ``chunksize`` of 0 will unbind the tensordict along the
        desired dimension and restack it after the function is applied,
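
To illustrate the shared/memory-mapped branch documented for ``out`` above (a sketch, not part of the diff): when ``out`` is memory-mapped, the function is wrapped so that each worker updates its slice of ``out`` directly, and nothing is shipped back to the main process. This path serializes the wrapped function with ``cloudpickle``, so that package must be available; the function and shapes below are illustrative.

    import torch
    from tensordict import TensorDict

    def add_one(td):
        return td.apply(lambda x: x + 1)

    if __name__ == "__main__":
        data = TensorDict({"a": torch.zeros(100)}, batch_size=[100])
        # out.is_memmap() is True, so workers write into the backing storage in place
        out = TensorDict({"a": torch.empty(100)}, batch_size=[100]).memmap_()
        data.map(add_one, num_workers=2, chunksize=25, out=out)
        assert (out["a"] == 1).all()
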
@@ -4003,6 +4016,7 @@ def map(
            num_chunks=num_chunks,
            pool=pool,
            pbar=pbar,
            out=out,
        )
num_workers = pool._processes
dim_orig = dim
@@ -4013,16 +4027,44 @@

self_split = _split_tensordict(self, chunksize, num_chunks, num_workers, dim)
call_chunksize = 1

def wrap_fn_with_out(fn, out):
    @wraps(fn)
    def newfn(item_and_out):
        item, out = item_and_out
        result = fn(item)
        out.update_(result)
        return

    out_split = _split_tensordict(out, chunksize, num_chunks, num_workers, dim)
    return _CloudpickleWrapper(newfn), zip(self_split, out_split)

if out is not None and (out.is_shared() or out.is_memmap()):
    fn, self_split = wrap_fn_with_out(fn, out)
    out = None

imap = pool.imap(fn, self_split, call_chunksize)

if pbar and importlib.util.find_spec("tqdm", None) is not None:
    import tqdm

    imap = tqdm.tqdm(imap, total=len(self_split))

imaplist = []
start = 0
for item in imap:
    if item is not None:
        if out is not None:
            if chunksize:
                end = start + item.shape[dim]
                chunk = slice(start, end)
                out[chunk].update_(item)
                start = end
            else:
                out[start].update_(item)
                start += 1
        else:
            imaplist.append(item)
del imap

# support inplace modif
@@ -4031,7 +4073,7 @@ def map(
        out = torch.stack(imaplist, dim)
    else:
        out = torch.cat(imaplist, dim)
return out

# Functorch compatibility
@abc.abstractmethod
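For readers tracing the collection loop above: when ``out`` is provided but neither shared nor memory-mapped, the results come back to the main process and are written into consecutive slices of ``out``. A self-contained rendering of that bookkeeping with plain tensors (illustrative only, not part of the diff):

    import torch

    # stand-ins for the chunked results yielded by pool.imap
    items = [torch.full((3,), 1.0), torch.full((3,), 2.0), torch.full((4,), 3.0)]
    out = torch.empty(10)

    start = 0
    for item in items:
        end = start + item.shape[0]   # item.shape[dim] in the diff, with dim=0
        out[start:end].copy_(item)    # plays the role of out[chunk].update_(item)
        start = end

    # out now holds [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
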
18 changes: 18 additions & 0 deletions tensordict/utils.py
@@ -2034,3 +2034,21 @@ def remove_duplicates(
    return output, unique_indices

return output


class _CloudpickleWrapper(object):
    def __init__(self, fn):
        self.fn = fn

    def __getstate__(self):
        import cloudpickle

        return cloudpickle.dumps(self.fn)

    def __setstate__(self, ob: bytes):
        import pickle

        self.fn = pickle.loads(ob)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)
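
Why the wrapper exists (a sketch, assuming ``cloudpickle`` is installed; the ``make_adder`` helper below is made up for the example): ``multiprocessing`` pickles the mapped callable, and a closure such as the ``newfn`` built inside ``wrap_fn_with_out`` cannot be pickled by reference with the standard library, so it is serialized with ``cloudpickle`` on the sending side and restored with plain ``pickle`` on the worker.

    import pickle

    from tensordict.utils import _CloudpickleWrapper

    def make_adder(offset):
        # a closure that plain pickle cannot serialize by reference
        def add(x):
            return x + offset
        return add

    wrapped = _CloudpickleWrapper(make_adder(10))
    restored = pickle.loads(pickle.dumps(wrapped))  # what mp.Pool does implicitly
    assert restored(5) == 15
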
18 changes: 18 additions & 0 deletions test/test_tensordict.py
@@ -12,6 +12,7 @@
import os
import re
import uuid
from pathlib import Path

import numpy as np
import pytest
@@ -7576,6 +7577,23 @@ def test_map_inplace(self, chunksize):
data = TensorDict({"tensor": torch.zeros(10)}, [10]).memmap_()
data.map(self._assert_is_memmap, chunksize=chunksize, num_workers=2)

@staticmethod
def selectfn(input):
return input.select("a")

@pytest.mark.parametrize("chunksize", [0, 5])
@pytest.mark.parametrize("mmap", [True, False])
def test_map_with_out(self, mmap, chunksize, tmpdir):
tmpdir = Path(tmpdir)
input = TensorDict({"a": torch.arange(10), "b": torch.arange(10)}, [10])
if mmap:
input.memmap_(tmpdir / "input")
out = TensorDict({"a": torch.zeros(10, dtype=torch.int)}, [10])
if mmap:
out.memmap_(tmpdir / "output")
input.map(self.selectfn, num_workers=2, chunksize=chunksize, out=out)
assert (out["a"] == torch.arange(10)).all(), (chunksize, mmap)


# class TestNonTensorData:
class TestNonTensorData:
