diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 9c42963b51..01b15aa533 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -3630,7 +3630,7 @@ def acc_ops_interpolate( else: layer.resize_mode = trt.ResizeMode.NEAREST - if align_corners != None: + if (align_corners is not None) and align_corners: layer.coordinate_transformation = ( trt.ResizeCoordinateTransformation.ALIGN_CORNERS ) diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py index f0054e5cb7..1f4b37f4b4 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py @@ -43,6 +43,14 @@ class TestInterpolateConverter(AccTestCase): ("bilinear"), (None), ), # linear for 4D only + ( + "4d_dim_scale_bilinear_align_corners_bool", + (2, 3, 4, 5), + (None), + (2), + ("bilinear"), + (False), + ), # linear for 4D only ( "4d_dim_scale_align", (2, 3, 4, 5),