diff --git a/CMakeLists.txt b/CMakeLists.txt index e0716af6fff4f..03937e4e0658b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -382,6 +382,9 @@ endif() # Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization. set(VLLM_PARENT_BUILD ON) +# Ensure the vllm/vllm_flash_attn directory exists before installation +install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c) + # Make sure vllm-flash-attn install rules are nested under vllm/ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c) install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) diff --git a/setup.py b/setup.py index cc559f26c6f3f..60e31af0a8d39 100644 --- a/setup.py +++ b/setup.py @@ -258,6 +258,21 @@ def build_extensions(self) -> None: ] subprocess.check_call(install_args, cwd=self.build_temp) + def run(self): + # First, run the standard build_ext command to compile the extensions + super().run() + + # copy vllm/vllm_flash_attn/*.py from self.build_lib to current + # directory so that they can be included in the editable build + import glob + files = glob.glob( + os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py")) + for file in files: + dst_file = os.path.join("vllm/vllm_flash_attn", + os.path.basename(file)) + print(f"Copying {file} to {dst_file}") + self.copy_file(file, dst_file) + def _no_device() -> bool: return VLLM_TARGET_DEVICE == "empty" diff --git a/vllm/vllm_flash_attn/.gitkeep b/vllm/vllm_flash_attn/.gitkeep new file mode 100644 index 0000000000000..e69de29bb2d1d