forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindexing_op.cc
771 lines (655 loc) · 28.9 KB
/
indexing_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \file indexing_op.cc
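* \brief CPU implementations and operator registrations for the indexing operators
*        (Embedding, SparseEmbedding, take, batch_take, one_hot, gather_nd, scatter_nd).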
* \author Siyi Li, Chi Zhang
*/
#include "./indexing_op.h"
namespace mxnet {
namespace op {
/*!
 * \brief CPU forward pass of the embedding lookup when `weight` is stored as row_sparse.
 *
 * \param ctx    operator context (supplies the CPU stream)
 * \param data   index array; every value must lie in [0, weight.shape()[0])
 * \param weight row_sparse embedding matrix
 * \param req    write request for `output` (kNullOp returns immediately)
 * \param output dense result blob
 *
 * Fails a CHECK when `data` contains an out-of-bound index.
 */
template<>
void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
                                          const TBlob& data,
                                          const NDArray& weight,
                                          const OpReqType req,
                                          const TBlob& output) {
  if (req == kNullOp) return;
  using namespace rowsparse;
  using namespace mxnet_op;
  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
  // A row_sparse weight whose storage is not initialized holds no rows, so for
  // kWriteTo the whole output is simply zero-filled.
  if (req == kWriteTo && !weight.storage_initialized()) {
    size_t out_size = output.shape_.Size();
    MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
      Fill<false>(s, TBlob(output.dptr<DType>(), mshadow::Shape1(out_size),
          cpu::kDevMask), kWriteTo, 0);
    })
    return;
  }
  // Check for out-of-bound indices. The comparison is done in double instead of
  // in DType: casting `num_rows - 1` to a narrow index type (int8/uint8/half)
  // wraps or rounds for large weights and would reject valid indices (or accept
  // invalid ones). double is exact for every index below 2^53.
  bool is_valid = true;
  const double max_index = static_cast<double>(weight.shape()[0] - 1);
  MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
    // check with single thread is faster since data is small
    const DType* data_ptr = data.dptr<DType>();
    const size_t data_size = data.shape_.Size();
    for (size_t i = 0; i < data_size; i++) {
      const double idx = static_cast<double>(data_ptr[i]);
      if (idx > max_index || idx < 0.0) {
        is_valid = false;
        break;  // a single bad index is enough to fail the CHECK below
      }
    }
  })
  CHECK(is_valid) << "SparseEmbedding input contains data out of bound";
  // When the row_sparse weight stores every row it is effectively dense, so the
  // dense lookup is applied to its data blob directly.
  if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) {
    EmbeddingOpForwardDnsImpl<cpu>(s, data, weight.data(), req, output);
  } else {
    EmbeddingOpForwardRspImpl<cpu>(s, data, weight, req, output);
  }
}
/*!
 * \brief CPU backward pass of SparseEmbedding: builds the row_sparse gradient of
 *        the embedding weight.
 *
 * \param deterministic not referenced by this CPU implementation
 * \param ctx    operator context; ctx.requested[embedding::kTempSpace] provides the
 *               workspace (one dim_t per weight row)
 * \param ograd  gradient of the forward output
 * \param data   indices that were looked up in the forward pass
 * \param req    kNullOp returns immediately; anything other than kWriteTo fails a CHECK
 * \param output row_sparse weight gradient; its storage is allocated here
 *
 * Steps: mark the weight rows referenced by `data` (MarkRowFlgKernel, presumably
 * setting row_flg[r] for each index r), turn the flags into an inclusive prefix sum
 * in place, read the number of non-zero rows (nnr) from the last prefix value, then
 * allocate nnr rows, fill their row indices, zero them and accumulate `ograd` into them.
 */
template<>
inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
const OpContext& ctx,
const TBlob& ograd,
const TBlob& data,
const OpReqType req,
const NDArray& output) {
using namespace mshadow;
using namespace mxnet_op;
using namespace mshadow::expr;
using namespace rowsparse;
using nnvm::dim_t;
if (req == kNullOp) return;
CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
<< "weight gradient calculation with req != write";
// Request temporary storage for marking non-zero rows and prefix sum
Stream<cpu> *s = ctx.get_stream<cpu>();
dim_t num_rows = output.shape()[0];
dim_t row_length = output.shape()[1];
// one dim_t flag / prefix-sum slot per weight row
size_t workspace_size = num_rows * sizeof(dim_t);
Tensor<cpu, 1, char> workspace =
ctx.requested[embedding::kTempSpace].get_space_typed<cpu, 1, char>(
Shape1(workspace_size), s);
dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
// prefix sum array re-uses the row_flg array temp space
dim_t* prefix_sum = row_flg;
dim_t data_size = static_cast<dim_t>(data.shape_.Size());
MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
// mark row flags
Fill<false>(s, TBlob(row_flg, Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
Kernel<MarkRowFlgKernel, cpu>::Launch(s, data_size, row_flg, data.dptr<IType>());
// calculate inclusive prefix sum (in place: prefix_sum aliases row_flg, so
// row_flg[i] is read before prefix_sum[i] overwrites the same slot)
// TODO(haibin) ideally this should be done in parallel
// NOTE(review): assumes num_rows >= 1; prefix_sum[0] is accessed unconditionally.
prefix_sum[0] = row_flg[0];
for (dim_t i = 1; i < num_rows; i++) {
prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
}
// total number of non-zero rows
dim_t nnr = prefix_sum[num_rows - 1];
if (nnr == 0) {
// no row was referenced: the gradient is an all-zero row_sparse array
FillZerosRspImpl(s, output);
return;
}
output.CheckAndAlloc({Shape1(nnr)});
RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
// fill row_idx array of output matrix, using the row_flg values
Kernel<FillRspRowIdxKernel, cpu>::Launch(s, num_rows,
grad_row_idx, prefix_sum, num_rows);
// prefill with zeros
DType* grad_data = output.data().dptr<DType>();
Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length),
cpu::kDevMask), kWriteTo, 0);
// add the final gradients
// The kernel is launched with one work item per recommended OpenMP thread;
// segment_len = ceil(nnr / num_threads). NOTE(review): presumably each work item
// accumulates a disjoint segment of output rows — confirm in AddTakeGradRspKernel.
const int num_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
dim_t segment_len = (nnr + num_threads - 1) / num_threads;
Kernel<AddTakeGradRspKernel, cpu>::Launch(s, num_threads, grad_data, prefix_sum,
ograd.dptr<DType>(), row_length,
data.dptr<IType>(), data_size, segment_len,
num_rows);
});
});
});
}
/*!
 * \brief CPU scatter-add used by the backward pass of gather_nd (non-half DType).
 *
 * For each of the N gathered slices, the destination offset is
 * sum over j < M of strides[j] * indices[j * N + slice], and the K contiguous
 * values starting at data[slice * K] are added into `out` at that offset.
 * Slices are processed in parallel with OpenMP. Several slices may target the
 * same destination, so every addition is atomic.
 * \param s unused by this implementation
 */
template<typename DType, typename IType>
inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
GatherNDBackwardImpl(int N, int M, int K,
                     const mshadow::Shape<10> strides,
                     DType* out,
                     const DType* data,
                     const IType* indices,
                     mshadow::Stream<cpu> *s) {
#pragma omp parallel for
  for (int slice = 0; slice < N; slice++) {
    // Flatten this slice's M coordinates into a linear offset into `out`.
    int dst_offset = 0;
    for (int axis = 0; axis < M; ++axis) {
      dst_offset += strides[axis] * static_cast<int>(indices[axis * N + slice]);
    }
    DType* dst = out + dst_offset;
    const DType* src = data + slice * K;
    for (int elem = 0; elem < K; ++elem) {
#pragma omp atomic
      dst[elem] += src[elem];
    }
  }
}
/*!
 * \brief Serial CPU scatter-add used by the backward pass of gather_nd for half_t.
 *
 * Same contract as the non-half overload. It runs single-threaded because an
 * OpenMP `atomic` update only applies to built-in scalar types, and half_t is a
 * class type.
 * \param s unused by this implementation
 */
template<typename DType, typename IType>
inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
GatherNDBackwardImpl(int N, int M, int K,
                     const mshadow::Shape<10> strides,
                     DType* out,
                     const DType* data,
                     const IType* indices,
                     mshadow::Stream<cpu> *s) {
  for (int slice = 0; slice < N; ++slice) {
    // Linear offset of this slice's destination inside `out`.
    int dst_offset = 0;
    for (int axis = 0; axis < M; ++axis) {
      dst_offset += strides[axis] * static_cast<int>(indices[axis * N + slice]);
    }
    DType* dst = out + dst_offset;
    const DType* src = data + slice * K;
    for (int elem = 0; elem < K; ++elem) {
      dst[elem] += src[elem];
    }
  }
}
// Register the parameter structs that the ParamParser<...> / __FIELDS__() calls
// in the operator registrations below rely on.
DMLC_REGISTER_PARAMETER(EmbeddingParam);
DMLC_REGISTER_PARAMETER(SparseEmbeddingParam);
DMLC_REGISTER_PARAMETER(TakeParam);
DMLC_REGISTER_PARAMETER(OneHotParam);
DMLC_REGISTER_PARAMETER(ScatterNDParam);
// Embedding: maps integer indices to rows of the weight matrix.
// FCompute<cpu> is the dense implementation. FComputeEx<cpu> serves the storage
// combinations selected by EmbeddingOpForwardStorageType (e.g. a row_sparse
// weight). A temporary workspace (kTempSpace) is requested.
// NOTE(review): MXNET_ADD_SPARSE_OP_ALIAS presumably also exposes a sparse-namespace
// alias of this op — confirm against its definition.
NNVM_REGISTER_OP(Embedding)
MXNET_ADD_SPARSE_OP_ALIAS(Embedding)
.describe(R"code(Maps integer indices to vector representations (embeddings).
This operator maps words to real-valued vectors in a high-dimensional space,
called word embeddings. These embeddings can capture semantic and syntactic properties of the words.
For example, it has been noted that in the learned embedding spaces, similar words tend
to be close to each other and dissimilar words far apart.
For an input array of shape (d1, ..., dK),
the shape of an output array is (d1, ..., dK, output_dim).
All the input values should be integers in the range [0, input_dim).
If the input_dim is ip0 and output_dim is op0, then shape of the embedding weight matrix must be
(ip0, op0).
By default, if any index mentioned is too large, it is replaced by the index that addresses
the last vector in an embedding matrix.
Examples::
input_dim = 4
output_dim = 5
// Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3)
y = [[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., 18., 19.]]
// Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)]
x = [[ 1., 3.],
[ 0., 2.]]
// Mapped input x to its vector representation y.
Embedding(x, y, 4, 5) = [[[ 5., 6., 7., 8., 9.],
[ 15., 16., 17., 18., 19.]],
[[ 0., 1., 2., 3., 4.],
[ 10., 11., 12., 13., 14.]]]
The storage type of weight can be either row_sparse or default.
.. Note::
If "sparse_grad" is set to True, the storage type of gradient w.r.t weights will be
"row_sparse". Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
and Adam. Note that by default lazy updates is turned on, which may perform differently
from standard updates. For more details, please check the Optimization API at:
https://mxnet.incubator.apache.org/api/python/optimization/optimization.html
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<EmbeddingParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
})
.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape<EmbeddingParam>)
.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<EmbeddingParam>)
.set_attr<FInferStorageType>("FInferStorageType", EmbeddingOpForwardStorageType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpForward<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// Besides the output gradients, only the forward `data` input (the indices)
// is passed to the backward node; the forward attributes are copied over.
return MakeNonlossGradNode("_backward_Embedding", n, ograds,
{n->inputs[0]}, n->attrs.dict);
})
.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.")
.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.")
.add_arguments(EmbeddingParam::__FIELDS__());
// Deprecated contrib variant of Embedding. It parses SparseEmbeddingParam
// (which carries `deterministic`), only has an FComputeEx implementation, and
// produces a row_sparse weight gradient through _backward_SparseEmbedding.
NNVM_REGISTER_OP(_contrib_SparseEmbedding)
.describe(R"code(Maps integer indices to vector representations (embeddings).
note:: ``contrib.SparseEmbedding`` is deprecated, use ``Embedding`` instead.
This operator maps words to real-valued vectors in a high-dimensional space,
called word embeddings. These embeddings can capture semantic and syntactic properties of the words.
For example, it has been noted that in the learned embedding spaces, similar words tend
to be close to each other and dissimilar words far apart.
For an input array of shape (d1, ..., dK),
the shape of an output array is (d1, ..., dK, output_dim).
All the input values should be integers in the range [0, input_dim).
If the input_dim is ip0 and output_dim is op0, then shape of the embedding weight matrix must be
(ip0, op0).
The storage type of the gradient will be `row_sparse`.
.. Note::
`SparseEmbedding` is designed for the use case where `input_dim` is very large (e.g. 100k).
The operator is available on both CPU and GPU.
When `deterministic` is set to `True`, the accumulation of gradients follows a
deterministic order if a feature appears multiple times in the input. However, the
accumulation is usually slower when the order is enforced on GPU.
When the operator is used on the GPU, the recommended value for `deterministic` is `True`.
Examples::
input_dim = 4
output_dim = 5
// Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3)
y = [[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., 18., 19.]]
// Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)]
x = [[ 1., 3.],
[ 0., 2.]]
// Mapped input x to its vector representation y.
SparseEmbedding(x, y, 4, 5) = [[[ 5., 6., 7., 8., 9.],
[ 15., 16., 17., 18., 19.]],
[[ 0., 1., 2., 3., 4.],
[ 10., 11., 12., 13., 14.]]]
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SparseEmbeddingParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape<SparseEmbeddingParam>)
.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<SparseEmbeddingParam>)
.set_attr<FInferStorageType>("FInferStorageType", SparseEmbeddingOpForwardStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// only the forward `data` input (the indices) is forwarded to the backward node
return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds,
{n->inputs[0]}, n->attrs.dict);
})
.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.")
.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.")
// The advertised fields must match the parser above (SparseEmbeddingParam), so
// that `deterministic` is listed and EmbeddingParam-only fields are not.
.add_arguments(SparseEmbeddingParam::__FIELDS__());
// Backward of Embedding: two inputs (presumably output gradient and the forward
// `data`, as built by Embedding's FGradient) and two outputs, one per forward
// input. The dense gradient path is FCompute and the sparse path is FComputeEx,
// chosen by EmbeddingOpBackwardStorageType.
NNVM_REGISTER_OP(_backward_Embedding)
.set_attr_parser(ParamParser<EmbeddingParam>)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", EmbeddingOpBackwardStorageType)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs&) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", EmbeddingOpBackwardEx<cpu>);
// Backward of _contrib_SparseEmbedding. Only an FComputeEx implementation is
// registered; output storage types come from SparseEmbeddingOpBackwardStorageType.
NNVM_REGISTER_OP(_backward_SparseEmbedding)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr_parser(ParamParser<SparseEmbeddingParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", SparseEmbeddingOpBackwardStorageType)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs&) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingOpBackwardEx<cpu>);
// take: gathers slices of `a` along `axis` at `indices` (TakeParam controls axis
// and out-of-range handling). The backward node receives the output gradients
// plus the forward `indices` input.
NNVM_REGISTER_OP(take)
.describe(R"code(Takes elements from an input array along the given axis.
This function slices the input array along a particular axis with the provided indices.
Given data tensor of rank r >= 1, and indices tensor of rank q, gather entries of the axis
dimension of data (by default outer-most one as axis=0) indexed by indices, and concatenates them
in an output tensor of rank q + (r - 1).
Examples::
x = [4. 5. 6.]
// Trivial case, take the second element along the first axis.
take(x, [1]) = [ 5. ]
// The other trivial case, axis=-1, take the third element along the first axis
take(x, [3], axis=-1, mode='clip') = [ 6. ]
x = [[ 1., 2.],
[ 3., 4.],
[ 5., 6.]]
// In this case we will get rows 0 and 1, then 1 and 2. Along axis 0
take(x, [[0,1],[1,2]]) = [[[ 1., 2.],
[ 3., 4.]],
[[ 3., 4.],
[ 5., 6.]]]
// In this case we will get columns 0 and 1, then 1 and 0 (3, -1 and -2 are
// wrapped around into [0, 2)). Along axis 1
take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1., 2.],
[ 2., 1.]],
[[ 3., 4.],
[ 4., 3.]],
[[ 5., 6.],
[ 6., 5.]]]
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<TakeParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "indices"};
})
.set_attr<nnvm::FInferShape>("FInferShape", TakeOpShape)
.set_attr<nnvm::FInferType>("FInferType", TakeOpType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", TakeOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
return MakeNonlossGradNode("_backward_take", n, ograds,
{n->inputs[1]}, n->attrs.dict);
})
.add_argument("a", "NDArray-or-Symbol", "The input array.")
.add_argument("indices", "NDArray-or-Symbol", "The indices of the values to be extracted.")
.add_arguments(TakeParam::__FIELDS__());
// Backward of `take`: two outputs (one per forward input), computed on CPU by
// TakeOpBackward with a temporary workspace.
NNVM_REGISTER_OP(_backward_take)
.set_attr_parser(ParamParser<TakeParam>)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs&) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", TakeOpBackward<cpu>);
// batch_take (deprecated in favour of `pick`): output[i] = a[i, indices[i]].
// It takes no parameters (no attr parser or __FIELDS__), and no FGradient is
// registered here.
NNVM_REGISTER_OP(batch_take)
.describe(R"code(Takes elements from a data batch.
.. note::
`batch_take` is deprecated. Use `pick` instead.
Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
an output array of shape ``(i0,)`` with::
output[i] = input[i, indices[i]]
Examples::
x = [[ 1., 2.],
[ 3., 4.],
[ 5., 6.]]
// takes elements with specified indices
batch_take(x, [0,1,0]) = [ 1. 4. 5.]
)code" ADD_FILELINE)
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "indices"};
})
.set_attr<nnvm::FInferShape>("FInferShape", BatchTakeOpShape)
.set_attr<nnvm::FInferType>("FInferType", BatchTakeOpType)
.set_attr<FCompute>("FCompute<cpu>", BatchTakeOpForward<cpu>)
.add_argument("a", "NDArray-or-Symbol", "The input array")
.add_argument("indices", "NDArray-or-Symbol", "The index array");
// one_hot: single input `indices`; the output gains a trailing axis of size
// `depth` (OneHotParam). The gradient is produced by MakeZeroGradNodes, i.e.
// no gradient flows back into `indices`.
NNVM_REGISTER_OP(one_hot)
.describe(R"code(Returns a one-hot array.
The locations represented by `indices` take value `on_value`, while all
other locations take value `off_value`.
`one_hot` operation with `indices` of shape ``(i0, i1)`` and `depth` of ``d`` would result
in an output array of shape ``(i0, i1, d)`` with::
output[i,j,:] = off_value
output[i,j,indices[i,j]] = on_value
Examples::
one_hot([1,0,2,0], 3) = [[ 0. 1. 0.]
[ 1. 0. 0.]
[ 0. 0. 1.]
[ 1. 0. 0.]]
one_hot([1,0,2,0], 3, on_value=8, off_value=1,
dtype='int32') = [[1 8 1]
[8 1 1]
[1 1 8]
[8 1 1]]
one_hot([[1,0],[1,0],[2,0]], 3) = [[[ 0. 1. 0.]
[ 1. 0. 0.]]
[[ 0. 1. 0.]
[ 1. 0. 0.]]
[[ 0. 0. 1.]
[ 1. 0. 0.]]]
)code" ADD_FILELINE)
.set_num_outputs(1)
.set_num_inputs(1)
.set_attr_parser(ParamParser<OneHotParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"indices"};
})
.set_attr<nnvm::FInferShape>("FInferShape", OneHotOpShape)
.set_attr<nnvm::FInferType>("FInferType", OneHotOpType)
.set_attr<FCompute>("FCompute<cpu>", OneHotOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value")
.add_arguments(OneHotParam::__FIELDS__());
// gather_nd: reads slices of `data` addressed by the leading axis of `indices`.
// Its gradient w.r.t. `data` is _backward_gather_nd(ograd, indices), a
// scatter-add. Its gradient w.r.t. `indices` is zeros_like(indices).
// NOTE(review): TIsBackward is set on this forward op. scatter_nd and
// _backward_gather_nd use gather_nd as their gradient node with a control
// dependency on the forward node, and this flag presumably lets graph passes treat
// it as a backward node in that role — confirm before changing.
NNVM_REGISTER_OP(gather_nd)
.describe(R"code(Gather elements or slices from `data` and store to a tensor whose
shape is defined by `indices`.
Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`,
where `M <= N`. If `M == N`, output shape will simply be `(Y_0, ..., Y_{K-1})`.
The elements in output is defined as follows::
output[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] = data[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}]
Examples::
data = [[0, 1], [2, 3]]
indices = [[1, 1, 0], [0, 1, 0]]
gather_nd(data, indices) = [2, 3, 0]
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<nnvm::FInferShape>("FInferShape", GatherNDShape)
.set_attr<nnvm::FInferType>("FInferType", GatherNDType)
.set_attr<FCompute>("FCompute<cpu>", GatherNDForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// gradient w.r.t. data: scatter-add the output gradient back through `indices`
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_backward_gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
// tie the backward node to the forward node
p->control_deps.emplace_back(n);
// gradient w.r.t. indices: all zeros
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices");
// scatter_nd: writes `data` into a zero tensor of shape ScatterNDParam::shape at
// the positions given by `indices`. Its gradient w.r.t. `data` is
// gather_nd(ograd, indices). Its gradient w.r.t. `indices` is zeros_like(indices).
// NOTE(review): TIsBackward is set on this forward op, as it is on gather_nd;
// the reason is not visible here — confirm before changing.
NNVM_REGISTER_OP(scatter_nd)
.describe(R"code(Scatters data into a new tensor according to indices.
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
The elements in output is defined as follows::
output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
all other entries in output are 0.
.. warning::
If the indices have duplicates, the result will be non-deterministic and
the gradient of `scatter_nd` will not be correct!!
Examples::
data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
shape = (2, 2)
scatter_nd(data, indices, shape) = [[0, 0], [2, 3]]
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", ScatterNDForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// gradient w.r.t. data: gather the output gradient at the scattered positions
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
// gradient w.r.t. indices: all zeros
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
// _backward_gather_nd: scatter-add counterpart of gather_nd (duplicates in
// `indices` accumulate). On CPU it runs GatherNDBackward. Its own gradient
// w.r.t. `data` is gather_nd(ograd, indices), and w.r.t. `indices` it is
// zeros_like(indices).
NNVM_REGISTER_OP(_backward_gather_nd)
.describe(R"code(Accumulates data according to indices and get the result. It's the backward of
`gather_nd`.
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
The elements in output is defined as follows::
output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
all other entries in output are 0 or the original value if AddTo is triggered.
Examples::
data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
shape = (2, 2)
_backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]] # Same as scatter_nd
# The difference between scatter_nd and _backward_gather_nd is the latter will accumulate
# the values that point to the same index.
data = [2, 3, 0]
indices = [[1, 1, 0], [1, 1, 0]]
shape = (2, 2)
_backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", GatherNDBackward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// gradient w.r.t. data: gather the output gradient at the accumulated positions
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
// gradient w.r.t. indices: all zeros
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
// _scatter_set_nd (internal): like scatter_nd, but entries not addressed by
// `indices` keep the values of `lhs`. Output 0 may reuse the storage of input 0
// (FInplaceOption {0, 0}, flagged as identity by FInplaceIdentity).
NNVM_REGISTER_OP(_scatter_set_nd)
.describe(R"code(This operator has the same functionality as scatter_nd
except that it does not reset the elements not indexed by the input
index `NDArray` in the input data `NDArray`. output should be explicitly
given and be the same as lhs.
.. note:: This operator is for internal use only.
Examples::
data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
out = [[1, 1], [1, 1]]
_scatter_set_nd(lhs=out, rhs=data, indices=indices, out=out)
out = [[0, 1], [2, 3]]
)code")
.set_num_outputs(1)
.set_num_inputs(3)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"lhs", "rhs", "indices"};
})
// Shape inference: `lhs` and the output share one shape. (rhs, indices) are then
// validated against the output by reusing ScatterNDShape on a two-input view,
// and any shapes it inferred are written back to inputs 1 and 2.
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
std::vector<TShape> tmp_in_attrs = {in_attrs->at(1), in_attrs->at(2)};
if (!ScatterNDShape(attrs, &tmp_in_attrs, out_attrs)) {
return false;
}
SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_in_attrs[0]);
SHAPE_ASSIGN_CHECK(*in_attrs, 2, tmp_in_attrs[1]);
SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return true;
})
// Type inference mirrors the shape logic above, using ScatterNDType.
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
std::vector<int> tmp_in_attrs = {in_attrs->at(1), in_attrs->at(2)};
if (!ScatterNDType(attrs, &tmp_in_attrs, out_attrs)) {
return false;
}
TYPE_ASSIGN_CHECK(*in_attrs, 1, tmp_in_attrs[0]);
TYPE_ASSIGN_CHECK(*in_attrs, 2, tmp_in_attrs[1]);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return true;
})
.set_attr<FCompute>("FCompute<cpu>", ScatterSetNDForward<cpu>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
.add_argument("lhs", "NDArray-or-Symbol", "source input")
.add_argument("rhs", "NDArray-or-Symbol", "value to assign")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
} // namespace op
} // namespace mxnet