Skip to content

Commit

Permalink
Add aten::_nested_tensor_softmax_with_shape (#1323)
Browse files Browse the repository at this point in the history
Part of #1141.

Depends on pytorch/pytorch#145467.

- `_nested_tensor_softmax_with_shape`
  • Loading branch information
min-jean-cho authored Jan 26, 2025
1 parent a6f4c32 commit b6786e3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
file(GLOB xpu_h "xpu/*.h")
file(GLOB xpu_cpp "xpu/*.cpp")
file(GLOB xpu_mkl "native/xpu/mkl/*.cpp")
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/nested/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")

list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
Expand Down
20 changes: 20 additions & 0 deletions src/ATen/native/nested/NestedTensorTransformerFunctions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include <ATen/ATen.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>

namespace at::native {

// Applies a masked softmax along the last dimension of `self`, using a mask
// derived from the nested tensor `query`.
//
// Steps:
//   1. Convert `query` into a mask via NestedTensor_to_mask, with mask
//      dimension 2 and mask length `self.size(2)`.
//   2. Move that mask to `query`'s device, requesting a non-blocking transfer.
//   3. Run `_masked_softmax` on `self` over dimension `self.dim() - 1`.
//
// NOTE(review): the indexing assumes `self` has at least 3 dimensions
// (`self.size(2)` is read unconditionally) -- confirm against callers.
Tensor NestedTensor_softmax_dropout_xpu(
    const Tensor& self,
    const Tensor& query) {
  const Tensor padding_mask = NestedTensor_to_mask(query, 2, self.size(2));
  const Tensor device_mask =
      padding_mask.to(query.device(), /*non-blocking=*/true);

  // Softmax is taken over the trailing dimension of `self`.
  const auto softmax_dim = self.dim() - 1;

  // Mask type 1: NestedTensor_to_mask produces a BxT mask.
  return _masked_softmax(self, device_mask, softmax_dim, /*mask type */ 1);
}

} // namespace at::native
5 changes: 5 additions & 0 deletions yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4437,6 +4437,11 @@
XPU: nested_from_padded_xpu
autogen: _nested_from_padded.out

# Operator schema: takes two tensors (`self`, `query`) and returns a tensor.
# The NestedTensorXPU dispatch key is routed to the C++ kernel
# NestedTensor_softmax_dropout_xpu.
# NOTE(review): the entry is tagged nondeterministic_seeded, but the kernel
# added in this commit only builds a mask and calls _masked_softmax -- confirm
# the tag is intentional (presumably kept for parity with the upstream schema).
- func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor
dispatch:
NestedTensorXPU: NestedTensor_softmax_dropout_xpu
tags: nondeterministic_seeded

- func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
structured: True
Expand Down

0 comments on commit b6786e3

Please sign in to comment.