forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathIndexing.cpp
572 lines (519 loc) · 21.6 KB
/
Indexing.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
// Indexing tensors by tensors
//
// This corresponds to "advanced indexing" in NumPy. The two operations are:
//
// index(Tensor self, indices) -> Tensor
// index_put_(Tensor self, indices, value, accumulate=false)
//
// The index is a TensorList containing kLong, kBool or kByte tensors or nulls. Byte
// and Bool tensors (boolean masks) are expanded to long tensors via nonzero(). Null
// tensors signify that the dimension is not indexed.
//
// All indexes are broadcast together and iterated as *one*. From NumPy:
//
// result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
// ..., ind_N[i_1, ..., i_M]]
//
// Note 1: ByteTensors expand to index as many dimensions as there are in the
// mask.
//
// Note 2: The behavior is more complicated when the index tensors are not all
// adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index
// tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]]
//
// The code contains two implementations of indexing. The more efficient
// implementation treats indexing like an elementwise operation over the
// tensors `result`, `x`, `ind_1`, `ind_2`, etc. This implementation does
// not work for index_put_ with accumulate=True. The other implementation
// combines the indexed tensors into a single linear index that is used
// with Tensor.put_. This is used for index_put_ with accumulate=True.
//
// The more efficient implementation takes the following steps for the
// above operation:
//
// 1) Broadcast ind_1, ind_2, ind_3 together to a common shape
// 2) Record x.stride(i) for each indexed dimension `i`
// 3) Replace the indexed subspace of `x` with the shape of the corresponding
// subspace of `result` but with stride 0
// 4) Add dimensions of size 1 to the index tensors (ind_1, ind_2, etc.) so
// that their shape is compatible with the result shape
//
// The CPU or CUDA kernel then computes element-wise over the broadcasted
// and restrided result, x, ind_1, ind_2, etc.:
//
// result[...] = *(&x[...] +
// ind_1[...] * x.stride(1) +
// ind_2[...] * x.stride(2) +
// ...)
//
// where & and * represent the C-style address-of and indirection operations.
#include <ATen/native/Indexing.h>
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/LegacyTHFunctions.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/TensorIterator.h>
#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>
namespace at { namespace native {
// Dispatch stubs for the element-wise advanced-indexing kernels. index() and
// index_put_() below invoke them as index_stub(device_type, ...) and
// index_put_stub(device_type, ...). The per-device kernels are presumably
// registered against these stubs in other translation units (not visible here).
DEFINE_DISPATCH(index_stub);
DEFINE_DISPATCH(index_put_stub);
// Raises an index error reporting that `mask`'s size along mask dimension
// `maskIdx` does not match `self`'s size along dimension `idx`. Never returns.
[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
  std::stringstream msg;
  msg << "The shape of the mask " << mask.sizes() << " at index " << maskIdx
      << " does not match the shape of the indexed tensor " << self.sizes()
      << " at index " << idx;
  AT_INDEX_ERROR(msg.str());
}
// Rejects any defined index tensor whose dtype is not Long, Byte or Bool.
// Undefined (null) entries are accepted; they mark dimensions that are not
// indexed.
static void checkIndexTensorTypes(TensorList indices) {
  for (const auto& idx : indices) {
    if (!idx.defined()) {
      continue;
    }
    const auto st = idx.scalar_type();
    const bool allowed = (st == kLong) || (st == kByte) || (st == kBool);
    if (!allowed) {
      AT_INDEX_ERROR("tensors used as indices must be long, byte or bool tensors");
    }
  }
}
static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
// Expands ByteTensor (masks) or BoolTensor (masks) into the equivalent indexing by LongTensors
std::vector<Tensor> result;
for (auto & index : indices) {
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
// The sizes of the ByteTensor mask or bool tensor must match the sizes of the
// corresponding dimensions in self
for (int64_t j = 0; j < index.dim(); j++) {
int64_t srcIdx = result.size() + j;
if (index.size(j) != self.size(srcIdx)) {
invalid_mask(self, srcIdx, index, j);
}
}
// Replace with nonzeros
auto nonzero = index.nonzero();
auto special_empty = false;
for (int64_t j = 0; j < index.dim(); j++) {
if (special_empty) {
// We can't call select on an empty tensor so we just create an empty
// tensor.
result.emplace_back(at::empty({0}, nonzero.options()));
} else {
result.emplace_back(nonzero.select(1, j));
}
}
} else {
result.emplace_back(index);
}
}
return result;
}
// Returns true if all the defined (non-null) tensors in `tl` are adjacent,
// i.e. no null tensor lies between the first and the last defined one.
// A list with zero or one defined tensor is trivially contiguous.
static bool hasContiguousSubspace(TensorList tl) {
  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
  auto start = std::find_if(tl.begin(), tl.end(), isDefined);
  if (start == tl.end()) {
    // No defined tensors. Without this early return, `stop.base()` below would
    // be tl.begin(), handing find_if the reversed (undefined) range [end, begin).
    return true;
  }
  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
  // Scan [first defined, last defined] for a null.
  return std::find_if(start, stop.base(), isNull) == stop.base();
}
// Permutes `self` so that every dimension that has a defined index comes
// first (in original order), followed by the un-indexed dimensions (also in
// original order). Reorders `indices` identically and returns both.
// For example:
//   transposeToFront(tensor, {nullptr, a, nullptr, b})
// returns
//   tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
// Reads indices[0 .. self.dim()-1], so callers pad `indices` with nulls first.
static std::tuple<Tensor, std::vector<Tensor>>
transposeToFront(Tensor self, TensorList indices) {
  const int64_t ndim = self.dim();
  std::vector<int64_t> permutation;
  std::vector<Tensor> reordered;
  permutation.reserve(ndim);
  // Pass 0 collects the indexed dims, pass 1 the un-indexed ones. Each pass
  // keeps the dims in their original relative order.
  for (int pass = 0; pass < 2; pass++) {
    const bool takeIndexed = (pass == 0);
    for (int64_t d = 0; d < ndim; d++) {
      if (indices[d].defined() != takeIndexed) {
        continue;
      }
      permutation.push_back(d);
      reordered.emplace_back(takeIndexed ? indices[d] : Tensor());
    }
  }
  return std::make_tuple(self.permute(permutation), std::move(reordered));
}
// Computes the strides `tensor` would have if it were contiguous (row-major).
// stride[d] is the product of sizes[d+1 .. dim-1], and the last stride is 1.
// A 0-dim tensor yields an empty vector.
static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
  auto sizes = tensor.sizes();
  std::vector<int64_t> stride(tensor.dim());
  if (stride.empty()) {
    // Without this guard, `stride[dim - 1] = 1` would write out of bounds for
    // a 0-dim tensor.
    return stride;
  }
  stride.back() = 1;
  // Running product of sizes from the innermost dim outward, written one slot
  // to the left: stride[d] = sizes[d+1] * stride[d+1].
  std::partial_sum(sizes.rbegin(), sizes.rend() - 1, stride.rbegin() + 1, std::multiplies<int64_t>());
  return stride;
}
// Returns a view of `src` with `before` size-1 dimensions prepended and
// `after` size-1 dimensions appended.
static Tensor unsqueezeN(const Tensor & src, int64_t before, int64_t after) {
  const auto srcSizes = src.sizes();
  // Start from all ones, then drop src's sizes into the middle.
  std::vector<int64_t> shape(src.dim() + before + after, 1);
  std::copy(srcSizes.begin(), srcSizes.end(), shape.begin() + before);
  return src.view(shape);
}
// Checks that every entry of `index` lies in [-dim_size, dim_size), then
// returns index.remainder(dim_size) so negative entries wrap around. On a
// violation, raises an index error naming the offending value, dimension
// `dim` and `dim_size`. An empty index skips the bounds check.
// NOTE(review): the .item() calls presumably force a host sync for device
// tensors — confirm if this is on a hot path.
static Tensor wrapIndexOnce(const Tensor & index, int64_t dim, int64_t dim_size) {
  if (index.numel() == 0) {
    return index.remainder(dim_size);
  }
  const auto hi = index.max().item<int64_t>();
  const auto lo = index.min().item<int64_t>();
  if (hi >= dim_size) {
    AT_INDEX_ERROR("index ", hi, " is out of bounds for dimension ", dim, " with size ", dim_size);
  }
  if (lo < -dim_size) {
    AT_INDEX_ERROR("index ", lo, " is out of bounds for dimension ", dim, " with size ", dim_size);
  }
  return index.remainder(dim_size);
}
// Computes, for an advanced-indexing expression over `src`, the flat
// (row-major, as-if-contiguous) offset of every selected element.
//
// Expects `indices` to have one entry per dimension of src, with the defined
// entries already broadcast to a common shape and adjacent (makeLinearIndex
// arranges both). Index values are bounds-checked and wrapped by
// wrapIndexOnce.
//
// The result lives on src's device and has shape
//   sizes[0 : emptyBefore] + index_shape + sizes[dim - emptyAfter : dim]
// where emptyBefore / emptyAfter count the un-indexed dims before / after the
// indexed block.
static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
  auto strides = computeLinearStride(src);
  // Compute the linear index by multiplying the indexing tensors by the
  // stride and summing them. All the indexing tensors have the same shape at
  // this point. We also compute the number of dimensions before and after that
  // are not being indexed.
  Tensor linearIndex;
  int64_t emptyBefore = 0, emptyAfter = 0, nElemBefore = 1, nElemAfter = 1;
  for (int64_t i = 0; i < src.dim(); i++) {
    if (indices[i].defined()) {
      // Move the index onto src's device and cast it to Long. The device move
      // is what allows e.g. indexing a CUDA tensor with a CPU index tensor;
      // .to(kLong) on its own only converts the dtype.
      Tensor index = (wrapIndexOnce(indices[i], i, src.size(i)) * strides[i])
                         .to(src.device())
                         .to(kLong);
      if (linearIndex.defined()) {
        linearIndex += index;
      } else {
        linearIndex = index;
      }
    } else if (linearIndex.defined()) {
      emptyAfter++;
      nElemAfter *= src.size(i);
    } else {
      emptyBefore++;
      nElemBefore *= src.size(i);
    }
  }

  // Compute the linear indices for the parts of the tensor not being indexed.
  // Leading dims: enumerate them row-major and scale by the stride of the last
  // leading dim. This equals sum_j i_j * strides[j] over the leading dims.
  Tensor beforeIndex;
  if (emptyBefore > 0) {
    auto index = at::arange(0, nElemBefore, src.options().dtype(kLong)) * strides[emptyBefore - 1];
    index = index.view(src.sizes().slice(0, emptyBefore));
    beforeIndex = unsqueezeN(index, 0, linearIndex.dim() + emptyAfter);
  }
  // Trailing dims: their contiguous strides are exactly the row-major
  // enumeration, so no scaling is needed.
  Tensor afterIndex;
  if (emptyAfter > 0) {
    auto index = at::arange(0, nElemAfter, src.options().dtype(kLong));
    index = index.view(src.sizes().slice(src.dim() - emptyAfter, emptyAfter));
    afterIndex = unsqueezeN(index, linearIndex.dim() + emptyBefore, 0);
  }

  // Sum with broadcasting to compute the full index.
  linearIndex = unsqueezeN(linearIndex, emptyBefore, emptyAfter);
  if (beforeIndex.defined()) {
    linearIndex = linearIndex + beforeIndex;
  }
  if (afterIndex.defined()) {
    linearIndex = linearIndex + afterIndex;
  }
  return linearIndex;
}
// Prepares the linear-index form of an advanced-indexing expression (used by
// index_put_'s CUDA accumulate path). It validates the index dtypes, expands
// Byte/Bool masks, broadcasts all index tensors together and pads them with
// nulls up to self.dim(). If the indexed dims are not adjacent, self and the
// indices are permuted so they are. Returns the (possibly permuted) self and
// the flat index produced by computeLinearIndex.
static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig) {
  checkIndexTensorTypes(orig);
  std::vector<Tensor> indices = expand_outplace(expandTensors(self, orig));
  // Pad with undefined tensors so there is one entry per dimension of self.
  indices.resize(std::max(indices.size(), static_cast<size_t>(self.dim())));
  if (!hasContiguousSubspace(indices)) {
    std::tie(self, indices) = transposeToFront(self, indices);
  }
  Tensor linearIndex = computeLinearIndex(self, indices);
  return std::make_tuple(self, linearIndex);
}
// Returns true if every tensor in `tensors` has exactly the same strides as
// the first one. Asserts that the list is non-empty.
static bool all_strides_match(TensorList tensors) {
  AT_ASSERT(tensors.size() >= 1);
  const auto reference = tensors[0].strides();
  return std::all_of(tensors.begin() + 1, tensors.end(), [&](const Tensor& t) {
    return reference.equals(t.strides());
  });
}
// Formats the sizes of the defined tensors in `tensors` as a ", "-separated
// list, skipping undefined entries. Used for error messages.
static std::string shapes_as_str(TensorList tensors) {
  std::ostringstream os;
  const char* sep = "";
  for (const auto& t : tensors) {
    if (!t.defined()) {
      continue;
    }
    os << sep << t.sizes();
    sep = ", ";
  }
  return os.str();
}
// Preprocessed form of an advanced-indexing expression, consumed by the
// element-wise index / index_put_ kernels. Filled in by the constructor
// (AdvancedIndex::AdvancedIndex).
struct AdvancedIndex {
  AdvancedIndex(const Tensor& src, TensorList indices);

