Skip to content

Commit

Permalink
Give precedence to scale_factor over size in TRT layer
Browse files Browse the repository at this point in the history
  • Loading branch information
HolyWu committed Jul 14, 2024
1 parent 0a51a45 commit 13323b2
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 9 deletions.
112 changes: 106 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3014,14 +3014,64 @@ def aten_ops_pad(


@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest1d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest1d_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 3 else [args[2]],
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest2d_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 3 else [args[2], args[3]],
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest3d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest_default(
def aten_ops_upsample_nearest3d_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -3035,7 +3085,7 @@ def aten_ops_upsample_nearest_default(
name,
args[0],
size=args[1],
scale_factor=None,
scale_factor=None if len(args) < 3 else [args[2], args[3], args[4]],
mode="nearest",
align_corners=False,
)
Expand Down Expand Up @@ -3070,14 +3120,64 @@ def aten_ops_upsample_nearest_vec(


@dynamo_tensorrt_converter(torch.ops.aten.upsample_linear1d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_linear1d_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 4 else [args[3]],
mode="linear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_bilinear2d_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 4 else [args[3], args[4]],
mode="bilinear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_trilinear3d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_linear_default(
def aten_ops_upsample_trilinear3d_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -3091,8 +3191,8 @@ def aten_ops_upsample_linear_default(
name,
args[0],
size=args[1],
scale_factor=None,
mode="linear",
scale_factor=None if len(args) < 4 else [args[3], args[4], args[5]],
mode="trilinear",
align_corners=args[2],
)

Expand Down Expand Up @@ -3145,7 +3245,7 @@ def aten_ops_upsample_bicubic_default(
name,
args[0],
size=args[1],
scale_factor=None,
scale_factor=None if len(args) < 4 else [args[3], args[4]],
mode="bicubic",
align_corners=args[2],
)
Expand Down
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def upsample(
) -> TRTTensor:
layer = ctx.net.add_resize(input)

if size is not None:
layer.shape = list(input.shape)[:2] + list(size)
else:
if scale_factor is not None:
layer.scales = [1.0, 1.0] + list(scale_factor)
else:
layer.shape = list(input.shape)[:2] + list(size)

if mode == "nearest":
layer.resize_mode = trt.InterpolationMode.NEAREST
Expand Down

0 comments on commit 13323b2

Please sign in to comment.