forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FractionalMaxPool2d.cpp
398 lines (348 loc) · 11.5 KB
/
FractionalMaxPool2d.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <tuple>
#include <vector>
namespace at {
namespace meta {
// Meta (shape-checking) function for fractional_max_pool2d.
//
// Validates the arguments and allocates the two outputs:
//   output  - same dtype as `input`, spatial dims replaced by output_size
//   indices - int64, same shape as `output`
// `input` must be a non-empty 3D (planes, H, W) or 4D (batch, planes, H, W)
// tensor. `randomSamples` supplies two samples (one for W, one for H) per
// (batch, plane) pair; the CPU kernel reads them as a flat
// [batch][plane][2] buffer of the input's scalar type, so their dtype and
// shape are checked here to rule out out-of-bounds or mistyped reads.
TORCH_META_FUNC(fractional_max_pool2d) (
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples
) {
  TORCH_CHECK(
      pool_size.size() == 2,
      "fractional_max_pool2d: kernel_size must either be a single Int or tuple of Ints")
  TORCH_CHECK(
      output_size.size() == 2,
      "fractional_max_pool2d: output_size must either be a single Int or tuple of Ints")
  int64_t numBatch = 1;
  int64_t planeDim = 0;
  int64_t heightDim = 1;
  int64_t widthDim = 2;
  const int64_t outputH = output_size[0];
  const int64_t outputW = output_size[1];
  const int64_t poolSizeH = pool_size[0];
  const int64_t poolSizeW = pool_size[1];

  // A non-positive window makes the kernel's window loops empty and leaves
  // indices that do not point at any selected element, so reject it here.
  TORCH_CHECK(poolSizeH > 0 && poolSizeW > 0,
    "fractional_max_pool2d(): kernel_size must be positive, but got (",
    poolSizeH, ", ", poolSizeW, ")");

  const int64_t ndims = input.ndimension();
  TORCH_CHECK(input.numel() > 0 && (ndims == 3 || ndims == 4),
    "non-empty 3D or 4D (batch mode) tensor expected for input, but got: ",
    ndims);

  if (ndims == 4) {
    numBatch = input.size(0);
    planeDim++;
    heightDim++;
    widthDim++;
  }

  /* sizes */
  const int64_t numPlanes = input.size(planeDim);
  const int64_t inputH = input.size(heightDim);
  // 64-bit like inputH; an `int` here would silently truncate large widths.
  const int64_t inputW = input.size(widthDim);

  TORCH_CHECK(outputH + poolSizeH - 1 <= inputH,
    "fractional_max_pool2d(): pool height ", poolSizeH,
    " too large relative to input height ", inputH);
  TORCH_CHECK(outputW + poolSizeW - 1 <= inputW,
    "fractional_max_pool2d(): pool width ", poolSizeW,
    " too large relative to input width ", inputW);

  // The kernel reads randomSamples with the input's scalar type and at
  // offset (batch * numPlanes + plane) * 2, so both dtype and shape matter.
  TORCH_CHECK(randomSamples.scalar_type() == input.scalar_type(),
    "fractional_max_pool2d(): expected _random_samples to have the same dtype as input, but got ",
    randomSamples.scalar_type(), " and ", input.scalar_type());
  TORCH_CHECK(randomSamples.dim() == 3,
    "fractional_max_pool2d(): expected _random_samples to be a 3D tensor, but got ",
    randomSamples.dim(), " dimensions");
  TORCH_CHECK(randomSamples.size(0) >= numBatch &&
              randomSamples.size(1) == numPlanes &&
              randomSamples.size(2) == 2,
    "fractional_max_pool2d(): expected _random_samples of shape (>=", numBatch,
    ", ", numPlanes, ", 2), but got ", randomSamples.sizes());

  if (ndims == 3) {
    set_output(0, {numPlanes, outputH, outputW}, input.options());
    /* indices will contain the locations for each output point */
    set_output(1, {numPlanes, outputH, outputW}, input.options().dtype(kLong));
  } else {
    set_output(0, {numBatch, numPlanes, outputH, outputW}, input.options());
    /* indices will contain the locations for each output point */
    set_output(1, {numBatch, numPlanes, outputH, outputW}, input.options().dtype(kLong));
  }
}
} // namespace meta
namespace native {
namespace {
// Returns the `outputSize` pooling-window start offsets along one spatial
// dimension.
//
// For outputSize > 1, the first outputSize - 1 starts are spaced by
// alpha = (inputSize - poolSize) / (outputSize - 1) and jittered by `sample`,
// normalized so that the first start is 0. The last start is pinned to
// inputSize - poolSize so the final window ends exactly at the input edge.
// When outputSize is 0 an empty sequence is returned (there is no "last"
// element to pin, and indexing sequence[-1] would be undefined behavior).
template <typename scalar_t>
static std::vector<int> fractional_max_pool2d_generate_intervals(
  scalar_t sample,
  int inputSize,
  int outputSize,
  int poolSize) {
  std::vector<int> sequence(outputSize);
  if (outputSize > 1) {
    const scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
      static_cast<scalar_t>(outputSize - 1);

    for (int i = 0; i < outputSize - 1; ++i) {
      sequence[i] =
        static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
    }
  }
  if (outputSize > 0) {
    sequence[outputSize - 1] = inputSize - poolSize;
  }

  return sequence;
}
// Forward fractional max pooling for one batch entry.
//
// `input` is a flat row-major [numPlanes][inputH][inputW] buffer; `output` and
// `indices` are flat [numPlanes][outputH][outputW] buffers. Each plane uses its
// two random samples (W first, then H) to derive window start offsets. Each
// output element is the max over a poolSizeH x poolSizeW window, and
// `indices` records that element's flat position (h * inputW + w) within the
// plane. A NaN inside a window becomes the output value, because NaNs always
// replace the running max and nothing replaces a NaN except another NaN.
// Planes are processed in parallel.
template <typename scalar_t>
static void fractional_max_pool2d_out_single_batch_frame(
  scalar_t* input,
  scalar_t* output,
  int64_t* indices,
  scalar_t* randomSamples,
  int numPlanes,
  int inputW, int inputH,
  int outputW, int outputH,
  int poolSizeW, int poolSizeH) {
  // Per-plane element counts in 64-bit so large planes cannot overflow `int`.
  const int64_t inputPlaneSize = static_cast<int64_t>(inputW) * inputH;
  const int64_t outputPlaneSize = static_cast<int64_t>(outputW) * outputH;

  at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
    for (auto plane = start; plane < end; ++plane) {
      /* each plane contains 2 random samples, one for W and one for H */
      scalar_t* randomSamplesForPlane = randomSamples + plane * 2;

      /* Generate interval sequence */
      const auto sequenceW = fractional_max_pool2d_generate_intervals<scalar_t>(
          randomSamplesForPlane[0], inputW, outputW, poolSizeW);
      const auto sequenceH = fractional_max_pool2d_generate_intervals<scalar_t>(
          randomSamplesForPlane[1], inputH, outputH, poolSizeH);

      scalar_t* inputForPlane = input + plane * inputPlaneSize;
      scalar_t* outputForPlane = output + plane * outputPlaneSize;
      int64_t* indicesForPlane = indices + plane * outputPlaneSize;

      /* loop over output */
      for (int h = 0; h < outputH; ++h) {
        const int inputHStart = sequenceH[h];

        for (int w = 0; w < outputW; ++w) {
          const int inputWStart = sequenceW[w];

          scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
          int64_t maxIndex =
              static_cast<int64_t>(inputHStart) * inputW + inputWStart;

          for (int h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
            for (int w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
              AT_ASSERT(h2 >= 0 && h2 < inputH);
              AT_ASSERT(w2 >= 0 && w2 < inputW);

              const int64_t planeIndex = static_cast<int64_t>(h2) * inputW + w2;
              const scalar_t val = inputForPlane[planeIndex];
              if (val > maxVal || std::isnan(val)) {
                maxVal = val;
                maxIndex = planeIndex;
              }
            }
          }

          const int64_t outputIndex = static_cast<int64_t>(h) * outputW + w;
          outputForPlane[outputIndex] = maxVal;
          indicesForPlane[outputIndex] = maxIndex;
        }
      }
    }
  });
}
// Forward fractional max pooling over a whole (possibly batched) input.
//
// All buffers are flat row-major arrays laid out batch-major; consecutive
// batch entries are numPlanes * H * W elements apart (numPlanes * 2 for the
// random samples). When there is more than one batch entry, batches are
// distributed with parallel_for. With a single entry there is nothing to split
// across batches, so the per-plane kernel is called directly; that kernel runs
// its own parallel_for over planes.
template <typename scalar_t>
static void fractional_max_pool2d_out_frame(
  scalar_t* input,
  scalar_t* output,
  int64_t* indices,
  scalar_t* randomSamples,
  int numBatch, int numPlanes,
  int inputW, int inputH,
  int outputW, int outputH,
  int poolSizeW, int poolSizeH) {
  if (numBatch != 1) {
    const int64_t inputBatchStride =
        static_cast<int64_t>(numPlanes) * inputH * inputW;
    const int64_t outputBatchStride =
        static_cast<int64_t>(numPlanes) * outputH * outputW;
    const int64_t samplesBatchStride = static_cast<int64_t>(numPlanes) * 2;

    at::parallel_for(0, numBatch, 0, [&](int64_t first, int64_t last) {
      for (int64_t b = first; b < last; ++b) {
        fractional_max_pool2d_out_single_batch_frame<scalar_t>(
            input + b * inputBatchStride,
            output + b * outputBatchStride,
            indices + b * outputBatchStride,
            randomSamples + b * samplesBatchStride,
            numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
      }
    });
    return;
  }

  fractional_max_pool2d_out_single_batch_frame<scalar_t>(
      input, output, indices, randomSamples,
      numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
}
} // anonymous namespace
// CPU implementation of fractional_max_pool2d.
//
// `output` and `indices` are allocated by the meta function (set_output).
// Both the input and the random samples are made contiguous here, because
// the kernel indexes each of them as a flat row-major buffer; a strided
// randomSamples tensor would otherwise be read at the wrong offsets.
TORCH_IMPL_FUNC(fractional_max_pool2d_out_cpu) (
  const at::Tensor& input_,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples,
  const at::Tensor& output,
  const at::Tensor& indices) {
  int64_t numBatch = 1;
  int64_t planeDim = 0;
  int64_t heightDim = 1;
  int64_t widthDim = 2;
  const int64_t outputH = output_size[0]; // output.size(heightDim)
  const int64_t outputW = output_size[1]; // output.size(widthDim)
  const int64_t poolSizeH = pool_size[0];
  const int64_t poolSizeW = pool_size[1];

  /* get contiguous input and random samples */
  auto input = input_.contiguous();
  auto randomSamplesContig = randomSamples.contiguous();

  const int64_t ndims = input.ndimension();
  if (ndims == 4) {
    numBatch = input.size(0);
    planeDim++;
    heightDim++;
    widthDim++;
  }

  /* sizes */
  const int64_t numPlanes = input.size(planeDim);
  const int64_t inputH = input.size(heightDim);
  const int64_t inputW = input.size(widthDim);

  // NOTE(review): output/indices are written as flat contiguous buffers; this
  // relies on set_output having produced contiguous tensors (including for
  // out= calls) — confirm.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
    "fractional_max_pool2d_out_frame", [&] {
      auto input_data = input.data_ptr<scalar_t>();
      auto output_data = output.data_ptr<scalar_t>();
      auto indices_data = indices.data_ptr<int64_t>();
      auto randomSamples_data = randomSamplesContig.data_ptr<scalar_t>();
      fractional_max_pool2d_out_frame<scalar_t>(
        input_data,
        output_data,
        indices_data,
        randomSamples_data,
        numBatch, numPlanes,
        inputW, inputH,
        outputW, outputH,
        poolSizeW, poolSizeH);
    }
  );
}
namespace {
// Backward fractional max pooling for one batch entry.
//
// For every output position, adds gradOutput into the gradInput element whose
// flat in-plane position is stored in `indices`. Gradients are accumulated
// with `+=` because overlapping windows may select the same input element, so
// gradInput must be zero-filled by the caller. Buffers are flat row-major
// [numPlanes][H][W] arrays; planes are processed in parallel.
template <typename scalar_t>
static void fractional_max_pool2d_backward_out_single_batch_frame(
  scalar_t* gradInput,
  scalar_t* gradOutput,
  int64_t* indices,
  int numPlanes,
  int inputW, int inputH,
  int outputW, int outputH) {
  // 64-bit plane sizes: `inputW * inputH` in `int` can overflow for big planes.
  const int64_t inputPlaneSize = static_cast<int64_t>(inputW) * inputH;
  const int64_t outputPlaneSize = static_cast<int64_t>(outputW) * outputH;

  at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
    for (auto plane = start; plane < end; plane++) {
      scalar_t* gradInputForPlane = gradInput + plane * inputPlaneSize;
      scalar_t* gradOutputForPlane = gradOutput + plane * outputPlaneSize;
      int64_t* indicesForPlane = indices + plane * outputPlaneSize;

      // Row-major (h, w) traversal of the output plane, flattened.
      for (int64_t outputIndex = 0; outputIndex < outputPlaneSize; ++outputIndex) {
        const int64_t index = indicesForPlane[outputIndex];
        AT_ASSERT(index >= 0 && index < inputPlaneSize);
        gradInputForPlane[index] += gradOutputForPlane[outputIndex];
      }
    }
  });
}
// Backward fractional max pooling over a whole (possibly batched) gradient.
//
// Buffers are flat row-major arrays laid out batch-major; consecutive batch
// entries are numPlanes * H * W elements apart. Gradients are accumulated
// with `+=` by the per-batch kernel, so gradInput must be zero-filled by the
// caller. Multiple batch entries are distributed with parallel_for; a single
// entry is handed straight to the per-plane kernel, which runs its own
// parallel_for over planes.
template <typename scalar_t>
static void fractional_max_pool2d_backward_out_frame(
  scalar_t* gradInput,
  scalar_t* gradOutput,
  int64_t* indices,
  int numBatch, int numPlanes,
  int inputW, int inputH,
  int outputW, int outputH) {
  if (numBatch != 1) {
    const int64_t gradInputBatchStride =
        static_cast<int64_t>(numPlanes) * inputH * inputW;
    const int64_t gradOutputBatchStride =
        static_cast<int64_t>(numPlanes) * outputH * outputW;

    at::parallel_for(0, numBatch, 0, [&](int64_t first, int64_t last) {
      for (int64_t b = first; b < last; ++b) {
        fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
            gradInput + b * gradInputBatchStride,
            gradOutput + b * gradOutputBatchStride,
            indices + b * gradOutputBatchStride,
            numPlanes, inputW, inputH, outputW, outputH);
      }
    });
    return;
  }

  fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
      gradInput, gradOutput, indices,
      numPlanes, inputW, inputH, outputW, outputH);
}
// Shared CPU backward implementation for fractional_max_pool2d.
//
// Resizes `gradInput` to `input`'s shape and zero-fills it. Then, for every
// output position, adds gradOutput into the gradInput element selected by
// `indices` (the index tensor produced by the forward pass). `pool_size` is
// unused because the recorded indices already identify the selected elements.
// gradOutput and indices are validated against the expected output shape and
// made contiguous, since the kernel reads both as flat buffers.
//
// Returns `gradInput`.
Tensor& fractional_max_pool2d_backward_out_cpu_template(
  const at::Tensor& input,
  const at::Tensor& gradOutput_,
  at::Tensor& gradInput,
  IntArrayRef output_size,
  IntArrayRef pool_size /* unused */,
  const at::Tensor& indices) {
  TORCH_CHECK(output_size.size() == 2,
    "fractional_max_pool2d_backward(): output_size must have 2 elements, but got ",
    output_size.size());

  int64_t numBatch = 1;
  int64_t planeDim = 0;
  int64_t heightDim = 1;
  int64_t widthDim = 2;
  const int64_t outputH = output_size[0];
  const int64_t outputW = output_size[1];

  const int64_t ndims = input.ndimension();
  if (ndims == 4) {
    numBatch = input.size(0);
    planeDim = 1;
    heightDim++;
    widthDim++;
  }

  /* sizes (64-bit: `int` would truncate large dimensions) */
  const int64_t numPlanes = input.size(planeDim);
  const int64_t inputH = input.size(heightDim);
  const int64_t inputW = input.size(widthDim);

  /* get contiguous gradOutput and indices: the kernel reads both as flat buffers */
  auto gradOutput = gradOutput_.contiguous();
  auto indicesContig = indices.contiguous();

  TORCH_CHECK(gradOutput.dim() == ndims,
    "fractional_max_pool2d_backward(): expected gradOutput to have ", ndims,
    " dimensions, but got ", gradOutput.dim());
  TORCH_CHECK(outputW == gradOutput.size(widthDim),
    "fractional_max_pool2d_backward(): gradOutput width unexpected");
  TORCH_CHECK(outputH == gradOutput.size(heightDim),
    "fractional_max_pool2d_backward(): gradOutput height unexpected");
  // Batch/plane sizes determine how far the kernel walks gradOutput/indices.
  TORCH_CHECK(gradOutput.size(planeDim) == numPlanes &&
              (ndims != 4 || gradOutput.size(0) == numBatch),
    "fractional_max_pool2d_backward(): gradOutput batch/plane sizes unexpected, got ",
    gradOutput.sizes());
  TORCH_CHECK(indicesContig.sizes() == gradOutput.sizes(),
    "fractional_max_pool2d_backward(): expected indices of shape ", gradOutput.sizes(),
    ", but got ", indicesContig.sizes());

  /* resize */
  // NOTE(review): the kernel writes gradInput as a flat contiguous buffer;
  // resize_as_ does not restride an already correctly-sized out= tensor —
  // confirm callers never pass a non-contiguous gradInput.
  gradInput.resize_as_(input);
  gradInput.zero_();

  /* backprop */
  AT_DISPATCH_FLOATING_TYPES(
    input.scalar_type(), "fractional_max_pool2d_backward_out_frame", [&] {
      auto gradInput_data = gradInput.data_ptr<scalar_t>();
      auto gradOutput_data = gradOutput.data_ptr<scalar_t>();
      auto indices_data = indicesContig.data_ptr<int64_t>();
      fractional_max_pool2d_backward_out_frame<scalar_t>(
        gradInput_data,
        gradOutput_data,
        indices_data,
        numBatch, numPlanes,
        inputW, inputH,
        outputW, outputH
      );
    }
  );
  return gradInput;
}
} // namespace
// out= variant of the CPU backward for fractional_max_pool2d: writes the
// gradient w.r.t. `input` into `gradInput` and returns it.
//
// The shared template already resizes gradInput to input's shape (and
// zero-fills it), so no separate resize is needed here.
Tensor& fractional_max_pool2d_backward_out_cpu(const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices,
  at::Tensor& gradInput)
{
  return fractional_max_pool2d_backward_out_cpu_template(
    input,
    gradOutput_,
    gradInput,
    output_size,
    pool_size,
    indices);
}
// Functional (non-out) CPU backward for fractional_max_pool2d: allocates a
// new gradient tensor with the input's options and fills it through the
// shared backward template.
Tensor fractional_max_pool2d_backward_cpu(
  const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices)
{
  // Begin with an empty tensor; the template resizes it to input's shape and
  // zero-fills it before accumulating gradients.
  auto grad_input = at::empty({0}, input.options());
  fractional_max_pool2d_backward_out_cpu_template(
      input, gradOutput_, grad_input, output_size, pool_size, indices);
  return grad_input;
}
} // at::native
} // at