Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex type compatibility for stft api and stft op. #40113

Merged
merged 11 commits into from
Mar 23, 2022
20 changes: 14 additions & 6 deletions paddle/fluid/operators/frame_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,26 @@ class FrameOp : public framework::OperatorWithKernel {
end_axis = x_rank - 2;
}

PADDLE_ENFORCE_LE(frame_length, seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length, seq_length));
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(frame_length, seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length, seq_length));
}

// It won't go into for loop when x_rank == 1U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}

n_frames = 1 + (seq_length - frame_length) / hop_length;
if (seq_length == -1) {
n_frames = -1;
} else {
n_frames = 1 + (seq_length - frame_length) / hop_length;
}

if (axis == 0) {
// (n_frames, frame_length, ...)
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/operators/mean_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,17 @@ REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
paddle::platform::bfloat16>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
mean_grad, ops::MeanGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
paddle::platform::bfloat16>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
11 changes: 9 additions & 2 deletions paddle/fluid/operators/mean_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,17 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
mean_grad,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
paddle::platform::complex<float>>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
23 changes: 16 additions & 7 deletions paddle/fluid/operators/overlap_add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class OverlapAddOp : public framework::OperatorWithKernel {
std::vector<int64_t> output_shape;
int n_frames;
int frame_length;
int seq_length;

int start_axis;
int end_axis;
Expand All @@ -69,14 +70,22 @@ class OverlapAddOp : public framework::OperatorWithKernel {
end_axis = x_rank - 3;
}

PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));
}

const int seq_length = (n_frames - 1) * hop_length + frame_length;
if (n_frames == -1 && frame_length == -1) {
seq_length = -1;
} else {
seq_length = (n_frames - 1) * hop_length + frame_length;
}

// It won't go into for loop when x_rank == 2U.
for (int i = start_axis; i <= end_axis; i++) {
Expand Down
10 changes: 7 additions & 3 deletions paddle/fluid/operators/pad3d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,10 @@ REGISTER_OPERATOR(pad3d_grad, ops::Pad3dOpGrad,
ops::Pad3dOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(pad3d, ops::Pad3dCPUKernel<float>,
ops::Pad3dCPUKernel<double>, ops::Pad3dCPUKernel<int>,
ops::Pad3dCPUKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(pad3d_grad, ops::Pad3dGradCPUKernel<float>,
ops::Pad3dGradCPUKernel<double>);
ops::Pad3dCPUKernel<int64_t>,
ops::Pad3dCPUKernel<paddle::platform::complex<float>>,
ops::Pad3dCPUKernel<paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
pad3d_grad, ops::Pad3dGradCPUKernel<float>, ops::Pad3dGradCPUKernel<double>,
ops::Pad3dGradCPUKernel<paddle::platform::complex<float>>,
ops::Pad3dGradCPUKernel<paddle::platform::complex<double>>);
18 changes: 11 additions & 7 deletions paddle/fluid/operators/pad3d_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -784,10 +784,14 @@ class Pad3dGradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(pad3d, ops::Pad3dCUDAKernel<plat::float16>,
ops::Pad3dCUDAKernel<float>,
ops::Pad3dCUDAKernel<double>, ops::Pad3dCUDAKernel<int>,
ops::Pad3dCUDAKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL(pad3d_grad, ops::Pad3dGradCUDAKernel<plat::float16>,
ops::Pad3dGradCUDAKernel<float>,
ops::Pad3dGradCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
pad3d, ops::Pad3dCUDAKernel<plat::float16>, ops::Pad3dCUDAKernel<float>,
ops::Pad3dCUDAKernel<double>, ops::Pad3dCUDAKernel<int>,
ops::Pad3dCUDAKernel<int64_t>,
ops::Pad3dCUDAKernel<paddle::platform::complex<float>>,
ops::Pad3dCUDAKernel<paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
pad3d_grad, ops::Pad3dGradCUDAKernel<plat::float16>,
ops::Pad3dGradCUDAKernel<float>, ops::Pad3dGradCUDAKernel<double>,
ops::Pad3dGradCUDAKernel<paddle::platform::complex<float>>,
ops::Pad3dGradCUDAKernel<paddle::platform::complex<double>>);
Loading