MaxPooling.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/MaxPooling.h>
#include <c10/util/irange.h>

namespace at {
namespace native {

namespace {

// For each kernel tap kj, visit only the output positions whose corresponding
// input index is in-bounds and take a running max. The explicit isnan check
// makes NaN inputs propagate to the output.
template <typename scalar_t>
inline void max_pool1d_kernel(
    scalar_t* C10_RESTRICT op,
    const scalar_t* C10_RESTRICT ip,
    const PoolingParams1D& p) {
  for (const auto kj : c10::irange(p.KW)) {
    int64_t oj = p.valid_output_start(kj);
    int64_t oe = p.valid_output_end(kj);
    int64_t ij = p.index(kj, oj);
    for (; oj < oe; ++oj, ij += p.SJ) {
      scalar_t val = ip[ij];
      bool update_max = std::isnan(val) || op[oj] < val;
      op[oj] = update_max ? val : op[oj];
    }
  }
}

void max_pool1d_impl(
    Tensor& output,
    const Tensor& input,
    const PoolingParams1D& p) {
  AT_DISPATCH_FLOATING_TYPES_AND(
      ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] {
        const Tensor in = input.contiguous();
        scalar_t* const OP = output.data_ptr<scalar_t>();
        const scalar_t* const IP = in.data_ptr<scalar_t>();

        // Value used for padding: -inf, or the lowest finite value for types
        // without an infinity, so any real input overwrites it.
        scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
            ? -std::numeric_limits<scalar_t>::infinity()
            : std::numeric_limits<scalar_t>::lowest();

        // Each (batch, channel) row is independent; parallelize over NB * NC.
        at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
          for (const auto it : c10::irange(begin, end)) {
            scalar_t* op = OP + it * p.OW;
            const scalar_t* ip = IP + it * p.IW;
            std::fill_n(op, p.OW, FILL);
            max_pool1d_kernel(op, ip, p);
          }
        });
      });
}

} // namespace

REGISTER_DISPATCH(max_pool1d_stub, &max_pool1d_impl);

} // namespace native
} // namespace at
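
For context, a minimal caller sketch follows. It is not part of the file above; it assumes a program linked against ATen/libtorch and shows a call to at::max_pool1d that, for an ordinary CPU float tensor in an inference setting, is expected to reach max_pool1d_impl through the max_pool1d_stub registered here. The tensor shape and pooling parameters are arbitrary illustration values.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // One batch, two channels, eight time steps; float on CPU is covered by
  // the AT_DISPATCH_FLOATING_TYPES_AND(BFloat16, ...) dispatch above.
  at::Tensor input = at::randn({1, 2, 8});
  // kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=false.
  at::Tensor output = at::max_pool1d(input, {3}, {2}, {1}, {1}, false);
  std::cout << output.sizes() << "\n"; // expect [1, 2, 4]
  return 0;
}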