forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
EmbeddingBag.cu
556 lines (474 loc) · 21.8 KB
/
EmbeddingBag.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCTensorSort.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCAtomics.cuh>
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/device_vector.h>
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <c10/macros/Macros.h>
namespace at {
namespace native {
namespace {
// embedding_bag reduction modes; every function in this file compares its
// `mode` argument against these values.
// NOTE(review): presumably this numbering must stay in sync with the mode
// encoding used by the caller (EmbeddingBag.cpp / Python frontend) — confirm.
constexpr int MODE_SUM = 0, MODE_MEAN = 1, MODE_MAX = 2;
// Converts `indices` and `offsets` to the common scalar type given by
// promoteTypes so both can be read through a single index_t in the kernels.
// A tensor that already has the common type is returned unchanged (same
// storage); otherwise it is converted with toType. Returns (indices, offsets).
std::pair<Tensor, Tensor> promoteIndicesAndOffsets(
    const Tensor& indices,
    const Tensor& offsets) {
  const auto target =
      promoteTypes(offsets.scalar_type(), indices.scalar_type());
  const auto to_target = [target](const Tensor& t) -> Tensor {
    if (t.scalar_type() == target) {
      return t;
    }
    return t.toType(target);
  };
  return {to_target(indices), to_target(offsets)};
}
// Forward kernel for mode='max'.
// Launch layout: threadIdx.x indexes a feature inside a chunk of blockDim.x
// features, threadIdx.y selects a (bag, chunk) pair, and the outer loop
// advances over pairs by gridDim.x * blockDim.y, so each thread owns one
// (bag, feature) output element per iteration.
// For every (bag, feature) it writes the maximum weight value among the bag's
// non-padding entries to `output`, and the word that produced it to
// `max_indices` (-1 when the bag has no non-padding entries, in which case the
// output is 0). bag_size[bag] receives the number of non-padding entries, and
// the featureDim == 0 thread fills offset2bag for the bag's entries.
// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel_max(
    index_t *input, index_t *offsets, scalar_t *weight, scalar_t *output,
    index_t *offset2bag, int64_t numIndices, int64_t numBags,
    int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
    index_t *bag_size, index_t *max_indices,
    index_t padding_idx) {
  const int64_t lanes = blockDim.x;
  const int64_t chunksPerBag = (featureSize + lanes - 1) / lanes;
  const int64_t totalChunks = numBags * chunksPerBag;
  const int64_t firstChunk = blockIdx.x * blockDim.y + threadIdx.y;
  const int64_t chunkStep = gridDim.x * blockDim.y;

  for (int64_t chunk = firstChunk; chunk < totalChunks; chunk += chunkStep) {
    const int64_t featureDim = (chunk % chunksPerBag) * lanes + threadIdx.x;
    if (featureDim >= featureSize) {
      continue;  // tail lanes of the last chunk have no feature to handle
    }
    const int64_t bag = chunk / chunksPerBag;
    const scalar_t* column = weight + featureDim * weight_stride1;
    // offsets[0] is ignored: the first bag always starts at index 0.
    const int64_t begin = (bag == 0) ? 0 : offsets[bag];
    const int64_t end = (bag + 1 < numBags) ? offsets[bag + 1] : numIndices;
    CUDA_KERNEL_ASSERT(end >= begin);

    scalar_t runningMax = 0;
    int64_t count = 0;
    int64_t argmaxWord = -1;
    for (int64_t pos = begin; pos < end; ++pos) {
      const index_t word = input[pos];
      // The row is loaded even for padding entries, matching the original access pattern.
      const scalar_t candidate = column[word * weight_stride0];
      if (word != padding_idx) {
        // The first non-padding entry always seeds the maximum.
        if (count == 0 || candidate > runningMax) {
          runningMax = candidate;
          argmaxWord = word;
        }
        ++count;
      }
      if (featureDim == 0) {
        offset2bag[pos] = bag;
      }
    }
    bag_size[bag] = count;
    const int64_t outIdx = bag * featureSize + featureDim;
    max_indices[outIdx] = argmaxWord;
    output[outIdx] = runningMax;
  }
}
// Forward kernel for mode='sum' and mode='mean', optionally scaling each entry
// by per_sample_weights (a null pointer means "no per-sample weights").
// Launch layout: threadIdx.x indexes a feature inside a chunk of blockDim.x
// features, threadIdx.y selects a (bag, chunk) pair, and the outer loop
// advances over pairs by gridDim.x * blockDim.y, so each thread owns one
// (bag, feature) output element per iteration.
// Accumulation happens in acc_type<scalar_t, true>. Entries equal to
// padding_idx contribute a zero weight value and are not counted in
// bag_size[bag]; for MODE_MEAN the sum is divided by that count when it is
// non-zero (an empty bag yields 0). The featureDim == 0 thread fills
// offset2bag for the bag's entries.
// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel_sum_mean(
    index_t *input, index_t *offsets, scalar_t *weight, scalar_t *output,
    index_t *offset2bag, int64_t numIndices, int64_t numBags,
    int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
    int mode, index_t *bag_size,
    scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
    index_t padding_idx) {
  using accscalar_t = acc_type<scalar_t, true>;
  const int64_t lanes = blockDim.x;
  const int64_t chunksPerBag = (featureSize + lanes - 1) / lanes;
  const int64_t totalChunks = numBags * chunksPerBag;
  const int64_t firstChunk = blockIdx.x * blockDim.y + threadIdx.y;
  const int64_t chunkStep = gridDim.x * blockDim.y;

  for (int64_t chunk = firstChunk; chunk < totalChunks; chunk += chunkStep) {
    const int64_t featureDim = (chunk % chunksPerBag) * lanes + threadIdx.x;
    if (featureDim >= featureSize) {
      continue;  // tail lanes of the last chunk have no feature to handle
    }
    const int64_t bag = chunk / chunksPerBag;
    const scalar_t* column = weight + featureDim * weight_stride1;
    // offsets[0] is ignored: the first bag always starts at index 0.
    const int64_t begin = (bag == 0) ? 0 : offsets[bag];
    const int64_t end = (bag + 1 < numBags) ? offsets[bag + 1] : numIndices;
    CUDA_KERNEL_ASSERT(end >= begin);

    accscalar_t acc = 0;
    int64_t count = 0;
    for (int64_t pos = begin; pos < end; ++pos) {
      const index_t word = input[pos];
      const bool isPad = (word == padding_idx);
      // The padding row is still loaded, then replaced by an exact zero.
      scalar_t value = column[word * weight_stride0];
      value = isPad ? static_cast<scalar_t>(0) : value;
      if (per_sample_weights) {
        const accscalar_t scale = static_cast<accscalar_t>(
            per_sample_weights[pos * per_sample_weights_stride]);
        acc += scale * static_cast<accscalar_t>(value);
      } else {
        acc += static_cast<accscalar_t>(value);
      }
      if (!isPad) {
        ++count;
      }
      if (featureDim == 0) {
        offset2bag[pos] = bag;
      }
    }
    if (mode == MODE_MEAN && count != 0) {
      acc /= static_cast<accscalar_t>(count);
    }
    bag_size[bag] = count;
    output[bag * featureSize + featureDim] = static_cast<scalar_t>(acc);
  }
}
// Backward of mode='sum' and mode='mean' with respect to `weight`.
//
// Sorts a copy of the flat `indices` while recording each element's original
// position in `orig_indices`, optionally computes per-index occurrence counts
// (only when scale_grad_by_freq is set; `count` stays undefined otherwise),
// and hands everything to embedding_backward_cuda_kernel (presumably provided
// by ATen/native/cuda/EmbeddingBackwardKernel.cuh), which produces the
// [num_weights, D] gradient. The mode == MODE_MEAN flag, offset2bag, bag_size
// and per_sample_weights are forwarded to that kernel unchanged.
// Returns a zero [num_weights, grad.size(1)] tensor when there are no indices.
Tensor embedding_bag_backward_cuda_sum_avg(
                                   const Tensor &grad,
                                   const Tensor &indices,
                                   const Tensor &offset2bag,
                                   const Tensor &bag_size,
                                   int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode,
                                   const Tensor& per_sample_weights,
                                   int64_t padding_idx) {
  ptrdiff_t numel = indices.numel();

  if (numel == 0) {
    // All bags are empty: no embedding row receives any gradient.
    return at::zeros({num_weights, grad.size(1)}, grad.options());
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor count;

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
    using device_ptr = thrust::device_ptr<index_t>;

    // Sort the inputs into sorted with the corresponding indices; we
    // don't need a stable or multidimensional sort, so just use Thrust
    // directly
    {
      sorted_indices.copy_(indices);

      auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
      auto policy = thrust::cuda::par(allocator).on(stream);

      // Fill sortedOrigIndices with sequential indices
      auto count_iter = thrust::counting_iterator<index_t>(0);
      auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
      thrust::copy(policy, count_iter, count_iter + numel, orig_data);

      // Sort; a stable sort is not required
      auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
      thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
                          ThrustLTOp<index_t>());
    }

    if (scale_grad_by_freq) {
      count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

      auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
      auto policy = thrust::cuda::par(allocator).on(stream);

      // Compute an increasing sequence per unique item in sortedIndices:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 1 2 3 1 2 1 1 2
      auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
      auto count_data = device_ptr(count.data_ptr<index_t>());
      thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel,
                                    thrust::make_constant_iterator(1),
                                    count_data);

      // Take the maximum of each count per unique key in reverse:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 3 3 3 2 2 1 2 2
      thrust::inclusive_scan_by_key(
          policy, thrust::make_reverse_iterator(sorted_data + numel),
          thrust::make_reverse_iterator(sorted_data),
          thrust::make_reverse_iterator(count_data + numel),
          thrust::make_reverse_iterator(count_data + numel),
          thrust::equal_to<index_t>(), thrust::maximum<index_t>());
    }
  });

  return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
      count, num_weights, padding_idx, mode == MODE_MEAN, offset2bag,
      bag_size, per_sample_weights);
}
// Backward kernel for mode='max': for each (bag, feature) pair, adds
// gradOutput[bag, feature] to gradWeight[word, feature], where `word` is the
// argmax recorded in max_indices by the forward pass.
// Launch layout mirrors the forward kernels: threadIdx.x indexes a feature
// inside a chunk of blockDim.x features, threadIdx.y selects a (bag, chunk)
// pair, and the loop advances over pairs by gridDim.x * blockDim.y.
// `stride` serves both as the number of features and as the row stride of
// max_indices, gradOutput and gradWeight, so all three are assumed to be laid
// out with rows of length `stride`.
// gpuAtomicAdd is required because different bags may select the same word.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_accGradParametersKernel_max(
    index_t *max_indices, scalar_t *gradOutput,
    scalar_t *gradWeight, int64_t stride, int64_t numBags,
    index_t padding_idx) {
  int64_t chunksPerBag = THCCeilDiv(stride, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < stride) {
      int64_t bag = chunk / chunksPerBag;

      index_t word_idx = max_indices[bag * stride + featureDim];
      // The forward pass leaves max_indices at -1 for bags with no
      // non-padding entries; those, and the padding row, get no gradient.
      if (word_idx >= 0 && word_idx != padding_idx) {
        gpuAtomicAdd(&(gradWeight[word_idx * stride + featureDim]),
                     gradOutput[bag * stride + featureDim]);
      }
    }
  }
}
// Dense gradient of mode='max' with respect to `weight`: returns a zero
// [num_weights, grad.size(1)] tensor into which
// EmbeddingBag_accGradParametersKernel_max scatters each bag's output gradient
// at the rows recorded in max_indices.
// Dispatches over the same floating types as the forward pass (float, double,
// Half, BFloat16) so every dtype accepted in forward is differentiable here.
Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
                                   const Tensor &max_indices,
                                   int64_t num_weights,
                                   int64_t padding_idx) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("embedding_bag_backward_cuda_max");

  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());

  if (grad.numel() == 0) {
    // No bags, or zero-width embeddings: there is nothing to scatter.
    // This guard is also needed for correctness: a contiguous [N, 0] tensor
    // may report stride(0) == 1, and the kernel below would then index the
    // (empty) max_indices / grad buffers.
    return grad_weight;
  }

  int64_t stride = grad_weight.stride(0);
  int64_t numBags = grad.size(0);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#ifdef __HIP_PLATFORM_HCC__
  dim3 block = dim3(64, 4);
