Skip to content

Commit

Permalink
_refit: Properly compare device type (#3149)
Browse files Browse the repository at this point in the history
  • Loading branch information
HolyWu authored Sep 10, 2024
1 parent 8154408 commit 91b8f6c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _refit_single_trt_engine_with_gm(
# Get the refitting mapping
trt_wt_location = (
trt.TensorLocation.DEVICE
if torch_device == "cuda"
if torch_device.type == "cuda"
else trt.TensorLocation.HOST
)
mapping = construct_refit_mapping_from_weight_name_map(
Expand Down

0 comments on commit 91b8f6c

Please sign in to comment.