Adding CUDNN Frontend and use for CUDA NN Convolution (#19470)
### Description
Added the cuDNN frontend and use it for NHWC convolutions, optionally
fusing the activation.

#### Backward compatibility
- Models that already contain FusedConv can still run.
- If ORT is built with cuDNN 8, the cuDNN frontend is not built into the
binary; the old kernels (using cuDNN backend APIs) are used instead.

#### Major Changes
- For cuDNN 9, the cuDNN frontend is used to fuse convolution and bias
when the provider option `fuse_conv_bias=1` is set (see the sketch after
this list).
- Removed the FusedConv fusion from the graph transformer for the CUDA
provider, so FusedConv nodes will no longer be added to graphs for the
CUDA EP.
- Updated the CMake files for the cuDNN settings. The build searches for
the cuDNN installation in the following order:
  * the environment variable `CUDNN_PATH`
  * the `onnxruntime_CUDNN_HOME` CMake define. When a build starts from
build.py/build.sh, users can pass it through the `--cudnn_home`
parameter, or through the environment variable `CUDNN_HOME` if
`--cudnn_home` is not used.
  * the cuDNN Python package installation directory, e.g.
python3.xx/site-packages/nvidia/cudnn
  * the CUDA installation path
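
A minimal C++ sketch of opting in through the V2 CUDA provider options.
The key string `"fuse_conv_bias"` is assumed to mirror the new struct
field added in this PR (the exact key is an assumption), and the model
path is hypothetical:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;

  // Create V2 CUDA provider options and turn on the new fusion path.
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"prefer_nhwc", "fuse_conv_bias"};
  const char* values[] = {"1", "1"};
  Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 2));
  Ort::ThrowOnError(Ort::GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(
      session_options, cuda_options));
  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);

  // Hypothetical model path; any conv model with bias exercises the fusion.
  Ort::Session session(env, ORT_TSTR("resnet50.onnx"), session_options);
  return 0;
}
```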

#### Potential Issues

- If ORT is built with cuDNN 8, the FusedConv fusion is no longer done
automatically, so some models might see a performance regression. Users
who still want the FusedConv operator for performance reasons have
several workarounds: use an older version of onnxruntime, or use an
older version of ORT to save the optimized ONNX model and then run it
with the latest version of ORT (see the first sketch after this list).
We believe the majority of users will have moved to cuDNN 9 by the 1.20
release (cuDNN 9 will have been the default in ORT and PyTorch for three
months by then), so the impact should be small.
- cuDNN graphs use TF32 by default, and TF32 cannot be disabled through
the `use_tf32` CUDA provider option. Users who encounter an accuracy
issue (e.g. in testing) have to set the environment variable
`NVIDIA_TF32_OVERRIDE=0` to disable TF32 (see the second sketch after
this list). The documentation for `use_tf32` needs to be updated later.
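
A hedged sketch of the save-then-run workaround, assuming hypothetical
model paths: an older ORT build performs the FusedConv fusion at load
time and writes the optimized graph to disk, which a newer ORT can then
load directly.

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;

  // The CUDA EP must be registered so its provider-specific fusions
  // (FusedConv in older ORT builds) run during graph optimization;
  // registration is elided here (see the provider-options sketch above).

  // Persist the optimized graph. Loading "model_optimized.onnx" with a
  // newer ORT keeps the FusedConv nodes, which remain runnable.
  session_options.SetOptimizedModelFilePath(ORT_TSTR("model_optimized.onnx"));
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}
```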
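
And a sketch of the TF32 workaround. Since the cuDNN graph path ignores
`use_tf32`, the override goes through NVIDIA's own environment variable,
which can also be set programmatically as long as it happens before the
first CUDA context is created:

```cpp
#include <cstdlib>

int main() {
  // Disable TF32 process-wide. This must happen before any CUDA/cuDNN
  // initialization for the override to take effect.
#ifdef _WIN32
  _putenv_s("NVIDIA_TF32_OVERRIDE", "0");
#else
  setenv("NVIDIA_TF32_OVERRIDE", "0", /*overwrite=*/1);
#endif
  // ... create the ORT environment and CUDA session afterwards ...
  return 0;
}
```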

#### Follow ups
This is one of the PRs that aim to enable NHWC convolution in the CUDA
EP by default when the device supports it. Other changes will follow to
make that possible:
(1) Enable `prefer_nhwc` by default for devices with sm >= 70.
(2) Make `fuse_conv_bias=1` the default after more testing.
(3) Add other NHWC operators (like Resize or UpSample).

### Motivation and Context

The new cuDNN frontend library provides the functionality to fuse
operations and new heuristics for kernel selection. Here it fuses the
convolution with the pointwise bias operation. On the [NVIDIA
ResNet50](https://pytorch.org/hub/nvidia_deeplearningexamples_resnet50/)
we get a performance boost from 49.1144 ms to 42.4643 ms per inference
on a 2560x1440 input (`onnxruntime_perf_test -e cuda -I -q -r 100 -d 1 -i
'prefer_nhwc|1' resnet50.onnx`).

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: Maximilian Mueller <maximilianm@nvidia.com>
3 people authored Aug 2, 2024
1 parent 0e708de commit 1391354
Showing 45 changed files with 1,806 additions and 559 deletions.
10 changes: 10 additions & 0 deletions cgmanifests/generated/cgmanifest.json
@@ -351,6 +351,16 @@
      },
      "comments": "directx_headers"
    }
  },
  {
    "component": {
      "type": "git",
      "git": {
        "commitHash": "98ca4e1941fe3263f128f74f10063a3ea35c7019",
        "repositoryUrl": "https://github.com/NVIDIA/cudnn-frontend.git"
      },
      "comments": "cudnn_frontend"
    }
  }
]
}
6 changes: 0 additions & 6 deletions cmake/CMakeLists.txt
@@ -729,9 +729,6 @@ set(ORT_PROVIDER_FLAGS)
set(ORT_PROVIDER_CMAKE_FLAGS)

if (onnxruntime_USE_CUDA)
if (onnxruntime_USE_CUDA_NHWC_OPS)
add_compile_definitions(ENABLE_CUDA_NHWC_OPS)
endif()
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")

@@ -1445,9 +1442,6 @@ if (onnxruntime_USE_CUDA)
file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
endif()
find_package(CUDAToolkit REQUIRED)
if(onnxruntime_CUDNN_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
endif()
if (NOT CMAKE_CUDA_ARCHITECTURES)
if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu")
# Support for Jetson/Tegra ARM devices
1 change: 1 addition & 0 deletions cmake/deps.txt
@@ -58,3 +58,4 @@ utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029
111 changes: 111 additions & 0 deletions cmake/external/cuDNN.cmake
@@ -0,0 +1,111 @@
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)

find_path(
    CUDNN_INCLUDE_DIR cudnn.h
    HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS}
    PATH_SUFFIXES include
    REQUIRED
)

file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}")
string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")

function(find_cudnn_library NAME)
  find_library(
      ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
      HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR}
      PATH_SUFFIXES lib64 lib/x64 lib
      REQUIRED
  )

  if(${NAME}_LIBRARY)
    add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
    set_target_properties(
        CUDNN::${NAME} PROPERTIES
        INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
        IMPORTED_LOCATION ${${NAME}_LIBRARY}
    )
    message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.")
  else()
    message(STATUS "${NAME} not found.")
  endif()
endfunction()

find_cudnn_library(cudnn)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
    LIBRARY REQUIRED_VARS
    CUDNN_INCLUDE_DIR cudnn_LIBRARY
)

if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)
  message(STATUS "cuDNN: ${cudnn_LIBRARY}")
  message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")
  set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")
else()
  set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found")
endif()

target_include_directories(
    CUDNN::cudnn_all
    INTERFACE
    $<INSTALL_INTERFACE:include>
    $<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>
)

target_link_libraries(
    CUDNN::cudnn_all
    INTERFACE
    CUDNN::cudnn
)

if(CUDNN_MAJOR_VERSION EQUAL 8)
  find_cudnn_library(cudnn_adv_infer)
  find_cudnn_library(cudnn_adv_train)
  find_cudnn_library(cudnn_cnn_infer)
  find_cudnn_library(cudnn_cnn_train)
  find_cudnn_library(cudnn_ops_infer)
  find_cudnn_library(cudnn_ops_train)

  target_link_libraries(
      CUDNN::cudnn_all
      INTERFACE
      CUDNN::cudnn_adv_train
      CUDNN::cudnn_ops_train
      CUDNN::cudnn_cnn_train
      CUDNN::cudnn_adv_infer
      CUDNN::cudnn_cnn_infer
      CUDNN::cudnn_ops_infer
  )
elseif(CUDNN_MAJOR_VERSION EQUAL 9)
  find_cudnn_library(cudnn_cnn)
  find_cudnn_library(cudnn_adv)
  find_cudnn_library(cudnn_graph)
  find_cudnn_library(cudnn_ops)
  find_cudnn_library(cudnn_engines_runtime_compiled)
  find_cudnn_library(cudnn_engines_precompiled)
  find_cudnn_library(cudnn_heuristic)

  target_link_libraries(
      CUDNN::cudnn_all
      INTERFACE
      CUDNN::cudnn_adv
      CUDNN::cudnn_ops
      CUDNN::cudnn_cnn
      CUDNN::cudnn_graph
      CUDNN::cudnn_engines_runtime_compiled
      CUDNN::cudnn_engines_precompiled
      CUDNN::cudnn_heuristic
  )
endif()

mark_as_advanced(CUDNN_INCLUDE_DIR)
12 changes: 12 additions & 0 deletions cmake/external/cudnn_frontend.cmake
@@ -0,0 +1,12 @@
include(FetchContent)
FetchContent_Declare(
    cudnn_frontend
    URL ${DEP_URL_cudnn_frontend}
    URL_HASH SHA1=${DEP_SHA1_cudnn_frontend}
)

set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_UNIT_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
set(CUDNN_PATH ${onnxruntime_CUDNN_HOME})
FetchContent_MakeAvailable(cudnn_frontend)
16 changes: 6 additions & 10 deletions cmake/external/onnxruntime_external_deps.cmake
@@ -587,20 +587,16 @@ endif()

message("Finished fetching external dependencies")


set(onnxruntime_LINK_DIRS )

if (onnxruntime_USE_CUDA)
#TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same
find_package(CUDAToolkit REQUIRED)
if (WIN32)
if(onnxruntime_CUDNN_HOME)
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64)
endif()
else()
if(onnxruntime_CUDNN_HOME)
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64)
endif()

if(onnxruntime_CUDNN_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
set(CUDNN_PATH ${onnxruntime_CUDNN_HOME})
endif()
include(cuDNN)
endif()

if(onnxruntime_USE_SNPE)
2 changes: 1 addition & 1 deletion cmake/onnxruntime_framework.cmake
@@ -69,7 +69,7 @@ endif()
if(onnxruntime_USE_TENSORRT OR onnxruntime_USE_NCCL)
# TODO: for now, core framework depends on CUDA. It should be moved to TensorRT EP
# TODO: provider_bridge_ort.cc should not include nccl.h
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_CUDNN_HOME}/include PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
else()
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
endif()
14 changes: 9 additions & 5 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -197,12 +197,16 @@
target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart)
else()
target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart
${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
if(onnxruntime_CUDNN_HOME)
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
include(cudnn_frontend) # also defines CUDNN::*
if (onnxruntime_USE_CUDA_NHWC_OPS)
if(CUDNN_MAJOR_VERSION GREATER 8)
add_compile_definitions(ENABLE_CUDA_NHWC_OPS)
else()
message( WARNING "To compile with NHWC ops enabled please compile against cuDNN 9 or newer." )
endif()
endif()
target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas CUDNN::cudnn_all cudnn_frontend CUDA::curand CUDA::cufft CUDA::cudart
${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
endif()

if (onnxruntime_USE_TRITON_KERNEL)
5 changes: 1 addition & 4 deletions cmake/onnxruntime_providers_tensorrt.cmake
@@ -159,7 +159,7 @@
if(onnxruntime_CUDA_MINIMAL)
set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
else()
set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
set(trt_link_libs CUDNN::cudnn_all cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
endif()
file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h"
@@ -183,9 +183,6 @@
endif()
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
if(onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()

# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA)
6 changes: 1 addition & 5 deletions cmake/onnxruntime_python.cmake
@@ -98,11 +98,7 @@ endif()
onnxruntime_add_include_to_target(onnxruntime_pybind11_state Python::Module Python::NumPy)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${ONNXRUNTIME_ROOT} ${pybind11_INCLUDE_DIRS})
if(onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# cudnn_home is optional for Window when cuda and cudnn are installed in the same directory.
if(onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()
target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
endif()
if(onnxruntime_USE_CANN)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CANN_HOME}/include)
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -145,6 +145,7 @@ set(provider_excluded_files
"rnn/rnn_impl.cu"
"rnn/rnn_impl.h"
"shared_inc/cuda_call.h"
"shared_inc/cudnn_fe_call.h"
"shared_inc/fpgeneric.h"
"cuda_allocator.cc"
"cuda_allocator.h"
@@ -171,6 +172,7 @@ set(provider_excluded_files
"cuda_utils.cu"
"cudnn_common.cc"
"cudnn_common.h"
"cudnn_fe_call.cc"
"cupti_manager.cc"
"cupti_manager.h"
"fpgeneric.cu"
4 changes: 1 addition & 3 deletions cmake/onnxruntime_session.cmake
@@ -44,9 +44,7 @@ if (onnxruntime_USE_EXTENSIONS)
endif()
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()

if (onnxruntime_USE_ROCM)
target_compile_options(onnxruntime_session PRIVATE -Wno-sign-compare -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1)
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
7 changes: 0 additions & 7 deletions cmake/onnxruntime_training.cmake
@@ -39,10 +39,6 @@ endif()

target_include_directories(onnxruntime_training PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header} ${MPI_CXX_INCLUDE_DIRS})

if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_training PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()

if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_training PRIVATE ${NCCL_INCLUDE_DIRS})
endif()
@@ -81,9 +77,6 @@ if (onnxruntime_BUILD_UNIT_TESTS)

target_include_directories(onnxruntime_training_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header})
target_link_libraries(onnxruntime_training_runner PRIVATE nlohmann_json::nlohmann_json)
if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_training_runner PUBLIC ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()

if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_training_runner PRIVATE ${NCCL_INCLUDE_DIRS})
8 changes: 4 additions & 4 deletions cmake/onnxruntime_unittests.cmake
@@ -67,15 +67,15 @@ function(AddTest)
if(onnxruntime_USE_CUDA)
#XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs,
# otherwise it will impact when CUDA DLLs can be unloaded.
target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart)
target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart cudnn_frontend)
endif()
target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES})
endif()

onnxruntime_add_include_to_target(${_UT_TARGET} date::date flatbuffers::flatbuffers)
target_include_directories(${_UT_TARGET} PRIVATE ${TEST_INC_DIR})
if (onnxruntime_USE_CUDA)
target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
if (onnxruntime_USE_NCCL)
target_include_directories(${_UT_TARGET} PRIVATE ${NCCL_INCLUDE_DIRS})
endif()
@@ -392,7 +392,7 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R
)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src})

if (onnxruntime_USE_CUDA_NHWC_OPS)
if (onnxruntime_USE_CUDA_NHWC_OPS AND CUDNN_MAJOR_VERSION GREATER 8)
file(GLOB onnxruntime_test_providers_cuda_nhwc_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/providers/cuda/nhwc/*.cc"
)
@@ -1498,7 +1498,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
list(APPEND custom_op_src_patterns
"${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu"
"${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*")
list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
if (HAS_QSPECTRE)
list(APPEND custom_op_lib_option "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /Qspectre>")
endif()
2 changes: 2 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -40,6 +40,7 @@ struct CudaContext : public CustomOpContext {
bool enable_skip_layer_norm_strict_mode = false;
bool prefer_nhwc = false;
bool use_tf32 = true;
bool fuse_conv_bias = true;

void Init(const OrtKernelContext& kernel_ctx) {
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
@@ -57,6 +58,7 @@ struct CudaContext : public CustomOpContext {
kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
fuse_conv_bias = FetchResource<bool>(kernel_ctx, CudaResource::fuse_conv_bias_t);
}

template <typename T>
@@ -38,5 +38,6 @@ struct OrtCUDAProviderOptionsV2 {
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32
int fuse_conv_bias = 0; // Enable CUDNN Frontend kernel fusing, results in JIT compiles
int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option
};
1 change: 1 addition & 0 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -19,4 +19,5 @@ enum CudaResource : int {
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
fuse_conv_bias_t
};
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc
@@ -21,7 +21,7 @@ namespace cuda {
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Conv<T, true>);
onnxruntime::cuda::Conv<T, true>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)