Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overhaul upsample dynamo converter #2790

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
342 changes: 327 additions & 15 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3041,9 +3041,175 @@ def aten_ops_pad(
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec)
def upsample_nearest2d(
for op in (
torch.ops.aten.upsample_nearest1d,
torch.ops.aten.upsample_nearest2d,
torch.ops.aten.upsample_nearest3d,
torch.ops.aten.upsample_linear1d,
torch.ops.aten.upsample_bilinear2d,
torch.ops.aten.upsample_trilinear3d,
torch.ops.aten.upsample_bicubic2d,
):
for key in (
torch._C.DispatchKey.Autograd,
torch._C.DispatchKey.CompositeImplicitAutograd,
):
if key in op.default.py_kernels:
del op.default.py_kernels[key]
if key in op.vec.py_kernels:
del op.vec.py_kernels[key]


def upsample_compute_output_size(
input_size: torch.Size,
output_size: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
) -> Sequence[int]:
spatial_dimensions = len(input_size) - 2

if output_size is not None:
torch._check(
scale_factors is None,
lambda: "Must specify exactly one of output_size and scale_factors",
)
torch._check(len(output_size) == spatial_dimensions)
return output_size

if scale_factors is not None:
torch._check(
output_size is None,
lambda: "Must specify exactly one of output_size and scale_factors",
)
torch._check(len(scale_factors) == spatial_dimensions)
output_size = []
for i, s in enumerate(scale_factors):
output_size.append(int(input_size[i + 2] * s))
return output_size

torch._check(
False, lambda: "Must specify exactly one of output_size and scale_factors"
)


@torch.ops.aten.upsample_nearest1d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest1d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_nearest1d.default(input, osize, *scale_factors)
return torch.ops.aten.upsample_nearest1d.default(input, osize)


@torch.ops.aten.upsample_nearest2d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest2d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_nearest2d.default(input, osize, *scale_factors)
return torch.ops.aten.upsample_nearest2d.default(input, osize)


@torch.ops.aten.upsample_nearest3d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest3d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_nearest3d.default(input, osize, *scale_factors)
return torch.ops.aten.upsample_nearest3d.default(input, osize)


@torch.ops.aten.upsample_linear1d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_linear1d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_linear1d.default(
input, osize, align_corners, *scale_factors
)
return torch.ops.aten.upsample_linear1d.default(input, osize, align_corners)


@torch.ops.aten.upsample_bilinear2d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_bilinear2d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_bilinear2d.default(
input, osize, align_corners, *scale_factors
)
return torch.ops.aten.upsample_bilinear2d.default(input, osize, align_corners)


@torch.ops.aten.upsample_trilinear3d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_trilinear3d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_trilinear3d.default(
input, osize, align_corners, *scale_factors
)
return torch.ops.aten.upsample_trilinear3d.default(input, osize, align_corners)


@torch.ops.aten.upsample_bicubic2d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_bicubic2d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_bicubic2d.default(
input, osize, align_corners, *scale_factors
)
return torch.ops.aten.upsample_bicubic2d.default(input, osize, align_corners)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_nearest1d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest1d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -3055,17 +3221,23 @@ def upsample_nearest2d(
target,
SourceIR.ATEN,
name,
input=args[0],
out_shape=args_bounds_check(args, 1),
scale_factors=args_bounds_check(args, 2),
resize_mode="nearest",
args[0],
size=args[1],
scale_factor=None if len(args) < 3 else [args[2]],
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec)
def upsample_bilinear2d(
@dynamo_tensorrt_converter(
torch.ops.aten.upsample_nearest2d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest2d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -3077,11 +3249,151 @@ def upsample_bilinear2d(
target,
SourceIR.ATEN,
name,
input=args[0],
out_shape=args_bounds_check(args, 1),
scale_factors=args_bounds_check(args, 3),
resize_mode="bilinear",
align_corners=args_bounds_check(args, 2),
args[0],
size=args[1],
scale_factor=None if len(args) < 4 else [args[2], args[3]],
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_nearest3d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest3d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 5 else [args[2], args[3], args[4]],
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_linear1d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_linear1d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 4 else [args[3]],
mode="linear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_bilinear2d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_bilinear2d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 5 else [args[3], args[4]],
mode="bilinear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_trilinear3d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_trilinear3d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 6 else [args[3], args[4], args[5]],
mode="trilinear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_bicubic2d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_bicubic2d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 5 else [args[3], args[4]],
mode="bicubic",
align_corners=args[2],
)


Expand Down
Loading
Loading