-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move dropout to phi #40148
Move dropout to phi #40148
Conversation
Thanks for your contribution! |
… move_dropout_to_phi
… move_dropout_to_phi
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void DropoutGradRawKernel(const Context& dev_ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grad是不是不需要raw kernel,就一个
template <typename T, typename Context> | ||
void DropoutRawKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
paddle::optional<const DenseTensor&> seed_tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seed tensor的下面的seed应该用Scalar统一表示?而不是写两个参数?
bool is_test, | ||
const std::string& mode, | ||
DenseTensor* x_grad) { | ||
x_grad->mutable_data<T>(dev_ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mutable_data -> Alloc
bool fix_seed, | ||
DenseTensor* out, | ||
DenseTensor* mask) { | ||
out->mutable_data<T>(dev_ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) { | ||
return KernelSignature( | ||
"dropout", | ||
{"X", "Seed"}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要有if分支选择seed
DenseTensor* x_grad) { | ||
auto* grad_x = x_grad; | ||
auto* grad_y = &out_grad; | ||
grad_x->mutable_data<T>(dev_ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dev_ctx.Alloc
|
||
#pragma once | ||
|
||
#include "paddle/phi/common/scalar.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scalar没有用到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall,细节问题后续PR再完善一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Breaking changes
PR changes
OPs
Describe
move dropout to phi