[complex] conv_transpose3d : complex support (#87967)
Reference: pytorch/pytorch#71108

Pull Request resolved: pytorch/pytorch#87967
Approved by: https://github.com/anjali411
kshitij12345 authored and pytorchmergebot committed Nov 2, 2022
1 parent 7674af9 commit e763b7a
Showing 3 changed files with 60 additions and 6 deletions.
aten/src/ATen/native/Convolution.cpp (8 changes: 7 additions & 1 deletion)
@@ -1087,8 +1087,14 @@ at::Tensor conv_transpose3d(
   Tensor input;
   bool is_batched;
   std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
-  auto output = at::convolution(
+  Tensor output;
+  if (at::isComplexType(input_.scalar_type())) {
+    output = complex_convolution(
+        input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  } else {
+    output = at::convolution(
       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  }
   return is_batched ? output : output.squeeze(0);
 }
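The change above routes complex inputs to complex_convolution while real inputs keep the existing at::convolution path. A minimal sketch of what this enables from Python, assuming a build that includes this commit; the tensor shapes below are illustrative only:

import torch
import torch.nn.functional as F

# Transposed 3D convolution on complex64 tensors, the case tracked by
# pytorch/pytorch#71108 and enabled by this commit.
x = torch.randn(1, 2, 4, 4, 4, dtype=torch.complex64)   # (N, C_in, D, H, W)
w = torch.randn(2, 3, 3, 3, 3, dtype=torch.complex64)   # (C_in, C_out, kD, kH, kW)
out = F.conv_transpose3d(x, w, stride=1, padding=0)
print(out.shape, out.dtype)  # torch.Size([1, 3, 6, 6, 6]) torch.complex64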

torch/testing/_internal/common_methods_invocations.py (38 changes: 33 additions & 5 deletions)
@@ -10673,6 +10673,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                DecorateInfo(
                    toleranceOverride({torch.complex32: tol(atol=1e-5, rtol=5e-3)}),
                    "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
+               DecorateInfo(
+                   toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }),
+                   'TestCommon', 'test_numpy_ref_mps'),
            ),
            skips=(
                # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486
@@ -10741,8 +10744,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     OpInfo('nn.functional.conv_transpose3d',
            aten_name='conv_transpose3d',
            aliases=('conv_transpose3d',),
-           dtypes=floating_types_and(torch.int64),
-           dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
+           # `ref` for this function is backward of
+           # corresponding `conv*d`
+           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d),
+           dtypes=floating_and_complex_types_and(torch.int64),
+           dtypesIfCUDA=floating_and_complex_types_and(
+               torch.float16, torch.chalf, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            sample_inputs_func=sample_inputs_conv_transpose3d,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -10752,24 +10759,45 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            decorators=[
                DecorateInfo(
-                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
+                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06),
+                                      torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
                    'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
                DecorateInfo(
                    toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }),
                    'TestCompositeCompliance', 'test_operator', device_type='cuda'),
                DecorateInfo(
-                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06), }),
+                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06),
+                                      torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
                    'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
                DecorateInfo(
                    toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }),
                    'TestCompositeCompliance', 'test_forward_ad', device_type='cuda',
-                   active_if=TEST_CUDNN)],
+                   active_if=TEST_CUDNN),
+               DecorateInfo(
+                   toleranceOverride({torch.complex32: tol(atol=5e-2, rtol=5e-2)}),
+                   "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}),
+                   "TestMathBits", "test_conj_view", device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }),
+                   'TestCommon', 'test_complex_half_reference_testing')],
            skips=(
                # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
                # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
                DecorateInfo(unittest.skip("Skipped! 75029"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
                DecorateInfo(unittest.skip("Skipped! 75363"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
+               # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long'
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.int64,)),
+               # Reference: https://github.com/pytorch/pytorch/issues/86356
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.double, torch.cdouble)),
+               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
            ),
            supports_out=False,),
     OpInfo('nn.functional.conv1d',
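The new ref= entry leans on the relationship stated in the inline comment: a transposed convolution computes the input-gradient ("backward") of the corresponding forward convolution. A small sketch of that identity for real dtypes follows; it only illustrates the idea, and the actual conv_transpose_ref helper in common_methods_invocations.py may differ in detail:

import torch
import torch.nn.functional as F

# conv_transpose3d(go, w) reproduces the gradient of conv3d with respect to its
# input when `go` plays the role of grad_output. Shapes are illustrative;
# a real dtype keeps the autograd check simple.
w = torch.randn(2, 3, 3, 3, 3, dtype=torch.float64)   # conv3d layout: (C_out, C_in, kD, kH, kW)
x = torch.randn(1, 3, 6, 6, 6, dtype=torch.float64, requires_grad=True)
go = torch.randn(1, 2, 4, 4, 4, dtype=torch.float64)

out = F.conv3d(x, w)                                    # -> (1, 2, 4, 4, 4)
(grad_x,) = torch.autograd.grad(out, x, grad_outputs=go)

# The same tensor, computed directly as a transposed convolution.
torch.testing.assert_close(grad_x, F.conv_transpose3d(go, w))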
torch/testing/_internal/common_modules.py (20 changes: 20 additions & 0 deletions)
@@ -1179,6 +1179,7 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train
                )),
     ModuleInfo(torch.nn.ConvTranspose3d,
                module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
+               dtypes=floating_and_complex_types_and(torch.chalf),
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                module_memformat_affects_out=True,
                skips=(
@@ -1190,9 +1191,28 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train
                    # This was wrongly being skipped before and needs investigation.
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+                   # These fail only on ROCm
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule',
+                                'test_if_train_and_eval_modes_differ', dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_multiple_device_transfer',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   # Ref: https://github.com/pytorch/pytorch/issues/73502
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_pickle', dtypes=(torch.chalf,)),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'),
                )),
     ModuleInfo(torch.nn.ELU,
                module_inputs_func=module_inputs_torch_nn_ELU,
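The module-level tests now cover complex dtypes (including chalf) for torch.nn.ConvTranspose3d as well. A hedged usage sketch of what these entries exercise; layer sizes are illustrative, construction with a complex dtype keyword argument is assumed to behave as for float dtypes, and chalf remains CUDA-only per the CPU expected failures above:

import torch

# ConvTranspose3d built directly with complex64 parameters via the dtype
# factory kwarg, then applied to a complex input of matching dtype.
m = torch.nn.ConvTranspose3d(in_channels=2, out_channels=3, kernel_size=3,
                             dtype=torch.complex64)
x = torch.randn(1, 2, 4, 4, 4, dtype=torch.complex64)
y = m(x)
print(y.shape, y.dtype)  # torch.Size([1, 3, 6, 6, 6]) torch.complex64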