  // The source tensor restrided so that its indexed dims are replaced by the
  // index shape with stride 0 (see restride_src).
  Tensor src;
  // The defined index tensors, reshaped to broadcast against the result.
  std::vector<Tensor> indices;
  // Sizes of the indexed dimensions of the original src.
  DimVector indexed_sizes;
  // Strides of the indexed dimensions of the original src, in bytes.
  DimVector indexed_strides;
  // Number of un-indexed dimensions before / after the indexed block.
  int64_t dims_before = 0;
  int64_t dims_after = 0;
};
// Returns a view of `src` in which the `dims_indexed` dimensions starting at
// `dims_before` are replaced by `replacement_shape` (the index shape), all
// with stride 0. The kernel computes the real offset within those dimensions
// from the index values and src's original strides. The new shape has no
// meaning of its own; it only makes src iterable alongside the result tensor.
static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed,
                           IntArrayRef replacement_shape) {
  const IntArrayRef old_sizes = src.sizes();
  const IntArrayRef old_strides = src.strides();
  const int64_t resume_at = dims_before + dims_indexed;
  DimVector shape;
  DimVector strides;
  // Un-indexed leading dimensions keep their size and stride.
  shape.append(old_sizes.begin(), old_sizes.begin() + dims_before);
  strides.append(old_strides.begin(), old_strides.begin() + dims_before);
  // The indexed block takes the index shape, with stride 0 everywhere.
  shape.append(replacement_shape.begin(), replacement_shape.end());
  strides.append(replacement_shape.size(), 0);
  // Un-indexed trailing dimensions keep their size and stride.
  shape.append(old_sizes.begin() + resume_at, old_sizes.end());
  strides.append(old_strides.begin() + resume_at, old_strides.end());
  return src.as_strided(shape, strides);
}
// Reshapes `index` to [1]*dims_before + index.sizes() + [1]*dims_after. The
// result broadcasts against the result shape and can be iterated
// element-wise together with the result tensor and the restrided src.
static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) {
  const IntArrayRef index_sizes = index.sizes();
  DimVector new_shape;
  new_shape.append(dims_before, 1);
  new_shape.append(index_sizes.begin(), index_sizes.end());
  new_shape.append(dims_after, 1);
  return index.reshape(new_shape);
}
// Builds the kernel-ready description of `src` indexed by `indices_list`, in
// these steps:
//   1. Counts the un-indexed dims before the first defined index
//      (dims_before) and after it (dims_after).
//   2. Records the size and byte stride of each indexed dim.
//   3. Restrides src so the indexed block takes the index shape.
//   4. Reshapes each defined index to broadcast against that shape.
// Expects the defined indices to be broadcast together and adjacent
// (make_info arranges this).
AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
{
  const int64_t elem_bytes = src.element_size();
  int64_t n_before = 0;
  int64_t n_after = 0;
  int64_t n_indexed = 0;
  IntArrayRef replacement_shape;
  for (size_t d = 0; d < indices_list.size(); d++) {
    const Tensor& idx = indices_list[d];
    if (idx.defined()) {
      n_indexed++;
      replacement_shape = idx.sizes();
      indexed_sizes.push_back(src.size(d));
      indexed_strides.push_back(src.stride(d) * elem_bytes);
    } else if (n_indexed == 0) {
      n_before++;
    } else {
      n_after++;
    }
  }
  // No value is a valid index into a size-0 dimension. If the indexed
  // subspace contains a size-0 dim but the index shape has no 0, some index
  // must be out of bounds. Report that here rather than letting restride_src
  // fail with an unhelpful message.
  const bool has_empty_indexed_dim =
      std::find(indexed_sizes.begin(), indexed_sizes.end(), 0) != indexed_sizes.end();
  const bool index_is_empty =
      std::find(replacement_shape.begin(), replacement_shape.end(), 0) != replacement_shape.end();
  if (has_empty_indexed_dim && !index_is_empty) {
    AT_INDEX_ERROR("index is out of bounds for dimension with size 0");
  }
  dims_before = n_before;
  dims_after = n_after;
  this->src = restride_src(src, n_before, n_indexed, replacement_shape);
  for (const auto& idx : indices_list) {
    if (idx.defined()) {
      indices.push_back(reshape_indexer(idx, n_before, n_after));
    }
  }
  // For CUDA tensors, force all index tensors to have the same striding to
  // simplify the CUDA kernel.
  if (indices.size() >= 2 && this->src.type().device_type() == kCUDA && !all_strides_match(indices)) {
    for (auto& idx : indices) {
      idx = idx.contiguous();
    }
  }
}
// Preprocesses an advanced-indexing expression for the element-wise kernels:
//   - validates the index dtypes and expands Byte/Bool masks;
//   - broadcasts all index tensors together, raising an index error that
//     lists the shapes if they cannot be broadcast;
//   - pads the indices with nulls up to self.dim();
//   - permutes self/indices so the indexed dims are adjacent at the front,
//     if they were not already;
//   - moves the indices onto self's device.
// Returns the resulting AdvancedIndex.
static AdvancedIndex make_info(Tensor self, TensorList orig) {
  checkIndexTensorTypes(orig);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
  auto indices = expandTensors(self, orig);
  // next broadcast all index tensors together
  try {
    indices = expand_outplace(indices);
  } catch (const std::exception&) {
    // Replace the generic broadcast failure with a message naming the shapes.
    // `indices` is still the pre-broadcast list here.
    AT_INDEX_ERROR("shape mismatch: indexing tensors could not be broadcast together"
                   " with shapes ", shapes_as_str(indices));
  }
  // add missing null Tensors so that it matches self.dim()
  while (indices.size() < (size_t)self.dim()) {
    indices.emplace_back();
  }
  // if the non-null indices are not all adjacent, transpose self and indices
  // together so that they're adjacent at the front
  if (!hasContiguousSubspace(indices)) {
    std::tie(self, indices) = transposeToFront(self, indices);
  }
  // Ensure indices are on the same device as self
  for (auto& idx : indices) {
    if (idx.defined() && idx.device() != self.device()) {
      idx = idx.to(self.device());
    }
  }
  return AdvancedIndex(self, indices);
}
static std::unique_ptr<TensorIterator> make_index_iterator(const AdvancedIndex& info) {
auto builder = TensorIterator::Builder();
builder.dont_compute_common_dtype();
builder.add_output(Tensor(), info.src.type().backend(), info.src.scalar_type());
builder.add_input(info.src);
for (auto& index : info.indices) {
builder.add_input(index);
}
return builder.build();
}
static std::unique_ptr<TensorIterator> make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
if (!is_expandable_to(value.sizes(), info.src.sizes())) {
AT_ERROR("shape mismatch: value tensor of shape ", value.sizes(),
" cannot be broadcast to indexing result of shape ", info.src.sizes());
}
auto builder = TensorIterator::Builder();
builder.dont_compute_common_dtype();
builder.dont_resize_outputs();
builder.add_output(info.src);
builder.add_input(value, info.src.type().backend(), info.src.scalar_type());
for (auto& index : info.indices) {
builder.add_input(index);
}
return builder.build();
}
// Advanced indexing, i.e. self[indices]. Returns a new tensor gathered from
// self by the element-wise index kernel. Raises an index error if there are
// more indices than dimensions.
Tensor index(const Tensor & self, TensorList indices) {
  const int64_t ndim = self.dim();
  if (indices.size() > static_cast<size_t>(ndim)) {
    AT_INDEX_ERROR("too many indices for tensor of dimension ", ndim, " (got ", indices.size(), ")");
  }
  AdvancedIndex info = make_info(self, indices);
  auto iter = make_index_iterator(info);
  index_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides);
  return iter->output();
}
// Out-of-place variant of index_put_: applies it to a clone of self and
// returns the clone.
Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
  Tensor result = self.clone();
  result.index_put_(indices, value, accumulate);
  return result;
}
// In-place advanced-index assignment: self[indices] = value, or
// self[indices] += value when `accumulate` is true. Returns self.
//
// CUDA with accumulate=true takes the linear-index path. makeLinearIndex
// yields src (self, or a permuted view of self) plus flat offsets into src,
// and put_(..., accumulate=true) adds the broadcast value at those offsets.
// Every other case goes through the element-wise index_put_stub kernel.
Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
  if (indices.size() > (size_t)self.dim()) {
    AT_INDEX_ERROR("too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
  }
  if (accumulate && self.type().device_type() == kCUDA) {
    Tensor src, linearIndex, expandedValue;
    std::tie(src, linearIndex) = makeLinearIndex(self, indices);
    std::tie(expandedValue) = expand_inplace(linearIndex, value);
    // src shares storage with self, so put_ writes through to self.
    src.put_(linearIndex, expandedValue, true);
    // Return self, not the reference put_ yields: that one refers to the
    // local `src` and would dangle once this function returns.
    return self;
  }
  auto info = make_info(self, indices);
  auto iter = make_index_put_iterator(info, value);
  index_put_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides, accumulate);
  return self;
}
// Copies slices of `source` into `self` along dimension `dim`, at the
// positions given by the 0-/1-D Long tensor `index`. The argument checks are
// done here; the copy itself is delegated to the legacy TH kernel.
//
// Raises an index error for:
//   - an index with >= 2 dims;
//   - a scalar source whose index does not have exactly one element;
//   - a non-Long index;
//   - an index count different from source.size(dim).
// Raises an error when both tensors are non-scalar with different ranks, or
// when their slice shapes (sizes with `dim` removed) differ.
Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  dim = maybe_wrap_dim(dim, self.dim());

  if (index.dim() >= 2) {
    AT_INDEX_ERROR("index_copy_(): Index should have dimension 1 or 0 (got ", index.dim(), ")");
  }
  int64_t numIndices = index.numel();
  if (source.dim() == 0 && numIndices != 1) {
    AT_INDEX_ERROR("index_copy_(): When source is scalar, index should have one element (got ", numIndices, ")");
  }
  if (index.scalar_type() != ScalarType::Long) {
    AT_INDEX_ERROR("index_copy_(): Expected LongTensor for index");
  }
  // With mismatched non-zero ranks, `dim` can be >= source.dim(), and the
  // erase(begin() + dim) on the source sizes below would be out of bounds.
  // Such inputs could never pass the slice-shape check anyway.
  if (source.dim() != 0 && self.dim() != 0 && source.dim() != self.dim()) {
    AT_ERROR("index_copy_(): When source and destination are not scalars, their dimensionality must match. ",
             "Source dimensionality (", source.dim(), "), destination dimensionality (", self.dim(), ")");
  }

  // Check that source and destination slices have the same size
  auto selfSlicedSizes = self.sizes().vec();
  if (!selfSlicedSizes.empty()) {
    selfSlicedSizes.erase(selfSlicedSizes.begin() + dim);
  }
  auto sourceSlicedSizes = source.sizes().vec();
  if (!sourceSlicedSizes.empty()) {
    sourceSlicedSizes.erase(sourceSlicedSizes.begin() + dim);
  }
  if (selfSlicedSizes.size() != sourceSlicedSizes.size() ||
      !std::equal(selfSlicedSizes.begin(), selfSlicedSizes.end(),
                  sourceSlicedSizes.begin())) {
    std::stringstream ss;
    ss << "index_copy_(): Source/destination tensor must have same slice shapes. ";
    ss << "Destination slice shape: " << selfSlicedSizes << " at dimension " << dim;
    ss << " and source slice shape: " << sourceSlicedSizes << " at dimension 0.";
    AT_ERROR(ss.str());
  }
  if (source.dim() > 0 && numIndices != source.size(dim)) {
    AT_INDEX_ERROR(
        "index_copy_(): Number of indices (", numIndices, ") should be equal to source.size(dim) (", source.size(dim), ")");
  }

  return at::legacy::th::_th_index_copy_(self, dim, index, source);
}
// Out-of-place index_copy: runs index_copy_ on a clone of self.
Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  Tensor result = self.clone();
  result.index_copy_(dim, index, source);
  return result;
}
// Out-of-place index_add: runs index_add_ on a clone of self.
Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  Tensor result = self.clone();
  result.index_add_(dim, index, source);
  return result;
}
// Out-of-place index_fill with a Scalar fill value: runs index_fill_ on a
// clone of self.
Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
  Tensor result = self.clone();
  result.index_fill_(dim, index, source);
  return result;
}
// Out-of-place index_fill with a Tensor fill value: runs index_fill_ on a
// clone of self.
Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  Tensor result = self.clone();
  result.index_fill_(dim, index, source);
  return result;
}
// Out-of-place scatter with a Tensor source: runs scatter_ on a clone of self.
Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  Tensor result = self.clone();
  result.scatter_(dim, index, source);
  return result;
}
// Out-of-place scatter with a Scalar source: runs scatter_ on a clone of self.
Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
  Tensor result = self.clone();
  result.scatter_(dim, index, source);
  return result;
}
// Out-of-place scatter_add: runs scatter_add_ on a clone of self.
Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  Tensor result = self.clone();
  result.scatter_add_(dim, index, source);
  return result;
}
// Out-of-place masked_scatter. It first broadcasts `mask` and `self`
// together, then runs masked_scatter_ on a clone of the broadcast self, using
// the broadcast mask.
Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) {
  auto expanded = expand_outplace(mask, self);
  Tensor result = std::get<1>(expanded).clone();
  result.masked_scatter_(std::get<0>(expanded), source);
  return result;
}
// Out-of-place masked_fill with a Scalar value. The result takes the
// broadcast shape of `mask` and `self`. masked_fill_ is then called with the
// original mask, which is broadcastable to that shape by construction; the
// broadcast mask itself is discarded.
Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar source) {
  Tensor broadcastSelf;
  std::tie(std::ignore, broadcastSelf) = expand_outplace(mask, self);
  Tensor result = broadcastSelf.clone();
  result.masked_fill_(mask, source);
  return result;
}
// Out-of-place masked_fill with a Tensor value. The result takes the
// broadcast shape of `mask` and `self`. masked_fill_ is then called with the
// original mask, which is broadcastable to that shape by construction; the
// broadcast mask itself is discarded.
Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & source) {
  Tensor broadcastSelf;
  std::tie(std::ignore, broadcastSelf) = expand_outplace(mask, self);
  Tensor result = broadcastSelf.clone();
  result.masked_fill_(mask, source);
  return result;
}
// Backward of gather() that produces a sparse COO gradient with self's shape.
// Each element of `grad` becomes one sparse entry, with these coordinates:
//   - along `dim`: the corresponding value of `index`;
//   - along every other dim i: the element's position along grad's dim i,
//     enumerated row-major.
// A grad with zero elements yields an empty (0-column) coordinate tensor.
// NOTE(review): assumes grad.dim() == self.dim() and index.numel() ==
// grad.numel(), as for a gather output — confirm against callers.
Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
  // special case scalar input and/or index
  if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes());
  if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(index.view({1,1}), grad, self.sizes());
  Tensor sparse_ind = at::empty({self.ndimension(), grad.numel()}, self.options().dtype(at::kLong));
  const int64_t grad_numel = grad.numel();
  // When grad has a zero-size dim, `n_above /= grad.size(i)` below would be an
  // integer division by zero. There are no coordinates to fill in that case.
  if (grad_numel > 0) {
    // n_above: product of grad sizes after dim i. n_below: product before it.
    int64_t n_above = grad_numel;
    int64_t n_below = 1;
    if (dim < 0) dim += self.ndimension();
    for (int64_t i = 0; i < self.ndimension(); i++) {
      n_above /= grad.size(i);
      if (i == dim) {
        sparse_ind[i] = index.reshape(-1);
      } else {
        // Row-major coordinate along dim i: each value repeats n_above times,
        // and the whole pattern is tiled n_below times.
        sparse_ind[i] = at::arange(grad.size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand({grad.size(i), n_above}).reshape(-1).repeat(n_below);
      }
      n_below *= grad.size(i);
    }
  }
  return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes());
}
}} // at::native