Skip to content

Commit

Permalink
Set Tensor Core MathType for bfloat16 in conv using cudnn (#34409)
Browse files Browse the repository at this point in the history
  • Loading branch information
AshburnLee authored Aug 4, 2021
1 parent 56b7ebb commit c79fa1c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 68 deletions.
99 changes: 32 additions & 67 deletions paddle/fluid/operators/conv_cudnn_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,35 @@ void ChooseAlgo(const std::vector<PerfType>& perf_results,

using framework::ConvSearchCache;

// Chooses the cuDNN math type for the convolution descriptor `cdesc` and
// applies it through cudnnSetConvolutionMathType.
//
// The choice depends on the device's compute capability (read from the
// CUDADeviceContext held by `ctx`), the tensor data type `dtype`, and the
// descriptor's `allow_tf32_` flag:
//   - HALF on compute capability >= 70            -> CUDNN_TENSOR_OP_MATH
//   - BFLOAT16 on compute capability >= 80
//     (only compiled with CUDA >= 11.0 and cuDNN >= 8.1.0)
//                                                  -> CUDNN_TENSOR_OP_MATH
//   - FLOAT with TF32 not allowed
//     (only compiled with CUDA >= 11.0)            -> CUDNN_FMA_MATH
//   - anything else                                -> CUDNN_DEFAULT_MATH
//
// When built without CUDA >= 9.0 and cuDNN >= 7.0.1 the function is a no-op.
// Any cuDNN failure is reported through PADDLE_ENFORCE_CUDA_SUCCESS.
static void SetConvMathType(const framework::ExecutionContext& ctx,
                            cudnnDataType_t dtype,
                            const platform::ConvolutionDescriptor& cdesc) {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
  auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  // NOTE: the else-if chain below is deliberately interleaved with
  // preprocessor guards; each guarded region contributes whole branches, so
  // the braces stay balanced whichever guards are active.
  if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cdesc.desc(), CUDNN_TENSOR_OP_MATH));
    VLOG(5) << "use cudnn_tensor_op_math";
#if CUDA_VERSION >= 11000
#if CUDNN_VERSION_MIN(8, 1, 0)
  } else if (dev_ctx.GetComputeCapability() >= 80 &&
             dtype == CUDNN_DATA_BFLOAT16) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cdesc.desc(), CUDNN_TENSOR_OP_MATH));
    // Log the selection like every other branch so the chosen math type can
    // always be read from the VLOG(5) output.
    VLOG(5) << "use cudnn_tensor_op_math";
#endif  // CUDNN_VERSION_MIN(8, 1, 0)
  } else if (dtype == CUDNN_DATA_FLOAT && !cdesc.allow_tf32_) {
    // The descriptor does not allow TF32 for float, so request FMA math.
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cdesc.desc(), CUDNN_FMA_MATH));
    VLOG(5) << "use cudnn_fma_math";
#endif  // CUDA_VERSION >= 11000
  } else {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cdesc.desc(), CUDNN_DEFAULT_MATH));
    VLOG(5) << "NOT use cudnn_tensor_op_math";
  }
#endif  // CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
}

struct ConvArgs {
cudnnHandle_t handle;
platform::TensorDescriptor idesc, odesc;
Expand Down Expand Up @@ -208,36 +237,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
algo_t algo;

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
#if CUDA_VERSION >= 11000
#if CUDNN_VERSION_MIN(8, 1, 0)
} else if (dev_ctx.GetComputeCapability() >= 80 &&
dtype == CUDNN_DATA_BFLOAT16) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
#endif // CUDNN_VERSION >= 8100
} else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_FMA_MATH));
VLOG(5) << "use cudnn_fma_math";
#endif // CUDA_VERSION >= 11000
} else {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_DEFAULT_MATH));
VLOG(5) << "use cudnn_default_math";
}
#endif
SetConvMathType(ctx, dtype, args.cdesc);

if (!exhaustive_search && !deterministic) {
#if CUDNN_VERSION >= 7001
Expand Down Expand Up @@ -353,24 +353,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
size_t workspace_size = 0;
bool has_got_workspace_size = true;
algo_t algo;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_FMA_MATH));
#endif // CUDA_VERSION >= 11000
}
#endif
SetConvMathType(ctx, dtype, args.cdesc);

if (!exhaustive_search && !deterministic) {
#if CUDNN_VERSION >= 7001
Expand Down Expand Up @@ -501,25 +484,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
bool has_got_workspace_size = true;

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_FMA_MATH));
#endif // CUDA_VERSION >= 11000
}
#endif
SetConvMathType(ctx, dtype, args.cdesc);

algo_t algo;
if (!exhaustive_search && !deterministic) {
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/platform/cudnn_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,14 @@ class ConvolutionDescriptor {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(desc,
CUDNN_TENSOR_OP_MATH));
} else if (dtype == CUDNN_DATA_FLOAT && !allow_tf32) {
#if CUDA_VERSION >= 11000
#if CUDNN_VERSION_MIN(8, 1, 0)
} else if (dtype == CUDNN_DATA_BFLOAT16) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(desc,
CUDNN_TENSOR_OP_MATH));
#endif // CUDNN_VERSION_MIN(8,1,0)
} else if (dtype == CUDNN_DATA_FLOAT && !allow_tf32) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(desc, CUDNN_FMA_MATH));
#endif // CUDA_VERSION >= 11000
Expand Down

0 comments on commit c79fa1c

Please sign in to comment.