Fft c2r #16

Merged: 2 commits merged on Sep 4, 2021
39 changes: 35 additions & 4 deletions paddle/fluid/operators/spectral_op.cc
@@ -240,6 +240,16 @@ class FFTC2ROpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>("normalization",
"fft_norm_type, the fft normalization type.");
AddAttr<bool>("forward", "bool, the fft direction.");
AddAttr<int64_t>(
"last_dim_size", "int",
"Length of the transformed "
"axis of the output. For n output points, last_dim_size//2 + 1 input"
" points are necessary. If the input is longer than this,"
" it is cropped. If it is shorter than this, it is padded"
" with zeros. If last_dim_size is not given, it is taken to be 2*(m-1)"
" where m is the length of the input along the axis "
"specified by axis.")
.SetDefault(0L);
AddComment(R"DOC(
// add doc here
)DOC");
@@ -259,10 +269,15 @@ class FFTC2ROp : public framework::OperatorWithKernel {
"Output(%s) of FFTC2ROp should not be null.", "Out"));
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");

const int64_t last_dim_size = ctx->Attrs().Get<int64_t>("last_dim_size");
framework::DDim out_dim(ctx->GetInputDim("X"));
const int64_t last_fft_axis = axes.back();
const int64_t last_fft_dim_size = out_dim.at(last_fft_axis);
out_dim.at(last_fft_axis) = (last_fft_dim_size - 1) * 2;
if (last_dim_size == 0) {
const int64_t last_fft_dim_size = out_dim.at(last_fft_axis);
out_dim.at(last_fft_axis) = (last_fft_dim_size - 1) * 2;
} else {
out_dim.at(last_fft_axis) = ctx->Attrs().Get<int64_t>("last_dim_size");
}
ctx->SetOutputDim("Out", out_dim);
}
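The shape rule implemented in this InferShape can be summarized with a small sketch (a hypothetical Python helper, not part of the patch): the last transformed axis becomes last_dim_size when it is non-zero, otherwise 2*(m-1) where m is the input extent along that axis.

```python
def c2r_out_shape(in_shape, axes, last_dim_size=0):
    # Mirrors FFTC2ROp::InferShape for the last transformed axis.
    out = list(in_shape)
    last_axis = axes[-1]
    m = in_shape[last_axis]
    out[last_axis] = last_dim_size if last_dim_size != 0 else 2 * (m - 1)
    return out

assert c2r_out_shape([4, 5], axes=[0, 1]) == [4, 8]                   # default: 2*(5-1)
assert c2r_out_shape([4, 5], axes=[0, 1], last_dim_size=9) == [4, 9]  # explicit length
```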

@@ -840,7 +855,23 @@ template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {}
FFTNormMode normalization, bool forward) {
if (axes.size() > 1) {
const std::vector<int64_t> c2c_dims(axes.begin(), axes.end() - 1);
Tensor temp;
temp.mutable_data<Ti>(x->dims(), ctx.GetPlace());

FFTC2CFunctor<platform::CPUDeviceContext, Ti, Ti> c2c_functor;
c2c_functor(ctx, x, &temp, c2c_dims, normalization, forward);

const std::vector<int64_t> new_axes{axes.back()};
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, &temp, out, new_axes,
normalization, forward);
} else {
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
normalization, forward);
}
}
};
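The functor above splits a multi-axis C2R transform into a C2C transform over all axes except the last, followed by a C2R transform over the last axis. A NumPy sketch of the same decomposition (assuming the default "backward" normalization; illustrative only) shows the two routes agree:

```python
import numpy as np

spec = np.fft.rfftn(np.random.randn(4, 6, 8))  # Hermitian-symmetric input, shape (4, 6, 5)
full = np.fft.irfftn(spec)                     # multi-axis C2R in one call

step = np.fft.ifft(spec, axis=0)               # C2C over the leading axes ...
step = np.fft.ifft(step, axis=1)
step = np.fft.irfft(step, axis=2)              # ... then C2R over the last axis

assert np.allclose(full, step)
```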

#elif defined(PADDLE_WITH_POCKETFFT)
@@ -955,7 +986,7 @@ struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
[](int64_t s) { return s * data_size; });
}

const auto* in_data = reinterpret_cast<const C*>(x->data<To>());
const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
auto* out_data = out->data<R>();
// well, we have to use std::vector<size_t> here
std::vector<size_t> axes_(axes.size());
113 changes: 111 additions & 2 deletions paddle/fluid/operators/spectral_op.cu
@@ -34,6 +34,25 @@ using ScalarType = framework::proto::VarType::Type;
const int64_t kMaxCUFFTNdim = 3;
const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1;

std::ostream& operator<<(std::ostream& os, FFTTransformType fft_type) {
std::string repr;
switch (fft_type) {
case FFTTransformType::C2C:
repr = "C2C";
break;
case FFTTransformType::C2R:
repr = "C2R";
break;
case FFTTransformType::R2C:
repr = "R2C";
break;
default:
repr = "UNK";
}
os << repr;
return os;
}

Comment on lines +37 to +55
Collaborator: Is this function still needed?
Owner: It was added while debugging, since an enum class cannot be written to std::ostream with operator<< by default.

static inline std::string get_cufft_error_info(cufftResult error) {
switch (error) {
case CUFFT_SUCCESS:
@@ -278,7 +297,7 @@ struct KeyHash {
value ^= ptr[i];
value *= 0x01000193;
}
return (size_t)value;
return static_cast<size_t>(value);
}
};

@@ -431,7 +450,10 @@ class PlanLRUCache {
// Execute a pre-planned transform
static void exec_cufft_plan(const CuFFTConfig& config, void* in_data,
void* out_data, bool forward) {
std::cout << "config address:" << &config << std::endl;
Collaborator: Once debugging is finished, remember to remove the helper cout/vlog statements.

auto& plan = config.plan();
std::cout << "inside exec_cufft_plan ==============--------" << std::endl;
// std::cout<<"plan ==============--------"<< *plan << std::endl;
#ifdef __HIPCC__
auto value_type = config.data_type();
if (value_type == framework::proto::VarType::FP32) {
@@ -450,6 +472,8 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data,
case FFTTransformType::C2R: {
CUFFT_CHECK(hipfftExecC2R(plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftReal*>(out_data)));
std::cout << "inside FFTTransformType ==============--------"
<< std::endl;
return;
}
}
@@ -478,8 +502,17 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data,
PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT only support transforms of type float32 and float64"));
#else
std::cout << "after __HIPCC__ ==============--------" << std::endl;
std::cout << "plan: " << plan << std::endl;
std::cout << "input pointer: " << in_data << std::endl;
std::cout << "output pointer: " << out_data << std::endl;
size_t ws = 0;
cufftGetSize(plan, &ws);
std::cout << "workspace size: " << ws << std::endl;
Comment on lines +505 to +511
Collaborator: Once debugging is finished, remember to remove the helper cout/vlog statements.


CUFFT_CHECK(cufftXtExec(plan, in_data, out_data,
forward ? CUFFT_FORWARD : CUFFT_INVERSE));
std::cout << "end end end __HIPCC__ end ==============--------" << std::endl;
#endif
}

@@ -605,6 +638,11 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X,
PlanKey Key(framework::vectorize(input.dims()),
framework::vectorize(output.dims()), signal_size, fft_type,
value_type);
std::cout << "input.dims()" << input.dims() << std::endl;
std::cout << "output.dims()" << output.dims() << std::endl;
std::cout << "signal_size" << framework::make_ddim(signal_size) << std::endl;
std::cout << "fft_type" << fft_type << std::endl;
std::cout << "value_type" << value_type << std::endl;
Comment on lines +641 to +645
Collaborator: Once debugging is finished, remember to remove the helper cout/vlog statements.

PlanLRUCache& plan_cache = cufft_get_plan_cache(static_cast<int64_t>(
(reinterpret_cast<platform::CUDAPlace*>(&tensor_place))->GetDeviceId()));
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
@@ -632,7 +670,6 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X,

// execute transform plan
exec_cufft_plan(*config, input.data<void>(), output.data<void>(), forward);

// Inverting output by reshape and transpose to original batch and dimension
output.Resize(framework::make_ddim(reshape_out_sizes));
out->Resize(framework::make_ddim(out_sizes));
@@ -644,7 +681,9 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X,
reverse_dim_permute);
}
*/
std::cout << "before TransCompute" << std::endl;
TransCompute<DeviceContext, To>(ndim, ctx, output, out, reverse_dim_permute);
std::cout << "after TransCompute" << std::endl;
}

// Calculates the normalization constant and applies it in-place to out
Expand Down Expand Up @@ -689,6 +728,19 @@ void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,

} // anonymous namespace

// Use the optimized path to perform single R2C or C2R if transformation dim is
// supported by cuFFT
bool use_optimized_cufft_path(const std::vector<int64_t>& axes) {
// For performance reason, when axes starts with (0, 1), do not use the
// optimized path.
if (axes.size() > kMaxCUFFTNdim ||
(axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
return false;
} else {
return true;
}
}
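For readability, the predicate above can be read as the following sketch (kMaxCUFFTNdim is 3 in this file; the Python mirror is illustrative only, not part of the patch):

```python
def use_optimized_cufft_path(axes, k_max_cufft_ndim=3):
    # Use cuFFT's native multi-dimensional plan unless the transform has too
    # many axes or starts with axes (0, 1), which is slower in practice.
    if len(axes) > k_max_cufft_ndim or (len(axes) >= 2 and axes[0] == 0 and axes[1] == 1):
        return False
    return True

assert use_optimized_cufft_path([2])            # single trailing axis: optimized path
assert not use_optimized_cufft_path([0, 1, 2])  # leading (0, 1): fall back to the split path
```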

template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
@@ -730,6 +782,45 @@ struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
}
};

template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
std::vector<int64_t> in_dims = framework::vectorize(X->dims());
// std::vector<int64_t> out_dims(in_dims.begin(), in_dims.end());
// out_dims[axes.back()] = out->dims();
std::vector<int64_t> out_dims = framework::vectorize(out->dims());

std::cout << "axes: " << framework::make_ddim(axes) << std::endl;
if (use_optimized_cufft_path(axes)) {
std::cout << "befor exec --------" << std::endl;
std::cout << "out dims: " << out->dims() << out->type() << std::endl;
std::cout << "in dims: " << X->dims() << X->type() << std::endl;
Comment on lines +797 to +799
Collaborator: Once debugging is finished, remember to remove the helper cout/vlog statements.

exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, out, X, out_dims, axes,
forward);
} else {
framework::Tensor temp_tensor;
const std::vector<int64_t> dims(axes.begin(), axes.end() - 1);

FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c_functor;
c2c_functor(ctx, X, &temp_tensor, dims, FFTNormMode::none, forward);

exec_fft<platform::CUDADeviceContext, Ti, To>(
ctx, out, &temp_tensor, out_dims, {axes.back()}, forward);
}
exec_normalization<platform::CUDADeviceContext, To>(
ctx, out, out, normalization, out_dims, axes);
}
};
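Note that the fallback branch runs the intermediate C2C with FFTNormMode::none and applies the whole normalization once at the end via exec_normalization. Because the scale factor is linear, deferring it is equivalent to scaling each sub-transform; a NumPy sketch under the default "backward" convention (norm="forward" requires NumPy >= 1.20; illustrative only):

```python
import numpy as np

spec = np.fft.rfftn(np.random.randn(4, 6))             # shape (4, 4), Hermitian-symmetric

scaled = np.fft.ifft(spec, axis=0)                      # 1/n applied per axis as we go
scaled = np.fft.irfft(scaled, axis=1)

deferred = np.fft.ifft(spec, axis=0, norm="forward")    # unnormalized inverse transforms ...
deferred = np.fft.irfft(deferred, axis=1, norm="forward")
deferred = deferred / (4 * 6)                           # ... then one 1/prod(output sizes) at the end

assert np.allclose(scaled, deferred)
```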

template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward, bool onesided) {}
};

} // namespace operators
} // namespace paddle

@@ -742,3 +833,21 @@ REGISTER_OP_CUDA_KERNEL(
fft_c2c_grad,
ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
fft_c2r, ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
fft_c2r_grad,
ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
fft_r2c, ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
fft_r2c_grad,
ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, double>);
10 changes: 5 additions & 5 deletions paddle/fluid/operators/spectral_op.h
@@ -31,10 +31,10 @@ enum class FFTNormMode : int64_t {
FFTNormMode get_norm_from_string(const std::string& norm, bool forward);

// Enum representing the FFT type
enum class FFTTransformType : int8_t {
C2C, // Complex-to-complex
R2C, // Real-to-complex
C2R, // Complex-to-real
enum class FFTTransformType : int64_t {
C2C = 0, // Complex-to-complex
R2C, // Real-to-complex
C2R, // Complex-to-real
};

// Create transform type enum from bools representing if input and output are
@@ -99,7 +99,7 @@ template <typename DeviceContext, typename Ti, typename To>
struct FFTC2RFunctor {
void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward, bool onesided);
bool forward);
};

template <typename DeviceContext, typename T>
26 changes: 10 additions & 16 deletions python/paddle/tensor/fft.py
@@ -409,16 +409,16 @@ def fftn_c2r(x, s, axes, norm, forward):
raise ValueError(
"Unexpected norm: {}. Norm should be forward, backward or ortho".
format(norm))
s = list(s)
rank = x.ndim
if axes is None:
if s is None:
axes = list(range(rank))
s = paddle.shape(x)
else:
fft_ndims = len(s)
axes = list(range(rank - fft_ndims, rank))
else:
axes_ = axes.copy()
axes_ = list(axes)
for i in range(len(axes_)):
if axes_[i] < -rank or axes_[i] >= rank:
raise ValueError(
@@ -427,26 +427,20 @@ def fftn_c2r(x, s, axes, norm, forward):
if axes_[i] < 0:
axes_[i] += rank
axes = axes_
axes.sort()
if s is None:
shape = paddle.shape(x)
s = [shape[axis] for axis in axes]
else:
assert len(axes) == len(s)

op_type = 'fft_c2r'

if in_dygraph_mode():
attrs = ('s', s, 'axes', axes, 'normalization', norm, 'forward',
forward)
if s:
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'last_dim_size', s[-1])
else:
attrs = ('axes', axes, 'normalization', norm, 'forward', forward)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
inputs = {'X': [x], }
attrs = {
's': s,
'axes': axes,
'normalization': norm,
'forward': forward
}
attrs = {'axes': axes, 'normalization': norm, 'forward': forward}
if s:
attr["last_dim_size"] = s[-1]
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
op_type)
helper = LayerHelper(op_type, **locals())