From 7e45bccff1056be9478f674d285872af477c2333 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 11 Oct 2024 11:08:21 +0100
Subject: [PATCH] [BugFix] Compatibility with non-tensor inputs in
 CudaGraphModule

ghstack-source-id: f5a48452c26ae0c28399355573fe0458e402574c
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1039
---
 tensordict/nn/cudagraphs.py | 46 +++++++++++++++++++++++++++++--------
 test/test_compile.py        | 13 +++++++++++
 2 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index d6eefe1eb..e99236b48 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -28,9 +28,9 @@
 from torch.utils._pytree import SUPPORTED_NODES, tree_map
 
 try:
-    from torch.utils._pytree import tree_leaves
+    from torch.utils._pytree import tree_flatten, tree_leaves, tree_unflatten
 except ImportError:
-    from torch.utils._pytree import tree_flatten
+    from torch.utils._pytree import tree_flatten, tree_unflatten
 
     def tree_leaves(pytree):
         """Torch 2.0 compatible version of tree_leaves."""
@@ -293,11 +293,13 @@ def check_tensor_id(name, t0, t1):
 
         def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
             if self.counter >= self._warmup:
-                tree_map(
-                    lambda x, y: x.copy_(y, non_blocking=True),
-                    (self._args, self._kwargs),
-                    (args, kwargs),
-                )
+                srcs, dests = [], []
+                for arg_src, arg_dest in zip(
+                    tree_leaves((args, kwargs)), self._flat_tree
+                ):
+                    self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests)
+                if dests:
+                    torch._foreach_copy_(dests, srcs)
                 torch.cuda.synchronize()
                 self.graph.replay()
                 if self._return_unchanged == "clone":
@@ -322,8 +324,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 self.counter += self._has_cuda
                 return out
             else:
-                args, kwargs = self._args, self._kwargs = tree_map(
-                    self._check_device_and_clone, (args, kwargs)
+                self._flat_tree, self._tree_spec = tree_flatten((args, kwargs))
+
+                self._flat_tree = tuple(
+                    self._check_device_and_clone(arg) for arg in self._flat_tree
+                )
+                args, kwargs = self._args, self._kwargs = tree_unflatten(
+                    self._flat_tree, self._tree_spec
                 )
 
                 torch.cuda.synchronize()
@@ -360,6 +367,27 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
         _call_func = functools.wraps(self.module)(_call)
         self._call_func = _call_func
 
+    @staticmethod
+    def _maybe_copy_onto_(src, dest, srcs, dests):
+        if isinstance(src, torch.Tensor):
+            srcs.append(src)
+            dests.append(dest)
+            return
+        if is_tensor_collection(src):
+            dest.copy_(src)
+            return
+        isdiff = False
+        try:
+            isdiff = src != dest
+        except Exception as err:
+            raise RuntimeError(
+                "Couldn't assess input value. Make sure your function only takes tensor inputs or that "
+                "the input value can be easily checked and is constant. For better efficiency, avoid "
+                "passing non-tensor inputs to your function."
+            ) from err
+        if isdiff:
+            raise ValueError("Varying inputs must be torch.Tensor subclasses.")
+
     @classmethod
     def _check_device_and_clone(cls, x):
         if isinstance(x, torch.Tensor) or is_tensor_collection(x):
diff --git a/test/test_compile.py b/test/test_compile.py
index 755a928e2..ff4e79f38 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -1056,6 +1056,19 @@ def test_td_input_non_tdmodule(self, compiled):
             if i == 5:
                 assert not func._is_tensordict_module
 
+    def test_td_input_non_tdmodule_nontensor(self, compiled):
+        func = lambda x, y: x + y
+        func = self._make_cudagraph(func, compiled)
+        for i in range(10):
+            assert func(torch.zeros(()), 1.0) == 1.0
+            if i == 5:
+                assert not func._is_tensordict_module
+        if torch.cuda.is_available():
+            with pytest.raises(
+                ValueError, match="Varying inputs must be torch.Tensor subclasses."
+            ):
+                func(torch.zeros(()), 2.0)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
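
A minimal sketch of the behavior this patch enables, assuming the public
tensordict.nn.CudaGraphModule constructor with its `warmup` argument and a
CUDA device (the names `add`, `func` and `x` below are illustrative, not from
the patch). Constant non-tensor inputs are accepted; a non-tensor input that
changes after the graph has been captured is rejected:

    import torch
    from tensordict.nn import CudaGraphModule

    def add(x, y):
        return x + y

    if torch.cuda.is_available():
        func = CudaGraphModule(add, warmup=2)
        x = torch.zeros((), device="cuda")
        for _ in range(10):
            # 1.0 is a non-tensor input that stays constant across calls: allowed.
            assert func(x, 1.0) == 1.0
        try:
            # Once the graph is captured, a changed non-tensor value can no
            # longer be fed to the recorded graph, so the module raises.
            func(x, 2.0)
        except ValueError as err:
            print(err)  # Varying inputs must be torch.Tensor subclasses.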
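
For reference, a standalone sketch of the replay-path technique the patch
switches to: flatten the input pytree once, then refresh the captured tensor
buffers with a single batched copy via the private torch._foreach_copy_ op
(which the patch itself uses). All variable names here are illustrative:

    import torch
    from torch.utils._pytree import tree_flatten

    # First call: flatten (args, kwargs) and clone the tensor leaves into
    # static buffers that a CUDA graph would read from.
    args, kwargs = (torch.randn(3),), {"y": torch.randn(3)}
    static_flat, spec = tree_flatten((args, kwargs))
    static_flat = [t.clone() for t in static_flat]

    # Subsequent calls: copy the new leaves onto the static buffers in one
    # batched op instead of one copy_ per leaf.
    new_args, new_kwargs = (torch.randn(3),), {"y": torch.randn(3)}
    new_flat, _ = tree_flatten((new_args, new_kwargs))
    torch._foreach_copy_(static_flat, new_flat)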