@@ -99,6 +99,8 @@ struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
99
99
}
100
100
};
101
101
102
+ enum PickOpMode {kWrap , kClip };
103
+
102
104
struct PickParam : public dmlc ::Parameter<PickParam> {
103
105
dmlc::optional<int > axis;
104
106
int mode;
@@ -112,6 +114,14 @@ struct PickParam : public dmlc::Parameter<PickParam> {
112
114
DMLC_DECLARE_FIELD (keepdims).set_default (false )
113
115
.describe (" If true, the axis where we pick the elements is left "
114
116
" in the result as dimension with size one." );
117
+ DMLC_DECLARE_FIELD (mode)
118
+ .add_enum (" wrap" , kWrap )
119
+ .add_enum (" clip" , kClip )
120
+ .set_default (kClip )
121
+ .describe (" Specify how out-of-bound indices behave. Default is \" clip\" ."
122
+ " \" clip\" means clip to the range. So, if all indices mentioned are too large,"
123
+ " they are replaced by the index that addresses the last element along an axis. "
124
+ " \" wrap\" means to wrap around." );
115
125
}
116
126
};
117
127
@@ -1108,7 +1118,7 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
1108
1118
const std::vector<NDArray>& outputs);
1109
1119
1110
1120
/* ! \brief index element from array along axes */
1111
- template <int ndim>
1121
+ template <int ndim, bool clip = true >
1112
1122
struct pick {
1113
1123
template <typename DType, typename IType>
1114
1124
MSHADOW_XINLINE static void Map (int i, DType* out, const DType* a,
@@ -1117,15 +1127,20 @@ struct pick {
1117
1127
mshadow::Shape<ndim> sshape) {
1118
1128
using namespace broadcast ;
1119
1129
int j = static_cast <int >(idx[i]);
1120
- if (j < 0 ) j = 0 ;
1121
- else if (j >= M) j = M-1 ;
1130
+ if (clip) {
1131
+ if (j <= 0 ) j = 0 ;
1132
+ else if (j >= M) j = M - 1 ;
1133
+ } else {
1134
+ j = j % M;
1135
+ j += (j < 0 ) ? M : 0 ;
1136
+ }
1122
1137
j = ravel (unravel (i, sshape), bshape) + j*stride;
1123
1138
out[i] = a[j];
1124
1139
}
1125
1140
};
1126
1141
1127
1142
/* ! \brief index element from array along axes */
1128
- template <int ndim>
1143
+ template <int ndim, bool clip = true >
1129
1144
struct pick_grad {
1130
1145
template <typename DType, typename IType>
1131
1146
MSHADOW_XINLINE static void Map (int i, DType* igrad, const DType* ograd,
@@ -1134,8 +1149,13 @@ struct pick_grad {
1134
1149
mshadow::Shape<ndim> sshape) {
1135
1150
using namespace broadcast ;
1136
1151
int j = static_cast <int >(idx[i]);
1137
- if (j < 0 ) j = 0 ;
1138
- else if (j >= M) j = M-1 ;
1152
+ if (clip) {
1153
+ if (j <= 0 ) j = 0 ;
1154
+ else if (j >= M) j = M - 1 ;
1155
+ } else {
1156
+ j = j % M;
1157
+ j += (j < 0 ) ? M : 0 ;
1158
+ }
1139
1159
j = ravel (unravel (i, sshape), bshape) + j*stride;
1140
1160
igrad[j] += ograd[i];
1141
1161
}
@@ -1195,15 +1215,28 @@ void PickOpForward(const nnvm::NodeAttrs& attrs,
1195
1215
1196
1216
MSHADOW_TYPE_SWITCH (outputs[0 ].type_flag_ , DType, { // output type
1197
1217
MSHADOW_TYPE_SWITCH (inputs[1 ].type_flag_ , IType, { // index type
1198
- if (trailing == 1 ) {
1199
- Kernel<pick<2 >, xpu>::Launch (s, outputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1200
- inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1201
- M, 1 , Shape2 (leading, M), Shape2 (leading, 1 ));
1218
+ if (param.mode == kWrap ) {
1219
+ if (trailing == 1 ) {
1220
+ Kernel<pick<2 , false >, xpu>::Launch (s, outputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1221
+ inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1222
+ M, 1 , Shape2 (leading, M), Shape2 (leading, 1 ));
1223
+ } else {
1224
+ Kernel<pick<3 , false >, xpu>::Launch (s, outputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1225
+ inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1226
+ M, trailing, Shape3 (leading, M, trailing),
1227
+ Shape3 (leading, 1 , trailing));
1228
+ }
1202
1229
} else {
1203
- Kernel<pick<3 >, xpu>::Launch (s, outputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1204
- inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1205
- M, trailing, Shape3 (leading, M, trailing),
1206
- Shape3 (leading, 1 , trailing));
1230
+ if (trailing == 1 ) {
1231
+ Kernel<pick<2 , true >, xpu>::Launch (s, outputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1232
+ inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1233
+ M, 1 , Shape2 (leading, M), Shape2 (leading, 1 ));
1234
+ } else {
1235
+ Kernel<pick<3 , true >, xpu>::Launch (s, outputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1236
+ inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1237
+ M, trailing, Shape3 (leading, M, trailing),
1238
+ Shape3 (leading, 1 , trailing));
1239
+ }
1207
1240
}
1208
1241
});
1209
1242
});
@@ -1230,15 +1263,28 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
1230
1263
MSHADOW_TYPE_SWITCH (outputs[0 ].type_flag_ , DType, { // output type
1231
1264
MSHADOW_TYPE_SWITCH (inputs[1 ].type_flag_ , IType, { // index type
1232
1265
if (req[0 ] != kAddTo ) outputs[0 ].FlatTo1D <xpu, DType>(s) = 0 ;
1233
- if (trailing == 1 ) {
1234
- Kernel<pick_grad<2 >, xpu>::Launch (s, inputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1235
- inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1236
- M, 1 , Shape2 (leading, M), Shape2 (leading, 1 ));
1266
+ if (param.mode == kWrap ) {
1267
+ if (trailing == 1 ) {
1268
+ Kernel<pick_grad<2 , false >, xpu>::Launch (s, inputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1269
+ inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1270
+ M, 1 , Shape2 (leading, M), Shape2 (leading, 1 ));
1271
+ } else {
1272
+ Kernel<pick_grad<3 , false >, xpu>::Launch (s, inputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1273
+ inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1274
+ M, trailing, Shape3 (leading, M, trailing),
1275
+ Shape3 (leading, 1 , trailing));
1276
+ }
1237
1277
} else {
1238
- Kernel<pick_grad<3 >, xpu>::Launch (s, inputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1239
- inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1240
- M, trailing, Shape3 (leading, M, trailing),
1241
- Shape3 (leading, 1 , trailing));
1278
+ if (trailing == 1 ) {
1279
+ Kernel<pick_grad<2 , true >, xpu>::Launch (s, inputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1280
+ inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1281
+ M, 1 , Shape2 (leading, M), Shape2 (leading, 1 ));
1282
+ } else {
1283
+ Kernel<pick_grad<3 , true >, xpu>::Launch (s, inputs[0 ].Size (), outputs[0 ].dptr <DType>(),
1284
+ inputs[0 ].dptr <DType>(), inputs[1 ].dptr <IType>(),
1285
+ M, trailing, Shape3 (leading, M, trailing),
1286
+ Shape3 (leading, 1 , trailing));
1287
+ }
1242
1288
}
1243
1289
});
1244
1290
});
0 commit comments