#else
  dim3 block = dim3(32, 8);
#endif
  int grid = 1024;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] {
        AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () {
          EmbeddingBag_accGradParametersKernel_max<
              scalar_t, index_t><<<grid, block, 0, stream>>>(
              max_indices.data_ptr<index_t>(), grad.data_ptr<scalar_t>(),
              grad_weight.data_ptr<scalar_t>(), stride, numBags,
              padding_idx);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
      });

  return grad_weight;
}
}
// Forward-only (no-grad) entry point for embedding_bag on CUDA.
// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
// This wrapper does no work of its own: every argument, including the optional
// per_sample_weights, is passed straight through to _embedding_bag_cuda, which
// unwraps the optional itself.
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_forward_only_cuda(const Tensor &weight, const Tensor &indices,
                   const Tensor &offsets, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt,
                   bool include_last_offset, int64_t padding_idx) {
  return _embedding_bag_cuda(weight, indices, offsets, scale_grad_by_freq,
                             mode, sparse, per_sample_weights_opt,
                             include_last_offset, padding_idx);
}
// CUDA forward pass of embedding_bag for modes sum / mean / max.
// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
//
// Returns (output, offset2bag, bag_size, max_indices):
//   output      : [numBags, featureSize] with weight's dtype/device
//   offset2bag  : [numIndices], the bag each index belongs to
//   bag_size    : shaped like `offsets`; entry b holds the number of
//                 non-padding indices in bag b (with include_last_offset the
//                 trailing entry is not written by the kernels)
//   max_indices : [numBags, featureSize] argmax word per (bag, feature) in
//                 max mode; an empty tensor for the other modes
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_cuda(const Tensor &weight, const Tensor &indices_,
                   const Tensor &offsets_, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt,
                   bool include_last_offset, int64_t padding_idx) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> psw_holder = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *psw_holder;

  // Bring indices and offsets to a shared index dtype, then validate them.
  Tensor indices;
  Tensor offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);

  TensorArg indices_arg(indices, "indices", 1);
  TensorArg offsets_arg(offsets, "offsets", 1);
  TensorArg weight_arg(weight, "weight", 1);
  checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt});
  checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt});
  checkSameType("embedding_bag_cuda", indices_arg, offsets_arg);
  checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
  checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);

  const int64_t numIndices = indices.size(0);
  int64_t numBags = offsets.size(0);
  if (include_last_offset) {
    // Check https://github.com/pytorch/pytorch/issues/29019
    // In this mode `offsets` carries one extra trailing element (the size of
    // indices), so there is one bag fewer than there are offsets. The CUDA
    // kernels still use the legacy behaviour of ending the last bag at
    // numIndices rather than reading that trailing offset.
    TORCH_CHECK(
        numBags >= 1, "include_last_offset: numBags should be at least 1");
    --numBags;
  }
  const int64_t featureSize = weight.size(1);

  auto bag_size = at::empty(offsets.sizes(), indices.options());
  auto offset2bag = at::empty({indices.size(0)}, indices.options());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto output = at::empty({numBags, featureSize}, weight.options());

  const bool use_max = (mode == MODE_MAX);
  // max_indices is only written (and only needed by backward) in max mode;
  // other modes get an empty placeholder instead of a full allocation.
  Tensor max_indices = use_max
      ? at::empty({numBags, featureSize}, indices.options())
      : at::empty({0}, indices.options());

