-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Lod Reset Operator #4747
Add Lod Reset Operator #4747
Conversation
paddle/operators/lod_reset_op.cc
Outdated
AddOutput("Out", "The output tensor of lod_reset operator."); | ||
AddAttr<std::vector<int>>("target_lod_0", | ||
"Target level 0 LoD of " | ||
"lod_reset operator."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now, the attributes are usually fixed during training, that is to say, it was passed to the network when creating the graph. I'm not sure whether the attributes can be changed during training in the future. And since the batch size and the sequence length may be varying during training, I think it's better to pass LoD as an input, not attributes. The input can be calculated by others operator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/lod_reset_op.cc
Outdated
LoDResetOpMaker(framework::OpProto *proto, | ||
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "(LoDTensor)The input tensor of lod_reset operator."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(LoDTensor)The
-> (LoDTensor) The
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/lod_reset_op.cc
Outdated
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "(LoDTensor)The input tensor of lod_reset operator."); | ||
AddInput("target_lod_in", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
target_lod_in -> TargetLoD
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Updated by following the comments
paddle/operators/lod_reset_op.cc
Outdated
LoDResetOpMaker(framework::OpProto *proto, | ||
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "(LoDTensor)The input tensor of lod_reset operator."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/lod_reset_op.cc
Outdated
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "(LoDTensor)The input tensor of lod_reset operator."); | ||
AddInput("target_lod_in", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/lod_reset_op.h
Outdated
"Target LoD should be an ascending vector."); | ||
} | ||
|
||
out->CopyFrom(*in, ctx.GetPlace(), ctx.device_context()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out->ShareDataWith(*in);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/lod_reset_op.h
Outdated
d_x->mutable_data<T>(ctx.GetPlace()); | ||
|
||
auto in_dims = d_x->dims(); | ||
d_x->CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
d_x->ShareDataWith(*d_out);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/lod_reset_op.cc
Outdated
An example: | ||
Given a float LoDTensor X with lod = [[0, 2, 5, 6]] | ||
|
||
[[1, 2], [3, 4, 5], [6]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact, the LodTensor X is saved as with shape (6, 1)
in our codes:
[1,
2,
3,
4,
5,
6]
But [[1, 2], [3, 4, 5], [6]]
is easy to understand, I wonder how to add this comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified
paddle/operators/lod_reset_op.h
Outdated
std::vector<int> level0; | ||
if (lod_t) { | ||
auto* lod = lod_t->data<int>(); | ||
if (!platform::is_cpu_place(ctx.GetPlace())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (platform::is_gpu_place(ctx.GetPlace())) {
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
paddle/operators/lod_reset_op.cc
Outdated
An example: | ||
Given a float LoDTensor X with lod = [[0, 2, 5, 6]] | ||
|
||
[[1, 2], [3, 4, 5], [6]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified
paddle/operators/lod_reset_op.h
Outdated
std::vector<int> level0; | ||
if (lod_t) { | ||
auto* lod = lod_t->data<int>(); | ||
if (!platform::is_cpu_place(ctx.GetPlace())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/lod_reset_op.h
Outdated
"Target LoD should be an ascending vector."); | ||
} | ||
|
||
out->CopyFrom(*in, ctx.GetPlace(), ctx.device_context()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/lod_reset_op.h
Outdated
d_x->mutable_data<T>(ctx.GetPlace()); | ||
|
||
auto in_dims = d_x->dims(); | ||
d_x->CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
f7ad5c1
to
e602c70
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Resolve #4719