diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5e5ddb20811..96739dae34d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -543,7 +543,7 @@ if(EXECUTORCH_BUILD_PYBIND)
   endif()
 
   if(EXECUTORCH_BUILD_CUSTOM)
-    list(APPEND _dep_libs custom_ops_lib)
+    list(APPEND _dep_libs custom_ops)
   endif()
 
   # compile options for pybind
diff --git a/examples/models/llama2/CMakeLists.txt b/examples/models/llama2/CMakeLists.txt
index ee75b59ea5c..ad6a2c78f9d 100644
--- a/examples/models/llama2/CMakeLists.txt
+++ b/examples/models/llama2/CMakeLists.txt
@@ -106,8 +106,8 @@ else()
 endif()
 
 if(EXECUTORCH_BUILD_CUSTOM)
-  target_link_options_shared_lib(custom_ops_lib)
-  list(APPEND link_libraries custom_ops_lib)
+  target_link_options_shared_lib(custom_ops)
+  list(APPEND link_libraries custom_ops)
 endif()
 
 set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
diff --git a/examples/models/llama2/custom_ops/CMakeLists.txt b/examples/models/llama2/custom_ops/CMakeLists.txt
index d06f3d5de81..d954b29f67b 100644
--- a/examples/models/llama2/custom_ops/CMakeLists.txt
+++ b/examples/models/llama2/custom_ops/CMakeLists.txt
@@ -50,15 +50,6 @@ list(APPEND custom_ops_libs cpuinfo)
 list(APPEND custom_ops_libs cpublas)
 list(APPEND custom_ops_libs eigen_blas)
 
-# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
-# Executorch (for runtime). Here select all ops in optimized.yaml
-set(_yaml "${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml")
-gen_selected_ops("${_yaml}" "" "")
-
-generate_bindings_for_kernels(FUNCTIONS_YAML
-                              ${CMAKE_CURRENT_SOURCE_DIR}/custom_ops.yaml)
-message("Generated files ${gen_command_sources}")
-
 list(TRANSFORM _custom_ops__srcs PREPEND "${EXECUTORCH_ROOT}/")
 
 # TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used
@@ -70,6 +61,8 @@ if(NOT EXECUTORCH_BUILD_XNNPACK)
     "${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool_guard.cpp"
   )
+else()
+  list(APPEND custom_ops_libs xnnpack_backend)
 endif()
 
 add_library(custom_ops ${_custom_ops__srcs})
@@ -82,7 +75,4 @@ target_link_libraries(custom_ops PUBLIC ${custom_ops_libs})
 target_compile_options(custom_ops PUBLIC ${_common_compile_options}
                        -DET_USE_THREADPOOL)
 
-# Build a library for _custom_ops_srcs
-#
-# custom_ops_lib: Register optimized ops kernels into Executorch runtime
-gen_operators_lib("custom_ops_lib" KERNEL_LIBS custom_ops DEPS executorch)
+install(TARGETS custom_ops DESTINATION lib)
diff --git a/examples/models/llama2/custom_ops/__init__.py b/examples/models/llama2/custom_ops/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/examples/models/llama2/custom_ops/custom_ops.yaml b/examples/models/llama2/custom_ops/custom_ops.yaml
deleted file mode 100644
index 8de14c6aaaf..00000000000
--- a/examples/models/llama2/custom_ops/custom_ops.yaml
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This yaml file contains operators that have optimized kernels available.
-
-- func: llama::sdpa.out(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(a!) out) -> Tensor(a!)
-  variants: function
-  kernels:
-    - arg_meta: null
-      kernel_name: torch::executor::flash_attention_kernel_out
-
-- func: llama::sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, Tensor(b!) value_cache, int start_pos, int seq_len, Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)
-  kernels:
-    - arg_meta: null
-      kernel_name: torch::executor::sdpa_with_kv_cache_out
diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp
index 18e24eb867c..bf8c31de73d 100644
--- a/examples/models/llama2/custom_ops/op_sdpa.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include 
+#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
 
 #include 
 #include 
