From 91b8f6cf02b370907ea19295c4bc108747e24cde Mon Sep 17 00:00:00 2001 From: HolyWu Date: Wed, 11 Sep 2024 04:58:20 +0800 Subject: [PATCH] _refit: Properly compare device type (#3149) --- py/torch_tensorrt/dynamo/_refit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 8b0d7c3e20..c68b0a22aa 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -156,7 +156,7 @@ def _refit_single_trt_engine_with_gm( # Get the refitting mapping trt_wt_location = ( trt.TensorLocation.DEVICE - if torch_device == "cuda" + if torch_device.type == "cuda" else trt.TensorLocation.HOST ) mapping = construct_refit_mapping_from_weight_name_map(