Skip to content

Commit

Permalink
[CMAKE][CUTLASS] Improve dependency management with different cutlass…
Browse files Browse the repository at this point in the history
… versions. (apache#47)

* Each cutlass-based submodule library now uses its own cutlass submodule dependency
 * TVM's cutlass submodule is decoupled from others and is bumped to
 v3.4.1 for H100 support
 * Add scaffold for new cutlass fp8 dequant gemm interface targeting
 TVM's cutlass submodule

Co-authored-by: Chris Sullivan <csullivan@octo.ai>
  • Loading branch information
csullivan and csullivan authored Feb 26, 2024
1 parent 1fa1c24 commit f4b0c28
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 13 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 1843 files
27 changes: 23 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/minrpc/*.cc
src/runtime/relax_vm/*.cc
)
# Extra object entries that are passed, alongside the core object targets, to the
# add_library() calls for `tvm` and `tvm_runtime`. Initialised to an empty
# string here so later code can list(APPEND ...) to it.
set(TVM_RUNTIME_EXT_OBJS "")

if(BUILD_FOR_HEXAGON)
if(NOT BUILD_STATIC_RUNTIME)
Expand Down Expand Up @@ -592,26 +593,44 @@ add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE})

include(GNUInstallDirs)
# Define the `tvm` shared library.
# - Normal build: tvm_objs + tvm_runtime_objs + tvm_libinfo_objs.
# - Dummy build (BUILD_DUMMY_LIBTVM): tvm_objs is left out.
# In the multi-line forms, the contents of TVM_RUNTIME_EXT_OBJS are appended to
# the source list as well.
# NOTE(review): each branch contains a one-line add_library(tvm ...) followed by
# a multi-line add_library(tvm ...) for the same target. This looks like the
# before/after lines of a diff shown together; presumably only the multi-line
# form belongs in the real file, since a target can only be defined once.
if(NOT BUILD_DUMMY_LIBTVM)
add_library(tvm SHARED $<TARGET_OBJECTS:tvm_objs> $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
add_library(tvm SHARED
$<TARGET_OBJECTS:tvm_objs>
$<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>
${TVM_RUNTIME_EXT_OBJS}
)

else()
# dummy version of libtvm that can be used by downstream to specify dependencies
# the real runner still need a full version of libtvm
add_library(tvm SHARED $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
add_library(tvm SHARED
$<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>
${TVM_RUNTIME_EXT_OBJS}
)
endif()

# Usage requirements and link options for the `tvm` target:
# - the INSTALL_INTERFACE include path applies only to consumers of an
#   installed/exported `tvm`, and resolves to CMAKE_INSTALL_INCLUDEDIR;
# - the values of TVM_NO_UNDEFINED_SYMBOLS and TVM_VISIBILITY_FLAG (both set
#   outside this view) are appended to the target's LINK_OPTIONS.
target_include_directories(tvm PUBLIC "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>")
set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}")
set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}")
# Define the `tvm_runtime` library from tvm_runtime_objs + tvm_libinfo_objs
# (plus, in the multi-line forms, the contents of TVM_RUNTIME_EXT_OBJS).
# BUILD_STATIC_RUNTIME selects a STATIC archive; otherwise a SHARED library.
# NOTE(review): each branch contains a one-line add_library(tvm_runtime ...)
# followed by a multi-line one for the same target. This looks like the
# before/after lines of a diff shown together; presumably only the multi-line
# form belongs in the real file.
if(BUILD_STATIC_RUNTIME)
add_library(tvm_runtime STATIC $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
add_library(tvm_runtime STATIC
$<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>
${TVM_RUNTIME_EXT_OBJS}
)
# The notice is written as several string pieces and joined into one message.
set(NOTICE_MULTILINE
"You have build static version of the TVM runtime library. Make "
"sure to use --whole-archive when linking it into your project.")
string(CONCAT NOTICE ${NOTICE_MULTILINE})
# After tvm_runtime is built, print the notice in bold yellow.
add_custom_command(TARGET tvm_runtime POST_BUILD
COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE})
else()
add_library(tvm_runtime SHARED $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
add_library(tvm_runtime SHARED
$<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>
${TVM_RUNTIME_EXT_OBJS}
)
# Only the shared variant gets the TVM_NO_UNDEFINED_SYMBOLS link option.
set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}")
endif()

Expand Down
56 changes: 48 additions & 8 deletions cmake/modules/contrib/CUTLASS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,61 @@
# under the License.

# CUTLASS contrib module: active only when both USE_CUDA and USE_CUTLASS are set.
# It adds the CUTLASS codegen sources to COMPILER_SRCS, builds several OBJECT
# libraries (each with its own cutlass include directory), and appends their
# objects to TVM_RUNTIME_EXT_OBJS.
# NOTE(review): this view appears to interleave removed and added diff lines
# (two tvm_file_glob calls, two CUTLASS_DIR assignments, the old RUNTIME_SRCS
# appends, and two endif() lines). Confirm which lines are live before reuse.
if(USE_CUDA AND USE_CUTLASS)
tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc)
# Generator-expression condition: true when USE_CUDA and USE_CUTLASS are both
# truthy. It wraps every $<TARGET_OBJECTS:...> entry appended below.
set(CUTLASS_GEN_COND "$<AND:$<BOOL:${USE_CUDA}>,$<BOOL:${USE_CUTLASS}>>")
# Accumulates the object entries produced by this module.
set(CUTLASS_RUNTIME_OBJS "")

# Relay and Relax CUTLASS backend sources go into the compiler source list.
tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC
src/relay/backend/contrib/cutlass/*.cc
src/relax/backend/contrib/cutlass/*.cc
)
list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC})

# Presumably read by the cutlass_fpA_intB_gemm subproject added below -- TODO confirm.
set(FPA_INTB_GEMM_TVM_BINDING ON)
set(FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR})

set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
### Build cutlass runtime objects for fpA_intB_gemm using its cutlass submodule
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
# `fpA_intB_gemm` is not defined in this file; presumably the subdirectory
# above creates it.
target_include_directories(fpA_intB_gemm PRIVATE
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
)
# TVM-side runtime sources compiled against the fpA_intB_gemm cutlass headers.
set(CUTLASS_FPA_INTB_RUNTIME_SRCS "")
list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/moe_gemm.cc)
list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS})
target_include_directories(fpA_intB_cutlass_objs PRIVATE
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
)
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:fpA_intB_cutlass_objs>>")

### Build cutlass runtime objects for flash attention using its cutlass submodule
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
# `flash_attn` is not defined in this file; presumably the subdirectory above
# creates it.
target_include_directories(flash_attn PRIVATE
${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn
${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include
)
# TVM-side runtime source compiled against the libflash_attn cutlass headers.
set(CUTLASS_FLASH_ATTN_RUNTIME_SRCS "")
list(APPEND CUTLASS_FLASH_ATTN_RUNTIME_SRCS src/runtime/contrib/cutlass/flash_decoding.cu)
add_library(flash_attn_cutlass_objs OBJECT ${CUTLASS_FLASH_ATTN_RUNTIME_SRCS})
target_include_directories(flash_attn_cutlass_objs PRIVATE
${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include
)
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:flash_attn_cutlass_objs>>")

### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
set(TVM_CUTLASS_RUNTIME_SRCS "")
# MATCHES is a regex search, so this is true whenever the architecture list
# contains the substring "90".
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_fp8_gemm.cu)
endif()
# The object library is only created when at least one source was selected.
if(TVM_CUTLASS_RUNTIME_SRCS)
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
endif()

# NOTE(review): the directory-scoped include_directories() and the RUNTIME_SRCS
# appends below name the same three sources that are already compiled into the
# object libraries above. They look like the removed side of the diff -- confirm.
include_directories(3rdparty/cutlass_fpA_intB_gemm
3rdparty/cutlass_fpA_intB_gemm/cutlass/include) # FIXME
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/moe_gemm.cc)
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/flash_decoding.cu)
### Add cutlass objects to list of TVM runtime extension objs
list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")

message(STATUS "Build with CUTLASS")
# NOTE(review): only one if() is opened at the top of this block, so the second
# endif() below is presumably the other side of a diff line pair -- confirm.
endif()
endif()
25 changes: 25 additions & 0 deletions src/runtime/contrib/cutlass/fp16_fp8_gemm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <cuda_fp16.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

// Registers a global function named "cutlass.fp16_fp8_gemm" via TVM_REGISTER_GLOBAL.
// The body is currently a stub: it takes no arguments and always returns 0.
TVM_REGISTER_GLOBAL("cutlass.fp16_fp8_gemm").set_body_typed([]() { return 0; });

0 comments on commit f4b0c28

Please sign in to comment.