Skip to content

Commit

Permalink
Fix incomplete upsample dynamo converter
Browse files Browse the repository at this point in the history
  • Loading branch information
HolyWu committed Jul 9, 2024
1 parent 853556e commit 35c504c
Show file tree
Hide file tree
Showing 4 changed files with 370 additions and 95 deletions.
142 changes: 131 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3004,9 +3004,43 @@ def aten_ops_pad(
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest1d.default)
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest3d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None,
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest1d.vec)
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec)
def upsample_nearest2d(
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest3d.vec)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest_vec(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -3018,17 +3052,51 @@ def upsample_nearest2d(
target,
SourceIR.ATEN,
name,
input=args[0],
out_shape=args_bounds_check(args, 1),
scale_factors=args_bounds_check(args, 2),
resize_mode="nearest",
args[0],
size=args_bounds_check(args, 1),
scale_factor=args_bounds_check(args, 2),
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_linear1d.default)
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.upsample_trilinear3d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_linear_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None,
mode="linear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_linear1d.vec)
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec)
def upsample_bilinear2d(
@dynamo_tensorrt_converter(torch.ops.aten.upsample_trilinear3d.vec)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_linear_vec(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -3040,11 +3108,63 @@ def upsample_bilinear2d(
target,
SourceIR.ATEN,
name,
input=args[0],
out_shape=args_bounds_check(args, 1),
scale_factors=args_bounds_check(args, 3),
resize_mode="bilinear",
align_corners=args_bounds_check(args, 2),
args[0],
size=args_bounds_check(args, 1),
scale_factor=args_bounds_check(args, 3),
mode="linear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_bicubic2d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_bicubic_default(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None,
mode="bicubic",
align_corners=args[2],
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_bicubic2d.vec)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_bicubic_vec(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args_bounds_check(args, 1),
scale_factor=args_bounds_check(args, 3),
mode="bicubic",
align_corners=args[2],
)


Expand Down
69 changes: 25 additions & 44 deletions py/torch_tensorrt/dynamo/conversion/impl/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,54 +14,35 @@ def upsample(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
out_shape: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
resize_mode: str,
size: Optional[Sequence[int]],
scale_factor: Optional[Sequence[float]],
mode: str,
align_corners: bool,
) -> TRTTensor:
resize_layer = ctx.net.add_resize(input)
# output size calculation
# Pytorch assumes that one of out_shape/scale_factor is None
# Pytorch assumes that dimensions match for out_shape/scale factor
if out_shape is not None:
resize_layer.shape = list(input.shape)[:2] + list(out_shape)
elif scale_factors is not None:
resize_layer.scales = [1.0, 1.0] + list(scale_factors)
else:
raise RuntimeError(
"At least one of out_shape and scale_factors should be specified."
)
layer = ctx.net.add_resize(input)

# interpolate mode
if resize_mode == "nearest" or None:
resize_layer.resize_mode = trt.InterpolationMode.NEAREST
elif resize_mode == "bilinear":
resize_layer.resize_mode = trt.InterpolationMode.LINEAR
if align_corners is None or not align_corners:
raise RuntimeError(
f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT."
)
if size is not None:
layer.shape = list(input.shape)[:2] + list(size)
else:
raise RuntimeError(
f"Interpolation mode is {resize_mode} which is not supported by TensorRT."
)
layer.scales = [1.0, 1.0] + list(scale_factor)

if resize_mode == "nearest":
resize_layer.coordinate_transformation = (
trt.ResizeCoordinateTransformation.ASYMMETRIC
if mode == "nearest":
layer.resize_mode = trt.InterpolationMode.NEAREST
layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ASYMMETRIC
elif mode in ("linear", "bilinear", "trilinear"):
layer.resize_mode = trt.InterpolationMode.LINEAR
layer.coordinate_transformation = (
trt.ResizeCoordinateTransformation.ALIGN_CORNERS
if align_corners
else trt.ResizeCoordinateTransformation.HALF_PIXEL
)
elif mode == "bicubic":
layer.resize_mode = trt.InterpolationMode.CUBIC
layer.coordinate_transformation = (
trt.ResizeCoordinateTransformation.ALIGN_CORNERS
if align_corners
else trt.ResizeCoordinateTransformation.HALF_PIXEL
)
elif resize_mode == "bilinear":
# align corners
if align_corners is not None and align_corners:
resize_layer.coordinate_transformation = (
trt.ResizeCoordinateTransformation.ALIGN_CORNERS
)
else:
resize_layer.coordinate_transformation = (
trt.ResizeCoordinateTransformation.ASYMMETRIC
)

set_layer_name(resize_layer, target, name, source_ir)

out = resize_layer.get_output(0)
return out
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@
aten.unfold_backward,
aten.unfold_copy,
aten._unsafe_index,
aten.upsample_bilinear2d,
aten.upsample_bilinear2d.vec,
aten.upsample_nearest2d_backward,
aten.var,
aten.var_mean,
Expand Down
Loading

0 comments on commit 35c504c

Please sign in to comment.