forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Bucketization.cpp
142 lines (122 loc) · 6.03 KB
/
Bucketization.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/BucketizationUtils.h>
/* Implement a TF like searchsorted and a bucketize function running on cpu
*
* - torch.searchsorted(sorted_sequence, values, right=False, out_int32=False)
* sorted_sequence - N*D or 1D (apply to all values) tensor containing sorted sequences in last dimension
* values - N*D tensor or a Scalar (when sorted_sequence is 1D) containing the search values
* right - corresponding to lower bound if False and upper bound if True
* out_int32 - the output tensor is int64_t type if False and int(32bit normally) type if True.
*
* - torch.bucketize(values, boundaries, right=False, out_int32=False)
* values - N*D tensor or a Scalar containing the search value
* boundaries - 1D tensor containing a sorted sequence
* right - corresponding to lower bound if False and upper bound if True
* out_int32 - the output tensor is int64_t type if False and int(32bit normally) type if True.
*
* - Restrictions are defined in searchsorted_pre_check()
*/
namespace at {
namespace native {
namespace {
// Grain size handed to at::parallel_for by searchsorted_cpu_contiguous: the
// minimal amount of work before the search is run in parallel (multithreaded).
constexpr int64_t SEARCHSORTED_GRAIN_SIZE{200};
// Hand-written lower bound over the sorted range [start, end).
// Returns a pointer to the first element for which `element >= val` holds, or
// the final `start` position (== end for a non-empty search that never matches).
// The test is deliberately spelled `!(element >= val)` instead of
// `element < val`: any comparison involving NaN is false, so a NaN `val`
// always moves to the right and ends up at the end of the boundaries.
// std::lower_bound is not used because a custom comparator for it must be a
// strict weak ordering, which this predicate is not once NaN is involved.
template<typename input_t>
const input_t* cus_lower_bound(const input_t* start, const input_t* end, input_t val) {
  // Track the number of candidates still in play rather than moving `end`.
  auto remaining = end - start;
  while (remaining > 0) {
    const auto half = remaining >> 1;
    const input_t* probe = start + half;
    if (*probe >= val) {
      // probe may be the answer: keep only the left half [start, probe)
      remaining = half;
    } else {
      // probe and everything before it are too small (or unordered): go right
      start = probe + 1;
      remaining -= half + 1;
    }
  }
  return start;
}
// CPU kernel: for every element of `input`, finds its insertion index inside
// the matching row of `boundaries` and stores it at the same linear position
// of `result`.
//   input_t  - element type read from both `input` and `boundaries`
//   output_t - element type written to `result`
//   right    - false: lower bound (cus_lower_bound); true: std::upper_bound
// All three buffers are addressed linearly through data_ptr, so callers are
// expected to pass contiguous tensors.
template<typename input_t, typename output_t>
void searchsorted_cpu_contiguous(Tensor& result, const Tensor& input, const Tensor& boundaries, const bool& right) {
  const int64_t total_values = input.numel();
  const bool scalar_value = input.dim() == 0 && total_values == 1;
  // length of the innermost dimension of the values and of the boundaries
  const int64_t values_per_row = scalar_value ? 1 : input.sizes().back();
  const int64_t bounds_per_row = boundaries.sizes().back();
  // a 1-D boundaries tensor is shared by every value
  const bool shared_boundaries = boundaries.dim() == 1;

  const input_t* values = input.data_ptr<input_t>();
  const input_t* bounds = boundaries.data_ptr<input_t>();
  output_t* indices = result.data_ptr<output_t>();

  at::parallel_for(0, total_values, SEARCHSORTED_GRAIN_SIZE, [&](int64_t chunk_begin, int64_t chunk_end) {
    for (int64_t idx = chunk_begin; idx < chunk_end; ++idx) {
      // Pick the boundary row that belongs to this value; with shared
      // boundaries the whole boundaries tensor is searched every time.
      const input_t* row_begin = bounds;
      if (!shared_boundaries) {
        row_begin += idx / values_per_row * bounds_per_row;
      }
      const input_t* row_end = row_begin + bounds_per_row;

      const input_t* hit;
      if (right) {
        hit = std::upper_bound(row_begin, row_end, values[idx]);
      } else {
        hit = cus_lower_bound(row_begin, row_end, values[idx]);
      }

      // the index is computed as int64_t and may be narrowed to output_t here
      const int64_t pos = hit - row_begin;
      indices[idx] = static_cast<output_t>(pos);
    }
  });
}
// Picks the searchsorted_cpu_contiguous instantiation to run:
//   - the element type comes from input.scalar_type() (AT_DISPATCH_ALL_TYPES)
//   - the index type is `int` when out_int32 is set, int64_t otherwise
// `boundaries` is read with the same element type as `input`.
void dispatch(Tensor& result, const Tensor& input, const Tensor& boundaries, bool out_int32, bool right) {
  if (out_int32) {
    AT_DISPATCH_ALL_TYPES(input.scalar_type(), "searchsorted_out_cpu", [&] {
      searchsorted_cpu_contiguous<scalar_t, int>(result, input, boundaries, right);
    });
    return;
  }
  AT_DISPATCH_ALL_TYPES(input.scalar_type(), "searchsorted_out_cpu", [&] {
    searchsorted_cpu_contiguous<scalar_t, int64_t>(result, input, boundaries, right);
  });
}
}
// Out-variant of searchsorted on CPU.
//   sorted_sequence - tensor whose innermost dimension holds the sorted boundaries
//   self            - values to locate
//   out_int32       - selects int (true) or int64_t (false) indices
//   right           - false: lower bound, true: upper bound
//   result          - output tensor; given the shape of `self` and returned
// Argument validation is delegated to searchsorted_pre_check.
Tensor& searchsorted_out_cpu(const Tensor& sorted_sequence, const Tensor& self, bool out_int32, bool right, Tensor& result) {
  searchsorted_pre_check(sorted_sequence, self, result, out_int32);

  // The kernel writes exactly self.numel() indices, so the output must have
  // the shape of `self`. Resizing only an empty output would let a non-empty,
  // wrongly shaped `result` be written past its end.
  // NOTE(review): if searchsorted_pre_check already rejects mismatched shapes
  // this branch only fires for empty outputs, as before - confirm.
  if (!result.sizes().equals(self.sizes())) {
    result.resize_(self.sizes());
  }
  if (self.numel() == 0) {
    return result;
  }

  // The kernel addresses the output buffer linearly through data_ptr, which
  // is only correct for contiguous memory. For a strided `result` compute
  // into a contiguous tensor and copy the indices back afterwards.
  const bool result_is_contiguous = result.is_contiguous();
  Tensor out = result_is_contiguous ? result : result.contiguous();

  if (sorted_sequence.is_contiguous() && self.is_contiguous() && sorted_sequence.dtype() == self.dtype()) {
    // Fast path: both inputs can be handed to the kernel unchanged.
    dispatch(out, self, sorted_sequence, out_int32, right);
  } else {
    // searchsorted_maybe_trim_input_tensors leaves a tensor undefined when the
    // original can be used as is; presumably it otherwise yields contiguous,
    // dtype-matched replacements - verify against BucketizationUtils.h.
    Tensor trimmed_input;
    Tensor trimmed_boundaries;
    searchsorted_maybe_trim_input_tensors(trimmed_input, trimmed_boundaries, self, sorted_sequence);
    const Tensor& final_input = trimmed_input.defined() ? trimmed_input : self;
    const Tensor& final_boundaries = trimmed_boundaries.defined() ? trimmed_boundaries : sorted_sequence;
    dispatch(out, final_input, final_boundaries, out_int32, right);
  }

  if (!result_is_contiguous) {
    result.copy_(out);
  }
  return result;
}
// Functional variant of searchsorted on CPU: allocates an empty index tensor
// (Int when out_int32 is set, Long otherwise) on the device of `self` and
// lets searchsorted_out_cpu size and fill it.
Tensor searchsorted_cpu(const Tensor& sorted_sequence, const Tensor& self, bool out_int32, bool right) {
  ScalarType index_dtype = ScalarType::Long;
  if (out_int32) {
    index_dtype = ScalarType::Int;
  }
  Tensor result = at::empty(
      {0},
      TensorOptions().device(self.options().device()).dtype(index_dtype),
      MemoryFormat::Contiguous);
  at::native::searchsorted_out_cpu(sorted_sequence, self, out_int32, right, result);
  return result;
}
// Scalar overload: wraps the scalar into a tensor on the device of
// `sorted_sequence` and forwards to the Tensor overload.
Tensor searchsorted_cpu(const Tensor& sorted_sequence, const Scalar& self, bool out_int32, bool right) {
  const Tensor value_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_cpu(sorted_sequence, value_tensor, out_int32, right);
}
// Out-variant of bucketize on CPU. Requires `boundaries` to be 1-D, then
// delegates to searchsorted_out_cpu with `boundaries` as the sorted sequence
// and `self` as the values to locate. Returns `result`.
Tensor& bucketize_out_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right, Tensor& result) {
  const auto boundaries_dim = boundaries.dim();
  TORCH_CHECK(boundaries_dim == 1, "boundaries tensor must be 1 dimension, but got dim(", boundaries_dim, ")");
  at::native::searchsorted_out_cpu(boundaries, self, out_int32, right, result);
  return result;
}
// Functional variant of bucketize on CPU: allocates an empty index tensor
// (Int when out_int32 is set, Long otherwise) on the device of `self` and
// lets bucketize_out_cpu size and fill it.
Tensor bucketize_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right) {
  ScalarType index_dtype = ScalarType::Long;
  if (out_int32) {
    index_dtype = ScalarType::Int;
  }
  Tensor result = at::empty(
      {0},
      TensorOptions().device(self.options().device()).dtype(index_dtype),
      MemoryFormat::Contiguous);
  at::native::bucketize_out_cpu(self, boundaries, out_int32, right, result);
  return result;
}
// Scalar overload: wraps the scalar into a tensor on the device of
// `boundaries` and forwards to the Tensor overload.
Tensor bucketize_cpu(const Scalar& self, const Tensor& boundaries, bool out_int32, bool right) {
  const Tensor value_tensor = searchsorted_scalar_tensor(self, boundaries.device());
  return bucketize_cpu(value_tensor, boundaries, out_int32, right);
}
}} // namespace at::native