Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(wip) add Flash decoding #5

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
17 changes: 15 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target
project("llama.cpp" C CXX)
include(CheckIncludeFileCXX)

set(FLASH_DIR ../flash-attention-cpp)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
Expand Down Expand Up @@ -351,13 +353,24 @@ if (LLAMA_CUBLAS)

find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
if (WIN32)
link_directories(${FLASH_DIR}/build/Release)
else()
link_directories(${FLASH_DIR}/build)
endif()

message(STATUS "cuBLAS found")

enable_language(CUDA)

set(GGML_HEADERS_CUDA ggml-cuda.h)
set(GGML_SOURCES_CUDA ggml-cuda.cu)

set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES}
${FLASH_DIR}
${FLASH_DIR}/fa
${FLASH_DIR}/cutlass/include)

add_compile_definitions(GGML_USE_CUBLAS)
if (LLAMA_CUDA_FORCE_DMMV)
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
Expand All @@ -379,12 +392,12 @@ if (LLAMA_CUBLAS)
if (LLAMA_STATIC)
if (WIN32)
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt flash_attn)
else ()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
endif()
else()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt flash_attn)
endif()

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
Expand Down
6 changes: 6 additions & 0 deletions examples/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

if (WIN32)
add_custom_command(TARGET main POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_SOURCE_DIR}/${FLASH_DIR}/build/Release/flash_attn.dll $<TARGET_FILE_DIR:main>)
else()
add_custom_command(TARGET main POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_SOURCE_DIR}/${FLASH_DIR}/build/libflash_attn.so $<TARGET_FILE_DIR:main>)
endif()
Loading
Loading