Skip to content

Commit

Permalink
[CINN] Integrate CUTLASS into CINN [Part1] (PaddlePaddle#58079)
Browse files Browse the repository at this point in the history
* [CINN] Integrate CUTLASS into CINN [Part1]

* fix cmake
  • Loading branch information
ZzSean authored and SecretXV committed Nov 28, 2023
1 parent 0fb898f commit f4ef02d
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 7 deletions.
14 changes: 14 additions & 0 deletions cmake/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ if(WITH_GPU)
message(STATUS "Enable CINN CUDNN")
add_definitions(-DCINN_WITH_CUDNN)
endif()
if(WITH_CUTLASS)
message(STATUS "Enable CINN CUTLASS")
add_definitions(-DCINN_WITH_CUTLASS)
endif()
enable_language(CUDA)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
Expand Down Expand Up @@ -199,6 +203,11 @@ if(WITH_GPU)
endif()
endif()

if(WITH_CUTLASS)
target_link_libraries(cinnapi cutlass)
add_dependencies(cinnapi cutlass)
endif()

function(gen_cinncore LINKTYPE)
set(CINNCORE_TARGET cinncore)
if(${LINKTYPE} STREQUAL "STATIC")
Expand Down Expand Up @@ -258,6 +267,11 @@ function(gen_cinncore LINKTYPE)
target_link_libraries(${CINNCORE_TARGET} ${CUDA_NVTX_LIB})
endif()
endif()

if(WITH_CUTLASS)
target_link_libraries(cinnapi cutlass)
add_dependencies(cinnapi cutlass)
endif()
endfunction()

gen_cinncore(STATIC)
Expand Down
16 changes: 11 additions & 5 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,17 @@ include(external/gflags) # download, build, install gflags
include(external/glog) # download, build, install glog

########################### include third_party according to flags ###############################
if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
AND NOT APPLE)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
include(external/cutlass) # download, build, install cusparselt
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
endif()
endif()

if(WITH_CINN)
if(WITH_MKL)
add_definitions(-DCINN_WITH_MKL_CBLAS)
Expand Down Expand Up @@ -555,11 +566,6 @@ if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
AND NOT APPLE)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
include(external/cutlass) # download, build, install cusparselt
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/frontend/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ PD_DECLARE_bool(use_reduce_split_pass);
PD_DECLARE_bool(cinn_use_dense_merge_pass);
PD_DECLARE_string(cinn_custom_call_deny_ops);
PD_DECLARE_bool(general_fusion_merge_pass);
PD_DECLARE_bool(cinn_use_cutlass);

namespace cinn {
namespace frontend {
Expand All @@ -58,6 +59,7 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
return FLAGS_cinn_custom_call_deny_ops.find(op) != std::string::npos;
};
bool is_gemm_use_cublas = FLAGS_cinn_use_custom_call &&
!FLAGS_cinn_use_cutlass &&
!can_find_custom_call_deny_op("matmul") &&
!can_find_custom_call_deny_op("cublas_gemm") &&
!can_find_custom_call_deny_op("cublas_matmul");
Expand Down
7 changes: 5 additions & 2 deletions paddle/cinn/hlir/pass/custom_call_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/cinn/utils/string.h"

PD_DECLARE_string(cinn_custom_call_deny_ops);
PD_DECLARE_bool(cinn_use_cutlass);

namespace cinn {
namespace hlir {
Expand Down Expand Up @@ -72,8 +73,10 @@ class GraphAlterHelper {
}
}

node->attrs.attr_store["original_op"] = node->op()->name;
node->attrs.op = framework::Operator::Get("custom_call");
if (!FLAGS_cinn_use_cutlass || node->op()->name != "matmul") {
node->attrs.attr_store["original_op"] = node->op()->name;
node->attrs.op = framework::Operator::Get("custom_call");
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/runtime/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ PD_DEFINE_double(cinn_infer_model_version,
"Paddle has different model format in inference model. We use "
"a flag to load different versions.");

PD_DEFINE_bool(cinn_use_cutlass,
BoolFromEnv("FLAGS_cinn_use_cutlass", false),
"Whether to use cutlass kernels");

namespace cinn {
namespace runtime {

Expand Down

0 comments on commit f4ef02d

Please sign in to comment.