-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fft c2r #16
Fft c2r #16
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,25 @@ using ScalarType = framework::proto::VarType::Type; | |
const int64_t kMaxCUFFTNdim = 3; | ||
const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1; | ||
|
||
std::ostream& operator<<(std::ostream& os, FFTTransformType fft_type) { | ||
std::string repr; | ||
switch (fft_type) { | ||
case FFTTransformType::C2C: | ||
repr = "C2C"; | ||
break; | ||
case FFTTransformType::C2R: | ||
repr = "C2R"; | ||
break; | ||
case FFTTransformType::R2C: | ||
repr = "R2C"; | ||
break; | ||
default: | ||
repr = "UNK"; | ||
} | ||
os << repr; | ||
return os; | ||
} | ||
|
||
static inline std::string get_cufft_error_info(cufftResult error) { | ||
switch (error) { | ||
case CUFFT_SUCCESS: | ||
|
@@ -278,7 +297,7 @@ struct KeyHash { | |
value ^= ptr[i]; | ||
value *= 0x01000193; | ||
} | ||
return (size_t)value; | ||
return static_cast<size_t>(value); | ||
} | ||
}; | ||
|
||
|
@@ -431,7 +450,10 @@ class PlanLRUCache { | |
// Execute a pre-planned transform | ||
static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, | ||
void* out_data, bool forward) { | ||
std::cout << "config address:" << &config << std::endl; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最终调试完成后,注意将辅助的cout/vlog等去除 |
||
auto& plan = config.plan(); | ||
std::cout << "inside exec_cufft_plan ==============--------" << std::endl; | ||
// std::cout<<"plan ==============--------"<< *plan << std::endl; | ||
#ifdef __HIPCC__ | ||
auto value_type = config.data_type(); | ||
if (value_type == framework::proto::VarType::FP32) { | ||
|
@@ -450,6 +472,8 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, | |
case FFTTransformType::C2R: { | ||
CUFFT_CHECK(hipfftExecC2R(plan, static_cast<hipfftComplex*>(in_data), | ||
static_cast<hipfftReal*>(out_data))); | ||
std::cout << "inside FFTTransformType ==============--------" | ||
<< std::endl; | ||
return; | ||
} | ||
} | ||
|
@@ -478,8 +502,17 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, | |
PADDLE_THROW(platform::errors::InvalidArgument( | ||
"hipFFT only support transforms of type float32 and float64")); | ||
#else | ||
std::cout << "after __HIPCC__ ==============--------" << std::endl; | ||
std::cout << "plan: " << plan << std::endl; | ||
std::cout << "input pointer: " << in_data << std::endl; | ||
std::cout << "output pointer: " << out_data << std::endl; | ||
size_t ws = 0; | ||
cufftGetSize(plan, &ws); | ||
std::cout << "workspace size: " << ws << std::endl; | ||
Comment on lines
+505
to
+511
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最终调试完成后,注意将辅助的cout/vlog等去除 |
||
|
||
CUFFT_CHECK(cufftXtExec(plan, in_data, out_data, | ||
forward ? CUFFT_FORWARD : CUFFT_INVERSE)); | ||
std::cout << "end end end __HIPCC__ end ==============--------" << std::endl; | ||
#endif | ||
} | ||
|
||
|
@@ -605,6 +638,11 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X, | |
PlanKey Key(framework::vectorize(input.dims()), | ||
framework::vectorize(output.dims()), signal_size, fft_type, | ||
value_type); | ||
std::cout << "input.dims()" << input.dims() << std::endl; | ||
std::cout << "output.dims()" << output.dims() << std::endl; | ||
std::cout << "signal_size" << framework::make_ddim(signal_size) << std::endl; | ||
std::cout << "fft_type" << fft_type << std::endl; | ||
std::cout << "value_type" << value_type << std::endl; | ||
Comment on lines
+641
to
+645
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最终调试完成后,注意将辅助的cout/vlog等去除 |
||
PlanLRUCache& plan_cache = cufft_get_plan_cache(static_cast<int64_t>( | ||
(reinterpret_cast<platform::CUDAPlace*>(&tensor_place))->GetDeviceId())); | ||
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock); | ||
|
@@ -632,7 +670,6 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X, | |
|
||
// execute transform plan | ||
exec_cufft_plan(*config, input.data<void>(), output.data<void>(), forward); | ||
|
||
// Inverting output by reshape and transpose to original batch and dimension | ||
output.Resize(framework::make_ddim(reshape_out_sizes)); | ||
out->Resize(framework::make_ddim(out_sizes)); | ||
|
@@ -644,7 +681,9 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X, | |
reverse_dim_permute); | ||
} | ||
*/ | ||
std::cout << "before TransCompute" << std::endl; | ||
TransCompute<DeviceContext, To>(ndim, ctx, output, out, reverse_dim_permute); | ||
std::cout << "after TransCompute" << std::endl; | ||
} | ||
|
||
// Calculates the normalization constant and applies it in-place to out | ||
|
@@ -689,6 +728,19 @@ void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, | |
|
||
} // anonymous namespace | ||
|
||
// Use the optimized path to perform single R2C or C2R if transformation dim is | ||
// supported by cuFFT | ||
bool use_optimized_cufft_path(const std::vector<int64_t>& axes) { | ||
// For performance reason, when axes starts with (0, 1), do not use the | ||
// optimized path. | ||
if (axes.size() > kMaxCUFFTNdim || | ||
(axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) { | ||
return false; | ||
} else { | ||
return true; | ||
} | ||
} | ||
|
||
template <typename Ti, typename To> | ||
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> { | ||
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X, | ||
|
@@ -730,6 +782,45 @@ struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> { | |
} | ||
}; | ||
|
||
template <typename Ti, typename To> | ||
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> { | ||
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X, | ||
Tensor* out, const std::vector<int64_t>& axes, | ||
FFTNormMode normalization, bool forward) { | ||
std::vector<int64_t> in_dims = framework::vectorize(X->dims()); | ||
// std::vector<int64_t> out_dims(in_dims.begin(), in_dims.end()); | ||
// out_dims[axes.back()] = out->dims(); | ||
std::vector<int64_t> out_dims = framework::vectorize(out->dims()); | ||
|
||
std::cout << "axes: " << framework::make_ddim(axes) << std::endl; | ||
if (use_optimized_cufft_path(axes)) { | ||
std::cout << "befor exec --------" << std::endl; | ||
std::cout << "out dims: " << out->dims() << out->type() << std::endl; | ||
std::cout << "in dims: " << X->dims() << X->type() << std::endl; | ||
Comment on lines
+797
to
+799
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最终调试完成后,注意将辅助的cout/vlog等去除 |
||
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, out, X, out_dims, axes, | ||
forward); | ||
} else { | ||
framework::Tensor temp_tensor; | ||
const std::vector<int64_t> dims(axes.begin(), axes.end() - 1); | ||
|
||
FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c_functor; | ||
c2c_functor(ctx, X, &temp_tensor, dims, FFTNormMode::none, forward); | ||
|
||
exec_fft<platform::CUDADeviceContext, Ti, To>( | ||
ctx, out, &temp_tensor, out_dims, {axes.back()}, forward); | ||
} | ||
exec_normalization<platform::CUDADeviceContext, To>( | ||
ctx, out, out, normalization, out_dims, axes); | ||
} | ||
}; | ||
|
||
template <typename Ti, typename To> | ||
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> { | ||
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X, | ||
Tensor* out, const std::vector<int64_t>& axes, | ||
FFTNormMode normalization, bool forward, bool onesided) {} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
|
@@ -742,3 +833,21 @@ REGISTER_OP_CUDA_KERNEL( | |
fft_c2c_grad, | ||
ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, float>, | ||
ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, double>); | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
fft_c2r, ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, float>, | ||
ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, double>); | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
fft_c2r_grad, | ||
ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, float>, | ||
ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, double>); | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
fft_r2c, ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, float>, | ||
ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, double>); | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
fft_r2c_grad, | ||
ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, float>, | ||
ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, double>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否还需要这个函数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
debug 时加上的,因为 enum class 默认不能被 ostream<< 输出。