#ifdef __HIP_PLATFORM_HCC__
  const dim3 block(64, 4);
#else
  const dim3 block(32, 8);
#endif
  const int grid = 1024;

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] {
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
      if (use_max) {
        EmbeddingBag_updateOutputKernel_max<scalar_t, index_t><<<grid, block, 0, stream>>>(
            indices.data_ptr<index_t>(), offsets.data_ptr<index_t>(),
            weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
            offset2bag.data_ptr<index_t>(), numIndices, numBags, featureSize,
            weight.stride(0), weight.stride(1), bag_size.data_ptr<index_t>(),
            max_indices.data_ptr<index_t>(),
            padding_idx);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        return;
      }

      // Only the sum/mean kernel consumes per_sample_weights; a null pointer
      // tells it that no weights were supplied.
      scalar_t* psw_data = nullptr;
      int64_t psw_stride = 0;
      if (per_sample_weights.defined()) {
        psw_data = per_sample_weights.data_ptr<scalar_t>();
        psw_stride = per_sample_weights.stride(0);
      }
      EmbeddingBag_updateOutputKernel_sum_mean<scalar_t, index_t><<<grid, block, 0, stream>>>(
          indices.data_ptr<index_t>(), offsets.data_ptr<index_t>(),
          weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
          offset2bag.data_ptr<index_t>(), numIndices, numBags, featureSize,
          weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<index_t>(),
          psw_data, psw_stride,
          padding_idx);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  });

  return std::make_tuple(output, offset2bag, bag_size, max_indices);
}
// Dense gradient of embedding_bag with respect to `weight` on CUDA.
// Routes mode sum/mean to embedding_bag_backward_cuda_sum_avg and mode max to
// embedding_bag_backward_cuda_max; any other mode raises an error.
// per_sample_weights is only permitted (asserted undefined otherwise) for
// mode sum.
Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &indices,
                                   const Tensor &offset2bag,
                                   const Tensor &bag_size_,
                                   const Tensor &max_indices,
                                   int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights_opt,
                                   int64_t padding_idx) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> psw_holder = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *psw_holder;

  // indices, offsets and offset2bag are assumed having correct dtypes and
  // contiguous here due to the checks in _embedding_bag_backward in
  // EmbeddingBag.cpp.
  // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
  // for more details.

  // The helpers below index grad by plain row-major offsets, so make it contiguous.
  const Tensor grad = grad_.contiguous();
  const TensorArg grad_arg(grad, "grad", 1);
  const TensorArg indices_arg(indices, "indices", 1);
  checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg);

  if (mode == MODE_SUM || mode == MODE_MEAN) {
    if (mode == MODE_MEAN) {
      AT_ASSERT(!per_sample_weights.defined());
    }
    return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag,
        bag_size_, num_weights, scale_grad_by_freq, mode,
        per_sample_weights, padding_idx);
  }

  if (mode == MODE_MAX) {
    AT_ASSERT(!per_sample_weights.defined());
    return embedding_bag_backward_cuda_max(grad, max_indices, num_weights,
        padding_idx);
  }

  AT_ERROR(
      "Unknown mode for embedding_bag_backward_cuda ", mode);
}
// Sums `val` across the lanes of a warp with a shift-down tree reduction
// (offsets C10_WARP_SIZE/2, C10_WARP_SIZE/4, ..., 1); after the loop lane 0
// holds the total. Uses WARP_SHFL_DOWN — NOTE(review): presumably this wraps
// the full-mask __shfl_down_sync, so all lanes of the warp must call it.
template <typename scalar_t>
__inline__ __device__
static scalar_t warpReduceSum(scalar_t val) {
  for (int delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
    val += WARP_SHFL_DOWN(val, delta);
  }
  return val;
}
// Gradient of embedding_bag (mode='sum') with respect to per_sample_weights.
// One warp handles one sample at a time and loops over samples by the total
// number of warps in the grid. For sample i the result is the dot product
// grad[offset2bag[i], :] . weight[indices[i], :]
// (0 when indices[i] == padding_idx). Lanes stride over the features, partial
// sums are reduced with warpReduceSum, and lane 0 writes output[i].
// Expects a 1-D launch whose blockDim.x is a multiple of C10_WARP_SIZE so that
// every warp is full (warpReduceSum needs all lanes).
// All thread/sample/row indices are 64-bit: the host may launch
// ceil(num_samples / warps_per_block) blocks, so the global thread id can
// exceed INT_MAX for large inputs, and index_t values may exceed int range.
// Products are formed in accscalar_t so reduced-precision inputs (Half) are not
// rounded before accumulation.
template <typename scalar_t, typename index_t>
__global__ static void _embedding_bag_per_sample_weights_backward_kernel(
  const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1,
  const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1,
  const index_t* indices,  // contiguous
  const index_t* offset2bag,  // contiguous
  int64_t num_samples,
  int64_t embedding_features,
  scalar_t* output,
  index_t padding_idx) {
  using accscalar_t = acc_type<scalar_t, true>;
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t warp = idx / C10_WARP_SIZE;
  const int thread_in_warp = static_cast<int>(idx % C10_WARP_SIZE);
  const int64_t num_warps =
      static_cast<int64_t>(blockDim.x) * gridDim.x / C10_WARP_SIZE;

  // Each warp is responsible for the accumulation of one sample.
  // This involves doing one dot product between grad[bag_idx] and weight[embedding_idx].
  for (int64_t sample_idx = warp; sample_idx < num_samples; sample_idx += num_warps) {
    accscalar_t result = 0;
    const int64_t bag_idx = static_cast<int64_t>(offset2bag[sample_idx]);
    const int64_t embedding_idx = static_cast<int64_t>(indices[sample_idx]);
    if (embedding_idx != padding_idx) {
      for (int64_t feature_idx = thread_in_warp; feature_idx < embedding_features;
          feature_idx += C10_WARP_SIZE) {
        result +=
            static_cast<accscalar_t>(
                grad[grad_stride0 * bag_idx + grad_stride1 * feature_idx]) *
            static_cast<accscalar_t>(
                weight[weight_stride0 * embedding_idx + weight_stride1 * feature_idx]);
      }
    }
    // Every lane of the warp reaches this point for the same sample_idx, so
    // the full-warp reduction is safe.
    result = warpReduceSum<accscalar_t>(result);
    if (thread_in_warp == 0) {
      output[sample_idx] = static_cast<scalar_t>(result);
    }
  }
}
// Host entry point for the gradient of embedding_bag (mode='sum') with respect
// to per_sample_weights. Returns a 1-D tensor with one value per index:
// the dot product of that index's bag gradient row with its embedding row
// (0 for padding_idx). Only mode='sum' is supported.
// Dispatches over float, double, Half and BFloat16, matching the dtypes the
// forward pass accepts for per_sample_weights.
Tensor _embedding_bag_per_sample_weights_backward_cuda(
    const Tensor& grad,
    const Tensor& weight,  // NB: embedding table, not per_sample_weights
    const Tensor& indices_,
    const Tensor& offsets_,
    const Tensor& offset2bag,
    int64_t mode,
    int64_t padding_idx) {
  TORCH_CHECK(
      mode == MODE_SUM,
      "embedding_bag_backward: per_sample_weights only supported for mode='sum'");

  AT_ASSERT(grad.dim() == 2);
  auto embedding_features = grad.size(1);

  Tensor indices, offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
  AT_ASSERT(indices.dim() == 1);
  auto num_samples = indices.size(0);

  AT_ASSERT(weight.dim() == 2);
  AT_ASSERT(weight.size(1) == embedding_features);

  constexpr int threads_per_block = 512;
  const int warps_per_block = threads_per_block / C10_WARP_SIZE;

  dim3 block(threads_per_block);
  // Ideally one warp per sample, but the kernel loops over samples with a
  // stride of the total warp count, so the grid can be capped without losing
  // work. The cap keeps gridDim.x * blockDim.x <= 2^31 - 1, so the global
  // thread index stays representable as a 32-bit int, and it stays below the
  // CUDA grid.x limit.
  constexpr int64_t max_blocks = ((int64_t{1} << 31) - 1) / threads_per_block;
  const int64_t needed_blocks = (num_samples + warps_per_block - 1) / warps_per_block;
  dim3 grid(static_cast<unsigned int>(
      needed_blocks < max_blocks ? needed_blocks : max_blocks));

  auto output = at::empty({num_samples}, grad.options());

  // Early return when there are no samples in the batch. This saves an unnecessary kernel
  // launch, and also prevents cudaGetLastError() from complaining about invalid launch args.
  if (num_samples == 0) {
    return output;
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half, at::ScalarType::BFloat16,
    grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
        _embedding_bag_per_sample_weights_backward_kernel<scalar_t, index_t>
          <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
            grad.data_ptr<scalar_t>(), grad.stride(0), grad.stride(1),
            weight.data_ptr<scalar_t>(), weight.stride(0), weight.stride(1),
            indices.data_ptr<index_t>(),
            offset2bag.data_ptr<index_t>(),
            num_samples,
            embedding_features,
            output.data_ptr<scalar_t>(),
            padding_idx);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
    }
  );
  return output;
}
}
}