forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LogAddExpKernel.cu
52 lines (46 loc) · 1.65 KB
/
LogAddExpKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/AccumulateType.h>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
namespace at { namespace native {
// Elementwise log(exp(a) + exp(b)) over the tensors in `iter`, dispatched for
// float, double and BFloat16.
//
// Both inputs are widened to accscalar_t (at::acc_type<scalar_t, true>) before
// any arithmetic, so intermediates such as a - b are not rounded back to a
// reduced-precision scalar_t (e.g. BFloat16) mid-computation. The result is
// narrowed to scalar_t exactly once, on return.
//
// Uses the overflow-free identity
//   log(e^a + e^b) = max(a, b) + log1p(exp(-|a - b|))
// log1p is used rather than log(1 + x): once x drops below machine epsilon,
// 1 + x rounds to exactly 1 and the correction term would be silently lost.
void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND(
    ScalarType::BFloat16,
    iter.dtype(), "logaddexp_cuda",
    [&]() {
      using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
        const accscalar_t a = static_cast<accscalar_t>(a_);
        const accscalar_t b = static_cast<accscalar_t>(b_);
        if (::isinf(a) && a == b) {
          // Same infinity on both sides: a - b below would be inf - inf = NaN,
          // so return the input directly (-inf stays -inf, +inf stays +inf).
          return a_;
        }
        const accscalar_t m = ::max(a, b);
        return static_cast<scalar_t>(m + ::log1p(::exp(-::abs(a - b))));
      });
    });
}
// Elementwise log2(2^a + 2^b) over the tensors in `iter`, dispatched for
// float, double and BFloat16.
//
// Both inputs are widened to accscalar_t (at::acc_type<scalar_t, true>) before
// any arithmetic, so intermediates are not rounded to a reduced-precision
// scalar_t (e.g. BFloat16). The result is narrowed to scalar_t once, on return.
//
// Uses the overflow-free identity
//   log2(2^a + 2^b) = max(a, b) + log2(1 + 2^(-|a - b|))
// with log2(1 + y) evaluated as log1p(y) * log2(e). Evaluating 1 + y directly
// rounds to exactly 1 once y is below machine epsilon, which would drop the
// correction term. exp2 is the dedicated base-2 exponential; it replaces
// pow(2, x).
void logaddexp2_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND(
    ScalarType::BFloat16,
    iter.dtype(), "logaddexp2_cuda",
    [&]() {
      using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
        // log2(e) == 1 / ln(2): converts the natural log from log1p to base 2.
        const accscalar_t log2_e =
            static_cast<accscalar_t>(1.44269504088896340736);
        const accscalar_t a = static_cast<accscalar_t>(a_);
        const accscalar_t b = static_cast<accscalar_t>(b_);
        if (::isinf(a) && a == b) {
          // Same infinity on both sides: a - b below would be inf - inf = NaN,
          // so return the input directly (-inf stays -inf, +inf stays +inf).
          return a_;
        }
        const accscalar_t m = ::max(a, b);
        return static_cast<scalar_t>(
            m + ::log1p(::exp2(-::abs(a - b))) * log2_e);
      });
    });
}
// Register the CUDA implementations defined in this file with their dispatch
// stubs: logaddexp_stub -> logaddexp_kernel_cuda and
// logaddexp2_stub -> logaddexp2_kernel_cuda. The stubs themselves are declared
// outside this file (presumably in ATen/native/BinaryOps.h, included above —
// confirm).
REGISTER_DISPATCH(logaddexp_stub, &logaddexp_kernel_cuda);
REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda);
}} // namespace at::native