From 1fa514158a5daf011eac0a7230ef47260db5b7ca Mon Sep 17 00:00:00 2001 From: cehongwang <123616592+cehongwang@users.noreply.github.com> Date: Tue, 2 Jul 2024 11:29:06 -0700 Subject: [PATCH] Implemented basic pipeline for Refitting (#2886) --- py/torch_tensorrt/dynamo/_compiler.py | 12 ++++++++++++ py/torch_tensorrt/dynamo/_refit.py | 8 -------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d562277aeb..43c0f8699a 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -178,6 +178,18 @@ def compile( if kwarg_inputs is None: kwarg_inputs = {} + + if "refit" in kwargs.keys(): + warnings.warn( + "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.", + DeprecationWarning, + stacklevel=2, + ) + if make_refitable: + raise ValueError("Use flag make_refitable only. Flag refit is deprecated.") + else: + make_refitable = kwargs["refit"] + engine_capability = EngineCapability._from(engine_capability) if torch_executed_modules is not None and torch_executed_modules: diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 98cab802c9..38810e59b3 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -6,10 +6,7 @@ from typing import Any, Sequence, Tuple import numpy as np -<<<<<<< HEAD -======= import tensorrt as trt ->>>>>>> 9f46d3940 (Implemented basic pipeline for Refitting (#2886)) import torch from torch.export import ExportedProgram from torch_tensorrt._enums import dtype @@ -46,11 +43,6 @@ ) from torch_tensorrt.logging import TRT_LOGGER -<<<<<<< HEAD -import tensorrt as trt - -======= ->>>>>>> 9f46d3940 (Implemented basic pipeline for Refitting (#2886)) logger = logging.getLogger(__name__)