diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 8b0d7c3e20..c68b0a22aa 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -156,7 +156,7 @@ def _refit_single_trt_engine_with_gm( # Get the refitting mapping trt_wt_location = ( trt.TensorLocation.DEVICE - if torch_device == "cuda" + if torch_device.type == "cuda" else trt.TensorLocation.HOST ) mapping = construct_refit_mapping_from_weight_name_map(