Use new API to register custom ExecuTorch kernels into ATen (#2937)
Summary:
Pull Request resolved: #2937

Retry of D55713944
Use `WRAP_TO_ATEN` to register the custom ExecuTorch kernel with PyTorch.

This PR adds installation logic for `libcustom_ops_aot_lib.so` to `setup.py`, ensuring the library is built and installed to the correct location (`<site-packages>/executorch/examples/models/llama2/custom_ops/libcustom_ops_aot_lib.so`) so that it can then be loaded with `torch.ops.load_library`.
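
For reference, a minimal Python sketch (not part of this commit) of how the installed library is expected to be consumed. The `executorch.__file__`-based lookup, the tensor shapes, and the build assumption in the comments are illustrative; `torch.ops.load_library` and the `llama::sdpa_with_kv_cache` op come from the changes below.

from pathlib import Path

import torch
import executorch  # assumes a wheel built with EXECUTORCH_BUILD_CUSTOM_OPS_AOT=ON is installed

# Locate the installed shared library under the package directory described above
# (an assumed lookup; the loader in sdpa_with_kv_cache.py globs next to its own file).
pkg_dir = Path(executorch.__file__).parent / "examples/models/llama2/custom_ops"
lib = next(pkg_dir.glob("libcustom_ops_aot_lib.*"))
torch.ops.load_library(str(lib))

# Hypothetical shapes: [batch, seq_len, n_heads, head_dim] for the projections,
# [batch, max_seq_len, n_heads, head_dim] for the KV caches.
q = torch.rand(1, 1, 8, 64)
k = torch.rand(1, 1, 8, 64)
v = torch.rand(1, 1, 8, 64)
k_cache = torch.zeros(1, 128, 8, 64)
v_cache = torch.zeros(1, 128, 8, 64)

# Positional args follow the schema defined in op_sdpa_aot.cpp: query, key, value,
# key_cache, value_cache, start_pos, seq_len (attn_mask / dropout / is_causal / scale use defaults).
out = torch.ops.llama.sdpa_with_kv_cache(q, k, v, k_cache, v_cache, 0, 1)

Once the library is loaded, `torch.ops.llama.sdpa_with_kv_cache` resolves like any other registered op, which is what the updated `sdpa_with_kv_cache.py` below relies on.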

Reviewed By: lucylq

Differential Revision: D55907749
larryliu0820 authored and facebook-github-bot committed Apr 11, 2024
1 parent 969edc1 commit 120c145
Showing 9 changed files with 217 additions and 112 deletions.
21 changes: 20 additions & 1 deletion CMakeLists.txt
@@ -144,6 +144,8 @@ option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF)

option(EXECUTORCH_BUILD_CUSTOM "Build the custom kernels" OFF)

option(EXECUTORCH_BUILD_CUSTOM_OPS_AOT "Build the custom ops lib for AOT" OFF)

option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension"
OFF)

@@ -185,12 +187,19 @@ cmake_dependent_option(
cmake_dependent_option(EXECUTORCH_BUILD_CPUINFO "Build cpuinfo library." ON
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF)

if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
set(EXECUTORCH_BUILD_CUSTOM ON)
endif()

if(EXECUTORCH_BUILD_CUSTOM)
set(EXECUTORCH_BUILD_OPTIMIZED ON)
endif()

if(EXECUTORCH_BUILD_CPUINFO)
# --- cpuinfo
set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
${CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CPUINFO_SOURCE_DIR "backends/xnnpack/third-party/cpuinfo")
set(CPUINFO_BUILD_TOOLS
OFF
@@ -212,10 +221,15 @@ if(EXECUTORCH_BUILD_CPUINFO)
CACHE STRING "")
set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog")
add_subdirectory("${CPUINFO_SOURCE_DIR}")
set(CMAKE_POSITION_INDEPENDENT_CODE
${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
endif()

if(EXECUTORCH_BUILD_PTHREADPOOL)
# --- pthreadpool
set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
${CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(PTHREADPOOL_SOURCE_DIR "backends/xnnpack/third-party/pthreadpool")
set(PTHREADPOOL_BUILD_TESTS
OFF
@@ -235,6 +249,8 @@ if(EXECUTORCH_BUILD_PTHREADPOOL)
CACHE STRING "")
endif()
add_subdirectory("${PTHREADPOOL_SOURCE_DIR}")
set(CMAKE_POSITION_INDEPENDENT_CODE
${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
endif()

if(NOT PYTHON_EXECUTABLE)
@@ -546,6 +562,9 @@ if(EXECUTORCH_BUILD_PYBIND)
list(APPEND _dep_libs custom_ops)
endif()

if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
list(APPEND _dep_libs custom_ops_aot_lib)
endif()
# compile options for pybind

set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti
@@ -559,7 +578,7 @@ if(EXECUTORCH_BUILD_PYBIND)
target_include_directories(util PUBLIC ${_common_include_directories}
${TORCH_INCLUDE_DIRS})
target_compile_options(util PUBLIC ${_pybind_compile_options})
-target_link_libraries(util PRIVATE torch c10 executorch)
+target_link_libraries(util PRIVATE torch c10 executorch_no_prim_ops)

# pybind portable_lib
pybind11_add_module(portable_lib extension/pybindings/pybindings.cpp)
3 changes: 2 additions & 1 deletion examples/models/llama2/TARGETS
@@ -18,7 +18,7 @@ runtime.python_library(
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama2/custom_ops:llama_custom_ops_aot_lib",
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
],
)

@@ -52,6 +52,7 @@ runtime.python_binary(
main_module = "executorch.examples.models.llama2.export_llama",
# visibility = ["//executorch/examples/..."],
preload_deps = [
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_lib",
"//executorch/kernels/quantized:aot_lib",
],
deps = [
20 changes: 18 additions & 2 deletions examples/models/llama2/custom_ops/CMakeLists.txt
@@ -25,7 +25,7 @@ if(NOT TORCH_ROOT)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
endif()

-set(_common_compile_options -Wno-deprecated-declarations)
+set(_common_compile_options -Wno-deprecated-declarations -fPIC)

include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
@@ -44,7 +44,7 @@ include(${EXECUTORCH_SRCS_FILE})
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# Custom op libraries
-set(custom_ops_libs extension_module)
+set(custom_ops_libs executorch_no_prim_ops)
list(APPEND custom_ops_libs pthreadpool)
list(APPEND custom_ops_libs cpuinfo)
list(APPEND custom_ops_libs cpublas)
@@ -76,3 +76,19 @@ target_compile_options(custom_ops PUBLIC ${_common_compile_options}
-DET_USE_THREADPOOL)

install(TARGETS custom_ops DESTINATION lib)

if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
# Add a AOT library
find_package(Torch CONFIG REQUIRED)
add_library(custom_ops_aot_lib SHARED
${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp)
target_include_directories(custom_ops_aot_lib
PUBLIC "${_common_include_directories}")
target_include_directories(
custom_ops_aot_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include")
target_link_libraries(custom_ops_aot_lib PUBLIC custom_ops torch)
target_compile_options(custom_ops_aot_lib PUBLIC -Wno-deprecated-declarations
-fPIC -frtti -fexceptions)

install(TARGETS custom_ops_aot_lib DESTINATION lib)
endif()
107 changes: 107 additions & 0 deletions examples/models/llama2/custom_ops/op_sdpa_aot.cpp
@@ -0,0 +1,107 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

#include <torch/library.h>

namespace torch {
namespace executor {

namespace native {

Tensor& sdpa_with_kv_cache_out_no_context(
const Tensor& q_projected,
const Tensor& k_projected,
const Tensor& v_projected,
Tensor& key_cache,
Tensor& value_cache,
const int64_t start_pos,
const int64_t seq_len,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
exec_aten::RuntimeContext context{};
return torch::executor::native::sdpa_with_kv_cache_out(
context,
q_projected,
k_projected,
v_projected,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
dropout_p,
is_causal,
scale,
output);
}

at::Tensor sdpa_with_kv_cache_aten(
const at::Tensor& q_projected,
const at::Tensor& k_projected,
const at::Tensor& v_projected,
at::Tensor& key_cache,
at::Tensor& value_cache,
const int64_t start_pos,
const int64_t seq_len,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<at::Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<double> scale) {
auto output = at::empty_like(q_projected);
WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
(q_projected,
k_projected,
v_projected,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
dropout_p,
is_causal,
scale,
output);
return output;
}

} // namespace native
} // namespace executor
} // namespace torch

TORCH_LIBRARY(llama, m) {
m.def(
"sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
m.def(
"sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
}

TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl(
"sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
m.impl(
"sdpa_with_kv_cache.out",
WRAP_TO_ATEN(
torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
}
111 changes: 20 additions & 91 deletions examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
@@ -4,21 +4,29 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Import custom op defined in op_sdpa_aot.cpp. Those ops are using PyTorch
# C++ APIs for registration so here we need to import the shared library.
# This is only needed for OSS.

import logging
from pathlib import Path

import torch
from torch.library import impl, impl_abstract

custom_ops_lib = torch.library.Library("llama", "DEF")
custom_ops_lib.define(
"sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor"
)
from torch.library import impl

custom_ops_lib.define(
"sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)"
)
try:
op = torch.ops.llama.sdpa_with_kv_cache.default
assert op is not None
except:
libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
logging.info(f"Loading custom ops library: {libs[0]}")
torch.ops.load_library(libs[0])
op = torch.ops.llama.sdpa_with_kv_cache.default
assert op is not None

custom_ops_lib = torch.library.Library("llama", "IMPL")


def _validate_params(
@@ -118,82 +126,3 @@ def sdpa_with_kv_cache_meta(
)

return torch.empty_like(query)


@impl(custom_ops_lib, "sdpa_with_kv_cache", "CompositeExplicitAutograd")
def sdpa_with_kv_cache(
query,
key,
value,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask=None,
drpout_p=0.0,
is_causal=False,
scale=None,
):
_validate_params(
query,
key,
value,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
drpout_p,
is_causal,
scale,
)

if attn_mask is not None:
attn_mask = attn_mask[start_pos].view((1, -1))
attn_mask = attn_mask[:, : start_pos + seq_len]
q = query.transpose(1, 2)
key_cache[:, start_pos] = key
value_cache[:, start_pos] = value

sliced_k_cache = key_cache
sliced_v_cache = value_cache
sliced_k_cache = sliced_k_cache[:, : start_pos + seq_len, :, :]
sliced_v_cache = sliced_v_cache[:, : start_pos + seq_len, :, :]
sliced_k_cache = sliced_k_cache.transpose(1, 2)
sliced_v_cache = sliced_v_cache.transpose(1, 2)
out = torch.nn.functional.scaled_dot_product_attention(
q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
)
out = out.transpose(1, 2)
return out


@impl_abstract("llama::sdpa_with_kv_cache.out")
def sdpa_with_kv_cache_out(
query,
key,
value,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
drpout_p,
is_causal,
scale,
out,
):
out = sdpa_with_kv_cache_meta(
query,
key,
value,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
drpout_p,
is_causal,
scale,
)
return out
(Diffs for the remaining 4 changed files are not shown.)
