From 0bdcb965160620fca1bf4069368a712504be5e7e Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 11 Oct 2024 11:02:51 +0100
Subject: [PATCH] [BugFix] Compatibility with non-tensor inputs in
 CudaGraphModule

ghstack-source-id: 3eff6c24b0fa381665823deb5a1efdb7d2cc1bd2
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1039
---
 tensordict/nn/cudagraphs.py | 41 +++++++++++++++++++++++++++++--------
 test/test_compile.py        | 13 +++++++++++++
 2 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index d6eefe1eb..6bb4a7837 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -28,9 +28,9 @@
 from torch.utils._pytree import SUPPORTED_NODES, tree_map
 
 try:
-    from torch.utils._pytree import tree_leaves
+    from torch.utils._pytree import tree_flatten, tree_leaves, tree_unflatten
 except ImportError:
-    from torch.utils._pytree import tree_flatten
+    from torch.utils._pytree import tree_flatten, tree_unflatten
 
     def tree_leaves(pytree):
         """Torch 2.0 compatible version of tree_leaves."""
@@ -293,11 +293,13 @@ def check_tensor_id(name, t0, t1):
 
             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 if self.counter >= self._warmup:
-                    tree_map(
-                        lambda x, y: x.copy_(y, non_blocking=True),
-                        (self._args, self._kwargs),
-                        (args, kwargs),
-                    )
+                    srcs, dests = [], []
+                    for arg_src, arg_dest in zip(
+                        tree_leaves((args, kwargs)), self._flat_tree
+                    ):
+                        self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests)
+                    if dests:
+                        torch._foreach_copy_(dests, srcs)
                     torch.cuda.synchronize()
                     self.graph.replay()
                     if self._return_unchanged == "clone":
@@ -322,8 +324,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     self.counter += self._has_cuda
                     return out
                 else:
-                    args, kwargs = self._args, self._kwargs = tree_map(
-                        self._check_device_and_clone, (args, kwargs)
+                    self._flat_tree, self._tree_spec = tree_flatten((args, kwargs))
+
+                    self._flat_tree = tuple(
+                        self._check_device_and_clone(arg) for arg in self._flat_tree
+                    )
+                    args, kwargs = self._args, self._kwargs = tree_unflatten(
+                        self._flat_tree, self._tree_spec
                     )
 
                     torch.cuda.synchronize()
@@ -360,6 +367,22 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
             _call_func = functools.wraps(self.module)(_call)
         self._call_func = _call_func
 
+    @staticmethod
+    def _maybe_copy_onto_(src, dest, srcs, dests):
+        if isinstance(src, (torch.Tensor, TensorDictBase)):
+            srcs.append(src)
+            dests.append(dest)
+            return
+        try:
+            if src != dest:
+                raise ValueError("Varying inputs must be torch.Tensor subclasses.")
+        except Exception:
+            raise RuntimeError(
+                "Couldn't assess input value. Make sure your function only takes tensor inputs or that "
+                "the input value can be easily checked and is constant. For better efficiency, avoid "
+                "passing non-tensor inputs to your function."
+            )
+
     @classmethod
     def _check_device_and_clone(cls, x):
         if isinstance(x, torch.Tensor) or is_tensor_collection(x):
diff --git a/test/test_compile.py b/test/test_compile.py
index 755a928e2..ff4e79f38 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -1056,6 +1056,19 @@ def test_td_input_non_tdmodule(self, compiled):
             if i == 5:
                 assert not func._is_tensordict_module
 
+    def test_td_input_non_tdmodule_nontensor(self, compiled):
+        func = lambda x, y: x + y
+        func = self._make_cudagraph(func, compiled)
+        for i in range(10):
+            assert func(torch.zeros(()), 1.0) == 1.0
+            if i == 5:
+                assert not func._is_tensordict_module
+        if torch.cuda.is_available():
+            with pytest.raises(
+                ValueError, match="Varying inputs must be torch.Tensor subclasses."
+            ):
+                func(torch.zeros(()), 2.0)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
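
Note (illustration, not part of the commit): a minimal sketch of the behavior this patch enables, assuming a CUDA device and the public tensordict.nn.CudaGraphModule API; the warmup argument is assumed from the _warmup attribute used in the diff above.

    import torch
    from tensordict.nn import CudaGraphModule

    # A plain function mixing a tensor input with a non-tensor (constant) input.
    def add(x, y):
        return x + y

    gm = CudaGraphModule(add, warmup=2)
    x = torch.zeros((), device="cuda")

    # Constant non-tensor inputs are tolerated: only tensor leaves of the
    # flattened (args, kwargs) pytree are copied onto the captured buffers
    # via torch._foreach_copy_; non-tensor leaves are checked for equality.
    for _ in range(5):
        assert gm(x, 1.0) == 1.0

    # Once the graph is captured, varying a non-tensor input is rejected
    # (the patch's test expects: ValueError("Varying inputs must be
    # torch.Tensor subclasses.")).
    gm(x, 2.0)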