Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 05e2915

Browse files
committed
Updates pick operation to also handle negative indices
1 parent 697424c commit 05e2915

File tree

3 files changed

+111
-49
lines changed

3 files changed

+111
-49
lines changed

src/operator/tensor/broadcast_reduce_op.h

+68-22
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
9999
}
100100
};
101101

102+
enum PickOpMode {kWrap, kClip};
103+
102104
struct PickParam : public dmlc::Parameter<PickParam> {
103105
dmlc::optional<int> axis;
104106
int mode;
@@ -112,6 +114,14 @@ struct PickParam : public dmlc::Parameter<PickParam> {
112114
DMLC_DECLARE_FIELD(keepdims).set_default(false)
113115
.describe("If true, the axis where we pick the elements is left "
114116
"in the result as dimension with size one.");
117+
DMLC_DECLARE_FIELD(mode)
118+
.add_enum("wrap", kWrap)
119+
.add_enum("clip", kClip)
120+
.set_default(kClip)
121+
.describe("Specify how out-of-bound indices behave. Default is \"clip\"."
122+
" \"clip\" means clip to the range. So, if all indices mentioned are too large,"
123+
" they are replaced by the index that addresses the last element along an axis. "
124+
" \"wrap\" means to wrap around.");
115125
}
116126
};
117127

@@ -1108,7 +1118,7 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
11081118
const std::vector<NDArray>& outputs);
11091119

11101120
/*! \brief index element from array along axes */
1111-
template<int ndim>
1121+
template<int ndim, bool clip = true>
11121122
struct pick {
11131123
template<typename DType, typename IType>
11141124
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
@@ -1117,15 +1127,20 @@ struct pick {
11171127
mshadow::Shape<ndim> sshape) {
11181128
using namespace broadcast;
11191129
int j = static_cast<int>(idx[i]);
1120-
if (j < 0) j = 0;
1121-
else if (j >= M) j = M-1;
1130+
if (clip) {
1131+
if (j <= 0) j = 0;
1132+
else if (j >= M) j = M - 1;
1133+
} else {
1134+
j = j % M;
1135+
j += (j < 0) ? M : 0;
1136+
}
11221137
j = ravel(unravel(i, sshape), bshape) + j*stride;
11231138
out[i] = a[j];
11241139
}
11251140
};
11261141

11271142
/*! \brief index element from array along axes */
1128-
template<int ndim>
1143+
template<int ndim, bool clip = true>
11291144
struct pick_grad {
11301145
template<typename DType, typename IType>
11311146
MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd,
@@ -1134,8 +1149,13 @@ struct pick_grad {
11341149
mshadow::Shape<ndim> sshape) {
11351150
using namespace broadcast;
11361151
int j = static_cast<int>(idx[i]);
1137-
if (j < 0) j = 0;
1138-
else if (j >= M) j = M-1;
1152+
if (clip) {
1153+
if (j <= 0) j = 0;
1154+
else if (j >= M) j = M - 1;
1155+
} else {
1156+
j = j % M;
1157+
j += (j < 0) ? M : 0;
1158+
}
11391159
j = ravel(unravel(i, sshape), bshape) + j*stride;
11401160
igrad[j] += ograd[i];
11411161
}
@@ -1195,15 +1215,28 @@ void PickOpForward(const nnvm::NodeAttrs& attrs,
11951215

11961216
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output type
11971217
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index type
1198-
if (trailing == 1) {
1199-
Kernel<pick<2>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
1200-
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1201-
M, 1, Shape2(leading, M), Shape2(leading, 1));
1218+
if (param.mode == kWrap) {
1219+
if (trailing == 1) {
1220+
Kernel<pick<2, false>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
1221+
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1222+
M, 1, Shape2(leading, M), Shape2(leading, 1));
1223+
} else {
1224+
Kernel<pick<3, false>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
1225+
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1226+
M, trailing, Shape3(leading, M, trailing),
1227+
Shape3(leading, 1, trailing));
1228+
}
12021229
} else {
1203-
Kernel<pick<3>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
1204-
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1205-
M, trailing, Shape3(leading, M, trailing),
1206-
Shape3(leading, 1, trailing));
1230+
if (trailing == 1) {
1231+
Kernel<pick<2, true>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
1232+
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1233+
M, 1, Shape2(leading, M), Shape2(leading, 1));
1234+
} else {
1235+
Kernel<pick<3, true>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
1236+
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1237+
M, trailing, Shape3(leading, M, trailing),
1238+
Shape3(leading, 1, trailing));
1239+
}
12071240
}
12081241
});
12091242
});
@@ -1230,15 +1263,28 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
12301263
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output type
12311264
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index type
12321265
if (req[0] != kAddTo) outputs[0].FlatTo1D<xpu, DType>(s) = 0;
1233-
if (trailing == 1) {
1234-
Kernel<pick_grad<2>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
1235-
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1236-
M, 1, Shape2(leading, M), Shape2(leading, 1));
1266+
if (param.mode == kWrap) {
1267+
if (trailing == 1) {
1268+
Kernel<pick_grad<2, false>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
1269+
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1270+
M, 1, Shape2(leading, M), Shape2(leading, 1));
1271+
} else {
1272+
Kernel<pick_grad<3, false>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
1273+
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1274+
M, trailing, Shape3(leading, M, trailing),
1275+
Shape3(leading, 1, trailing));
1276+
}
12371277
} else {
1238-
Kernel<pick_grad<3>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
1239-
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1240-
M, trailing, Shape3(leading, M, trailing),
1241-
Shape3(leading, 1, trailing));
1278+
if (trailing == 1) {
1279+
Kernel<pick_grad<2, true>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
1280+
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1281+
M, 1, Shape2(leading, M), Shape2(leading, 1));
1282+
} else {
1283+
Kernel<pick_grad<3, true>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
1284+
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
1285+
M, trailing, Shape3(leading, M, trailing),
1286+
Shape3(leading, 1, trailing));
1287+
}
12421288
}
12431289
});
12441290
});

