[AutoParallel] Fix PHI API inplace output code generation. (PaddlePad…
GhostScreaming authored and SecretXV committed Nov 28, 2023
1 parent b0c6352 commit b059e0f
Showing 7 changed files with 393 additions and 168 deletions.
164 changes: 78 additions & 86 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -547,25 +547,95 @@ std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
 }
 
 phi::distributed::DistTensor* SetKernelDistOutput(
-    Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
+    Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
+  PADDLE_ENFORCE_EQ(
+      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
+      true,
+      phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
   if (out) {
     if (out->impl() == nullptr) {
-      auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
-                                                                   dist_attr);
+      auto dist_t = std::make_shared<phi::distributed::DistTensor>(
+          phi::DDim(), paddle::get<0>(dist_attr));
       out->set_impl(dist_t);
     }
     return static_cast<phi::distributed::DistTensor*>(out->impl().get());
   }
   return nullptr;
 }
 
-phi::distributed::DistTensor* SetKernelDistOutput(
-    Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    size_t out_size, std::vector<Tensor>* out) {
+  std::vector<phi::distributed::DistTensor*> results(out_size);
+  if (out->size() != out_size) {
+    // Empty out vector
+    out->reserve(out_size);
+  }
+  for (size_t i = 0; i < out_size; ++i) {
+    if (out->size() != out_size) {
+      auto dist_t = std::make_shared<phi::distributed::DistTensor>();
+      out->emplace_back();
+      out->back().set_impl(dist_t);
+    }
+    results[i] =
+        static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
+  }
+  return results;
+}
+
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out) {
   PADDLE_ENFORCE_EQ(
-      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
+      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
+          dist_attr),
       true,
-      phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
-  return SetKernelDistOutput(out, paddle::get<0>(dist_attr));
+      phi::errors::PreconditionNotMet(
+          "Arg must be a vector of TensorDistAttr"));
+  const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
+      PADDLE_GET_CONST(std::vector<phi::distributed::TensorDistAttr>,
+                       dist_attr);
+  auto out_size = dist_attrs.size();
+  std::vector<phi::distributed::DistTensor*> results(out_size);
+  // TODO(GhostScreaming): Inplace outputs are initialized, just set their
+  // dist_attr.
+  if (out->size() == out_size) {
+    VLOG(3) << "Outputs are inplace vector Tensors, just set their dist_attrs "
+            << "according to InferSPMD output result.";
+    for (size_t i = 0; i < out_size; ++i) {
+      results[i] =
+          static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
+      results[i]->unsafe_set_dist_attr(dist_attrs[i]);
+    }
+  } else {
+    out->reserve(out_size);
+    for (size_t i = 0; i < out_size; ++i) {
+      auto dist_t = std::make_shared<phi::distributed::DistTensor>(
+          phi::DDim(), dist_attrs[i]);
+      results[i] = dist_t.get();
+      out->emplace_back();
+      out->back().set_impl(dist_t);
+    }
+  }
+  return results;
 }
 
+// For backward
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    std::vector<Tensor*> out) {
+  std::vector<phi::distributed::DistTensor*> result;
+  for (auto tmp : out) {
+    if (tmp) {
+      // TODO(GhostScreaming): now all dist case are nullptr
+      if (tmp->impl() == nullptr) {
+        auto dist_t = std::make_shared<phi::distributed::DistTensor>();
+        tmp->set_impl(dist_t);
+      }
+      result.emplace_back(
+          static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
+    } else {
+      result.emplace_back(nullptr);
+    }
+  }
+  return result;
+}
+
 std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
@@ -609,84 +679,6 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
   return nullptr;
 }
 
-std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
-    std::vector<Tensor*> out) {
-  std::vector<phi::distributed::DistTensor*> result;
-  for (auto tmp : out) {
-    if (tmp) {
-      // TODO(GhostScreaming): now all dist case are nullptr
-      if (tmp->impl() == nullptr) {
-        auto dist_t = std::make_shared<phi::distributed::DistTensor>();
-        tmp->set_impl(dist_t);
-      }
-      result.emplace_back(
-          static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
-    } else {
-      result.emplace_back(nullptr);
-    }
-  }
-  return result;
-}
-
-std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
-    const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out) {
-  PADDLE_ENFORCE_EQ(
-      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
-          dist_attr),
-      true,
-      phi::errors::PreconditionNotMet(
-          "Arg must be a vector of TensorDistAttr"));
-  const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
-      PADDLE_GET_CONST(std::vector<phi::distributed::TensorDistAttr>,
-                       dist_attr);
-  auto out_size = dist_attrs.size();
-  out->reserve(out_size);
-  std::vector<phi::distributed::DistTensor*> results(out_size);
-  for (size_t i = 0; i < out_size; ++i) {
-    auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
-                                                                 dist_attrs[i]);
-    results[i] = dist_t.get();
-    out->emplace_back();
-    out->back().set_impl(dist_t);
-  }
-  return results;
-}
-
-std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
-    size_t out_size, std::vector<Tensor>* out) {
-  out->reserve(out_size);
-  std::vector<phi::distributed::DistTensor*> results(out_size);
-  for (size_t i = 0; i < out_size; ++i) {
-    auto dist_t = std::make_shared<phi::distributed::DistTensor>();
-    results[i] = dist_t.get();
-    out->emplace_back();
-    out->back().set_impl(dist_t);
-  }
-  return results;
-}
-
-std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
-    size_t out_size, std::vector<Tensor>* out) {
-  std::vector<phi::distributed::DistTensor*> results(out->size(), nullptr);
-  for (size_t i = 0; i < out->size(); ++i) {
-    results[i] =
-        static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
-  }
-  return results;
-}
-
-std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
-    size_t out_size, paddle::optional<std::vector<Tensor>> out) {
-  std::vector<phi::distributed::DistTensor*> results;
-  if (out) {
-    results = std::vector<phi::distributed::DistTensor*>(out->size(), nullptr);
-    for (size_t i = 0; i < out->size(); ++i) {
-      results[i] =
-          static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
-    }
-  }
-  return results;
-}
 void SetReplicatedDistAttrForOutput(
     phi::distributed::DistTensor* out,
     const phi::distributed::ProcessMesh& process_mesh) {
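Context for the hunks above: ArgDistAttr in this commit is a variant-like type that carries either a single TensorDistAttr or a vector of them, and every rewritten overload first checks which alternative is held and only then unpacks it. Below is a minimal standalone C++ sketch of that check-then-get pattern using std::variant; TensorDistAttr here is a stand-in struct, and UnpackSingle and the mesh_id field are invented for illustration, not Paddle's real API.

#include <cassert>
#include <iostream>
#include <variant>
#include <vector>

// Stand-ins modeling phi::distributed::TensorDistAttr and ArgDistAttr.
struct TensorDistAttr {
  int mesh_id = 0;  // placeholder field
};
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

// Mirrors the single-output overload: enforce that the variant holds one
// TensorDistAttr, then extract it with get<0> before building the output.
TensorDistAttr UnpackSingle(const ArgDistAttr& dist_attr) {
  // Paddle's PADDLE_ENFORCE_EQ(holds_alternative<...>, true, ...) becomes a
  // plain assertion in this sketch.
  assert(std::holds_alternative<TensorDistAttr>(dist_attr) &&
         "Arg must be a single TensorDistAttr");
  return std::get<0>(dist_attr);
}

int main() {
  ArgDistAttr single = TensorDistAttr{7};
  std::cout << "mesh_id = " << UnpackSingle(single).mesh_id << "\n";
  return 0;
}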
24 changes: 8 additions & 16 deletions paddle/phi/api/lib/api_gen_utils.h
@@ -145,21 +145,10 @@ std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(

 phi::distributed::DistTensor* SetKernelDistOutput(
-    Tensor* out,
-    const phi::distributed::TensorDistAttr& dist_attr =
-        phi::distributed::TensorDistAttr());
-
-phi::distributed::DistTensor* SetKernelDistOutput(
     Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);
 
-std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
-    Tensor* out,
-    bool set_dist_output_as_tensor_impl,
-    const phi::distributed::ArgDistAttr& dist_attr =
-        phi::distributed::TensorDistAttr());
-
-std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
-    Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);
-
 // For backward
 std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
     std::vector<Tensor*> out);
 
@@ -169,11 +158,14 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
 std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
     const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out);
 
-std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
-    size_t out_size, std::vector<Tensor>* out);
+std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
+    Tensor* out,
+    bool set_dist_output_as_tensor_impl,
+    const phi::distributed::ArgDistAttr& dist_attr =
+        phi::distributed::TensorDistAttr());
 
-std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
-    size_t out_size, paddle::optional<std::vector<Tensor>> out);
+std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
+    Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);
 
 // DistTensor need to set initial dist attr after the dims setted, it is
 // constructed based dims and current process mesh, beforce calling this
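The behavioral fix in the .cc above is the inplace branch of the vector SetKernelDistOutput overload: when the caller's output vector already holds out_size tensors, their impls were created by the inplace machinery, so only the dist_attrs produced by InferSPMD need refreshing; otherwise fresh DistTensors are allocated and appended. A minimal model of that control flow follows, under stand-in types: Tensor, DistTensor, and SetVectorOutput are simplified placeholders, with an int standing in for TensorDistAttr.

#include <iostream>
#include <memory>
#include <vector>

// Minimal stand-ins: a "DistTensor" holding only a dist_attr tag, and a
// "Tensor" owning one through a shared impl pointer.
struct DistTensor {
  int dist_attr = -1;
  void unsafe_set_dist_attr(int attr) { dist_attr = attr; }
};
struct Tensor {
  std::shared_ptr<DistTensor> impl;
};

// Mirrors the fixed control flow: when the output vector already has
// out_size elements it is an inplace output, so only the dist_attrs are
// refreshed; otherwise fresh DistTensors are allocated and registered.
std::vector<DistTensor*> SetVectorOutput(const std::vector<int>& dist_attrs,
                                         std::vector<Tensor>* out) {
  const size_t out_size = dist_attrs.size();
  std::vector<DistTensor*> results(out_size);
  if (out->size() == out_size) {  // inplace: impls already exist
    for (size_t i = 0; i < out_size; ++i) {
      results[i] = out->at(i).impl.get();
      results[i]->unsafe_set_dist_attr(dist_attrs[i]);
    }
  } else {  // non-inplace: allocate and append new impls
    out->reserve(out_size);
    for (size_t i = 0; i < out_size; ++i) {
      auto dist_t = std::make_shared<DistTensor>();
      dist_t->unsafe_set_dist_attr(dist_attrs[i]);
      results[i] = dist_t.get();
      out->push_back(Tensor{dist_t});
    }
  }
  return results;
}

int main() {
  std::vector<Tensor> inplace_out{Tensor{std::make_shared<DistTensor>()}};
  SetVectorOutput({42}, &inplace_out);  // takes the inplace branch
  std::cout << inplace_out[0].impl->dist_attr << "\n";  // prints 42
  return 0;
}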
102 changes: 102 additions & 0 deletions paddle/phi/api/lib/data_transform.cc
@@ -722,6 +722,108 @@ ReshardApiInputToKernelInput(
   return paddle::none;
 }
 
+void SetInplaceOutputCorrectDistAttr(
+    phi::DeviceContext* dev_ctx,
+    Tensor& tensor,  // NOLINT
+    const phi::distributed::TensorDistAttr& dist_attr,
+    bool use_general_spmd_rule) {
+  auto tensor_in = tensor.impl();
+  if (tensor_in) {
+    phi::distributed::DistTensor* dist_tensor =
+        static_cast<phi::distributed::DistTensor*>(tensor_in.get());
+    if (dist_tensor->initialized()) {
+      if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
+        if (use_general_spmd_rule) {
+          VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
+                  << " to origin dist_attr "
+                  << ReshardDebugInfo(*dist_tensor, dist_attr);
+          auto* func = phi::distributed::ChooseProperReshardFunction(
+              *dist_tensor, dist_attr);
+          func->Eval(dev_ctx, *dist_tensor, dist_attr, dist_tensor);
+        } else {
+          // just set correct SPMD dist_attrs
+          VLOG(6) << "SetInplaceOutputCorrectDistAttr input " << tensor.name()
+                  << " set its dist_attr from " << dist_tensor->dist_attr()
+                  << " to " << dist_attr;
+          dist_tensor->unsafe_set_dist_attr(dist_attr);
+        }
+      }
+    } else {
+      VLOG(6) << "SetInplaceOutputCorrectDistAttr has"
+              << " uninitialized DistTensor input " << tensor.name()
+              << ", just set its dist_attr from " << dist_tensor->dist_attr()
+              << " to " << dist_attr;
+      dist_tensor->unsafe_set_dist_attr(dist_attr);
+    }
+  }
+}
+
+void SetInplaceOutputCorrectDistAttr(
+    phi::DeviceContext* dev_ctx,
+    Tensor& tensor,  // NOLINT
+    const phi::distributed::ArgDistAttr& dist_attr,
+    bool use_general_spmd_rule) {
+  PADDLE_ENFORCE_EQ(
+      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
+      true,
+      phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr"));
+  SetInplaceOutputCorrectDistAttr(
+      dev_ctx, tensor, paddle::get<0>(dist_attr), use_general_spmd_rule);
+}
+
+void SetInplaceOutputCorrectDistAttr(
+    phi::DeviceContext* dev_ctx,
+    std::vector<Tensor>& tensors,  // NOLINT
+    const std::vector<phi::distributed::TensorDistAttr>& dist_attr,
+    bool use_general_spmd_rule) {
+  for (size_t i = 0; i < tensors.size(); i++) {
+    auto tensor_in = tensors[i].impl();
+    if (tensor_in) {
+      phi::distributed::DistTensor* dist_tensor =
+          static_cast<phi::distributed::DistTensor*>(tensor_in.get());
+      if (dist_tensor->initialized()) {
+        if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr[i])) {
+          if (use_general_spmd_rule) {
+            VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
+                    << " to origin dist_attr "
+                    << ReshardDebugInfo(*dist_tensor, dist_attr[i]);
+            auto* func = phi::distributed::ChooseProperReshardFunction(
+                *dist_tensor, dist_attr[i]);
+            func->Eval(dev_ctx, *dist_tensor, dist_attr[i], dist_tensor);
+          } else {
+            // just set correct SPMD dist_attrs
+            VLOG(6) << "SetInplaceOutputCorrectDistAttr input "
+                    << tensors[i].name() << " set its dist_attr from "
+                    << dist_tensor->dist_attr() << " to " << dist_attr[i];
+            dist_tensor->unsafe_set_dist_attr(dist_attr[i]);
+          }
+        }
+      } else {
+        VLOG(6) << "SetInplaceOutputCorrectDistAttr has"
+                << " uninitialized DistTensor input " << tensors[i].name()
+                << ", just set its dist_attr from " << dist_tensor->dist_attr()
+                << " to " << dist_attr[i];
+        dist_tensor->unsafe_set_dist_attr(dist_attr[i]);
+      }
+    }
+  }
+}
+
+void SetInplaceOutputCorrectDistAttr(
+    phi::DeviceContext* dev_ctx,
+    std::vector<Tensor>& tensors,  // NOLINT
+    const phi::distributed::ArgDistAttr& dist_attr,
+    bool use_general_spmd_rule) {
+  PADDLE_ENFORCE_EQ(
+      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
+          dist_attr),
+      true,
+      phi::errors::PreconditionNotMet(
+          "Arg must be a vector of TensorDistAttr"));
+  SetInplaceOutputCorrectDistAttr(
+      dev_ctx, tensors, paddle::get<1>(dist_attr), use_general_spmd_rule);
+}
+
 void ReshardOutputPartialAxisToReplicated(
     phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) {
   if (out_tensor->dist_attr().is_partial()) {
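The new SetInplaceOutputCorrectDistAttr helpers choose between resharding and retagging: an initialized inplace output whose current dist_attr diverges from the InferSPMD result is resharded when use_general_spmd_rule is set, while otherwise (and for uninitialized outputs) only unsafe_set_dist_attr is applied. A toy sketch of that decision tree follows, with strings standing in for dist_attrs and a log line standing in for the ChooseProperReshardFunction/Eval call; FakeDistTensor and FixInplaceDistAttr are invented names.

#include <iostream>
#include <string>

// Stand-in for a DistTensor that knows its dist_attr and whether its local
// value has been materialized.
struct FakeDistTensor {
  std::string dist_attr;
  bool initialized = false;
  void unsafe_set_dist_attr(const std::string& attr) { dist_attr = attr; }
};

// Mirrors the decision tree of SetInplaceOutputCorrectDistAttr: an
// initialized tensor whose attr diverges is resharded under the general SPMD
// rule (modeled here as logging plus retagging), otherwise its attr is simply
// overwritten; an uninitialized tensor always just gets the attr overwritten.
void FixInplaceDistAttr(FakeDistTensor* t, const std::string& target_attr,
                        bool use_general_spmd_rule = true) {
  if (t->initialized) {
    if (t->dist_attr != target_attr) {  // plays the role of ReshardIsNeeded
      if (use_general_spmd_rule) {
        std::cout << "reshard " << t->dist_attr << " -> " << target_attr
                  << "\n";  // real code calls ChooseProperReshardFunction
        t->dist_attr = target_attr;
      } else {
        t->unsafe_set_dist_attr(target_attr);
      }
    }
  } else {
    t->unsafe_set_dist_attr(target_attr);  // nothing to move, just retag
  }
}

int main() {
  FakeDistTensor t{"shard(0)", true};
  FixInplaceDistAttr(&t, "replicated");
  std::cout << t.dist_attr << "\n";  // prints: replicated
  return 0;
}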
24 changes: 24 additions & 0 deletions paddle/phi/api/lib/data_transform.h
@@ -197,6 +197,30 @@ ReshardApiInputToKernelInput(
     const paddle::optional<std::vector<Tensor>>& tensors,
     const phi::distributed::ArgDistAttr& dist_attr);
 
+void SetInplaceOutputCorrectDistAttr(
+    phi::DeviceContext* dev_ctx,
+    Tensor& tensor,  // NOLINT
+    const phi::distributed::TensorDistAttr& dist_attr,
+    bool use_general_spmd_rule = true);
+
+void SetInplaceOutputCorrectDistAttr(
+    phi::DeviceContext* dev_ctx,
+    Tensor& tensor,  // NOLINT
+    const phi::distributed::ArgDistAttr& dist_attr,
+    bool use_general_spmd_rule = true);
+
+void SetInplaceOutputCorrectDistAttr(
+    phi::DeviceContext* dev_ctx,
+    std::vector<Tensor>& tensors,  // NOLINT
+    const std::vector<phi::distributed::TensorDistAttr>& dist_attr,
+    bool use_general_spmd_rule = true);
+
+void SetInplaceOutputCorrectDistAttr(
+    phi::DeviceContext* dev_ctx,
+    std::vector<Tensor>& tensors,  // NOLINT
+    const phi::distributed::ArgDistAttr& dist_attr,
+    bool use_general_spmd_rule = true);
+
 void ReshardOutputPartialAxisToReplicated(
     phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor);
 
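The four declarations above follow the commit's uniform dispatch pattern: each ArgDistAttr overload asserts which alternative the variant holds, unpacks it with paddle::get, and forwards to the concrete overload. A self-contained sketch of that dispatch using std::variant; ApplyEach and the stand-in TensorDistAttr are hypothetical names mirroring paddle::holds_alternative/paddle::get.

#include <cassert>
#include <iostream>
#include <variant>
#include <vector>

struct TensorDistAttr { int placement = 0; };  // placeholder stand-in
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

// Concrete overload: operates on the already-unpacked vector of attrs.
void ApplyEach(const std::vector<TensorDistAttr>& attrs) {
  for (const auto& a : attrs) std::cout << a.placement << " ";
  std::cout << "\n";
}

// Variant overload: enforce that the vector alternative is held, then
// forward the unpacked value to the concrete overload.
void ApplyEach(const ArgDistAttr& dist_attr) {
  assert(std::holds_alternative<std::vector<TensorDistAttr>>(dist_attr) &&
         "Arg must be a vector of TensorDistAttr");
  ApplyEach(std::get<1>(dist_attr));  // paddle::get<1> in the real code
}

int main() {
  ArgDistAttr attrs = std::vector<TensorDistAttr>{{1}, {2}, {3}};
  ApplyEach(attrs);  // prints: 1 2 3
  return 0;
}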