[AutoParallel] Fix PHI API inplace output code generation. #59133

Merged (11 commits, Nov 22, 2023)
164 changes: 78 additions & 86 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -547,25 +547,95 @@ std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
}

phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
true,
phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
if (out) {
if (out->impl() == nullptr) {
auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
dist_attr);
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
phi::DDim(), paddle::get<0>(dist_attr));
Contributor:

In principle this should use the PADDLE_GET family of macros, or be wrapped in a try/catch; suggest fixing it in a follow-up PR.

Contributor Author:

Sure, I'll fix it together in the next PR~
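
A minimal sketch of the suggested follow-up (hypothetical, not part of this PR; PADDLE_GET_CONST is assumed usable here because the vector overload further down in this file already uses it):

```cpp
// Hypothetical follow-up: replace the bare paddle::get<0>(...) with the
// PADDLE_GET_CONST macro so that a failed variant access surfaces through
// Paddle's error handling rather than as an unhandled exception.
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
    phi::DDim(),
    PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr));
```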

out->set_impl(dist_t);
}
return static_cast<phi::distributed::DistTensor*>(out->impl().get());
}
return nullptr;
}

phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
size_t out_size, std::vector<Tensor>* out) {
std::vector<phi::distributed::DistTensor*> results(out_size);
if (out->size() != out_size) {
// Empty out vector
out->reserve(out_size);
}
for (size_t i = 0; i < out_size; ++i) {
if (out->at(i).impl() == nullptr) {
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
out->emplace_back();
out->back().set_impl(dist_t);
}
results[i] =
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
}
return results;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
dist_attr),
true,
phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
return SetKernelDistOutput(out, paddle::get<0>(dist_attr));
phi::errors::PreconditionNotMet(
"Arg must be a vector of TensorDistAttr"));
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
PADDLE_GET_CONST(std::vector<phi::distributed::TensorDistAttr>,
dist_attr);
auto out_size = dist_attrs.size();
std::vector<phi::distributed::DistTensor*> results(out_size);
// TODO(GhostScreaming): Inplace outputs are initialized, just set their
// dist_attr.
if (out->size() == out_size) {
VLOG(3) << "Outputs are inplace vector Tensors, just set their dist_attrs "
<< "according to InferSPMD output result.";
for (size_t i = 0; i < out_size; ++i) {
results[i] =
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
results[i]->unsafe_set_dist_attr(dist_attrs[i]);
}
} else {
out->reserve(out_size);
for (size_t i = 0; i < out_size; ++i) {
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
phi::DDim(), dist_attrs[i]);
results[i] = dist_t.get();
out->emplace_back();
out->back().set_impl(dist_t);
}
}
return results;
}

// For backward
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out) {
std::vector<phi::distributed::DistTensor*> result;
for (auto tmp : out) {
if (tmp) {
// TODO(GhostScreaming): now all dist case are nullptr
if (tmp->impl() == nullptr) {
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
tmp->set_impl(dist_t);
}
result.emplace_back(
static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
} else {
result.emplace_back(nullptr);
}
}
return result;
}

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
@@ -609,84 +679,6 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
return nullptr;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out) {
std::vector<phi::distributed::DistTensor*> result;
for (auto tmp : out) {
if (tmp) {
// TODO(GhostScreaming): now all dist case are nullptr
if (tmp->impl() == nullptr) {
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
tmp->set_impl(dist_t);
}
result.emplace_back(
static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
} else {
result.emplace_back(nullptr);
}
}
return result;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
dist_attr),
true,
phi::errors::PreconditionNotMet(
"Arg must be a vector of TensorDistAttr"));
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
PADDLE_GET_CONST(std::vector<phi::distributed::TensorDistAttr>,
dist_attr);
auto out_size = dist_attrs.size();
out->reserve(out_size);
std::vector<phi::distributed::DistTensor*> results(out_size);
for (size_t i = 0; i < out_size; ++i) {
auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
dist_attrs[i]);
results[i] = dist_t.get();
out->emplace_back();
out->back().set_impl(dist_t);
}
return results;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
size_t out_size, std::vector<Tensor>* out) {
out->reserve(out_size);
std::vector<phi::distributed::DistTensor*> results(out_size);
for (size_t i = 0; i < out_size; ++i) {
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
results[i] = dist_t.get();
out->emplace_back();
out->back().set_impl(dist_t);
}
return results;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
size_t out_size, std::vector<Tensor>* out) {
std::vector<phi::distributed::DistTensor*> results(out->size(), nullptr);
for (size_t i = 0; i < out->size(); ++i) {
results[i] =
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
}
return results;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
size_t out_size, paddle::optional<std::vector<Tensor>> out) {
std::vector<phi::distributed::DistTensor*> results;
if (out) {
results = std::vector<phi::distributed::DistTensor*>(out->size(), nullptr);
for (size_t i = 0; i < out->size(); ++i) {
results[i] =
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
}
}
return results;
}
void SetReplicatedDistAttrForOutput(
phi::distributed::DistTensor* out,
const phi::distributed::ProcessMesh& process_mesh) {
24 changes: 8 additions & 16 deletions paddle/phi/api/lib/api_gen_utils.h
@@ -145,21 +145,10 @@ std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(

phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out,
const phi::distributed::TensorDistAttr& dist_attr =
phi::distributed::TensorDistAttr());

phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
Tensor* out,
bool set_dist_output_as_tensor_impl,
const phi::distributed::ArgDistAttr& dist_attr =
phi::distributed::TensorDistAttr());

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);

// For backward
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out);

@@ -169,11 +158,14 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out);

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
size_t out_size, std::vector<Tensor>* out);
std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
Tensor* out,
bool set_dist_output_as_tensor_impl,
const phi::distributed::ArgDistAttr& dist_attr =
phi::distributed::TensorDistAttr());

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
size_t out_size, paddle::optional<std::vector<Tensor>> out);
std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);

// DistTensor needs to have its initial dist attr set after the dims are set; it is
// constructed based on the dims and the current process mesh. Before calling this
102 changes: 102 additions & 0 deletions paddle/phi/api/lib/data_transform.cc
@@ -722,6 +722,108 @@ ReshardApiInputToKernelInput(
return paddle::none;
}

void SetInplaceOutputCorrectDistAttr(
phi::DeviceContext* dev_ctx,
Tensor& tensor, // NOLINT
const phi::distributed::TensorDistAttr& dist_attr,
bool need_reshard) {
auto tensor_in = tensor.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
if (dist_tensor->initialized()) {
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
if (need_reshard) {
VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
<< " to origin dist_attr "
<< ReshardDebugInfo(*dist_tensor, dist_attr);
auto* func = phi::distributed::ChooseProperReshardFunction(
*dist_tensor, dist_attr);
func->Eval(dev_ctx, *dist_tensor, dist_attr, dist_tensor);
} else {
// just set correct SPMD dist_attrs
VLOG(6) << "SetInplaceOutputCorrectDistAttr input " << tensor.name()
<< " set its dist_attr from " << dist_tensor->dist_attr()
<< " to " << dist_attr;
dist_tensor->unsafe_set_dist_attr(dist_attr);
}
}
} else {
VLOG(6) << "SetInplaceOutputCorrectDistAttr has"
<< " uninitialized DistTensor input " << tensor.name()
<< ", just set its dist_attr from " << dist_tensor->dist_attr()
<< " to " << dist_attr;
dist_tensor->unsafe_set_dist_attr(dist_attr);
}
}
}

void SetInplaceOutputCorrectDistAttr(
phi::DeviceContext* dev_ctx,
Tensor& tensor, // NOLINT
const phi::distributed::ArgDistAttr& dist_attr,
bool need_reshard) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
true,
phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr"));
SetInplaceOutputCorrectDistAttr(
dev_ctx, tensor, paddle::get<0>(dist_attr), need_reshard);
}

void SetInplaceOutputCorrectDistAttr(
phi::DeviceContext* dev_ctx,
std::vector<Tensor>& tensors, // NOLINT
const std::vector<phi::distributed::TensorDistAttr>& dist_attr,
bool need_reshard) {
for (size_t i = 0; i < tensors.size(); i++) {
auto tensor_in = tensors[i].impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
if (dist_tensor->initialized()) {
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr[i])) {
if (need_reshard) {
Contributor:

What is the difference between these two? ReshardIsNeeded and need_reshard sound like nearly the same name.

Contributor Author:

ReshardIsNeeded means the input dist_tensor->dist_attr() and the incoming dist_attr[i] disagree, so a reshard is required. need_reshard is an input parameter: at the PHI API layer we check whether the current API has SPMD rules. If it does, the output does not need to be resharded again, because the output DistAttr derived by InferSPMD is already correct and the output local tensor produced by the kernel already has the correct shape; we only need to set the output's dist_attr.

Contributor Author:

I'll open another PR later to rename it.

Contributor Author:

Fixed now, thx~
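
A minimal sketch of how the two flags interact, as described above (hypothetical wrapper for illustration only; SetInplaceOutputCorrectDistAttr is the only name taken from this PR):

```cpp
// Hypothetical helper, not generated code: need_reshard is decided once per
// API from whether it has SPMD rules, while ReshardIsNeeded is evaluated
// inside SetInplaceOutputCorrectDistAttr per tensor by comparing dist_attrs.
void ApplyInplaceOutputDistAttr(phi::DeviceContext* dev_ctx,
                                Tensor& out,  // NOLINT
                                const phi::distributed::TensorDistAttr& attr,
                                bool api_has_spmd_rules) {
  // With SPMD rules, InferSPMD already derived the correct output dist_attr
  // and the kernel produced the correct local shape, so skip the reshard and
  // only record the attr; otherwise fall back to resharding the output.
  SetInplaceOutputCorrectDistAttr(dev_ctx, out, attr,
                                  /*need_reshard=*/!api_has_spmd_rules);
}
```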

VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
<< " to origin dist_attr "
<< ReshardDebugInfo(*dist_tensor, dist_attr[i]);
auto* func = phi::distributed::ChooseProperReshardFunction(
*dist_tensor, dist_attr[i]);
func->Eval(dev_ctx, *dist_tensor, dist_attr[i], dist_tensor);
} else {
// just set correct SPMD dist_attrs
VLOG(6) << "SetInplaceOutputCorrectDistAttr input "
<< tensors[i].name() << " set its dist_attr from "
<< dist_tensor->dist_attr() << " to " << dist_attr[i];
dist_tensor->unsafe_set_dist_attr(dist_attr[i]);
Contributor:

In the inplace case, if the output's dist_attr is simply discarded like this, can the result still be guaranteed correct?

Contributor Author:

Only APIs with SPMD rules reach this branch. Because an inplace output shares its dist_tensor with the input, we cannot hand the spmd_info result to the output earlier, otherwise resharding the input would go wrong. Setting the output dist_attr is therefore deferred to the end, similar in purpose to SetReplicatedDistAttrForOutput.
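
A simplified, hypothetical sketch of the ordering described above for an inplace API with SPMD rules (illustration only; SetInplaceOutputCorrectDistAttr is the only call taken from this PR):

```cpp
// Hypothetical flow, not the actual generated API code. The inplace output
// shares its DistTensor with the input, so the InferSPMD output dist_attr is
// applied only after the input has been resharded and the kernel has run.
void InplaceApiFlowSketch(phi::DeviceContext* dev_ctx,
                          Tensor& x,  // NOLINT: input and inplace output
                          const phi::distributed::ArgDistAttr& out_dist_attr) {
  // 1. Reshard the input according to the InferSPMD input dist_attr (omitted
  //    here); assigning out_dist_attr at this point would break that step,
  //    since x is shared between input and output.
  // 2. Run the dense kernel on x's local tensor (omitted here).
  // 3. Finally record the InferSPMD output dist_attr on the shared tensor;
  //    no reshard is needed because the kernel already produced the correct
  //    local shape.
  SetInplaceOutputCorrectDistAttr(dev_ctx, x, out_dist_attr,
                                  /*need_reshard=*/false);
}
```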

}
}
} else {
VLOG(6) << "SetInplaceOutputCorrectDistAttr has"
<< " uninitialized DistTensor input " << tensors[i].name()
<< ", just set its dist_attr from " << dist_tensor->dist_attr()
<< " to " << dist_attr[i];
dist_tensor->unsafe_set_dist_attr(dist_attr[i]);
}
}
}
}

void SetInplaceOutputCorrectDistAttr(
phi::DeviceContext* dev_ctx,
std::vector<Tensor>& tensors, // NOLINT
const phi::distributed::ArgDistAttr& dist_attr,
bool need_reshard) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
dist_attr),
true,
phi::errors::PreconditionNotMet(
"Arg must be a vector of TensorDistAttr"));
SetInplaceOutputCorrectDistAttr(
dev_ctx, tensors, paddle::get<1>(dist_attr), need_reshard);
}

void ReshardOutputPartialAxisToReplicated(
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) {
if (out_tensor->dist_attr().is_partial()) {
24 changes: 24 additions & 0 deletions paddle/phi/api/lib/data_transform.h
@@ -197,6 +197,30 @@ ReshardApiInputToKernelInput(
const paddle::optional<std::vector<Tensor>>& tensors,
const phi::distributed::ArgDistAttr& dist_attr);

void SetInplaceOutputCorrectDistAttr(
phi::DeviceContext* dev_ctx,
Tensor& tensor, // NOLINT
const phi::distributed::TensorDistAttr& dist_attr,
bool need_reshard = true);

void SetInplaceOutputCorrectDistAttr(
phi::DeviceContext* dev_ctx,
Tensor& tensor, // NOLINT
const phi::distributed::ArgDistAttr& dist_attr,
bool need_reshard = true);

void SetInplaceOutputCorrectDistAttr(
phi::DeviceContext* dev_ctx,
std::vector<Tensor>& tensors, // NOLINT
const std::vector<phi::distributed::TensorDistAttr>& dist_attr,
bool need_reshard = true);

void SetInplaceOutputCorrectDistAttr(
phi::DeviceContext* dev_ctx,
std::vector<Tensor>& tensors, // NOLINT
const phi::distributed::ArgDistAttr& dist_attr,
bool need_reshard = true);

void ReshardOutputPartialAxisToReplicated(
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor);
