diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 60215801e4c..109f0ac0592 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -1087,8 +1087,14 @@ at::Tensor conv_transpose3d(
   Tensor input;
   bool is_batched;
   std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
-  auto output = at::convolution(
+  Tensor output;
+  if (at::isComplexType(input_.scalar_type())) {
+    output = complex_convolution(
       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  } else {
+    output = at::convolution(
+        input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  }
   return is_batched ? output : output.squeeze(0);
 }
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e8c7b419498..ba2e8bc492c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10673,6 +10673,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                DecorateInfo(
                    toleranceOverride({torch.complex32: tol(atol=1e-5, rtol=5e-3)}),
                    "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
+               DecorateInfo(
+                   toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }),
+                   'TestCommon', 'test_numpy_ref_mps'),
            ),
            skips=(
                # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486
@@ -10741,8 +10744,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     OpInfo('nn.functional.conv_transpose3d',
            aten_name='conv_transpose3d',
            aliases=('conv_transpose3d',),
-           dtypes=floating_types_and(torch.int64),
-           dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
+           # `ref` for this function is backward of
+           # corresponding `conv*d`
+           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d),
+           dtypes=floating_and_complex_types_and(torch.int64),
+           dtypesIfCUDA=floating_and_complex_types_and(
+               torch.float16, torch.chalf, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            sample_inputs_func=sample_inputs_conv_transpose3d,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -10752,24 +10759,45 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            decorators=[
                DecorateInfo(
-                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
+                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06),
+                                      torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
                    'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
                DecorateInfo(
                    toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }),
                    'TestCompositeCompliance', 'test_operator', device_type='cuda'),
                DecorateInfo(
-                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06), }),
+                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06),
+                                      torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
                    'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
                DecorateInfo(
                    toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }),
                    'TestCompositeCompliance', 'test_forward_ad', device_type='cuda',
-                   active_if=TEST_CUDNN)],
+                   active_if=TEST_CUDNN),
+               DecorateInfo(
+                   toleranceOverride({torch.complex32: tol(atol=5e-2, rtol=5e-2)}),
+                   "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}),
+                   "TestMathBits", "test_conj_view", device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }),
+                   'TestCommon', 'test_complex_half_reference_testing')],
            skips=(
                # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
                # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
                DecorateInfo(unittest.skip("Skipped! 75029"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
                DecorateInfo(unittest.skip("Skipped! 75363"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
+               # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long'
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.int64,)),
+               # Reference: https://github.com/pytorch/pytorch/issues/86356
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.double, torch.cdouble)),
+               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
            ),
            supports_out=False,),
     OpInfo('nn.functional.conv1d',
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index fed908e14dd..1f395cbe606 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -1179,6 +1179,7 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train
                )),
     ModuleInfo(torch.nn.ConvTranspose3d,
                module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
+               dtypes=floating_and_complex_types_and(torch.chalf),
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                module_memformat_affects_out=True,
                skips=(
@@ -1190,9 +1191,28 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train
                    # This was wrongly being skipped before and needs investigation.
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+                   # These fail only on ROCm
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule',
+                                'test_if_train_and_eval_modes_differ', dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_multiple_device_transfer',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   # Ref: https://github.com/pytorch/pytorch/issues/73502
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_pickle', dtypes=(torch.chalf,)),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'),
                )),
     ModuleInfo(torch.nn.ELU,
                module_inputs_func=module_inputs_torch_nn_ELU,
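
As a quick sanity check of what the ATen change enables, here is a minimal usage sketch (not part of the patch; all shapes and variable names are illustrative). With complex inputs, conv_transpose3d now takes the complex_convolution branch added in Convolution.cpp instead of erroring out in the real-only path:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch=1, in_channels=2, out_channels=3, 4x4x4 input, 3x3x3 kernel.
x = torch.randn(1, 2, 4, 4, 4, dtype=torch.complex64)
# conv_transpose3d weights are laid out as (in_channels, out_channels // groups, kT, kH, kW).
w = torch.randn(2, 3, 3, 3, 3, dtype=torch.complex64)
b = torch.randn(3, dtype=torch.complex64)

# Functional form: complex64 input is routed through complex_convolution.
out = F.conv_transpose3d(x, w, b, stride=1, padding=0)
print(out.shape, out.dtype)  # torch.Size([1, 3, 6, 6, 6]) torch.complex64

# Module form: the ConvTranspose3d ModuleInfo above now also covers complex dtypes.
m = torch.nn.ConvTranspose3d(2, 3, kernel_size=3, dtype=torch.complex64)
print(m(x).shape)  # torch.Size([1, 3, 6, 6, 6])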