@@ -22,6 +22,7 @@
 #include 
 #include 
 #endif
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
 
 namespace torch {
 namespace executor {
@@ -843,3 +844,8 @@ Tensor& sdpa_with_kv_cache_out(
 } // namespace native
 } // namespace executor
 } // namespace torch
+
+EXECUTORCH_LIBRARY(
+    llama,
+    "sdpa_with_kv_cache.out",
+    torch::executor::native::sdpa_with_kv_cache_out);
diff --git a/examples/models/llama2/custom_ops/op_sdpa.h b/examples/models/llama2/custom_ops/op_sdpa.h
new file mode 100644
index 00000000000..fd130964ebb
--- /dev/null
+++ b/examples/models/llama2/custom_ops/op_sdpa.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& sdpa_with_kv_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& q_projected,
+    const Tensor& k_projected,
+    const Tensor& v_projected,
+    Tensor& key_cache,
+    Tensor& value_cache,
+    const int64_t start_pos,
+    const int64_t seq_len,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output);
+
+Tensor& flash_attention_kernel_out(
+    RuntimeContext& ctx,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output);
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/examples/models/llama2/custom_ops/op_sdpa_test.cpp b/examples/models/llama2/custom_ops/op_sdpa_test.cpp
index 293359d19c9..971e8cf45cb 100644
--- a/examples/models/llama2/custom_ops/op_sdpa_test.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa_test.cpp
@@ -8,7 +8,8 @@
 
 #include 
 
-#include  // Declares the operator
+#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
+
 #include 
 #include 
 #include 
