[BugFix] dynamo compat refactors
ghstack-source-id: 7681ecf34d26be5d50f780fe8dd98e1dc7974ec8
Pull Request resolved: #975
vmoens committed Sep 2, 2024
1 parent 0e86141 commit 785cac1
Showing 12 changed files with 371 additions and 222 deletions.
77 changes: 77 additions & 0 deletions benchmarks/compile/tensordict_nn_test.py
@@ -329,6 +329,83 @@ def call_with_backward(*args):
benchmark(call_with_backward, x)


@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_vmap_func_call_cm_runtime(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = mlp(device=device, depth=10, num_cells=16, feature_dim=16)
# module = torch.nn.Transformer(16, dim_feedforward=64, device=device)
td = TensorDict.from_module(module)
td = TensorDictParams(td.data.expand(10).clone().zero_())

def call(x, td):
        # to_module used as a `with` context manager (relies on the last-op registration)
with td.to_module(module):
return module(x)

call_vmap = torch.vmap(call, (None, 0))
if mode == "compile":
call_vmap = torch.compile(call_vmap)
elif mode == "compile-overhead":
call_vmap = torch.compile(call_vmap, mode="reduce-overhead")

    x = torch.randn(2, 2, 16, device=device)
call_vmap(x, td)
call_vmap(x, td)
benchmark(call_vmap, x, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize("plain_decorator", [None, False, True])
def test_vmap_func_call_runtime_and_backward(mode, plain_decorator, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = mlp(device=device, depth=10, num_cells=16, feature_dim=16)
# module = torch.nn.Transformer(16, dim_feedforward=64, device=device)
td = TensorDict.from_module(module)
td = TensorDictParams(td.data.expand(10).clone().zero_())
if not plain_decorator:

def call(x, td):
if torch.cuda.is_available():
torch.compiler.cudagraph_mark_step_begin()
            # avoid the `with` form: swap params in and restore them explicitly
params = td.to_module(module, return_swap=True)
result = module(x)
params.to_module(module, return_swap=False)
return result

else:

def call(x, td):
if torch.cuda.is_available():
torch.compiler.cudagraph_mark_step_begin()
            # to_module used as a `with` context manager (relies on the last-op registration)
with td.to_module(module):
return module(x)

    call_vmap = torch.vmap(call, (None, 0))
    if mode == "compile":
        call_vmap = torch.compile(call_vmap, fullgraph=not plain_decorator)
    elif mode == "compile-overhead":
        call_vmap = torch.compile(
            call_vmap, fullgraph=not plain_decorator, mode="reduce-overhead"
        )

def call_with_backward(*args):
call_vmap(*args).mean().backward()

    x = torch.randn(2, 2, 16, device=device)
call_with_backward(x, td)
call_with_backward(x, td)
benchmark(call_with_backward, x, td)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
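
Both benchmarks added above follow the same functional-call pattern: the module's parameters are pulled out with TensorDict.from_module, expanded into a batch of parameter sets, and swapped back into the module inside the function that torch.vmap maps over. A minimal sketch of that pattern outside the benchmark harness (torch.nn.Sequential stands in for the benchmark's mlp helper; shapes are illustrative, not part of the diff):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

module = torch.nn.Sequential(
    torch.nn.Linear(16, 16), torch.nn.Tanh(), torch.nn.Linear(16, 16)
)
# Extract the parameters as a TensorDict and build a batch of 10 independent copies.
params = TensorDictParams(TensorDict.from_module(module).data.expand(10).clone().zero_())

def call(x, td):
    # Swap the batched parameters into the module for the duration of the call.
    with td.to_module(module):
        return module(x)

# Map over the leading dim of the parameter batch, broadcasting the input.
call_vmap = torch.vmap(call, (None, 0))
x = torch.randn(2, 2, 16)
out = call_vmap(x, params)  # one output per parameter set: shape (10, 2, 2, 16)
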
4 changes: 1 addition & 3 deletions tensordict/__init__.py
@@ -29,12 +29,10 @@
lazy_legacy,
NestedKey,
set_lazy_legacy,
)
from tensordict._pytree import *
from tensordict._C import ( # @manual=//pytorch/tensordict:_C
unravel_key,
unravel_key_list,
)
from tensordict._pytree import *
from tensordict.nn import TensorDictParams

try:
181 changes: 181 additions & 0 deletions tensordict/_contextlib.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib

# This is a copy from https://github.com/pytorch/pytorch/blob/main/torch/utils/_contextlib.py#L120
# We use it for compatibility with torch >= 1.10 where the implementation fails
@@ -16,6 +17,10 @@
import warnings
from typing import Any, Callable, cast, TypeVar

import numpy as np
from torch.compiler import is_dynamo_compiling


# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
@@ -155,3 +160,179 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
def clone(self):
# override this method if your children class takes __init__ parameters
return type(self)()


# TD cm functions
LAST_OP_MAPS = {}


def _reverse_lock(self, args, kwargs, out):
return self.unlock_()


LAST_OP_MAPS["lock_"] = _reverse_lock


def _reverse_unlock(self, args, kwargs, out):
return self.lock_()


LAST_OP_MAPS["unlock_"] = _reverse_unlock


def _reverse_transpose(self, args, kwargs, out):
dim0, dim1 = args
if not out.is_locked:
return out.update(self.transpose(dim0, dim1), inplace=False)
else:
return out.update_(self.transpose(dim0, dim1))


LAST_OP_MAPS["transpose"] = _reverse_transpose


def _reverse_flatten_keys(self, args, kwargs, out):
sep = args[0] if args else "."
if not out.is_locked:
return out.update(self.unflatten_keys(sep), inplace=False)
else:
return out.update_(self.unflatten_keys(sep))


LAST_OP_MAPS["flatten_keys"] = _reverse_flatten_keys


def _reverse_unflatten_keys(self, args, kwargs, out):
sep = args[0] if args else "."
if not out.is_locked:
return out.update(self.flatten_keys(sep), inplace=False)
else:
return out.update_(self.flatten_keys(sep))


LAST_OP_MAPS["unflatten_keys"] = _reverse_unflatten_keys


def _reverse_flatten(self, args, kwargs, out):
if len(args) == 2:
dim0, dim1 = args
elif len(args) == 1:
dim0 = args[0]
dim1 = kwargs.get("end_dim", -1)
else:
dim0 = kwargs.get("start_dim", 0)
dim1 = kwargs.get("end_dim", -1)
if dim1 < 0:
dim1 = out.ndim + dim1
if dim0 < 0:
dim0 = out.ndim + dim0

if not out.is_locked:
return out.update(
self.unflatten(dim0, out.shape[dim0 : dim1 + 1]), inplace=False
)
else:
return out.update_(self.unflatten(dim0, out.shape[dim0 : dim1 + 1]))


LAST_OP_MAPS["flatten"] = _reverse_flatten


def _reverse_unflatten(self, args, kwargs, out):
if args:
dim0 = args[0]
if len(args) > 1:
unflattened_size = args[1]
else:
unflattened_size = kwargs.get("unflattened_size")
else:
dim0 = kwargs.get("dim")
unflattened_size = kwargs.get("unflattened_size")
if dim0 < 0:
dim0 = out.ndim + dim0
dim1 = dim0 + len(unflattened_size) - 1
if not out.is_locked:
unflattened = self.flatten(dim0, dim1)
return out.update(unflattened, inplace=False)
else:
unflattened = self.flatten(dim0, dim1)
return out.update_(unflattened)


LAST_OP_MAPS["unflatten"] = _reverse_unflatten


def _reverse_permute(self, args, kwargs, out):
from tensordict.utils import _get_shape_from_args

dims_list = _get_shape_from_args(*args, kwarg_name="dims", **kwargs)
dims_list = [dim if dim >= 0 else self.ndim + dim for dim in dims_list]
# inverse map
inv_dims_list = np.argsort(dims_list)
if not out.is_locked:
return out.update(self.permute(inv_dims_list), inplace=False)
else:
return out.update_(self.permute(inv_dims_list))


LAST_OP_MAPS["permute"] = _reverse_permute


def _reverse_view(self, args, kwargs, out):
if not out.is_locked:
return out.update(self.view(out.shape), inplace=False)
else:
return out.update_(self.view(out.shape))


LAST_OP_MAPS["view"] = _reverse_view


def _reverse_unsqueeze(self, args, kwargs, out):
if args:
(dim,) = args
elif kwargs:
dim = kwargs["dim"]
else:
raise RuntimeError(
"Cannot use td.unsqueeze() as a decorator if the dimension is implicit."
)
if not out.is_locked:
return out.update(self.squeeze(dim), inplace=False)
else:
return out.update_(self.squeeze(dim))


LAST_OP_MAPS["unsqueeze"] = _reverse_unsqueeze


def _reverse_squeeze(self, args, kwargs, out):
if args:
(dim,) = args
elif kwargs:
dim = kwargs["dim"]
else:
raise RuntimeError(
"Cannot use td.squeeze() as a decorator if the dimension is implicit."
)
if not out.is_locked:
return out.update(self.unsqueeze(dim), inplace=False)
else:
return out.update_(self.unsqueeze(dim))


LAST_OP_MAPS["squeeze"] = _reverse_squeeze


def _reverse_to_module(self, args, kwargs, out):
try:
with out.unlock_() if not is_dynamo_compiling() else contextlib.nullcontext():
return self.to_module(*args, **kwargs, swap_dest=out)
except AttributeError:
# This is a bit unsafe but we assume that out won't have an unlock_() if it's not a TD
raise RuntimeError(
"to_module cannot be used as a decorator when return_swap=False."
)


LAST_OP_MAPS["to_module"] = _reverse_to_module
12 changes: 8 additions & 4 deletions tensordict/_td.py
@@ -68,7 +68,6 @@
_sub_index,
_unravel_key_to_tuple,
_zip_strict,
Buffer,
cache,
convert_ellipsis_to_idx,
DeviceType,
@@ -103,6 +102,11 @@
except ImportError: # torch 2.0
from torch._dynamo import is_compiling as is_dynamo_compiling

try:
from torch.nn.parameter import Buffer
except ImportError:
from tensordict.utils import Buffer

_register_tensor_class(ftdim.Tensor)

__base__setattr__ = torch.nn.Module.__setattr__
@@ -247,8 +251,8 @@ def __init__(

self._tensordict = _StringOnlyDict()

if names and is_dynamo_compiling():
graph_break()
# if names and is_dynamo_compiling():
# graph_break()
has_device = device is not None
sub_non_blocking = False
call_sync = False
@@ -2971,7 +2975,7 @@ def _clone(self, recurse: bool = True) -> T:
source={key: _clone_value(value, recurse) for key, value in self.items()},
batch_size=self.batch_size,
device=self.device,
names=copy(self._td_dim_names) if self._has_names() else None,
names=self._maybe_names(),
)
# If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too
# This may be undesirable, not sure if this should be the default behaviour
