Skip to content

Commit

Permalink
[NVIDIA] Set proper math type for convloution
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrii Pavliuk committed Aug 9, 2023
1 parent f0d44a3 commit 8d74365
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ CUDA::DnnConvolutionDescriptor ConvolutionParamsCuDnn::MakeConvolutionDescriptor

// Enable computations on Tensor Core hardware which requires at least Volta GPU
// (compute capability 7.0).
const cudnnMathType_t math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
const cudnnMathType_t math_type = CUDNN_TENSOR_OP_MATH;
throwIfError(::cudnnSetConvolutionMathType(conv_desc.get(), math_type));
throwIfError(::cudnnSetConvolutionGroupCount(conv_desc.get(), groups_));

Expand All @@ -96,6 +96,7 @@ ConvolutionDescriptorsCuDnn::ConvolutionDescriptorsCuDnn(const CreationContext&
} else {
GetAlgo(dnnHandle);
}
throwIfError(::cudnnSetConvolutionMathType(conv_.get(), algo_perf_.mathType));
}

void ConvolutionDescriptorsCuDnn::BenchmarkOptimalAlgo(const CUDA::DnnHandle& dnnHandle,
Expand Down Expand Up @@ -305,7 +306,7 @@ CUDA::DnnConvolutionDescriptor ConvolutionBackpropDataParamsCuDnn::MakeConvoluti

// Enable computations on Tensor Core hardware which requires at least Volta GPU
// (compute capability 7.0).
const cudnnMathType_t math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
const cudnnMathType_t math_type = CUDNN_TENSOR_OP_MATH;
throwIfError(::cudnnSetConvolutionMathType(conv_desc.get(), math_type));
throwIfError(::cudnnSetConvolutionGroupCount(conv_desc.get(), groups_));

Expand All @@ -327,6 +328,7 @@ ConvolutionBackpropDataDescriptorCuDnn::ConvolutionBackpropDataDescriptorCuDnn(
} else {
GetAlgo(dnnHandle);
}
throwIfError(::cudnnSetConvolutionMathType(conv_.get(), algo_perf_.mathType));
}

void ConvolutionBackpropDataDescriptorCuDnn::BenchmarkOptimalAlgo(const CUDA::DnnHandle& dnnHandle) {
Expand Down

0 comments on commit 8d74365

Please sign in to comment.