// LinearAlgebraKernel.cpp
#include <ATen/ATen.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/cpu/Reduce.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>  // Vectorized<> (likely also pulled in transitively via Loops.h)

#include <cmath>    // INFINITY
#include <limits>   // std::numeric_limits
#include <utility>  // std::swap

namespace at { namespace native { namespace {
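
// addr_kernel computes the outer-product update
//   out = beta * self + alpha * outer(vec1, vec2)
// element-wise over a TensorIterator prepared by the caller (for Bool,
// multiplication and addition degenerate to && and ||). A hedged sketch of
// reaching this kernel through the public API; shapes and values are
// illustrative only, and the iterator setup itself lives elsewhere:
//
//   // auto vec1 = at::randn({3});
//   // auto vec2 = at::randn({4});
//   // auto self = at::zeros({3, 4});
//   // auto out  = at::addr(self, vec1, vec2, /*beta=*/1, /*alpha=*/2);
//   // on CPU this dispatches through addr_stub to addr_kernel below.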
void addr_kernel(TensorIterator &iter,
                 const Scalar& beta, const Scalar& alpha) {
  if (iter.dtype() == ScalarType::Bool) {
    using scalar_t = bool;
    auto beta_val = beta.to<scalar_t>();
    auto alpha_val = alpha.to<scalar_t>();

    // when beta is false, values in self should be ignored,
    // nans and infs in self should not propagate.
    if (beta_val == false) {
      cpu_kernel(iter,
        [=](scalar_t self_val,
            scalar_t vec1_val,
            scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
          return alpha_val && vec1_val && vec2_val;
        }
      );
    } else {
      cpu_kernel(iter,
        [=](scalar_t self_val,
            scalar_t vec1_val,
            scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
          return (beta_val && self_val) || (alpha_val && vec1_val && vec2_val);
        }
      );
    }
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
    iter.dtype(), "addr_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;

      auto beta_val = beta.to<scalar_t>();
      auto alpha_val = alpha.to<scalar_t>();

      auto beta_vec = Vec(beta_val);
      auto alpha_vec = Vec(alpha_val);

      const scalar_t zero_val(0);
      // when beta == 0, values in self should be ignored,
      // nans and infs in self should not propagate.
      if (beta_val == zero_val) {
        cpu_kernel_vec(iter,
          [=](scalar_t self_val,
              scalar_t vec1_val,
              scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
            return alpha_val * vec1_val * vec2_val;
          },
          [=](Vec self_vec,
              Vec vec1_vec,
              Vec vec2_vec) __ubsan_ignore_undefined__ {
            return alpha_vec * vec1_vec * vec2_vec;
          }
        );
      } else {
        cpu_kernel_vec(iter,
          [=](scalar_t self_val,
              scalar_t vec1_val,
              scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
            return beta_val * self_val + alpha_val * vec1_val * vec2_val;
          },
          [=](Vec self_vec,
              Vec vec1_vec,
              Vec vec2_vec) __ubsan_ignore_undefined__ {
            return beta_vec * self_vec + alpha_vec * vec1_vec * vec2_vec;
          }
        );
      }
    }
  );
}
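
// linalg_vector_norm_kernel_cpu_impl reduces each output element of `iter` to
// a vector norm, picking a specialized reduction op for the common orders:
//   ord ==  0    -> number of non-zero entries   (NormZeroOps)
//   ord ==  1    -> sum(|x|)                     (NormOneOps)
//   ord ==  2    -> sqrt(sum(|x|^2))             (NormTwoOps)
//   ord ==  inf  -> max(|x|)                     (AbsMaxOps)
//   ord == -inf  -> min(|x|)                     (AbsMinOps)
//   otherwise    -> (sum(|x|^ord))^(1/ord)       (NormOps)
// Hedged usage sketch through the public API (illustrative only):
//   // auto x = at::randn({5});
//   // auto n = at::linalg_vector_norm(x, /*ord=*/3.0);  // reaches this kernel on CPU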
template <typename scalar_t, typename acc_t=typename scalar_value_type<scalar_t>::type>
void linalg_vector_norm_kernel_cpu_impl(TensorIterator& iter, Scalar ord) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double ord_val;
  if (ord.isFloatingPoint()) {
    ord_val = ord.to<double>();
  } else {
    TORCH_CHECK(false, "linalg.vector_norm expects ord to be float");
  }
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  acc_t init_val = (ord_val == -INFINITY) ? std::numeric_limits<acc_t>::infinity() : static_cast<acc_t>(0);
  if (iter.numel() == 0) {
    iter.output().fill_((ord_val < 0) ? INFINITY : 0);
    return;
  }
  if (ord_val == 0) {
    binary_kernel_reduce(iter, NormZeroOps<scalar_t, acc_t>(), init_val);
  } else if (ord_val == 1) {
    binary_kernel_reduce(iter, NormOneOps<scalar_t, acc_t>(), init_val);
  } else if (ord_val == 2) {
    binary_kernel_reduce(iter, NormTwoOps<scalar_t, acc_t>(), init_val);
  } else if (ord_val == INFINITY) {
    binary_kernel_reduce(iter, AbsMaxOps<scalar_t, acc_t>(), init_val);
  } else if (ord_val == -INFINITY) {
    binary_kernel_reduce(iter, AbsMinOps<scalar_t, acc_t>(), init_val);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    binary_kernel_reduce(iter, NormOps<scalar_t, acc_t> { static_cast<acc_t>(ord_val) }, init_val);
  }
  // For complex outputs, the above kernels do not touch the imaginary values,
  // so we must zero them out
  if (isComplexType(iter.output().scalar_type())) {
    at::imag(iter.output()).zero_();
  }
}
static void linalg_vector_norm_kernel_cpu(TensorIterator& iter, Scalar ord) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "linalg_vector_norm_cpu", [&] {
    linalg_vector_norm_kernel_cpu_impl<scalar_t>(iter, ord);
  });
}
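
// unpack_pivots_cpu_kernel expands the compact pivot vector produced by an LU
// factorization (a sequence of row swaps) into an explicit permutation: for
// each batch element it applies swap(perm[i], perm[pivots[i]]) in order to a
// permutation buffer the caller has pre-initialized (this kernel backs the
// unpack_pivots_stub used by torch.lu_unpack). Worked example, tracing the
// loop below: with dim_size == 3, pivots == [2, 2, 2] and perm starting as
// the identity [0, 1, 2], the swaps (0<->2), (1<->2), (2<->2) yield [2, 0, 1].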
void unpack_pivots_cpu_kernel(
  TensorIterator& iter,
  int64_t dim_size
) {
  if (iter.numel() == 0) {
    return;
  }

  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    auto* unpacked_pivots_ptr = data[0];
    const auto* pivots_ptr = data[1];

    for (int64_t elem = 0; elem < nelems; ++elem) {
      // WARNING: torch.lu returns int32 pivots,
      // this behavior could change in the future.
      auto* unpacked_pivots_data = reinterpret_cast<int32_t*>(unpacked_pivots_ptr);
      auto* pivots_data = reinterpret_cast<const int32_t*>(pivots_ptr);

      for (int64_t i = 0; i < dim_size; ++i) {
        std::swap(
          unpacked_pivots_data[i],
          unpacked_pivots_data[pivots_data[i]]
        );
      }

      unpacked_pivots_ptr += strides[0];
      pivots_ptr += strides[1];
    }
  };

  iter.for_each(loop);
}
} // anonymous namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(addr_stub, &addr_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(linalg_vector_norm_stub, &linalg_vector_norm_kernel_cpu);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
}} // namespace at::native