-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ReorderLoDTensorByRank #6789
Add ReorderLoDTensorByRank #6789
Conversation
fe473f9
to
70236b6
Compare
AddInput("RankTable", | ||
"(LoDRankTable) the rank table that input need follow"); | ||
AddOutput("Out", "(LoDTensor) reordered lod tensor"); | ||
AddComment(R"DOC(ReorderLoDTensorLoDRankTable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ReorderLoDTensorLoDRankTable --> ReorderLoDTensorByRankTable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
index [3, 0, 2, 1]. Input X will reorder its sequence, the third sequence of | ||
X will be the first sequence of Output. | ||
|
||
NOTE: The RankTable does not need to be calculated by X. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please give a specific example to make things easier to understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
namespace paddle { | ||
namespace operators { | ||
|
||
class ReorderLoDTensorProtoMaker : public framework::OpProtoAndCheckerMaker { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to change to ReorderLoDTensorByRankTableProtoMaker ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
|
||
def reorder_lod_tensor_by_rank(x, rank_table): | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please refine the doc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we could use python to generate the docstring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
size_t out_offset = 0; | ||
out->mutable_lod()->clear(); | ||
for (auto &item : rank_table.items()) { | ||
out_offset = this->CopyTensorAndLod(dev_ctx, absolute_table[item.index], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add some assert
to make sure absolute_table[item.index]
valid
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
out.Resize(x.dims()); | ||
out.mutable_data(x.place(), x.type()); | ||
this->process(dev_ctx, x, rank_table, &out); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make sure the size of rank_table
and size of x' lod
is match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will be checked during Tensor::Slice.
|
||
|
||
class TestReorderLoDTensor(unittest.TestCase): | ||
def test_reorder(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add Gradient test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has been tested, since IG.sum() == 1 and LoD can be restored.
It is useful to reorder RNN memory block.
70236b6
to
b23982a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great Job, Thx.
It is useful to reorder RNN memory block.