src/operator/tensor/broadcast_reduce_op_index.cc

+9-1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ Examples::
133133
// picks elements with specified indices along axis 1
134134
pick(x, y=[0,1,0], 1) = [ 1., 4., 5.]
135135
136+
y = [[ 1.],
137+
[ 0.],
138+
[ 2.]]
139+
140+
// picks elements with specified indices along axis 1 using 'wrap' mode
141+
// to place indices that would normally be out of bounds
142+
pick(x, y=[2,-1,-2], 1, mode='wrap') = [ 1., 4., 5.]
143+
136144
y = [[ 1.],
137145
[ 0.],
138146
[ 2.]]
@@ -165,7 +173,7 @@ Examples::
165173
})
166174
.add_argument("data", "NDArray-or-Symbol", "The input array")
167175
.add_argument("index", "NDArray-or-Symbol", "The index array")
168-
.add_arguments(ReduceAxisParam::__FIELDS__());
176+
.add_arguments(PickParam::__FIELDS__());
169177

170178

171179
NNVM_REGISTER_OP(_backward_pick)

tests/python/unittest/test_operator.py

+34-26
Original file line numberDiff line numberDiff line change
@@ -4523,33 +4523,41 @@ def test_log_softmax():
45234523
def test_pick():
45244524
def test_pick_helper(index_type=np.int32):
45254525
for _ in range(100):
4526-
ndim = np.random.randint(1, 5)
4527-
bshape = np.random.randint(1, 10, size=ndim)
4528-
axis = np.random.randint(0, ndim)
4529-
sshape = bshape.copy()
4530-
sshape[axis] = 1
4531-
data = np.random.uniform(-1, 1, size=bshape)
4532-
index = np.random.randint(0, bshape[axis], size=sshape)
4533-
exp = []
4534-
for i in range(ndim):
4535-
if i == axis:
4536-
exp.append(index)
4526+
for mode in ['clip', 'wrap']:
4527+
ndim = np.random.randint(1, 5)
4528+
bshape = np.random.randint(1, 10, size=ndim)
4529+
axis = np.random.randint(0, ndim)
4530+
sshape = bshape.copy()
4531+
sshape[axis] = 1
4532+
data = np.random.uniform(-1, 1, size=bshape)
4533+
4534+
if mode == 'wrap':
4535+
index = np.random.randint(-2*bshape[axis], 2*bshape[axis], size=sshape)
45374536
else:
4538-
ishape = [1 for _ in range(ndim)]
4539-
ishape[i] = bshape[i]
4540-
exp.append(np.arange(bshape[i]).reshape(ishape))
4541-
expected = data[exp]
4542-
data = mx.nd.array(data, dtype='float32')
4543-
index = mx.nd.array(index, dtype=index_type)
4544-
out = mx.nd.pick(data, index, axis=axis, keepdims=True)
4545-
assert_almost_equal(out.asnumpy(), expected)
4546-
4547-
data_holder = data
4548-
index_holder = index
4549-
data = mx.sym.Variable('data')
4550-
index = mx.sym.Variable('index')
4551-
sym = mx.sym.pick(data, index, axis=axis, keepdims=True)
4552-
check_numeric_gradient(sym, [data_holder, index_holder], grad_nodes=['data'])
4537+
index = np.random.randint(0, bshape[axis], size=sshape)
4538+
exp = []
4539+
for i in range(ndim):
4540+
if i == axis:
4541+
if mode == 'wrap':
4542+
exp.append(index % bshape[axis])
4543+
else:
4544+
exp.append(index)
4545+
else:
4546+
ishape = [1 for _ in range(ndim)]
4547+
ishape[i] = bshape[i]
4548+
exp.append(np.arange(bshape[i]).reshape(ishape))
4549+
expected = data[exp]
4550+
data = mx.nd.array(data, dtype='float32')
4551+
index = mx.nd.array(index, dtype=index_type)
4552+
out = mx.nd.pick(data, index, axis=axis, keepdims=True, mode=mode)
4553+
assert_almost_equal(out.asnumpy(), expected)
4554+
4555+
data_holder = data
4556+
index_holder = index
4557+
data = mx.sym.Variable('data')
4558+
index = mx.sym.Variable('index')
4559+
sym = mx.sym.pick(data, index, axis=axis, keepdims=True, mode=mode)
4560+
check_numeric_gradient(sym, [data_holder, index_holder], grad_nodes=['data'])
45534561

45544562
test_pick_helper(np.int32)
45554563
test_pick_helper(np.float32)

0 commit comments

Comments
 (0)