This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-791] Pick with negative indices #12090

Merged
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -177,3 +177,4 @@ List of Contributors
* [Istvan Fehervari](https://github.com/ifeherva)
* [Aaron Markham](https://github.com/aaronmarkham)
* [Sam Skalicky](https://github.com/samskalicky)
* [Per Goncalves da Silva](https://github.com/perdasilva)
90 changes: 68 additions & 22 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -99,6 +99,8 @@ struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
}
};

enum PickOpMode {kWrap, kClip};

struct PickParam : public dmlc::Parameter<PickParam> {
dmlc::optional<int> axis;
int mode;
@@ -112,6 +114,14 @@ struct PickParam : public dmlc::Parameter<PickParam> {
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If true, the axis where we pick the elements is left "
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(mode)
.add_enum("wrap", kWrap)
.add_enum("clip", kClip)
.set_default(kClip)
.describe("Specify how out-of-bound indices behave. Default is \"clip\"."
" \"clip\" means clip to the range. So, if all indices mentioned are too large,"
" they are replaced by the index that addresses the last element along an axis. "
" \"wrap\" means to wrap around.");
}
};
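A minimal Python sketch of what the two modes mean for a single index j on an axis of length M; resolve_index is an illustrative name, not an MXNet function:

def resolve_index(j, M, mode='clip'):
    """Sketch of the two index modes described by PickParam above."""
    if mode == 'clip':
        # too-small indices become 0, too-large ones become M - 1
        return min(max(j, 0), M - 1)
    # 'wrap': out-of-range indices wrap around the axis, negatives included
    return j % M

# e.g. with an axis of length 3:
#   resolve_index(5, 3, 'clip')  -> 2
#   resolve_index(-1, 3, 'wrap') -> 2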

@@ -1108,7 +1118,7 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs);

/*! \brief index element from array along axes */
template<int ndim>
template<int ndim, bool clip = true>
struct pick {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
@@ -1117,15 +1127,20 @@
mshadow::Shape<ndim> sshape) {
using namespace broadcast;
int j = static_cast<int>(idx[i]);
if (j < 0) j = 0;
else if (j >= M) j = M-1;
if (clip) {
if (j <= 0) j = 0;
else if (j >= M) j = M - 1;
} else {
j = j % M;
j += (j < 0) ? M : 0;
}
j = ravel(unravel(i, sshape), bshape) + j*stride;
out[i] = a[j];
}
};
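The kernel above can be read as plain index arithmetic. A rough NumPy emulation, assuming a is the flattened data of logical shape (leading, M, trailing), idx holds one entry per output element, and pick_forward_ref is an illustrative name rather than MXNet code:

import numpy as np

def pick_forward_ref(a, idx, leading, M, trailing, clip=True):
    # a: 1-D flattened data of logical shape (leading, M, trailing)
    # idx: 1-D indices, one per output element (leading * trailing of them)
    bshape = (leading, M, trailing)   # "big" shape of the data
    sshape = (leading, 1, trailing)   # "small" shape of the index/output
    stride = trailing                 # step between consecutive picks along the axis
    out = np.empty(leading * trailing, dtype=a.dtype)
    for i in range(out.size):
        j = int(idx[i])
        j = min(max(j, 0), M - 1) if clip else j % M
        # ravel(unravel(i, sshape), bshape): offset of element (l, 0, t) in the data
        base = np.ravel_multi_index(np.unravel_index(i, sshape), bshape)
        out[i] = a[base + j * stride]
    return out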

/*! \brief index element from array along axes */
template<int ndim>
template<int ndim, bool clip = true>
struct pick_grad {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd,
@@ -1134,8 +1149,13 @@
mshadow::Shape<ndim> sshape) {
using namespace broadcast;
int j = static_cast<int>(idx[i]);
if (j < 0) j = 0;
else if (j >= M) j = M-1;
if (clip) {
if (j <= 0) j = 0;
else if (j >= M) j = M - 1;
} else {
j = j % M;
j += (j < 0) ? M : 0;
}
j = ravel(unravel(i, sshape), bshape) + j*stride;
igrad[j] += ograd[i];
}
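pick_grad is the mirror image: each output gradient is accumulated back into the position that was picked, so repeated indices sum. A NumPy sketch under the same assumptions as above (pick_grad_ref is illustrative, not MXNet code):

import numpy as np

def pick_grad_ref(ograd, idx, leading, M, trailing, clip=True):
    bshape = (leading, M, trailing)
    sshape = (leading, 1, trailing)
    igrad = np.zeros(leading * M * trailing, dtype=ograd.dtype)
    for i in range(ograd.size):
        j = int(idx[i])
        j = min(max(j, 0), M - 1) if clip else j % M
        base = np.ravel_multi_index(np.unravel_index(i, sshape), bshape)
        igrad[base + j * trailing] += ograd[i]   # accumulate; indices may repeat
    return igrad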
@@ -1195,15 +1215,28 @@ void PickOpForward(const nnvm::NodeAttrs& attrs,

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index type
if (trailing == 1) {
Kernel<pick<2>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
if (param.mode == kWrap) {
if (trailing == 1) {
Kernel<pick<2, false>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick<3, false>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
} else {
Kernel<pick<3>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
if (trailing == 1) {
Kernel<pick<2, true>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick<3, true>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
}
});
});
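The dispatch above only selects compile-time options: clip=false for mode='wrap', and a 2-D or 3-D launch depending on whether any trailing dimensions follow the pick axis. leading, M and trailing themselves are computed earlier in the function, outside this hunk; presumably they are the products of the dimensions before, at and after the axis, as the Shape3(leading, M, trailing) launches suggest. A small illustrative helper with a hypothetical name, not MXNet code:

import numpy as np

def split_around_axis(shape, axis):
    # leading:  product of dims before the pick axis
    # M:        length of the pick axis
    # trailing: product of dims after the pick axis
    leading = int(np.prod(shape[:axis], dtype=np.int64))
    M = int(shape[axis])
    trailing = int(np.prod(shape[axis + 1:], dtype=np.int64))
    return leading, M, trailing

# e.g. shape (2, 3, 4) picked along axis=1 -> (leading, M, trailing) == (2, 3, 4);
# trailing == 1 when axis is the last dimension, which selects the Shape2 kernels above.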
@@ -1230,15 +1263,28 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index type
if (req[0] != kAddTo) outputs[0].FlatTo1D<xpu, DType>(s) = 0;
if (trailing == 1) {
Kernel<pick_grad<2>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
if (param.mode == kWrap) {
if (trailing == 1) {
Kernel<pick_grad<2, false>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick_grad<3, false>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
} else {
Kernel<pick_grad<3>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
if (trailing == 1) {
Kernel<pick_grad<2, true>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick_grad<3, true>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
}
});
});
10 changes: 9 additions & 1 deletion src/operator/tensor/broadcast_reduce_op_index.cc
@@ -133,6 +133,14 @@ Examples::
// picks elements with specified indices along axis 1
pick(x, y=[0,1,0], 1) = [ 1., 4., 5.]

y = [[ 1.],
[ 0.],
[ 2.]]

// picks elements with specified indices along axis 1 using 'wrap' mode
// to wrap indices that would otherwise be out of bounds
pick(x, y=[2,-1,-2], 1, mode='wrap') = [ 1., 4., 5.]
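The same call can be reproduced with the Python API, assuming the x defined earlier in this docstring (not shown in this hunk), x = [[1., 2.], [3., 4.], [5., 6.]]:

import mxnet as mx

x = mx.nd.array([[1., 2.], [3., 4.], [5., 6.]])
y = mx.nd.array([2, -1, -2])

# axis 1 has length 2, so under 'wrap' the indices 2, -1, -2 become 0, 1, 0
out = mx.nd.pick(x, y, axis=1, mode='wrap')
print(out.asnumpy())   # expected: [1. 4. 5.]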

y = [[ 1.],
[ 0.],
[ 2.]]
@@ -165,7 +173,7 @@ Examples::
})
.add_argument("data", "NDArray-or-Symbol", "The input array")
.add_argument("index", "NDArray-or-Symbol", "The index array")
.add_arguments(ReduceAxisParam::__FIELDS__());
.add_arguments(PickParam::__FIELDS__());


NNVM_REGISTER_OP(_backward_pick)
60 changes: 34 additions & 26 deletions tests/python/unittest/test_operator.py
@@ -4523,33 +4523,41 @@ def test_log_softmax():
def test_pick():
def test_pick_helper(index_type=np.int32):
for _ in range(100):
ndim = np.random.randint(1, 5)
bshape = np.random.randint(1, 10, size=ndim)
axis = np.random.randint(0, ndim)
sshape = bshape.copy()
sshape[axis] = 1
data = np.random.uniform(-1, 1, size=bshape)
index = np.random.randint(0, bshape[axis], size=sshape)
exp = []
for i in range(ndim):
if i == axis:
exp.append(index)
for mode in ['clip', 'wrap']:
ndim = np.random.randint(1, 5)
bshape = np.random.randint(1, 10, size=ndim)
axis = np.random.randint(0, ndim)
sshape = bshape.copy()
sshape[axis] = 1
data = np.random.uniform(-1, 1, size=bshape)

if mode == 'wrap':
index = np.random.randint(-2*bshape[axis], 2*bshape[axis], size=sshape)
else:
ishape = [1 for _ in range(ndim)]
ishape[i] = bshape[i]
exp.append(np.arange(bshape[i]).reshape(ishape))
expected = data[exp]
data = mx.nd.array(data, dtype='float32')
index = mx.nd.array(index, dtype=index_type)
out = mx.nd.pick(data, index, axis=axis, keepdims=True)
assert_almost_equal(out.asnumpy(), expected)

data_holder = data
index_holder = index
data = mx.sym.Variable('data')
index = mx.sym.Variable('index')
sym = mx.sym.pick(data, index, axis=axis, keepdims=True)
check_numeric_gradient(sym, [data_holder, index_holder], grad_nodes=['data'])
index = np.random.randint(0, bshape[axis], size=sshape)
exp = []
for i in range(ndim):
if i == axis:
if mode == 'wrap':
exp.append(index % bshape[axis])
else:
exp.append(index)
else:
ishape = [1 for _ in range(ndim)]
ishape[i] = bshape[i]
exp.append(np.arange(bshape[i]).reshape(ishape))
expected = data[exp]
data = mx.nd.array(data, dtype='float32')
index = mx.nd.array(index, dtype=index_type)
out = mx.nd.pick(data, index, axis=axis, keepdims=True, mode=mode)
assert_almost_equal(out.asnumpy(), expected)

data_holder = data
index_holder = index
data = mx.sym.Variable('data')
index = mx.sym.Variable('index')
sym = mx.sym.pick(data, index, axis=axis, keepdims=True, mode=mode)
check_numeric_gradient(sym, [data_holder, index_holder], grad_nodes=['data'])

test_pick_helper(np.int32)
test_pick_helper(np.float32)