-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[AutoParallel] Fix PHI API inplace output code generation. #59133
Changes from 7 commits
7f5f5bd
ef7c7e6
02732a1
22a1a4d
95eea34
f039370
ab43c20
1c73ab9
a54c212
aae2773
0787ce0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -722,6 +722,108 @@ ReshardApiInputToKernelInput( | |
return paddle::none; | ||
} | ||
|
||
// Make an inplace output DistTensor carry the dist_attr expected by the
// caller: reshard its data when requested, otherwise just overwrite the
// attribute in place.
void SetInplaceOutputCorrectDistAttr(
    phi::DeviceContext* dev_ctx,
    Tensor& tensor,  // NOLINT
    const phi::distributed::TensorDistAttr& dist_attr,
    bool need_reshard) {
  auto tensor_in = tensor.impl();
  if (!tensor_in) {
    return;
  }
  auto* dist_tensor =
      static_cast<phi::distributed::DistTensor*>(tensor_in.get());
  if (!dist_tensor->initialized()) {
    // No data has been materialized yet, so the attribute can be replaced
    // directly without moving anything.
    VLOG(6) << "SetInplaceOutputCorrectDistAttr has"
            << " uninitialized DistTensor input " << tensor.name()
            << ", just set its dist_attr from " << dist_tensor->dist_attr()
            << " to " << dist_attr;
    dist_tensor->unsafe_set_dist_attr(dist_attr);
    return;
  }
  if (!ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
    // Attribute already matches; nothing to do.
    return;
  }
  if (need_reshard) {
    VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
            << " to origin dist_attr "
            << ReshardDebugInfo(*dist_tensor, dist_attr);
    auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
                                                               dist_attr);
    func->Eval(dev_ctx, *dist_tensor, dist_attr, dist_tensor);
  } else {
    // just set correct SPMD dist_attrs
    VLOG(6) << "SetInplaceOutputCorrectDistAttr input " << tensor.name()
            << " set its dist_attr from " << dist_tensor->dist_attr()
            << " to " << dist_attr;
    dist_tensor->unsafe_set_dist_attr(dist_attr);
  }
}
|
||
// ArgDistAttr overload: the variant must hold a single TensorDistAttr,
// which is unpacked and forwarded to the TensorDistAttr overload.
void SetInplaceOutputCorrectDistAttr(
    phi::DeviceContext* dev_ctx,
    Tensor& tensor,  // NOLINT
    const phi::distributed::ArgDistAttr& dist_attr,
    bool need_reshard) {
  PADDLE_ENFORCE_EQ(
      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
      true,
      phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr"));
  // Use PADDLE_GET_CONST rather than a bare paddle::get<0> so that a
  // bad-variant access surfaces as a diagnosable Paddle error instead of an
  // uncaught exception (raised in review of this change).
  SetInplaceOutputCorrectDistAttr(
      dev_ctx,
      tensor,
      PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr),
      need_reshard);
}
|
||
// Vector variant: applies the single-tensor logic element-wise.
//
// Fixes two issues in the previous version:
//   1. It duplicated the whole single-tensor branch structure inline;
//      delegating keeps the reshard/set logic in exactly one place.
//   2. It indexed dist_attr[i] without checking that the two vectors have
//      the same length, which would read out of bounds on a mismatch.
void SetInplaceOutputCorrectDistAttr(
    phi::DeviceContext* dev_ctx,
    std::vector<Tensor>& tensors,  // NOLINT
    const std::vector<phi::distributed::TensorDistAttr>& dist_attr,
    bool need_reshard) {
  PADDLE_ENFORCE_EQ(
      tensors.size(),
      dist_attr.size(),
      phi::errors::PreconditionNotMet(
          "The size of tensors (%d) must equal the size of dist_attr (%d).",
          tensors.size(),
          dist_attr.size()));
  for (size_t i = 0; i < tensors.size(); ++i) {
    // Per-element behavior (reshard vs. unsafe_set_dist_attr, VLOG output)
    // is identical to the single-tensor overload above.
    SetInplaceOutputCorrectDistAttr(
        dev_ctx, tensors[i], dist_attr[i], need_reshard);
  }
}
|
||
// ArgDistAttr overload for vectors: the variant must hold a
// std::vector<TensorDistAttr>, which is unpacked and forwarded.
void SetInplaceOutputCorrectDistAttr(
    phi::DeviceContext* dev_ctx,
    std::vector<Tensor>& tensors,  // NOLINT
    const phi::distributed::ArgDistAttr& dist_attr,
    bool need_reshard) {
  PADDLE_ENFORCE_EQ(
      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
          dist_attr),
      true,
      phi::errors::PreconditionNotMet(
          "Arg must be a vector of TensorDistAttr"));
  // Use PADDLE_GET_CONST rather than a bare paddle::get<1> so that a
  // bad-variant access surfaces as a diagnosable Paddle error instead of an
  // uncaught exception (raised in review of this change).
  SetInplaceOutputCorrectDistAttr(
      dev_ctx,
      tensors,
      PADDLE_GET_CONST(std::vector<phi::distributed::TensorDistAttr>,
                       dist_attr),
      need_reshard);
}
|
||
void ReshardOutputPartialAxisToReplicated( | ||
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) { | ||
if (out_tensor->dist_attr().is_partial()) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个原则上要用PADDLE_GET系列宏,或者用try_catch包裹,建议下个PR再修复一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯,我下个PR一起修一下~