-
Notifications
You must be signed in to change notification settings - Fork 312
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use new API to register custom ExecuTorch kernels into ATen (#2937)
Summary: Pull Request resolved: #2937 Retry of D55713944 Use `WRAP_TO_ATEN` to register custom ExecuTorch kernel to PyTorch. This PR added installation logic to `libcustom_ops_aot_lib.so` in `setup.py`. This is to make sure we can build `libcustom_ops_aot_lib.so` and install it to the correct position (`<site-packages>/executorch/examples/models/llama2/custom_ops/libcustom_ops_aot_lib.so`) and then it can be loaded by `torch.ops.load_library`. Reviewed By: lucylq Differential Revision: D55907749
- Loading branch information
1 parent
203c9d2
commit b1e027e
Showing
9 changed files
with
217 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h> | ||
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h> | ||
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> | ||
|
||
#include <torch/library.h> | ||
|
||
namespace torch { | ||
namespace executor { | ||
|
||
namespace native { | ||
|
||
/// Context-free shim over the ExecuTorch out-variant SDPA kernel.
///
/// The ExecuTorch kernel signature takes a RuntimeContext as its first
/// argument; the ATen wrapping machinery (WRAP_TO_ATEN) expects a plain
/// functor without one. This function supplies a default-constructed
/// context and forwards every remaining argument unchanged.
///
/// @param q_projected / k_projected / v_projected  projected Q/K/V tensors
/// @param key_cache / value_cache  mutable KV-cache tensors updated in place
/// @param start_pos  position in the cache at which this sequence starts
/// @param seq_len    number of new tokens in this call
/// @param attn_mask  optional attention mask
/// @param dropout_p  dropout probability
/// @param is_causal  whether to apply a causal mask
/// @param scale      optional softmax scale override
/// @param output     pre-allocated output tensor, written in place
/// @return reference to `output`
Tensor& sdpa_with_kv_cache_out_no_context(
    const Tensor& q_projected,
    const Tensor& k_projected,
    const Tensor& v_projected,
    Tensor& key_cache,
    Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  // A default RuntimeContext is sufficient here; the kernel only needs it
  // to satisfy the ExecuTorch calling convention.
  exec_aten::RuntimeContext runtime_context{};
  // We are already inside torch::executor::native, so the unqualified name
  // resolves to the same kernel as the fully-qualified call.
  return sdpa_with_kv_cache_out(
      runtime_context,
      q_projected,
      k_projected,
      v_projected,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
}
|
||
/// Functional (non-out) ATen entry point for the SDPA-with-KV-cache op.
///
/// Allocates a fresh output tensor shaped like the query projection, then
/// dispatches to the ExecuTorch out-variant kernel through WRAP_TO_ATEN,
/// which adapts ExecuTorch tensors to ATen tensors. The `11` tells the
/// wrapper that argument index 11 is the out-tensor.
///
/// @return a newly allocated tensor holding the attention output
at::Tensor sdpa_with_kv_cache_aten(
    const at::Tensor& q_projected,
    const at::Tensor& k_projected,
    const at::Tensor& v_projected,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const c10::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const c10::optional<double> scale) {
  // Output matches the query projection's shape/dtype/device.
  auto result = at::empty_like(q_projected);
  // WRAP_TO_ATEN produces a callable taking ATen tensors; invoke it with
  // `result` in the out-parameter slot (position 11) and discard the
  // returned reference, since we return the owning tensor instead.
  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
  (q_projected,
   k_projected,
   v_projected,
   key_cache,
   value_cache,
   start_pos,
   seq_len,
   attn_mask,
   dropout_p,
   is_causal,
   scale,
   result);
  return result;
}
|
||
} // namespace native | ||
} // namespace executor | ||
} // namespace torch | ||
|
||
// Declare the `llama::sdpa_with_kv_cache` op schemas (functional and out
// variants) so PyTorch can resolve torch.ops.llama.sdpa_with_kv_cache.
//
// Fix: the dropout argument was misspelled `drpout_p` in both schema
// strings; since the schema defines the Python-visible keyword-argument
// name, it is corrected to `dropout_p` here. Positional callers are
// unaffected.
TORCH_LIBRARY(llama, m) {
  m.def(
      "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float dropout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
  m.def(
      "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float dropout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
}
|
||
// Register kernel implementations for the `llama` library under the
// CompositeExplicitAutograd dispatch key, so both CPU eager calls and
// traced/exported graphs hit these functions.
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  // Functional variant: allocates and returns its own output tensor.
  m.impl(
      "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
  // Out variant: WRAP_TO_ATEN adapts the ExecuTorch-tensor kernel to ATen
  // tensors; `11` marks argument index 11 as the out-tensor slot.
  m.impl(
      "sdpa_with_kv_cache.out",
      WRAP_TO_ATEN(
          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.