-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Enhancements for MXTensor for custom operators #17204
Conversation
@mxnet-label-bot add [pr-awaiting-review] |
@rondogency @mseth10 @wkcn @junrushao1994 for review |
include/mxnet/lib_api.h
Outdated
size_t ID) | ||
: data_ptr(data_ptr), shape(shape), dtype(dtype), version(ID) {} | ||
|
||
void update(void *dptr, MXDType type, size_t ver) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we really need this function? it doesn't have any checks, only copy pointers. I think we can copy them line by line in lib_api.h and keep MXTensor as simple as possible
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
per our discussion, lets move the the for loop to copy shape and call to setDLTensor inside this function. Change name to "setTensor"
include/mxnet/lib_api.h
Outdated
: data_ptr(data_ptr), shape(shape), dtype(dtype) {} | ||
MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype, | ||
size_t ID) | ||
: data_ptr(data_ptr), shape(shape), dtype(dtype), version(ID) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it will be better to unify the naming across all places, like using verID in lib_api.h and here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
include/mxnet/lib_api.h
Outdated
@@ -277,6 +283,14 @@ struct MXTensor { | |||
return size; | |||
} | |||
|
|||
/*! \brief helper function to compare two MXTensors */ | |||
inline bool isSame(const MXTensor &oth) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we override operator==? since we won't support C anyway
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think operator== is confusing. For a tensor object, == usually means value comparison.
In the future, we may add other operators !=, <, >, etc.
It may be better and more consistent with MXNet NDArray API to use ‘isSame’.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comparing object is how c++ is doing for vector and usually for struct, and in NDArray we don't have operators !=, <, > either, so I don't think it is going to be confusing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I agree with @wkcn here, == should compare the values of the tensor not the "state" of the tensor (data_ptr, versionID, etc)
@mseth10 @eric-haibin-lin @haojin2 what do you guys think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For C++ vector container, operator==
compares the values.
#include <iostream>
#include <vector>
using namespace std;
int main() {
vector<int> a{1,2,3};
vector<int> b{1,2,3};
vector<int> c{1,2,4};
cout << (a == b) << endl; // 1
cout << (a == c) << endl; // 0
cout << (a.data() == b.data()) << endl; // 0
return 0;
}
Although (a==b)
is 1, a.data()
is not equal to b.data()
.
include/mxnet/lib_api.h
Outdated
@@ -277,6 +289,14 @@ struct MXTensor { | |||
return size; | |||
} | |||
|
|||
/*! \brief helper function to compare two MXTensors */ | |||
inline bool operator==(const MXTensor &oth) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will be more consistent with MXNet NDArray API to use ‘IsSame’, since operator== is confusing between same object and same tensor content.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry that I did not read the previous review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
writing a response together with @rondogency:
For clarity, here were checking the following items: dtype, version ID, shape, and the pointer address of the data in memory. If two tensors have the same data pointer then they will have the same data values also. So here the two tensors will have the same content and will have the same tensor attributes (shape, type, version ID) also.
I dont think we need to be consistent with NDarray API here, since users writing custom ops wont necessarily be familiar with NDarray API. And here we're trying to make the right API for this use-case (custom operators). Our comparison API also checks the version number, and the NDarray API does not. So we're already diverging from being consistent with NDArray API.
For the record, here is the similar NDarray API:
https://github.com/apache/incubator-mxnet/blob/55e222b8c97f99193f832f785cf210f876632add/include/mxnet/ndarray.h#L212-L217
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank @rondogency and you for the explanation in detail!
- If two tensors have the same data pointer then they will have the same data values also.
It is right. However, if two tensors have the same data values, they may not have the same data pointer. I did a simple test on C++ vector.
#include <iostream>
#include <vector>
using namespace std;
int main() {
vector<int> a{1,2,3};
vector<int> b{1,2,3};
cout << (a == b) << endl; // 1
cout << (a.data() == b.data()) << endl; // 0
return 0;
}
- we don't need to be consistent with NDarray API here.
Agree : )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @wkcn, if @rondogency agrees ill change it back to 'isSame'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with @wkcn on this. isSame
is less confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok I agree with changing it back to isSame, thanks for the inspection on c++ vector!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
@mxnet-label-bot update [pr-awaiting-merge] |
@wkcn we're ready to merge! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@wkcn will you merge this? |
Merged : ) Thank you! |
Description
Enhancements to MXTensor for custom operators. Adds the following features:
Uprevs MX_LIBRARY_VERSION to 2 since MXTensor struct is changing
Removes opCallBkwd_t since we use opCallFCompute for both forward/backward functions (was dead code)
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.