From 4c17c2567075eb17c79123a185d87126a7e87052 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Tue, 1 Aug 2023 11:06:49 +0800 Subject: [PATCH] fix pytorch deepcopy trace error --- mmdeploy/pytorch/functions/__init__.py | 1 + mmdeploy/pytorch/functions/copy.py | 17 +++++++++++++++++ tests/test_pytorch/test_pytorch_functions.py | 20 ++++++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 mmdeploy/pytorch/functions/copy.py diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index 19515408af..ba40970b3d 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -5,6 +5,7 @@ from . import cat # noqa: F401,F403 from . import chunk # noqa: F401,F403 from . import clip # noqa: F401,F403 +from . import copy # noqa: F401,F403 from . import expand # noqa: F401,F403 from . import flatten # noqa: F401,F403 from . import getattribute # noqa: F401,F403 diff --git a/mmdeploy/pytorch/functions/copy.py b/mmdeploy/pytorch/functions/copy.py new file mode 100644 index 0000000000..1d3cf9b190 --- /dev/null +++ b/mmdeploy/pytorch/functions/copy.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter(func_name='copy.deepcopy') +def copy__default(tensor: Tensor, *args, **kwargs) -> Tensor: + """Rewrite `copy.deepcopy` for default backend. + + Replace it with tensor.clone(), or may raise `NYI: Named tensors are not + supported with the tracer` + """ + ctx = FUNCTION_REWRITER.get_context() + if isinstance(tensor, Tensor) and args == () and kwargs == {}: + return tensor.clone() + return ctx.origin_func(tensor, *args, **kwargs) diff --git a/tests/test_pytorch/test_pytorch_functions.py b/tests/test_pytorch/test_pytorch_functions.py index f0313915ef..50481ac436 100644 --- a/tests/test_pytorch/test_pytorch_functions.py +++ b/tests/test_pytorch/test_pytorch_functions.py @@ -667,3 +667,23 @@ def test_cat__tensorrt(dtype, dynamic_axes): rewrite_output[0].cpu().float(), rtol=1e-3, atol=1e-5) + + +@backend_checker(Backend.TENSORRT) +def test_copy__default(): + import copy + input = torch.rand(2, 4) + model = WrapFunction( + lambda input: [copy.deepcopy(input) for i in range(3)]) + pytorch_output = model(input) + rewrite_output, _ = get_rewrite_outputs( + model, + model_inputs={'input': input}, + deploy_cfg=get_trt_config(['output'], shape=[2, 4], dynamic_axes=None), + run_with_backend=True) + for pytorch_out, rewrite_out in zip(pytorch_output, rewrite_output): + assert torch.allclose( + pytorch_out.cpu().float(), + rewrite_out.cpu().float(), + rtol=1e-3, + atol=1e-5)