[Feature] chunk and unbind for memory mapped tensors (#646)
vmoens authored Jan 31, 2024
1 parent c84b40e commit fdcc403
Showing 6 changed files with 120 additions and 19 deletions.
34 changes: 28 additions & 6 deletions tensordict/base.py
@@ -9,6 +9,7 @@
import collections
import concurrent.futures
import contextlib
import importlib
import json
import numbers
import warnings
@@ -3806,6 +3807,7 @@ def map(
generator: torch.Generator | None = None,
max_tasks_per_child: int | None = None,
worker_threads: int = 1,
pbar: bool = False,
):
"""Maps a function to splits of the tensordict across one dimension.
@@ -3872,6 +3874,8 @@ def map(
on the number of jobs.
worker_threads (int, optional): the number of threads for the workers.
Defaults to ``1``.
pbar (bool, optional): if ``True``, a progress bar will be displayed.
Requires tqdm to be available. Defaults to ``False``.
Examples:
>>> import torch
@@ -3916,7 +3920,12 @@ def map(
maxtasksperchild=max_tasks_per_child,
) as pool:
return self.map(
- fn, dim=dim, chunksize=chunksize, num_chunks=num_chunks, pool=pool
fn,
dim=dim,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
pbar=pbar,
)
num_workers = pool._processes
dim_orig = dim
@@ -3928,11 +3937,24 @@ def map(
self_split = _split_tensordict(self, chunksize, num_chunks, num_workers, dim)
call_chunksize = 1
imap = pool.imap(fn, self_split, call_chunksize)
- if chunksize == 0:
- out = torch.stack(list(imap), dim)
- else:
- out = torch.cat(list(imap), dim)
- return out
if pbar and importlib.util.find_spec("tqdm", None) is not None:
import tqdm

imap = tqdm.tqdm(imap, total=len(self_split))

imaplist = []
for item in imap:
if item is not None:
imaplist.append(item)
del imap

# support inplace modif
if imaplist:
if chunksize == 0:
out = torch.stack(imaplist, dim)
else:
out = torch.cat(imaplist, dim)
return out

# Functorch compatibility
@abc.abstractmethod
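A minimal usage sketch of the new pbar option (key name and sizes are made up; tqdm must be installed for the bar to show, otherwise map simply skips the wrapping):

import torch
from tensordict import TensorDict

def add_one(td):
    td["x"] = td["x"] + 1
    return td

if __name__ == "__main__":
    data = TensorDict({"x": torch.zeros(100)}, [100])
    # one progress update per processed chunk; chunksize=10 gives 10 chunks
    out = data.map(add_one, num_workers=2, chunksize=10, pbar=True)
    print(out["x"].sum())  # tensor(100.)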
67 changes: 56 additions & 11 deletions tensordict/memmap.py
@@ -5,6 +5,8 @@

from __future__ import annotations

import functools

import mmap
import os

@@ -13,7 +15,7 @@
from multiprocessing import util
from multiprocessing.context import reduction
from pathlib import Path
from typing import Any, overload
from typing import Any, Callable, overload

import numpy as np
import torch
@@ -569,11 +571,7 @@ def __getitem__(self, item):
) from err
raise
if out.untyped_storage().data_ptr() == self.untyped_storage().data_ptr():
- out = MemoryMappedTensor(out)
- out._handler = self._handler
- out._filename = self._filename
- out.index = item
- out.parent_shape = self.parent_shape
out = self._index_wrap(out, item)
return out

@implement_for("torch", None, "2.0")
@@ -588,13 +586,34 @@ def __getitem__(self, item): # noqa: F811
) from err
raise
if out.storage().data_ptr() == self.storage().data_ptr():
- out = MemoryMappedTensor(out)
- out._handler = self._handler
- out._filename = self._filename
- out.index = item
- out.parent_shape = self.parent_shape
out = self._index_wrap(out, item)
return out

def _index_wrap(self, tensor, item, check=False):
if check:
if tensor.storage().data_ptr() == self.storage().data_ptr():
return self._index_wrap(tensor, item)
return tensor
tensor = MemoryMappedTensor(tensor)
tensor._handler = self._handler
tensor._filename = self._filename
tensor.index = item
tensor.parent_shape = self.parent_shape
return tensor

def unbind(self, dim):
out = super().unbind(dim)
if dim < 0:
dim = self.ndim + dim
index_base = (slice(None),) * dim
return tuple(
self._index_wrap(_out, index_base + (i,)) for i, _out in enumerate(out)
)

def chunk(self, chunks, dim=0):
out = super().chunk(chunks, dim)
return tuple(self._index_wrap(chunk, None, check=True) for chunk in out)
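A short sketch of the new methods: both return views that stay MemoryMappedTensor instances backed by the same file (shapes are arbitrary; from_tensor is assumed here to build a temporary file-backed copy):

import torch
from tensordict import MemoryMappedTensor

mm = MemoryMappedTensor.from_tensor(torch.arange(12.0).reshape(4, 3))
halves = mm.chunk(2, dim=0)   # two (2, 3) memory-mapped views
rows = mm.unbind(0)           # four (3,) memory-mapped slices
assert all(isinstance(t, MemoryMappedTensor) for t in halves + rows)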


#####################
# File handler
@@ -709,3 +728,29 @@ def _proc_args_const(*args, **kwargs):
kwargs.pop("fill_value", None),
kwargs.pop("filename", None),
)


# Torch functions

MEMMAP_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}


def implements_for_memmap(torch_function: Callable) -> Callable[[Callable], Callable]:
"""Register a torch function override for MemoryMappedTensor."""

@functools.wraps(torch_function)
def decorator(func: Callable) -> Callable:
MEMMAP_HANDLED_FUNCTIONS[torch_function] = func
return func

return decorator


@implements_for_memmap(torch.unbind)
def _unbind(tensor, dim):
return tensor.unbind(dim)


@implements_for_memmap(torch.chunk)
def _chunk(input, chunks, dim=0):
return input.chunk(chunks, dim=dim)
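With these overrides registered, the torch-namespace calls can route back to the methods above, assuming MemoryMappedTensor consults MEMMAP_HANDLED_FUNCTIONS in its __torch_function__ hook (the hook itself is not part of this diff). A hedged sketch:

import torch
from tensordict import MemoryMappedTensor

mm = MemoryMappedTensor.from_tensor(torch.zeros(6, 2))
parts = torch.chunk(mm, 3, dim=0)   # dispatched to MemoryMappedTensor.chunk
slices = torch.unbind(mm, dim=0)    # dispatched to MemoryMappedTensor.unbind
# with the dispatch wired up, both results remain memory-mapped views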
3 changes: 3 additions & 0 deletions tensordict/nn/params.py
@@ -445,6 +445,9 @@ def map(
chunksize: int = None,
num_chunks: int = None,
pool: mp.Pool = None,
generator: torch.Generator | None = None,
max_tasks_per_child: int | None = None,
worker_threads: int = 1,
):
raise RuntimeError(
"Cannot call map on a TensorDictParams object. Convert it "
20 changes: 19 additions & 1 deletion tensordict/persistent.py
@@ -36,6 +36,7 @@
_KEY_ERROR,
_LOCK_ERROR,
_parse_to,
_proc_init,
_split_tensordict,
cache,
expand_right,
@@ -684,11 +685,28 @@ def map(
chunksize: int = None,
num_chunks: int = None,
pool: mp.Pool = None,
generator: torch.Generator | None = None,
max_tasks_per_child: int | None = None,
worker_threads: int = 1,
):
if pool is None:
if num_workers is None:
num_workers = mp.cpu_count() # Get the number of CPU cores
- with mp.Pool(num_workers) as pool:
if generator is None:
generator = torch.Generator()
seed = (
torch.empty((), dtype=torch.int64).random_(generator=generator).item()
)

queue = mp.Queue(maxsize=num_workers)
for i in range(num_workers):
queue.put(i)
with mp.Pool(
processes=num_workers,
initializer=_proc_init,
initargs=(seed, queue, worker_threads),
maxtasksperchild=max_tasks_per_child,
) as pool:
return self.map(fn, dim=dim, chunksize=chunksize, pool=pool)
num_workers = pool._processes
dim_orig = dim
2 changes: 1 addition & 1 deletion tensordict/utils.py
@@ -1753,7 +1753,7 @@ def _legacy_lazy(func):

# Process initializer for map
def _proc_init(base_seed, queue, num_threads):
- worker_id = queue.get(timeout=10)
worker_id = queue.get(timeout=120)
seed = base_seed + worker_id
torch.manual_seed(seed)
np_seed = _generate_state(base_seed, worker_id)
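For context, the worker-seeding handshake added here and in base.py works as follows: the parent draws a base seed (optionally from a user-supplied generator), fills a queue with worker ids, and _proc_init on each worker pops an id and seeds torch (and numpy) with base_seed + worker_id. A stripped-down sketch outside tensordict, with illustrative names (_init_worker, _sample):

import multiprocessing as mp
import torch

def _init_worker(base_seed, queue):
    # mirrors _proc_init: pop a unique worker id, derive a per-worker seed
    worker_id = queue.get(timeout=120)
    torch.manual_seed(base_seed + worker_id)

def _sample(_):
    return torch.randint(0, 1000, ()).item()

if __name__ == "__main__":
    num_workers = 2
    base_seed = torch.randint(0, 2**31, ()).item()
    queue = mp.Queue(maxsize=num_workers)
    for i in range(num_workers):
        queue.put(i)
    with mp.Pool(
        processes=num_workers,
        initializer=_init_worker,
        initargs=(base_seed, queue),
    ) as pool:
        print(pool.map(_sample, range(4)))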
13 changes: 13 additions & 0 deletions test/test_tensordict.py
@@ -7479,6 +7479,19 @@ def test_map_unbind(self):
assert td_out[1]["1"] == 1
assert (td_out["2"] == 2).all()

@staticmethod
def _assert_is_memmap(data):
assert isinstance(data["tensor"], MemoryMappedTensor)

@pytest.mark.parametrize("chunksize", [0, 5])
def test_map_inplace(self, chunksize):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method("spawn")
# Tests that we can return None values
# Also tests that MemoryMapped id is kept using multiprocessing
data = TensorDict({"tensor": torch.zeros(10)}, [10]).memmap_()
data.map(self._assert_is_memmap, chunksize=chunksize, num_workers=2)
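The new test relies on the in-place pattern enabled by the map change in base.py: when the mapped function returns None for every chunk, there is nothing to stack or concatenate, and the workers are expected to have written into the memory-mapped storage directly. A hedged sketch of that usage (key name is made up; visibility of the writes rests on the chunks staying backed by the same file):

import torch
from tensordict import TensorDict

def fill_one(td):
    # write straight into the memory-mapped storage; returning None tells
    # map() there is nothing to collect
    td.get("tensor").fill_(1.0)

if __name__ == "__main__":
    data = TensorDict({"tensor": torch.zeros(10)}, [10]).memmap_()
    data.map(fill_one, num_workers=2, chunksize=5)
    print(data["tensor"])  # the workers' writes land in the shared file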


# class TestNonTensorData:
class TestNonTensorData:
