Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aten::trunc, aten::xlogy and their variants #697

Merged
merged 11 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <ATen/native/xpu/sycl/BinaryKernels.h>
#include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
#include <ATen/native/xpu/sycl/BinaryShiftOpsKernels.h>
#include <ATen/native/xpu/sycl/CopysignKernel.h>
Expand Down Expand Up @@ -51,6 +52,7 @@ REGISTER_XPU_DISPATCH(fmax_stub, &xpu::fmax_kernel);
REGISTER_XPU_DISPATCH(fmin_stub, &xpu::fmin_kernel);
REGISTER_XPU_DISPATCH(lshift_stub, &xpu::lshift_kernel);
REGISTER_XPU_DISPATCH(rshift_stub, &xpu::rshift_kernel);
REGISTER_XPU_DISPATCH(xlogy_stub, &xpu::xlogy_kernel);

TORCH_IMPL_FUNC(add_out_xpu)
(const Tensor& self,
Expand Down
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ void aminmax_impl(
Tensor& min,
Tensor& max) {
auto dtype = self.scalar_type();
TensorIterator iter = make_reduction(
"aminmax_xpu", min, max, self, dim_opt, keepdim, dtype);
TensorIterator iter =
make_reduction("aminmax_xpu", min, max, self, dim_opt, keepdim, dtype);
if (iter.numel() != 0) {
native::xpu::aminmax_kernel(iter);
}
Expand Down
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,7 @@ REGISTER_XPU_DISPATCH(nan_to_num_stub, &xpu::nan_to_num_kernel);
REGISTER_XPU_DISPATCH(round_stub, &xpu::round_kernel);
REGISTER_XPU_DISPATCH(round_decimals_stub, &xpu::round_decimals_kernel);
REGISTER_XPU_DISPATCH(floor_stub, &xpu::floor_kernel);
REGISTER_XPU_DISPATCH(trunc_stub, &xpu::trunc_kernel);

} // namespace native
} // namespace at
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"triangular_solve.X",
"tril_indices",
"triu_indices",
"trunc.out",
"upsample_bicubic2d_backward.grad_input",
"_upsample_bilinear2d_aa.out",
"upsample_nearest3d.out",
Expand All @@ -292,7 +291,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"upsample_trilinear3d.out",
"_validate_compressed_sparse_indices",
"vdot",
"xlogy.OutTensor",
"_upsample_bicubic2d_aa.out",
};
for (auto& op_name : fallback_list) {
Expand Down
24 changes: 23 additions & 1 deletion src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
#include <ATen/native/TensorIterator.h>
#include <comm/xpu_aten.h>

#include <ATen/NumericUtils.h>
#include <ATen/native/xpu/sycl/Loops.h>

#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>

namespace at::native::xpu {

template <typename scalar_t>
struct MSEFunctor {
scalar_t operator()(scalar_t a, scalar_t b) const {
Expand Down Expand Up @@ -72,4 +72,26 @@ void huber_kernel(TensorIterator& iter, double delta) {
});
}

// Elementwise functor for xlogy(x, y) = x * log(y), with the two special
// cases the op defines: a NaN in y propagates to the output regardless of x,
// and x == 0 yields 0 without evaluating log(y).
template <typename scalar_t>
struct XlogyFunctor {
  scalar_t operator()(scalar_t x, scalar_t y) const {
    // NaN in y wins over the x == 0 short-circuit.
    if (at::_isnan(y)) {
      return NAN;
    }
    // xlogy(0, y) is defined as 0, so log(y) is never computed here.
    return x == scalar_t(0) ? scalar_t(0) : x * std::log(y);
  }
};

// Elementwise xlogy: out = self * log(other), with xlogy(0, y) == 0 and NaN
// propagated from `other` (see XlogyFunctor above for the exact semantics).
void xlogy_kernel(TensorIteratorBase& iter) {
  // Dispatch over float/double plus Half and BFloat16; the functor is
  // instantiated at the iterator's common dtype so mixed-dtype inputs
  // compute in the promoted type.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "xlogy_xpu",
      // gpu_kernel_with_scalars also covers the Scalar-operand variants.
      [&]() { gpu_kernel_with_scalars(iter, XlogyFunctor<scalar_t>()); });
}

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ TORCH_XPU_API void smooth_l1_kernel(TensorIteratorBase& iter, double beta);

TORCH_XPU_API void huber_kernel(TensorIterator& iter, double delta);

TORCH_XPU_API void xlogy_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
37 changes: 37 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,41 @@ void floor_kernel(TensorIteratorBase& iter) {
});
}

// We manually overload trunc because std::trunc does not work with std::complex
// types and ROCm.
// Generic trunc (round toward zero) routed through float: reduced-precision
// types such as Half/BFloat16 are not accepted by std::trunc directly, so we
// widen to float, truncate, and narrow back.
template <typename scalar_t>
inline scalar_t trunc_wrapper(scalar_t a) {
  const float widened = static_cast<float>(a);
  return static_cast<scalar_t>(std::truncf(widened));
}

// double overload: must not round-trip through float (the generic wrapper
// would lose precision), so call std::trunc on the double directly.
inline double trunc_wrapper(double a) {
  return std::trunc(a);
}

// Complex overload: trunc is applied componentwise to the real and
// imaginary parts (std::trunc has no complex overload).
inline c10::complex<float> trunc_wrapper(c10::complex<float> a) {
  const float re = std::truncf(static_cast<float>(a.real()));
  const float im = std::truncf(static_cast<float>(a.imag()));
  return c10::complex<float>(re, im);
}

// Complex<double> overload: componentwise trunc at full double precision.
inline c10::complex<double> trunc_wrapper(c10::complex<double> a) {
  const double re = std::trunc(static_cast<double>(a.real()));
  const double im = std::trunc(static_cast<double>(a.imag()));
  return c10::complex<double>(re, im);
}

// Elementwise functor used by trunc_kernel; defers to the type-appropriate
// trunc_wrapper overload defined above.
template <typename scalar_t>
struct TruncFunctor {
  scalar_t operator()(scalar_t v) const {
    return trunc_wrapper(v);
  }
};

// Truncate each element toward zero (drop the fractional part).
// NOTE(review): only real floating dtypes (plus Half/BFloat16) are
// dispatched here, even though complex trunc_wrapper overloads exist above —
// presumably complex dispatch is handled elsewhere or unsupported; confirm.
void trunc_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "trunc_xpu", [&]() {
        gpu_kernel(iter, TruncFunctor<scalar_t>());
      });
}

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ TORCH_XPU_API void round_decimals_kernel(

TORCH_XPU_API void frac_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void trunc_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@
"sign",
"signbit",
"round",
"trunc",
"xlogy",
"nn.functional.embedding_bag",
"bucketize",
"searchsorted",
Expand Down
43 changes: 43 additions & 0 deletions yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4435,6 +4435,29 @@
XPU: logit_out
tags: pointwise

- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: xlogy.OutTensor
variants: function, method
tags: pointwise

# xlogy: inplace variant
- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: function, method
structured_delegate: xlogy.OutTensor
tags: pointwise

# xlogy: out variant
- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
variants: function
dispatch:
XPU: xlogy_out
tags: pointwise

- func: erfinv(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: erfinv.out
Expand Down Expand Up @@ -4598,6 +4621,26 @@
XPU: floor_out
tags: pointwise

- func: trunc(Tensor self) -> Tensor
structured_delegate: trunc.out
device_check: NoCheck # TensorIterator
variants: function, method
tags: [core, pointwise]

- func: trunc_(Tensor(a!) self) -> Tensor(a!)
structured_delegate: trunc.out
device_check: NoCheck # TensorIterator
variants: function, method
tags: pointwise

- func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
device_check: NoCheck # TensorIterator
dispatch:
XPU: trunc_out
tags: pointwise

- func: replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
structured: True
Expand Down