From 96282bfa35a41eabadd97b2d0ffab2a67e1699e2 Mon Sep 17 00:00:00 2001 From: HolyWu Date: Sun, 8 Sep 2024 12:29:00 +0800 Subject: [PATCH] _refit: Properly compare device type --- py/torch_tensorrt/dynamo/_refit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 8b0d7c3e20..c68b0a22aa 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -156,7 +156,7 @@ def _refit_single_trt_engine_with_gm( # Get the refitting mapping trt_wt_location = ( trt.TensorLocation.DEVICE - if torch_device == "cuda" + if torch_device.type == "cuda" else trt.TensorLocation.HOST ) mapping = construct_refit_mapping_from_weight_name_map(