Merge pull request #6 from jeff41404/fft_c2c_cufft
fft c2c cufft kernel done with compiling and linking
Feiyu Chan authored Aug 19, 2021
2 parents 322b9e3 + 9f0cc98 commit 96e8b09
Showing 3 changed files with 76 additions and 59 deletions.
9 changes: 8 additions & 1 deletion paddle/fluid/operators/CMakeLists.txt
@@ -79,7 +79,7 @@ if(WITH_UNITY_BUILD)
endif()

register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})

op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})

@@ -91,6 +91,13 @@ if (WITH_GPU OR WITH_ROCM)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
find_library(CUFFT_LIB libcufft.so
PATHS
${CUDA_TOOLKIT_ROOT_DIR}/lib64/
NO_DEFAULT_PATH
)
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS ${OP_HEADER_DEPS})
target_link_libraries(spectral_op ${CUFFT_LIB})
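# Note (not part of the diff): spectral_op is added to the EXCLUDES list of
# register_operators above so it can be registered here through op_library
# and explicitly linked against the libcufft.so found under
# ${CUDA_TOOLKIT_ROOT_DIR}/lib64.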
endif()
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
116 changes: 64 additions & 52 deletions paddle/fluid/operators/spectral_op.cu
@@ -22,10 +22,8 @@
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
@@ -201,8 +199,6 @@ FFTTransformType::C2R);
}
auto out_layout = as_cufft_embed(out_strides, sizes, fft_type ==
FFTTransformType::R2C);
TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be
represented as CuFFT embedding");
clone_input |= in_layout.must_clone;
// Check if we can take advantage of simple data layout.
@@ -253,7 +249,8 @@ represented as CuFFT embedding");
exec_type = CUDA_C_16F;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuFFT doesn't support tensor of type: [%s]", dtype));
"cuFFT only support transforms of type float16, float32 and "
"float64"));
}
#endif
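// Note (not part of the diff): exec_type here is a cudaDataType value;
// CUDA_C_16F, CUDA_C_32F and CUDA_C_64F denote half-, single- and
// double-precision complex data for the cufftXt plan/exec APIs, which is why
// only float16, float32 and float64 transforms are accepted.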

@@ -404,9 +401,7 @@ class PlanLRUCache {

// If key is in this cache, return the cached config. Otherwise, emplace the
// config in this cache and return it.
// Return const reference because CuFFTConfig shouldn't be tampered with once
// created.
const CuFFTConfig& lookup(PlanKey params) {
CuFFTConfig& lookup(PlanKey params) {
PADDLE_ENFORCE_GT(_max_size, 0,
platform::errors::InvalidArgument(
"The max size of PlanLRUCache must be great than 0,"
@@ -536,7 +531,7 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data,
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT doesn't support transforms on type: [%s]", value_type));
"hipFFT only support transforms of type float32 and float64"));
#else
CUFFT_CHECK(cufftXtExec(plan, in_data, out_data,
forward ? CUFFT_FORWARD : CUFFT_INVERSE));
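// Note (not part of the diff): cufftXtExec runs the plan on data of whatever
// precision the plan was created for; CUFFT_FORWARD / CUFFT_INVERSE only pick
// the sign of the exponent, so a single cached C2C plan serves both the
// forward and the inverse transform.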
@@ -570,15 +565,18 @@ static inline PlanLRUCache& cufft_get_plan_cache(int64_t device_index) {
// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
template <typename DeviceContext, typename T>
void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X,
const std::vector<int64_t>& out_sizes,
const std::vector<int64_t>& dim, bool forward) {
const auto x_dims = X->dim() const auto ndim =
static_cast<int64_t>(X->dim().size());
const std::vector<int64_t> out_sizes,
const std::vector<int64_t> dim, bool forward) {
const auto x_dims = framework::vectorize(X->dims());
const auto ndim = static_cast<int64_t>(X->dims().size());
const int64_t signal_ndim = dim.size();
const auto batch_dims = ndim - signal_ndim;
auto tensor_place = ctx.GetPlace();

// Transpose batch dimensions first, then the transform dims
std::vector<int> dim_permute(ndim);
std::vector<int> reverse_dim_permute(ndim);
std::vector<int64_t> trans_dims(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), int{0});
std::vector<bool> is_transformed_dim(ndim);
for (const auto& d : dim) {
@@ -589,17 +587,22 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X,
[&](int64_t d) { return !is_transformed_dim[d]; });
std::sort(dim_permute.begin(), batch_end);
std::copy(dim.cbegin(), dim.cend(), batch_end);
framework::DDim trans_dims(X->dim());

for (size_t i = 0; i < ndim; i++) {
trans_dims[i] = x_dims[dim_permute[i]]; // shape of input transpose
reverse_dim_permute[dim_permute[i]] = i; // reverse of dim permute
trans_dims[i] = x_dims[dim_permute[i]]; // shape of input transpose
reverse_dim_permute[dim_permute[i]] =
static_cast<int>(i); // reverse of dim permute
}
framework::Tensor input;
input.Resize(trans_dims) input.mutable_data<T>(ctx.GetPlace());
auto ret = TransposeSimple<T>::run(ctx, *X, dim_permute, input);
if (!ret) {
framework::Tensor* input;
input->Resize(framework::make_ddim(trans_dims));
input->mutable_data<T>(tensor_place);
/*
auto in_ret = TransposeSimple<T>::run(ctx, *X, dim_permute, input);
if (!in_ret) {
TransCompute<DeviceContext, T>(ndim, ctx, *X, input, dim_permute);
}
*/
TransCompute<DeviceContext, T>(ndim, ctx, *X, input, dim_permute);
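// Worked example (not part of the diff): for a 4-D input transformed over
// dims {1, 3}, is_transformed_dim is {false, true, false, true}, dim_permute
// becomes {0, 2, 1, 3} (batch dims first, transform dims last), and
// reverse_dim_permute = {0, 2, 1, 3} undoes the transpose afterwards.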

// Reshape batch dimensions into a single dimension
std::vector<int64_t> batched_sizes(signal_ndim + 1);
@@ -608,7 +611,7 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X,
std::multiplies<int>());
batched_sizes[0] = batch_size;
std::copy(dim.cbegin(), dim.cend(), batched_sizes.begin() + 1);
input->Resize(batched_sizes);
input->Resize(framework::make_ddim(batched_sizes));

// Check the shape of transforming dims with input and output
std::vector<int64_t> signal_size(signal_ndim + 1);
@@ -648,19 +651,20 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X,
}

// output
framework::Tensor output;
output->Resize(batched_out_sizes) output.mutable_data<T>(ctx.GetPlace());
framework::Tensor* output;
output->Resize(framework::make_ddim(batched_out_sizes));
output->mutable_data<T>(tensor_place);

// Create the transform plan (either from cache or locally)
const auto value_type = framework::ToRealType(input.type());
auto fft_type = GetFFTTransformType(input.type(), output.type());
PlanKey Key(framework::vectorize(input->dim()),
framework::vectorize(output->dim()), signal_size, fft_type,
const auto value_type = framework::ToRealType(input->type());
auto fft_type = GetFFTTransformType(input->type(), output->type());
PlanKey Key(framework::vectorize(input->dims()),
framework::vectorize(output->dims()), signal_size, fft_type,
value_type);
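// Note (not part of the diff): PlanKey captures the batched input/output
// shapes, the per-axis signal sizes, the transform type (C2C/R2C/C2R) and the
// real value type, so transforms with identical geometry can reuse a cached
// cuFFT plan instead of rebuilding one.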
PlanLRUCache& plan_cache = cufft_get_plan_cache(input.device().index());
PlanLRUCache& plan_cache = cufft_get_plan_cache(static_cast<int64_t>(
(reinterpret_cast<platform::CUDAPlace*>(&tensor_place))->GetDeviceId()));
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
c10::optional<CuFFTConfig> uncached_plan;
const CuFFTConfig* config = nullptr;
CuFFTConfig* config = nullptr;

if (plan_cache.max_size() > 0) {
guard.lock();
@@ -670,32 +674,34 @@ void exec_fft(const DeviceContext& ctx, Tensor* out, const Tensor* X,
}

if (config == nullptr) {
uncached_plan.emplace(Key);
config = &uncached_plan.value();
CuFFTConfig uncached_plan(Key);
config = &uncached_plan;
}

auto& plan = config->plan();

// prepare cufft for execution
// CUFFT_CHECK(cufftSetStream(plan, reinterpret_cast<const
// platform::CUDADeviceContext&>(ctx).stream()));
CUFFT_CHECK(cufftSetStream(plan, ctx.stream()));
framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<T>(ctx.GetPlace(),
requested_size = config->workspace_size());
CUFFT_CHECK(cufftSetWorkArea(plan, workspace.data()));
workspace_tensor.mutable_data<T>(tensor_place, config->workspace_size());
CUFFT_CHECK(cufftSetWorkArea(plan, workspace_tensor.data<T>()));
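// Note (not part of the diff): cufftSetWorkArea points the plan at a
// caller-owned scratch buffer, here a temporary Tensor sized by
// config->workspace_size(); this is typically paired with
// cufftSetAutoAllocation(plan, 0) at plan-creation time so cuFFT does not
// allocate a second work area of its own.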

// execute transform plan
exec_cufft_plan(*config, input.data_ptr(), output.data_ptr(), forward);
exec_cufft_plan(*config, input->data<T>(), output->data<T>(), forward);

// Inverting output by reshape and transpose to original batch and dimension
output->Resize(reshape_out_sizes);
output->Resize(framework::make_ddim(reshape_out_sizes));
// Todo: transpose out
out->Resize(out_sizes) auto ret =
out->Resize(framework::make_ddim(out_sizes));
/*
auto out_ret =
TransposeSimple<T>::run(ctx, *output, reverse_dim_permute, out);
if (!ret) {
if (!out_ret) {
TransCompute<DeviceContext, T>(ndim, ctx, *output, out,
reverse_dim_permute);
}
*/
TransCompute<DeviceContext, T>(ndim, ctx, *output, out, reverse_dim_permute);

/*
std::vector<int64_t> out_strides(ndim);
@@ -736,7 +742,8 @@ void exec_normalization(Tensor* out, FFTNormMode normalization,
const std::vector<int64_t>& axes) {
auto scale = fft_normalization_scale(normalization, sizes, axes);
if (scale != 1.0) {
out->mul(scale);
// TODO: in-place multiply by scalar
// out->mul(scale);
}
}
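// Sketch (not part of the diff) of the in-place scaling the TODO above refers
// to, assuming the caller can pass the CUDA device context and that
// paddle/fluid/framework/eigen.h is included; flattening with EigenVector and
// running on eigen_device() is the usual Paddle elementwise pattern.
template <typename T>
static void scale_tensor_inplace(const platform::CUDADeviceContext& ctx,
                                 Tensor* out, float scale) {
  auto eigen_out = framework::EigenVector<T>::Flatten(*out);
  // Elementwise multiply on the tensor's device; intended for both real and
  // complex element types.
  eigen_out.device(*ctx.eigen_device()) = eigen_out * static_cast<T>(scale);
}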

@@ -796,32 +803,35 @@ struct FFTC2CFunctor<platform::CUDADeviceContext, T> {
return;
}

auto out_dims = framework::vectorize(X->dims());
std::vector<int64_t> out_dims = framework::vectorize(X->dims());
std::vector<int64_t> working_axes(axes.begin(), axes.end());
framework::Tensor working_tensor;
working_tensor.mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*X, ctx.GetPlace(), &working_tensor);
std::vector<int64_t> first_dims;
size_t max_dims;
framework::Tensor* working_tensor;
working_tensor->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*X, ctx.GetPlace(), working_tensor);

while (true) {
const auto max_dims =
max_dims =
std::min(static_cast<size_t>(kMaxCUFFTNdim), working_axes.size());
auto first_dims = std::vector<int64_t>(working_axes.end() - max_dims,
working_axes.end());
first_dims.assign(working_axes.end() - max_dims, working_axes.end());

exec_fft<CUDADeviceContext, T>(ctx, out, working_tensor, out_dims,
first_dims, forward);
exec_fft<platform::CUDADeviceContext, T>(ctx, out, working_tensor,
out_dims, first_dims, forward);
working_axes.resize(working_axes.size() - max_dims);
first_dims.clear();

if (working_axes.empty()) {
break;
}

std::swap(*out, working_tensor);
std::swap(out, working_tensor);
}
exec_normalization(out, normalization, out_dims, axes);
}
};
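// Worked example (not part of the diff): each exec_fft call handles at most
// kMaxCUFFTNdim axes (cuFFT plans support rank 3 at most), so a transform
// over axes {0, 1, 2, 3, 4} runs once over the trailing axes {2, 3, 4} and
// once over {0, 1}, swapping `out` and `working_tensor` between passes so
// each pass reads the previous pass's result.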

/*
template <typename T>
class FFTC2CKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
Expand Down Expand Up @@ -866,6 +876,8 @@ class FFTC2CGradKernel<platform::CUDADeviceContext, T>
normalization, forward);
}
};
*/

} // namespace operators
} // namespace paddle

10 changes: 4 additions & 6 deletions paddle/fluid/operators/spectral_op.h
@@ -10,8 +10,10 @@
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
@@ -26,7 +28,6 @@ enum class FFTNormMode : int64_t {

FFTNormMode get_norm_from_string(const std::string& norm, bool forward);
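// Note (not part of the diff): following the NumPy/PyTorch convention,
// "backward" scales only the inverse transform by 1/n, "forward" scales only
// the forward transform by 1/n, and "ortho" scales both by 1/sqrt(n), with n
// the product of the transformed lengths; get_norm_from_string maps these
// strings to FFTNormMode for the requested direction.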

/*
template <typename DeviceContext, typename T>
struct FFTC2CFunctor {
void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out,
@@ -89,7 +90,6 @@ class FFTC2CGradKernel : public framework::OpKernel<T> {
fft_c2c_func(dev_ctx, dy, dx, axes, normalization, forward);
}
};
*/

// Enum representing the FFT type
enum class FFTTransformType : int8_t {
@@ -126,8 +126,7 @@ inline bool has_complex_input(FFTTransformType type) {
case FFTTransformType::R2C:
return false;
}
PADDLE_THROW(
platform::errors::InvalidArgument("Real to real FFTs are not supported"));
PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType"));
}

// Returns true if the transform type has complex output
@@ -140,8 +139,7 @@ inline bool has_complex_output(FFTTransformType type) {
case FFTTransformType::C2R:
return false;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown FFTTransformType : [%s]", type));
PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType"));
}
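// Summary (not part of the diff) of the two helpers above: C2C is complex in,
// complex out; R2C is real in, complex out; C2R is complex in, real out. The
// PADDLE_THROW branches are reached only for an out-of-range enum value.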

template <typename DeviceContext, typename T>