@@ -28,7 +29,7 @@ exec_aten::Tensor op_scaled_dot_product_attention(
     exec_aten::optional<double> scale,
     exec_aten::Tensor& out) {
   exec_aten::RuntimeContext context{};
-  return torch::executor::llama::sdpa_outf(
+  return torch::executor::native::flash_attention_kernel_out(
       context, query, key, value, attn_mask, dropout_p, is_causal, scale, out);
 }
diff --git a/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp b/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
index 6ec6f429264..fa2d164fe3d 100644
--- a/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
@@ -8,7 +8,7 @@
 
 #include 
 
-#include  // Declares the operator
+#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h> // Declares the operator
 #include 
 #include 
 #include 
@@ -32,7 +32,7 @@ exec_aten::Tensor op_sdpa_with_kv_cache(
     exec_aten::optional<double> scale,
     exec_aten::Tensor& out) {
   exec_aten::RuntimeContext context{};
-  return torch::executor::llama::sdpa_with_kv_cache_outf(
+  return torch::executor::native::sdpa_with_kv_cache_out(
       context,
       query,
diff --git a/examples/models/llama2/custom_ops/targets.bzl b/examples/models/llama2/custom_ops/targets.bzl
index ab611125fd0..66ce6e0c04a 100644
--- a/examples/models/llama2/custom_ops/targets.bzl
+++ b/examples/models/llama2/custom_ops/targets.bzl
@@ -1,41 +1,4 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
-load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
-load("@fbsource//xplat/executorch/kernels/test:util.bzl", "codegen_function_header_wrapper")
-
-def define_tests():
-    codegen_function_header_wrapper("executorch/examples/models/llama2/custom_ops", "custom_ops")
-
-    # In the long run we should really have aten variant available as well
-    deps = [":function_header_wrapper_custom_ops"]
-    generated_lib_and_op_deps = [
-        ":custom_ops",
-        ":sdpa",
-        ":custom_ops_headers",
-    ]
-    runtime.cxx_test(
-        name = "op_sdpa_test",
-        srcs = [
-            "op_sdpa_test.cpp",
-        ],
-        visibility = ["//executorch/..."],
-        deps = [
-            "//executorch/runtime/core/exec_aten:lib",
-            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
-            "//executorch/kernels/test:test_util",
-        ] + generated_lib_and_op_deps + deps,
-    )
-    runtime.cxx_test(
-        name = "op_sdpa_with_kv_cache_test",
-        srcs = [
-            "op_sdpa_with_kv_cache_test.cpp",
-        ],
-        visibility = ["//executorch/..."],
-        deps = [
-            "//executorch/runtime/core/exec_aten:lib",
-            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
-            "//executorch/kernels/test:test_util",
-        ] + generated_lib_and_op_deps + deps,
-    )
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
@@ -43,7 +6,6 @@ def define_common_targets():
     The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
""" - runtime.python_library( name = "llama_custom_ops_aot_lib", srcs = [ @@ -58,71 +20,54 @@ def define_common_targets(): ], ) - runtime.export_file( - name = "custom_ops.yaml", - visibility = [ - "//executorch/...", - "@EXECUTORCH_CLIENTS", - ], - ) - - # ~~~ START of custom ops 1 `my_ops::mul3` library definitions ~~~ - et_operator_library( - name = "sdpa_op", - ops = [ - "llama::sdpa.out", - ], - define_static_targets = True, - visibility = [ - "//executorch/codegen/...", - "@EXECUTORCH_CLIENTS", - ], - ) - - et_operator_library( - name = "sdpa_with_kv_cache", - ops = [ - "llama::sdpa_with_kv_cache.out", - ], - define_static_targets = True, - visibility = [ - "//executorch/codegen/...", - "@EXECUTORCH_CLIENTS", - ], - ) - runtime.cxx_library( - name = "sdpa", + name = "custom_ops", srcs = ["op_sdpa.cpp"], - deps = [ + exported_headers = ["op_sdpa.h"], + exported_deps = [ "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/kernels/optimized:libblas", "//executorch/kernels/optimized:libvec", + "//executorch/extension/kernel_util:kernel_util", "//executorch/extension/parallel:thread_parallel", "//executorch/backends/xnnpack/threadpool:threadpool", ], - compiler_flags = ["-Wno-missing-prototypes"], + compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"], visibility = [ "//executorch/...", "//executorch/examples/models/llama2/custom_ops/...", "@EXECUTORCH_CLIENTS", ], + # @lint-ignore BUCKLINT link_whole + link_whole = True, force_static = True, ) - executorch_generated_lib( - name = "custom_ops", + runtime.cxx_test( + name = "op_sdpa_test", + srcs = [ + "op_sdpa_test.cpp", + ], + visibility = ["//executorch/..."], deps = [ - ":sdpa_op", - ":sdpa_with_kv_cache", - ":sdpa", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + "//executorch/kernels/test:test_util", + ":custom_ops", ], - custom_ops_yaml_target = ":custom_ops.yaml", - visibility = [ - "//executorch/...", - "@EXECUTORCH_CLIENTS", + ) + + runtime.cxx_test( + name = "op_sdpa_with_kv_cache_test", + srcs = [ + "op_sdpa_with_kv_cache_test.cpp", + ], + visibility = ["//executorch/..."], + deps = [ + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + "//executorch/kernels/test:test_util", + ":custom_ops", ], - define_static_targets = True, ) - define_tests() diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 4d4460203c0..addeb86185e 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -61,6 +61,12 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) set(CUSTOM_OPS_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/custom_ops/libcustom_ops.a) add_library(custom_ops STATIC IMPORTED) set_property(TARGET custom_ops PROPERTY IMPORTED_LOCATION ${CUSTOM_OPS_PATH}) +<<<<<<< dest: 1cc8fc2b4f1c - linliwan: [SOX] Bring back SOX over delivery e... +||||||| base: d016ce149e52 - larryliu: Consolidate EXECUTORCH_BUILD_CUSTOM ... + target_link_options_shared_lib(custom_ops_lib) +======= + target_link_options_shared_lib(custom_ops) +>>>>>>> source: 8155f0751618 - larryliu: [retake] Use new API to register cus... if(TARGET pthreadpool) set(LLAMA_JNI_SRCS jni/jni_layer_llama.cpp ../../backends/xnnpack/threadpool/cpuinfo_utils.cpp)