Add row conv operator #6013
@sidgoyal78 I like your code, including the code in the previous PRs. In my opinion, it is very clean and of high quality.
paddle/operators/row_conv_op.cc
$$
out_{i, :} = \sum_{j=i}^{i + context} in_{j,:} \dot W_{i-j, :}
$$
For the doc, there are some comments in
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/function/RowConvOp.cpp#L97
Thanks for the pointer, I have included them now.
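For reference, the computation that the doc comment describes can be sketched in plain C++ as below. This is only an illustration, not the PR's code: the names `RowConvReference` and `Matrix` are hypothetical, the filter row is indexed by the offset between the two timesteps (matching the `weights(w, d) * cip_seq(k + w, d)` pattern in the kernel under review), and timesteps past the end of the sequence are assumed to contribute nothing.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major matrix stored as a vector of rows (hypothetical alias for this sketch).
using Matrix = std::vector<std::vector<float>>;

// Lookahead row convolution over a single sequence.
//   in:      T x D (timesteps x feature dimension)
//   weights: C x D (context rows, current timestep first)
// Returns a T x D matrix with
//   out[k][d] = sum over w of weights[w][d] * in[k + w][d],
// where any k + w beyond the last timestep is skipped (an assumption of this sketch).
Matrix RowConvReference(const Matrix &in, const Matrix &weights) {
  const std::size_t T = in.size();
  const std::size_t C = weights.size();
  const std::size_t D = T > 0 ? in[0].size() : 0;
  Matrix out(T, std::vector<float>(D, 0.0f));
  for (std::size_t k = 0; k < T; ++k) {
    for (std::size_t w = 0; w < C && k + w < T; ++w) {
      // Every row must have the same feature dimension D.
      assert(weights[w].size() == D && in[k + w].size() == D);
      for (std::size_t d = 0; d < D; ++d) {
        out[k][d] += weights[w][d] * in[k + w][d];
      }
    }
  }
  return out;
}
```

With a 3 x 2 input and a 2 x 2 filter, the last timestep only sees its own row, because no future row is available.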
paddle/operators/row_conv_op.cc
AddInput("Filter",
         "(Tensor), the input(Filter) is a learnable parameter. It "
         "is a 2-D tensor with shape (future_context x N), where, "
         "future_context is the batch size and N is the data dimension.");
I like your code, and I think the name future_context is good :)
"future_context is the batch size" should be "future_context is the future context length."
Fixed, thanks.
paddle/operators/row_conv_op.cc
auto *Out = context.Output<LoDTensor>("Out");

Out->mutable_data<T>(context.GetPlace());
context.ShareLoD("X", "Out");
Since there is ctx->ShareLoD("X", "Out") in InferShape, and the previous bug for ShareLoD in InferShape has been fixed, line 123 can be removed.
Fixed.
  cot_seq(k, d) = weights(w, d) * cip_seq(k + w, d);
} else {
  cot_seq(k, d) += weights(w, d) * cip_seq(k + w, d);
}
Maybe we can use an elementwise multiply and a column-wise sum to remove the for loops in line 145 and line 147. But that optimization can be done in the future, so I think the current code is OK for this PR.
Sure.
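One possible reading of the suggestion above is sketched below: for each timestep, multiply each filter row elementwise with the matching input row and sum the products down the columns, so the per-feature loop disappears. This is an assumption-laden illustration rather than the planned optimization. It uses `std::valarray` from the standard library as a stand-in for whatever tensor expressions the real kernel would use, the names `Row` and `RowConvVectorized` are hypothetical, and timesteps past the end of the sequence are assumed to contribute nothing.

```cpp
#include <cstddef>
#include <valarray>
#include <vector>

// One row of features; valarray provides elementwise arithmetic (hypothetical alias).
using Row = std::valarray<float>;

// Lookahead row convolution over a single sequence, one valarray per timestep.
//   in:      T rows of D features
//   weights: C rows of D features (current timestep first)
// All rows must have the same length D; valarray arithmetic on
// mismatched lengths is undefined behavior.
std::vector<Row> RowConvVectorized(const std::vector<Row> &in,
                                   const std::vector<Row> &weights) {
  const std::size_t T = in.size();
  const std::size_t C = weights.size();
  std::vector<Row> out;
  out.reserve(T);
  for (std::size_t k = 0; k < T; ++k) {
    // Accumulator of D zeros (valarray takes the value first, then the count).
    Row acc(0.0f, in[k].size());
    for (std::size_t w = 0; w < C && k + w < T; ++w) {
      // Elementwise multiply of filter row w with input row k + w;
      // accumulating over w is the column-wise sum of those products.
      acc += weights[w] * in[k + w];
    }
    out.push_back(acc);
  }
  return out;
}
```

The loop over the context window remains in this sketch; removing it as well would require stacking the window into a matrix and reducing along its rows.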
paddle/operators/row_conv_op.cc
void Compute(const framework::ExecutionContext &context) const override {
  auto *X = context.Input<LoDTensor>("X");
  auto *Filter = context.Input<Tensor>("Filter");
  auto *Out = context.Output<LoDTensor>("Out");
The naming style:
X -> x
Filter -> filter
Out -> out
https://google.github.io/styleguide/cppguide.html#Variable_Names
Fixed.
(I still need to fix this in the .cu code.)
LGTM.
Resolves #5612 by adding the implementation of the row-convolution operator.
A few notes: