-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-791] Pick with negative indices #12090
[MXNET-791] Pick with negative indices #12090
Conversation
Hi, This is my first PR to the project and I'm still learning the ropes. Any and all suggestions would be much welcome. I'm not sure how to handle the case. For simplicity, I handled it with mod arithmetic. But in numpy, you can index in the range [-len, len). Which behaviour should be correct for this issue? Cheers, Per |
@@ -30,6 +30,41 @@ | |||
from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled, assertRaises | |||
import unittest | |||
|
|||
@with_seed() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will move this back to its proper place - didn't get a chance as my family arrived home =P
efb0c4f
to
975c9db
Compare
Welcome to MXNet family! |
@@ -131,7 +130,7 @@ Examples:: | |||
pick(x, y=[0,1], 0) = [ 1., 4.] | |||
|
|||
// picks elements with specified indices along axis 1 | |||
pick(x, y=[0,1,0], 1) = [ 1., 4., 5.] | |||
pick(x, y=[0,-1,0], 1) = [ 1., 4., 5.] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please consider adding extra example for the newly supported case instead of modifying the existing cases.
For this case maybe you can do
pick(x, y=[0,1,0,-1], 1) = [ 1., 4., 5., 4.]
@@ -4417,7 +4417,7 @@ def test_pick_helper(index_type=np.int32): | |||
sshape = bshape.copy() | |||
sshape[axis] = 1 | |||
data = np.random.uniform(-1, 1, size=bshape) | |||
index = np.random.randint(0, bshape[axis], size=sshape) | |||
index = np.random.randint(-1*bshape[axis], bshape[axis], size=sshape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer -bshape[axis]
.
0a4064b
to
a9fb7c2
Compare
Thank you for the welcome and guidance ^^ Cheers!! Per |
@@ -99,6 +99,10 @@ struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> { | |||
} | |||
}; | |||
|
|||
namespace pick_ { // to avoid name conflict | |||
enum TakeOpMode {kRaise, kWrap, kClip}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer this to be like
enum PickOpMode {kWrap, kClip};
as we do not have raise
mode at all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This way you do not need the extra namespace either.
@@ -4417,7 +4417,12 @@ def test_pick_helper(index_type=np.int32): | |||
sshape = bshape.copy() | |||
sshape[axis] = 1 | |||
data = np.random.uniform(-1, 1, size=bshape) | |||
index = np.random.randint(0, bshape[axis], size=sshape) | |||
mode = np.random.choice(a=['clip', 'wrap']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please test both choices for mode
using a for-loop instead of randomized testing for mode
@@ -117,7 +117,7 @@ an output array of shape ``(i0,)`` with:: | |||
output[i] = input[i, indices[i]] | |||
|
|||
By default, if any index mentioned is too large, it is replaced by the index that addresses | |||
the last element along an axis (the `clip` mode). | |||
the last element along an axis (the `clip` mode). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please get rid of the unnecessary extra whitespace on this line.
Will do. I kept the raise there to be consistent with the take implementation. Would you like me to remove raise from take? |
a9fb7c2
to
0326116
Compare
Do you think it would make sense to test 'wrap' with the range -2*-bshape[axis]..2*bshape[axis]? To cover cases where one is wrapping more than the length of the axis in both directions? In my understanding, in numpy indices should generally be [-len, len). So, the current implementation diverges from that. |
@perdasilva Sorry for the late reply, I think adding more coverage is totally okay. Regarding the inconsistency with numpy, does this operator has some equivalent op in numpy? There is one for |
@haojin2 No problem at all ^^ I'll also increase the coverage for the test. |
0326116
to
8b23e19
Compare
@perdasilva So if there's no direct equivalence of this one in numpy then we may not need to match the behavior with numpy, please simply ensure we have enough documentation to document the behavior. |
8b23e19
to
21631b6
Compare
I believe that's all done now. Thanks for everything ^^ |
21631b6
to
eb6baf0
Compare
@haojin2 I believe everything is done as specified. Please merge when you get a chance ^^ |
be8c120
to
ecd90aa
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM.
Thank you @perdasilva for the contributions and @haojin2 for reviewing the code.
I see dmlc-core is being updated with this PR. Is it due to recent revert @haojin2 ? How do we proceed here
.add_enum("wrap", kWrap) | ||
.add_enum("clip", kClip) | ||
.set_default(kClip) | ||
.describe("Specify how out-of-bound indices bahave. Default is \"clip\"." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: bahave-> behave
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! Thank you!
Please rebase with the latest master, then do a |
c106ef2
to
05e2915
Compare
@haojin2 @sandeep-krishnamurthy I've fixed the typo and done as hao jin asked. Thanks for the review guys ^^ |
Hmm those dmlc changes seem to have persisted...strange |
05e2915
to
63bda11
Compare
Ok - Just did a soft reset, updated the submodule, then re-applied my commits. Seems to have fixed the issue =) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @perdasilva
LGTM.
@haojin2 - This PR is ready for merge after your approval. |
@sandeep-krishnamurthy Ready for merge. |
This reverts commit 633bef3.
This reverts commit 633bef3.
* Adds Per G. da Silva to CONTRIBUTORS.md * Updates pick operation to also handle negative indices
Description
Updates pick operator to handle negative indices.
Implements #12009
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Updates pick and pick_grad in broadcast_reduce_op to handle negative indices
Updates pick to include a mode parameter ('pick' or 'wrap') to specify whether out of bounds indices should either be clipped or wrapped.
Comments
Making change as it was tagged with
Call for Contribution