From 643c62848b875468af55f2a1848ee525667b9082 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Sun, 7 Apr 2024 16:10:02 -0700
Subject: [PATCH] Revert "Use new API to register custom ops for llama model
 (#2840)" (#2912)

Summary:
This reverts commit 020d8bee8fddf837006c4e35bf7dff5278df2e24.

Pull Request resolved: https://github.com/pytorch/executorch/pull/2912

Reviewed By: shoumikhin

Differential Revision: D55852547

Pulled By: larryliu0820

fbshipit-source-id: c8528041c03196239d6daef7e2843ee5cf8a8f3d
---
 .ci/scripts/test_llama.sh                      |  41 ++---
 .github/workflows/pull.yml                     |   2 +-
 .github/workflows/trunk.yml                    |   2 +-
 CMakeLists.txt                                 |  62 +++-----
 examples/models/llama2/CMakeLists.txt          |  66 +++-----
 examples/models/llama2/TARGETS                 |   3 +-
 .../models/llama2/custom_ops/CMakeLists.txt    |  32 ++--
 .../models/llama2/custom_ops/custom_ops.yaml   |  14 ++
 examples/models/llama2/custom_ops/op_sdpa.cpp  |   8 +-
 examples/models/llama2/custom_ops/op_sdpa.h    |  48 ------
 .../models/llama2/custom_ops/op_sdpa_aot.cpp   | 107 -------------
 .../models/llama2/custom_ops/op_sdpa_test.cpp  |   5 +-
 .../custom_ops/op_sdpa_with_kv_cache_test.cpp  |   4 +-
 .../llama2/custom_ops/sdpa_with_kv_cache.py    | 118 ++++++++++----
 examples/models/llama2/custom_ops/targets.bzl  | 147 +++++++++++------
 examples/models/llama2/llama_transformer.py    |   2 +-
 examples/models/llama2/runner/CMakeLists.txt   |   3 +-
 .../make_aten_functor_from_et_functor.h        |   3 +-
 extension/aten_util/targets.bzl                |   1 -
 extension/kernel_util/meta_programming.h       |   2 +-
 20 files changed, 286 insertions(+), 384 deletions(-)
 create mode 100644 examples/models/llama2/custom_ops/custom_ops.yaml
 delete mode 100644 examples/models/llama2/custom_ops/op_sdpa.h
 delete mode 100644 examples/models/llama2/custom_ops/op_sdpa_aot.cpp

diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh
index 8cb683228a..90ea13281b 100644
--- a/.ci/scripts/test_llama.sh
+++ b/.ci/scripts/test_llama.sh
@@ -37,18 +37,6 @@ if [[ -z "${MODE:-}" ]]; then
   exit 1
 fi

-if [[ "${MODE}" =~ xnnpack.* ]]; then
-  XNNPACK=ON
-else
-  XNNPACK=OFF
-fi
-
-if [[ "${MODE}" =~ .*custom.* ]]; then
-  CUSTOM=ON
-else
-  CUSTOM=OFF
-fi
-
 if [[ -z "${BUCK:-}" ]]; then
   BUCK=buck2
 fi
@@ -59,23 +47,25 @@ fi

 which "${PYTHON_EXECUTABLE}"

-CMAKE_PREFIX_PATH=$($PYTHON_EXECUTABLE -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
-
 cmake_install_executorch_libraries() {
   echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
   rm -rf cmake-out
+  if [[ "${MODE}" == "xnnpack" ]]; then
+    XNNPACK=ON
+  else
+    XNNPACK=OFF
+  fi
   retry cmake -DBUCK2="$BUCK" \
     -DCMAKE_INSTALL_PREFIX=cmake-out \
-    -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" \
-    -DCMAKE_BUILD_TYPE=Debug \
+    -DCMAKE_BUILD_TYPE=Release \
     -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-    -DEXECUTORCH_BUILD_CUSTOM="$CUSTOM" \
     -DEXECUTORCH_BUILD_OPTIMIZED=ON \
     -DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
     -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
     -Bcmake-out .
-  cmake --build cmake-out -j9 --target install --config Debug
+  cmake --build cmake-out -j9 --target install --config Release
 }

 cmake_build_llama_runner() {
@@ -83,15 +73,12 @@ cmake_build_llama_runner() {
   dir="examples/models/llama2"
   retry cmake -DBUCK2="$BUCK" \
     -DCMAKE_INSTALL_PREFIX=cmake-out \
-    -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" \
-    -DCMAKE_BUILD_TYPE=Debug \
-    -DEXECUTORCH_BUILD_CUSTOM="$CUSTOM" \
+    -DCMAKE_BUILD_TYPE=Release \
     -DEXECUTORCH_BUILD_OPTIMIZED=ON \
-    -DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
     -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
     -Bcmake-out/${dir} \
     ${dir}
-  cmake --build cmake-out/${dir} -j9 --config Debug
+  cmake --build cmake-out/${dir} -j9 --config Release
 }

@@ -126,20 +113,13 @@ else
   exit 1
 fi

-# Install custom ops before exporting
-echo "Installing executorch libraries"
-cmake_install_executorch_libraries
-
 # Export model.
 EXPORTED_MODEL_NAME="${EXPORTED_MODEL_NAME}.pte"
 echo "Exporting ${EXPORTED_MODEL_NAME}"
 EXPORT_ARGS="-c stories110M.pt -p ${PARAMS} -d ${DTYPE} -n ${EXPORTED_MODEL_NAME}"
-if [[ "${MODE}" == "xnnpack+kv+custom" ]]; then
+if [[ "${MODE}" == "xnnpack" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} -kv --use_sdpa_with_kv_cache -X -qmode 8da4w -G 128"
 fi
-# Add dynamically linked library location
-export LD_LIBRARY_PATH=${PWD}/cmake-out/lib
-export DYLD_LIBRARY_PATH=${PWD}/cmake-out/lib
 $PYTHON_EXECUTABLE -m examples.models.llama2.export_llama ${EXPORT_ARGS}

 # Create tokenizer.bin.
@@ -155,6 +135,7 @@ if [[ "${BUILD_TOOL}" == "buck2" ]]; then
   # shellcheck source=/dev/null
   $BUCK run examples/models/llama2:main -- ${RUNTIME_ARGS} > result.txt
 elif [[ "${BUILD_TOOL}" == "cmake" ]]; then
+  cmake_install_executorch_libraries
   cmake_build_llama_runner
   # Run llama runner
   NOW=$(date +"%H:%M:%S")
diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index f2cc83693c..9751b906cd 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -90,7 +90,7 @@ jobs:
       matrix:
         dtype: [fp32]
         build-tool: [buck2, cmake]
-        mode: [portable, xnnpack+kv+custom]
+        mode: [portable, xnnpack]
       fail-fast: false
     with:
       runner: linux.2xlarge
diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index 0be28c40c7..16ed6a2757 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -254,7 +254,7 @@ jobs:
       matrix:
         dtype: [fp32]
         build-tool: [buck2, cmake]
-        mode: [portable, xnnpack+kv+custom]
+        mode: [portable, xnnpack]
       fail-fast: false
     with:
       runner: macos-m1-stable
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ab790e9ea3..46b73f6349 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -175,9 +174,8 @@ option(EXECUTORCH_BUILD_VULKAN "Build the Vulkan backend" OFF)
 #
 # pthreadpool: build pthreadpool library. Disable on unsupported platforms
 #
-cmake_dependent_option(
-  EXECUTORCH_BUILD_PTHREADPOOL "Build pthreadpool library." ON
-  "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF)
+cmake_dependent_option(EXECUTORCH_BUILD_PTHREADPOOL "Build pthreadpool library."
+                       ON "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF)

 #
 # cpuinfo: build cpuinfo library. Disable on unsupported platforms
 #
@@ -187,9 +186,6 @@ cmake_dependent_option(EXECUTORCH_BUILD_CPUINFO "Build cpuinfo library." ON
 if(EXECUTORCH_BUILD_CPUINFO)
   # --- cpuinfo
-  set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
-      ${CMAKE_POSITION_INDEPENDENT_CODE})
-  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
   set(CPUINFO_SOURCE_DIR "backends/xnnpack/third-party/cpuinfo")
   set(CPUINFO_BUILD_TOOLS
       OFF
@@ -211,15 +207,10 @@ if(EXECUTORCH_BUILD_CPUINFO)
       CACHE STRING "")
   set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog")
   add_subdirectory("${CPUINFO_SOURCE_DIR}")
-  set(CMAKE_POSITION_INDEPENDENT_CODE
-      ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
 endif()

 if(EXECUTORCH_BUILD_PTHREADPOOL)
   # --- pthreadpool
-  set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
-      ${CMAKE_POSITION_INDEPENDENT_CODE})
-  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
   set(PTHREADPOOL_SOURCE_DIR "backends/xnnpack/third-party/pthreadpool")
   set(PTHREADPOOL_BUILD_TESTS
       OFF
@@ -239,8 +230,6 @@ if(EXECUTORCH_BUILD_PTHREADPOOL)
       CACHE STRING "")
   endif()
   add_subdirectory("${PTHREADPOOL_SOURCE_DIR}")
-  set(CMAKE_POSITION_INDEPENDENT_CODE
-      ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
 endif()

 if(NOT PYTHON_EXECUTABLE)
@@ -515,38 +504,25 @@ if(EXECUTORCH_BUILD_PYBIND)
     add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sdk)
   endif()

-  # find pytorch lib, to allow pybind to take at::Tensor as input/output
-  find_package(Torch CONFIG REQUIRED)
-  find_library(TORCH_PYTHON_LIBRARY torch_python
-               PATHS "${TORCH_INSTALL_PREFIX}/lib")
-
-  set(_dep_libs
-      ${TORCH_PYTHON_LIBRARY}
-      bundled_program
-      etdump
-      executorch
-      extension_data_loader
-      portable_ops_lib
-      util
-      torch)
-
   if(EXECUTORCH_BUILD_COREML)
-    list(APPEND _dep_libs coremldelegate)
+    set(PYBIND_LINK_COREML "coremldelegate")
   endif()

   if(EXECUTORCH_BUILD_MPS)
-    list(APPEND _dep_libs mpsdelegate)
+    set(PYBIND_LINK_MPS "mpsdelegate")
   endif()

   if(EXECUTORCH_BUILD_XNNPACK)
-    # need to explicitly specify XNNPACK here otherwise uses XNNPACK symbols
-    # from libtorch_cpu
-    list(APPEND _dep_libs xnnpack_backend XNNPACK)
+    # need to explicitly specify XNNPACK here
+    # otherwise uses XNNPACK symbols from libtorch_cpu
+    set(PYBIND_LINK_XNNPACK xnnpack_backend XNNPACK)
   endif()

-  if(EXECUTORCH_BUILD_CUSTOM)
-    list(APPEND _dep_libs custom_ops custom_ops_aot_lib)
-  endif()
+  # find pytorch lib, to allow pybind to take at::Tensor as input/output
+  find_package(Torch CONFIG REQUIRED)
+  find_library(TORCH_PYTHON_LIBRARY torch_python
+               PATHS "${TORCH_INSTALL_PREFIX}/lib")

+  # compile options for pybind
   set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti
@@ -568,7 +544,19 @@ if(EXECUTORCH_BUILD_PYBIND)
                               PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=portable_lib)
   target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS})
   target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
-  target_link_libraries(portable_lib PUBLIC ${_dep_libs})
+  target_link_libraries(
+    portable_lib
+    PUBLIC ${TORCH_PYTHON_LIBRARY}
+           bundled_program
+           etdump
+           executorch
+           extension_data_loader
+           portable_ops_lib
+           util
+           torch
+           ${PYBIND_LINK_COREML}
+           ${PYBIND_LINK_MPS}
+           ${PYBIND_LINK_XNNPACK})

   install(TARGETS portable_lib
           LIBRARY DESTINATION executorch/extension/pybindings)
diff --git a/examples/models/llama2/CMakeLists.txt b/examples/models/llama2/CMakeLists.txt
index abe0bbfdc5..ea4096074e 100644
--- a/examples/models/llama2/CMakeLists.txt
+++ b/examples/models/llama2/CMakeLists.txt
@@ -49,16 +49,22 @@ set(_common_compile_options -Wno-deprecated-declarations -fPIC)
 # Let files say "include <executorch/path/to/header.h>".
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)

-# For some reason android build is not able to find where gflags is and hence
-# cannot find corresponding .cmake file
+# For some reason android build is not able to find where gflags is
+# and hence cannot find corresponding .cmake file
 set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
 find_package(gflags REQUIRED)

 #
 # llama_main: test binary to run llama, with tokenizer and sampler integrated
 #
+add_executable(llama_main main.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/threadpool/cpuinfo_utils.cpp)
+if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
+  target_link_options(llama_main PRIVATE "LINKER:--gc-sections")
+endif()

-# find `executorch` libraries Same as for gflags
+# find `executorch` libraries
+# Same as for gflags
 set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/ExecuTorch)
 find_package(executorch CONFIG REQUIRED)
 if(CMAKE_TOOLCHAIN_IOS OR ANDROID)
@@ -66,55 +72,33 @@ if(CMAKE_TOOLCHAIN_IOS OR ANDROID)
 endif()

 # custom ops library
-if(EXECUTORCH_BUILD_CUSTOM)
-  add_subdirectory(custom_ops)
-endif()
+add_subdirectory(custom_ops)

 # llama_runner library
 add_subdirectory(runner)

+target_include_directories(llama_main PUBLIC
+${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/cpuinfo/include)
+target_include_directories(llama_main PUBLIC
+${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/pthreadpool/include)
+
 set(link_libraries)
-set(_srcs main.cpp)

 if(EXECUTORCH_BUILD_OPTIMIZED)
-  list(
-    APPEND
-    link_libraries
-    optimized_native_cpu_ops_lib
-    optimized_kernels
-    portable_kernels
-    cpublas
-    eigen_blas)
+  list(APPEND link_libraries optimized_native_cpu_ops_lib optimized_kernels
+              portable_kernels cpublas eigen_blas)
   target_link_options_shared_lib(optimized_native_cpu_ops_lib)
 else()
   list(APPEND link_libraries portable_ops_lib portable_kernels)
   target_link_options_shared_lib(portable_ops_lib)
 endif()

-if(EXECUTORCH_BUILD_CUSTOM)
-  target_link_options_shared_lib(custom_ops)
-  list(APPEND link_libraries custom_ops)
-endif()
+target_link_libraries(llama_main PUBLIC gflags llama_runner custom_ops_lib)

 # XNNPACK pthreadpool cpuinfo
 if(TARGET xnnpack_backend)
   set(xnnpack_backend_libs xnnpack_backend XNNPACK pthreadpool cpuinfo)
   list(APPEND link_libraries ${xnnpack_backend_libs})
-  # HACK: main only include these when xnnpack backend is availabe, so that we
-  # have all the threadpool sources under xnnpack.
-  list(APPEND _common_compile_options -DET_USE_THREADPOOL)
-  list(
-    APPEND
-    _srcs
-    ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/threadpool/cpuinfo_utils.cpp
-  )
-  list(
-    APPEND
-    _common_include_directories
-    ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/cpuinfo/include
-    ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack/third-party/pthreadpool/include
-  )
-  # end of hack
   target_link_options_shared_lib(xnnpack_backend)
 endif()

@@ -130,19 +114,15 @@ if(TARGET qnn_executorch_backend)
   target_link_options_shared_lib(qnn_executorch_backend)
 endif()

-# This one is needed for cpuinfo where it uses android specific log lib
+# This one is needed for cpuinfo where it uses android
+# specific log lib
 if(ANDROID)
   list(APPEND link_libraries log)
 endif()

-add_executable(llama_main ${_srcs})
-if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
-  target_link_options(llama_main PRIVATE "LINKER:--gc-sections")
-endif()
-
-target_include_directories(llama_main PUBLIC ${_common_include_directories})
-target_link_libraries(llama_main PUBLIC gflags llama_runner ${link_libraries})
-target_compile_options(llama_main PUBLIC ${_common_compile_options})
+target_compile_options(llama_main PUBLIC ${_common_compile_options}
+                       -DET_USE_THREADPOOL)
+target_link_libraries(llama_main PUBLIC ${link_libraries})

 if(APPLE)
   target_link_options_shared_lib(executorch)
diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index 09ebd5aead..c93ea6149f 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -18,7 +18,7 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
+        "//executorch/examples/models/llama2/custom_ops:llama_custom_ops_aot_lib",
     ],
 )

@@ -52,7 +52,6 @@ runtime.python_binary(
     main_module = "executorch.examples.models.llama2.export_llama",
     # visibility = ["//executorch/examples/..."],
     preload_deps = [
-        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_lib",
         "//executorch/kernels/quantized:aot_lib",
     ],
     deps = [
diff --git a/examples/models/llama2/custom_ops/CMakeLists.txt b/examples/models/llama2/custom_ops/CMakeLists.txt
index dbda13363b..d06f3d5de8 100644
--- a/examples/models/llama2/custom_ops/CMakeLists.txt
+++ b/examples/models/llama2/custom_ops/CMakeLists.txt
@@ -25,7 +25,7 @@ if(NOT TORCH_ROOT)
   set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
 endif()

-set(_common_compile_options -Wno-deprecated-declarations -fPIC)
+set(_common_compile_options -Wno-deprecated-declarations)

 include(${EXECUTORCH_ROOT}/build/Utils.cmake)
 include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
@@ -44,12 +44,21 @@ include(${EXECUTORCH_SRCS_FILE})
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)

 # Custom op libraries
-set(custom_ops_libs executorch)
+set(custom_ops_libs extension_module)
 list(APPEND custom_ops_libs pthreadpool)
 list(APPEND custom_ops_libs cpuinfo)
 list(APPEND custom_ops_libs cpublas)
 list(APPEND custom_ops_libs eigen_blas)

+# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
+# Executorch (for runtime). Here select all ops in optimized.yaml
+set(_yaml "${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml")
+gen_selected_ops("${_yaml}" "" "")
+
+generate_bindings_for_kernels(FUNCTIONS_YAML
+                              ${CMAKE_CURRENT_SOURCE_DIR}/custom_ops.yaml)
+message("Generated files ${gen_command_sources}")
+
 list(TRANSFORM _custom_ops__srcs PREPEND "${EXECUTORCH_ROOT}/")

 # TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used
@@ -61,8 +70,6 @@ if(NOT EXECUTORCH_BUILD_XNNPACK)
     "${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool_guard.cpp"
   )
-else()
-  list(APPEND custom_ops_libs xnnpack_backend)
 endif()

 add_library(custom_ops ${_custom_ops__srcs})
@@ -75,16 +82,7 @@ target_link_libraries(custom_ops PUBLIC ${custom_ops_libs})
 target_compile_options(custom_ops PUBLIC ${_common_compile_options}
                                          -DET_USE_THREADPOOL)

-# Add a AOT library
-find_package(Torch CONFIG REQUIRED)
-add_library(custom_ops_aot_lib SHARED
-            ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp)
-target_include_directories(custom_ops_aot_lib
-                           PUBLIC "${_common_include_directories}")
-target_include_directories(
-  custom_ops_aot_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include")
-target_link_libraries(custom_ops_aot_lib PUBLIC custom_ops torch)
-target_compile_options(custom_ops_aot_lib PUBLIC -Wno-deprecated-declarations
-                                                 -fPIC -frtti -fexceptions)
-
-install(TARGETS custom_ops custom_ops_aot_lib DESTINATION lib)
+# Build a library for _custom_ops_srcs
+#
+# custom_ops_lib: Register optimized ops kernels into Executorch runtime
+gen_operators_lib("custom_ops_lib" KERNEL_LIBS custom_ops DEPS executorch)
diff --git a/examples/models/llama2/custom_ops/custom_ops.yaml b/examples/models/llama2/custom_ops/custom_ops.yaml
new file mode 100644
index 0000000000..8de14c6aaa
--- /dev/null
+++ b/examples/models/llama2/custom_ops/custom_ops.yaml
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This yaml file contains operators that have optimized kernels available.
+
+- func: llama::sdpa.out(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::flash_attention_kernel_out
+
+- func: llama::sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, Tensor(b!) value_cache, int start_pos, int seq_len, Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::sdpa_with_kv_cache_out
diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp
index bf8c31de73..18e24eb867 100644
--- a/examples/models/llama2/custom_ops/op_sdpa.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */

-#include
+#include

 #include
 #include
@@ -22,7 +22,6 @@
 #include
 #include
 #endif
-#include

 namespace torch {
 namespace executor {
@@ -844,8 +843,3 @@ Tensor& sdpa_with_kv_cache_out(
 } // namespace native
 } // namespace executor
 } // namespace torch
-
-EXECUTORCH_LIBRARY(
-    llama,
-    "sdpa_with_kv_cache.out",
-    torch::executor::native::sdpa_with_kv_cache_out);
diff --git a/examples/models/llama2/custom_ops/op_sdpa.h b/examples/models/llama2/custom_ops/op_sdpa.h
deleted file mode 100644
index fd130964eb..0000000000
--- a/examples/models/llama2/custom_ops/op_sdpa.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#pragma once
-
-#include
-
-namespace torch {
-namespace executor {
-
-namespace native {
-
-Tensor& sdpa_with_kv_cache_out(
-    RuntimeContext& ctx,
-    const Tensor& q_projected,
-    const Tensor& k_projected,
-    const Tensor& v_projected,
-    Tensor& key_cache,
-    Tensor& value_cache,
-    const int64_t start_pos,
-    const int64_t seq_len,
-    const optional<Tensor>& attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const optional<double> scale,
-    Tensor& output);
-
-Tensor& flash_attention_kernel_out(
-    RuntimeContext& ctx,
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
-    const optional<Tensor>& attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const optional<double> scale,
-    Tensor& output);
-
-} // namespace native
-} // namespace executor
-} // namespace torch
diff --git a/examples/models/llama2/custom_ops/op_sdpa_aot.cpp b/examples/models/llama2/custom_ops/op_sdpa_aot.cpp
deleted file mode 100644
index ed735406ad..0000000000
--- a/examples/models/llama2/custom_ops/op_sdpa_aot.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-#include
-
-#include
-
-namespace torch {
-namespace executor {
-
-namespace native {
-
-Tensor& sdpa_with_kv_cache_out_no_context(
-    const Tensor& q_projected,
-    const Tensor& k_projected,
-    const Tensor& v_projected,
-    Tensor& key_cache,
-    Tensor& value_cache,
-    const int64_t start_pos,
-    const int64_t seq_len,
-    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const optional<Tensor> attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const optional<double> scale,
-    Tensor& output) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::native::sdpa_with_kv_cache_out(
-      context,
-      q_projected,
-      k_projected,
-      v_projected,
-      key_cache,
-      value_cache,
-      start_pos,
-      seq_len,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      output);
-}
-
-at::Tensor sdpa_with_kv_cache_aten(
-    const at::Tensor& q_projected,
-    const at::Tensor& k_projected,
-    const at::Tensor& v_projected,
-    at::Tensor& key_cache,
-    at::Tensor& value_cache,
-    const int64_t start_pos,
-    const int64_t seq_len,
-    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const c10::optional<at::Tensor> attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const c10::optional<double> scale) {
-  auto output = at::empty_like(q_projected);
-  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
-  (q_projected,
-   k_projected,
-   v_projected,
-   key_cache,
-   value_cache,
-   start_pos,
-   seq_len,
-   attn_mask,
-   dropout_p,
-   is_causal,
-   scale,
-   output);
-  return output;
-}
-
-} // namespace native
-} // namespace executor
-} // namespace torch
-
-TORCH_LIBRARY(llama, m) {
-  m.def(
-      "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
-      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
-      "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
-  m.def(
-      "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
-      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
-      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
-}
-
-TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
-  m.impl(
-      "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
-  m.impl(
-      "sdpa_with_kv_cache.out",
-      WRAP_TO_ATEN(
-          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
-}
diff --git a/examples/models/llama2/custom_ops/op_sdpa_test.cpp b/examples/models/llama2/custom_ops/op_sdpa_test.cpp
index 971e8cf45c..293359d19c 100644
--- a/examples/models/llama2/custom_ops/op_sdpa_test.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa_test.cpp
@@ -8,8 +8,7 @@

 #include

-#include
-
+#include // Declares the operator
 #include
 #include
 #include
@@ -29,7 +28,7 @@ exec_aten::Tensor op_scaled_dot_product_attention(
     exec_aten::optional<double> scale,
     exec_aten::Tensor& out) {
   exec_aten::RuntimeContext context{};
-  return torch::executor::native::flash_attention_kernel_out(
+  return torch::executor::llama::sdpa_outf(
       context, query, key, value, attn_mask, dropout_p, is_causal, scale, out);
 }
diff --git a/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp b/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
index fa2d164fe3..6ec6f42926 100644
--- a/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
@@ -8,7 +8,7 @@

 #include

-#include // Declares the operator
+#include // Declares the operator
 #include
 #include
 #include
@@ -32,7 +32,7 @@ exec_aten::Tensor op_sdpa_with_kv_cache(
     exec_aten::optional<double> scale,
     exec_aten::Tensor& out) {
   exec_aten::RuntimeContext context{};
-  return torch::executor::native::sdpa_with_kv_cache_out(
+  return torch::executor::llama::sdpa_with_kv_cache_outf(
       context,
       query,
       key,
diff --git a/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
index e02bc0a367..5f11defb11 100644
--- a/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
+++ b/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
@@ -4,36 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# Import custom op defined in op_sdpa_aot.cpp. Those ops are using PyTorch
-# C++ APIs for registration so here we need to import the shared library.
-# This is only needed for OSS.
-
-import os
-from ctypes.util import find_library
-
 import torch
+from torch.library import impl, impl_abstract

-from torch.library import impl
-
-try:
-    op = torch.ops.llama.sdpa_with_kv_cache.default
-    assert op is not None
-except:
-    # assuming we only hit this in OSS, find the default install path
-    full_name = find_library("custom_ops_aot_lib")
-    ld_library_path = os.environ.get("LD_LIBRARY_PATH", None)
-    assert (
-        full_name and ld_library_path
-    ), f"custom_ops_aot_lib does not exist, please set LD_LIBRARY_PATH: {ld_library_path} correctly"
-    # find the true path
-    for p in ld_library_path.split(":"):
-        full_path = os.path.join(p, full_name)
-        if os.path.exists(full_path):
-            torch.ops.load_library(full_path)
-            op = torch.ops.llama.sdpa_with_kv_cache.default
-            assert op is not None
+custom_ops_lib = torch.library.Library("llama", "DEF")
+custom_ops_lib.define(
+    "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
+    "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
+    "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor"
+)

-custom_ops_lib = torch.library.Library("llama", "IMPL")
+custom_ops_lib.define(
+    "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
+    "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
+    "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)"
+)


 def _validate_params(
@@ -133,3 +118,82 @@ def sdpa_with_kv_cache_meta(
     )

     return torch.empty_like(query)
+
+
+@impl(custom_ops_lib, "sdpa_with_kv_cache", "CompositeExplicitAutograd")
+def sdpa_with_kv_cache(
+    query,
+    key,
+    value,
+    key_cache,
+    value_cache,
+    start_pos,
+    seq_len,
+    attn_mask=None,
+    drpout_p=0.0,
+    is_causal=False,
+    scale=None,
+):
+    _validate_params(
+        query,
+        key,
+        value,
+        key_cache,
+        value_cache,
+        start_pos,
+        seq_len,
+        attn_mask,
+        drpout_p,
+        is_causal,
+        scale,
+    )
+
+    if attn_mask is not None:
+        attn_mask = attn_mask[start_pos].view((1, -1))
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+    q = query.transpose(1, 2)
+    key_cache[:, start_pos] = key
+    value_cache[:, start_pos] = value
+
+    sliced_k_cache = key_cache
+    sliced_v_cache = value_cache
+    sliced_k_cache = sliced_k_cache[:, : start_pos + seq_len, :, :]
+    sliced_v_cache = sliced_v_cache[:, : start_pos + seq_len, :, :]
+    sliced_k_cache = sliced_k_cache.transpose(1, 2)
+    sliced_v_cache = sliced_v_cache.transpose(1, 2)
+    out = torch.nn.functional.scaled_dot_product_attention(
+        q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
+    )
+    out = out.transpose(1, 2)
+    return out
+
+
+@impl_abstract("llama::sdpa_with_kv_cache.out")
+def sdpa_with_kv_cache_out(
+    query,
+    key,
+    value,
+    key_cache,
+    value_cache,
+    start_pos,
+    seq_len,
+    attn_mask,
+    drpout_p,
+    is_causal,
+    scale,
+    out,
+):
+    out = sdpa_with_kv_cache_meta(
+        query,
+        key,
+        value,
+        key_cache,
+        value_cache,
+        start_pos,
+        seq_len,
+        attn_mask,
+        drpout_p,
+        is_causal,
+        scale,
+    )
+    return out
diff --git a/examples/models/llama2/custom_ops/targets.bzl b/examples/models/llama2/custom_ops/targets.bzl
index 47ab799c49..ab611125fd 100644
--- a/examples/models/llama2/custom_ops/targets.bzl
+++ b/examples/models/llama2/custom_ops/targets.bzl
@@ -1,4 +1,41 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
+load("@fbsource//xplat/executorch/kernels/test:util.bzl", "codegen_function_header_wrapper")
+
+def define_tests():
+    codegen_function_header_wrapper("executorch/examples/models/llama2/custom_ops", "custom_ops")
+
+    # In the long run we should really have aten variant available as well
+    deps = [":function_header_wrapper_custom_ops"]
+    generated_lib_and_op_deps = [
+        ":custom_ops",
+        ":sdpa",
+        ":custom_ops_headers",
+    ]
+    runtime.cxx_test(
+        name = "op_sdpa_test",
+        srcs = [
+            "op_sdpa_test.cpp",
+        ],
+        visibility = ["//executorch/..."],
+        deps = [
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+            "//executorch/kernels/test:test_util",
+        ] + generated_lib_and_op_deps + deps,
+    )
+    runtime.cxx_test(
+        name = "op_sdpa_with_kv_cache_test",
+        srcs = [
+            "op_sdpa_with_kv_cache_test.cpp",
+        ],
+        visibility = ["//executorch/..."],
+        deps = [
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+            "//executorch/kernels/test:test_util",
+        ] + generated_lib_and_op_deps + deps,
+    )

 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
@@ -7,83 +44,85 @@ def define_common_targets():
     TARGETS and BUCK files that call this function.
     """

-    runtime.cxx_library(
-        name = "custom_ops",
-        srcs = ["op_sdpa.cpp"],
-        exported_headers = ["op_sdpa.h"],
-        exported_deps = [
-            "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/kernels/portable/cpu:scalar_utils",
-            "//executorch/kernels/optimized:libblas",
-            "//executorch/kernels/optimized:libvec",
-            "//executorch/extension/kernel_util:kernel_util",
-            "//executorch/extension/parallel:thread_parallel",
-            "//executorch/backends/xnnpack/threadpool:threadpool",
+    runtime.python_library(
+        name = "llama_custom_ops_aot_lib",
+        srcs = [
+            "sdpa_with_kv_cache.py",
         ],
-        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
         visibility = [
             "//executorch/...",
-            "//executorch/examples/models/llama2/custom_ops/...",
             "@EXECUTORCH_CLIENTS",
         ],
-        # @lint-ignore BUCKLINT link_whole
-        link_whole = True,
-        force_static = True,
+        deps = [
+            "//caffe2:torch",
+        ],
     )

-    runtime.cxx_library(
-        name = "custom_ops_aot_lib",
-        srcs = [
-            "op_sdpa_aot.cpp",
-        ],
+    runtime.export_file(
+        name = "custom_ops.yaml",
         visibility = [
             "//executorch/...",
             "@EXECUTORCH_CLIENTS",
         ],
-        external_deps = [
-            "libtorch",
+    )
+
+    # ~~~ START of custom ops 1 `my_ops::mul3` library definitions ~~~
+    et_operator_library(
+        name = "sdpa_op",
+        ops = [
+            "llama::sdpa.out",
         ],
-        deps = [
-            ":custom_ops",
-            "//executorch/extension/aten_util:aten_bridge",
+        define_static_targets = True,
+        visibility = [
+            "//executorch/codegen/...",
+            "@EXECUTORCH_CLIENTS",
         ],
     )

-    runtime.python_library(
-        name = "custom_ops_aot_py",
-        srcs = [
-            "sdpa_with_kv_cache.py",
+    et_operator_library(
+        name = "sdpa_with_kv_cache",
+        ops = [
+            "llama::sdpa_with_kv_cache.out",
         ],
-        visibility = ["//executorch/..."],
-        deps = [
-            "//caffe2:torch",
+        define_static_targets = True,
+        visibility = [
+            "//executorch/codegen/...",
+            "@EXECUTORCH_CLIENTS",
         ],
     )

-    runtime.cxx_test(
-        name = "op_sdpa_test",
-        srcs = [
-            "op_sdpa_test.cpp",
-        ],
-        visibility = ["//executorch/..."],
+    runtime.cxx_library(
+        name = "sdpa",
+        srcs = ["op_sdpa.cpp"],
         deps = [
-            "//executorch/runtime/core/exec_aten:lib",
-            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
-            "//executorch/kernels/test:test_util",
-            ":custom_ops",
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/kernels/optimized:libblas",
+            "//executorch/kernels/optimized:libvec",
+            "//executorch/extension/parallel:thread_parallel",
+            "//executorch/backends/xnnpack/threadpool:threadpool",
+        ],
+        compiler_flags = ["-Wno-missing-prototypes"],
+        visibility = [
+            "//executorch/...",
+            "//executorch/examples/models/llama2/custom_ops/...",
+            "@EXECUTORCH_CLIENTS",
         ],
+        force_static = True,
     )

-    runtime.cxx_test(
-        name = "op_sdpa_with_kv_cache_test",
-        srcs = [
-            "op_sdpa_with_kv_cache_test.cpp",
-        ],
-        visibility = ["//executorch/..."],
+    executorch_generated_lib(
+        name = "custom_ops",
         deps = [
-            "//executorch/runtime/core/exec_aten:lib",
-            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
-            "//executorch/kernels/test:test_util",
-            ":custom_ops",
+            ":sdpa_op",
+            ":sdpa_with_kv_cache",
+            ":sdpa",
+        ],
+        custom_ops_yaml_target = ":custom_ops.yaml",
+        visibility = [
+            "//executorch/...",
+            "@EXECUTORCH_CLIENTS",
         ],
+        define_static_targets = True,
     )
+
+    define_tests()
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index e9650f8181..2a259af59c 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -277,7 +277,7 @@ def forward(
             y = self.wo(y)
             return y
         else:
-            from .custom_ops import sdpa_with_kv_cache  # noqa
+            from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache  # noqa

             output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
diff --git a/examples/models/llama2/runner/CMakeLists.txt b/examples/models/llama2/runner/CMakeLists.txt
index 81a80dab9c..8e9190eb4c 100644
--- a/examples/models/llama2/runner/CMakeLists.txt
+++ b/examples/models/llama2/runner/CMakeLists.txt
@@ -47,7 +47,8 @@ else()
   add_library(llama_runner SHARED ${_llama_runner__srcs})
 endif()

-set(llama_runner_deps executorch extension_module extension_data_loader)
+set(llama_runner_deps executorch extension_module extension_data_loader
+    custom_ops)

 target_link_libraries(
   llama_runner PUBLIC ${llama_runner_deps})
diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h
index 92d19c0484..976549af8d 100644
--- a/extension/aten_util/make_aten_functor_from_et_functor.h
+++ b/extension/aten_util/make_aten_functor_from_et_functor.h
@@ -149,7 +149,8 @@ struct type_convert<
     }
     c10::ScalarType scalar_type =
         static_cast<c10::ScalarType>(val.scalar_type());
-    converted = at::from_blob(val.mutable_data_ptr(), sizes, scalar_type);
+    converted =
+        at::from_blob(val.mutable_data_ptr(), val.numel(), sizes, scalar_type);
   }
   ATensor call() {
     return converted;
diff --git a/extension/aten_util/targets.bzl b/extension/aten_util/targets.bzl
index 6e32583029..b396cb7832 100644
--- a/extension/aten_util/targets.bzl
+++ b/extension/aten_util/targets.bzl
@@ -27,7 +27,6 @@ def define_common_targets():
         ],
         exported_deps = [
             "//executorch/extension/kernel_util:kernel_util",
-            "//executorch/extension/runner_util:managed_tensor",
             "//executorch/runtime/core:core",
             "//executorch/runtime/core:evalue",
             "//executorch/runtime/core/exec_aten:lib",
diff --git a/extension/kernel_util/meta_programming.h b/extension/kernel_util/meta_programming.h
index c412e907ea..46262b843e 100644
--- a/extension/kernel_util/meta_programming.h
+++ b/extension/kernel_util/meta_programming.h
@@ -49,7 +49,7 @@ struct is_compile_time_function_pointer<
     CompileTimeFunctionPointer> : std::true_type {};

 #define EXECUTORCH_FN_TYPE(func) \
-  ::torch::executor::CompileTimeFunctionPointer< \
+  CompileTimeFunctionPointer< \
       std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
       func>
 #define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
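
Note on the AOT flow this patch restores: after the revert, the llama::sdpa_with_kv_cache
operator is defined and implemented purely in Python via torch.library (see
sdpa_with_kv_cache.py above), so no shared library has to be preloaded before export.
A minimal sketch of how that registration is exercised, mirroring the call site in
llama_transformer.py -- the absolute import path and the tensor shapes (batch=1,
seq_len=1, n_heads=8, head_dim=64, max_seq_len=128) are illustrative assumptions,
not values taken from this patch:

    import torch

    # Importing the module runs the torch.library define()/impl() calls above.
    # Hypothetical absolute form of the relative import used in llama_transformer.py.
    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache  # noqa

    q = torch.randn(1, 1, 8, 64)  # (bs, seq_len, n_heads, head_dim)
    k = torch.randn(1, 1, 8, 64)
    v = torch.randn(1, 1, 8, 64)
    key_cache = torch.zeros(1, 128, 8, 64)  # (bs, max_seq_len, n_heads, head_dim)
    value_cache = torch.zeros(1, 128, 8, 64)

    # Dispatches to the CompositeExplicitAutograd impl registered above; the
    # caches are updated in place at start_pos before attention is computed.
    out = torch.ops.llama.sdpa_with_kv_cache(q, k, v, key_cache, value_cache, 0, 1)
    print(out.shape)  # torch.Size([1, 1, 8, 64])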