Skip to content

Commit

Permalink
fix irfft invalid descriptor (#4165) (#4480)
Browse files Browse the repository at this point in the history
Co-authored-by: hjhee <hjhee@users.noreply.github.com>
  • Loading branch information
CuiYifeng and hjhee authored Jul 16, 2024
1 parent d698764 commit 3e60e87
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
9 changes: 7 additions & 2 deletions csrc/gpu/aten/operators/SpectralOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,13 @@ void _mkl_dft(
auto ostrides = output.strides();
int64_t idist = istrides[0];
int64_t odist = ostrides[0];
desc_config->set_value(config_param::FWD_DISTANCE, idist);
desc_config->set_value(config_param::BWD_DISTANCE, odist);
if (!inverse) {
desc_config->set_value(config_param::FWD_DISTANCE, idist);
desc_config->set_value(config_param::BWD_DISTANCE, odist);
} else {
desc_config->set_value(config_param::FWD_DISTANCE, odist);
desc_config->set_value(config_param::BWD_DISTANCE, idist);
}
std::vector<int64_t> mkl_istrides(1 + signal_ndim, 0),
mkl_ostrides(1 + signal_ndim, 0);
for (int64_t i = 1; i <= signal_ndim; i++) {
Expand Down
22 changes: 13 additions & 9 deletions tests/gpu/examples/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,21 @@ def test_fftn(self, dtype=torch.float):
self.assertEqual(y2, y2_dpcpp.cpu())

def test_irfft(self, dtype=torch.float):
x1 = torch.randn(5, 5)
x2 = torch.randn(4, 3, 2)
x1_dpcpp = x1.to("xpu")
x2_dpcpp = x2.to("xpu")
y1 = torch.fft.rfft(x1, 2)
y2 = torch.fft.irfft(y1, 2)
shapes = [[5, 5], [4, 3, 2], [2, 32, 6], [65, 80, 115]]

y1_dpcpp = torch.fft.rfft(x1_dpcpp, 2)
y2_dpcpp = torch.fft.irfft(y1_dpcpp, 2)
for shape in shapes:
fft_input = torch.randn(*shape, dtype=dtype)
fft_input = torch.fft.rfft(fft_input)

self.assertEqual(y2, y2_dpcpp.cpu())
x1 = fft_input
y1 = torch.fft.irfft(x1, n=shape[-1])
y2 = torch.fft.rfft(y1)

x1_dpcpp = fft_input.to(device="xpu")
y1_dpcpp = torch.fft.irfft(x1_dpcpp, n=shape[-1])
y2_dpcpp = torch.fft.rfft(y1_dpcpp)

self.assertEqual(y2, y2_dpcpp.cpu())

def test_ifft2(self, dtype=torch.float):
f_real = torch.randn(2, 72, 72)
Expand Down

0 comments on commit 3e60e87

Please sign in to comment.