From 0a075f0c4c3515fd600662ba6499e55657dac4e2 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Tue, 20 Dec 2022 11:25:45 -0500 Subject: [PATCH] fix: Bugfix for `align_corners=False`- FX interpolate (#1561) --- py/torch_tensorrt/fx/converters/acc_ops_converters.py | 2 +- .../fx/test/converters/acc_op/test_interpolate.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 9c42963b51..01b15aa533 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -3630,7 +3630,7 @@ def acc_ops_interpolate( else: layer.resize_mode = trt.ResizeMode.NEAREST - if align_corners != None: + if (align_corners is not None) and align_corners: layer.coordinate_transformation = ( trt.ResizeCoordinateTransformation.ALIGN_CORNERS ) diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py index f0054e5cb7..1f4b37f4b4 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py @@ -43,6 +43,14 @@ class TestInterpolateConverter(AccTestCase): ("bilinear"), (None), ), # linear for 4D only + ( + "4d_dim_scale_bilinear_align_corners_bool", + (2, 3, 4, 5), + (None), + (2), + ("bilinear"), + (False), + ), # linear for 4D only ( "4d_dim_scale_align", (2, 3, 4, 5),