Skip to content

Commit

Permalink
[Auto Parallel]: Support optional, inplace input and output for DistTensor. (PaddlePaddle#57092)
Browse files Browse the repository at this point in the history

* [WIP] Support std::vector<phi::Tensor> input and output for DistTensor.
Concat forward and backward are verified.

* Polish code for new dist tensor implementation.

* Fix bug of DistTensor upgrade. Add support functions for std::vector<Tensor> -> std::vector<Tensor>.

* Add support for DistTensor type of std::vector<phi::Tensor> as input or output of operators.
The following test cases pass:
1. concat: std::vector<phi::Tensor> -> phi::Tensor
2. unbind: phi::Tensor -> std::vector<phi::Tensor>
3. broadcast_tensors: std::vector<phi::Tensor> -> std::vector<phi::Tensor>

* Polish code. Remove useless comments.

* Add update_loss_scaling in skip_op_lists.

* Polish code.

* [Auto Parallel]: Support paddle::optional<Tensor> and
paddle::optional<std::vector<phi::Tensor>> input and output for DistTensor.

* Polish code.

* Polish code. And support inplace Tensor, std::vector<Tensor>, paddle::optional<Tensor> and paddle::optional<std::vector<Tensor>>

* Polish testcase code. Add testcase for inplace paddle::optional<phi::Tensor>.

* Remove unused code from the test cases.

* Polish code style.

* Polish code style. And fix problems of testcases.
  • Loading branch information
GhostScreaming authored Sep 13, 2023
1 parent 42adc8c commit 79c4f77
Show file tree
Hide file tree
Showing 10 changed files with 624 additions and 75 deletions.
23 changes: 23 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -590,5 +590,28 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
return results;
}

// Collects the raw DistTensor* backing each element of an inplace output
// vector. For inplace kernels the output tensors already exist (they alias
// the inputs), so no new DistTensor instances are created here; we only
// expose the underlying pointers for the kernel to write into.
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
    size_t out_size, std::vector<Tensor>* out) {
  // `out` is already fully constructed for inplace outputs, so its own size
  // is authoritative. `out_size` only mirrors the non-inplace
  // SetKernelDistOutput signature used by the code generator; discard it
  // explicitly to avoid an unused-parameter warning.
  (void)out_size;
  std::vector<phi::distributed::DistTensor*> results(out->size(), nullptr);
  for (size_t i = 0; i < out->size(); ++i) {
    // NOTE(review): assumes each element's impl() actually holds a
    // DistTensor; an empty impl() yields a nullptr entry (static_cast of a
    // null pointer is null) — verify callers tolerate that.
    results[i] =
        static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
  }
  return results;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
size_t out_size, paddle::optional<std::vector<Tensor>> out) {
std::vector<phi::distributed::DistTensor*> results;
if (out) {
results = std::vector<phi::distributed::DistTensor*>(out->size(), nullptr);
for (size_t i = 0; i < out->size(); ++i) {
results[i] =
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
}
}
return results;
}

} // namespace experimental
} // namespace paddle
6 changes: 6 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,11 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
size_t out_size, std::vector<Tensor>* out);

// Returns the raw DistTensor pointers backing each element of the inplace
// output vector `out`. `out_size` mirrors the non-inplace overload's
// signature; the vector's own size determines the result length.
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
    size_t out_size, std::vector<Tensor>* out);

// Returns the raw DistTensor pointers backing an optional inplace output
// vector, or an empty vector when `out` is absent.
// NOTE(review): `out` is passed by value, copying the contained
// std::vector<Tensor> — presumably acceptable at API-generation call sites;
// confirm before tightening to a const reference.
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
    size_t out_size, paddle::optional<std::vector<Tensor>> out);

} // namespace experimental
} // namespace paddle
44 changes: 34 additions & 10 deletions paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
const TransformFlag& transform_flag,
bool is_stride_kernel) {
std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
for (auto x : input) {
for (auto& x : input) {
const auto& tensor_in = x.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
Expand All @@ -691,22 +691,46 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
dense_tensor.meta().is_contiguous()))) {
out.push_back(
std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
continue;
} else {
phi::DenseTensor trans_in_tensor = TransformData(
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(GhostScreaming): The global meta in DistTensor is not changed,
// but the local meta in DenseTensor maybe changed, such as layout
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
out.push_back(std::make_shared<phi::distributed::DistTensor>(
trans_in_tensor, dist_tensor->dist_attr()));
}
phi::DenseTensor trans_in_tensor = TransformData(
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(GhostScreaming): The global meta in DistTensor is not changed,
// but the local meta in DenseTensor maybe changed, such as layout
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
out.push_back(std::make_shared<phi::distributed::DistTensor>(
trans_in_tensor, dist_tensor->dist_attr()));
} else {
out.push_back(nullptr);
}
}
return out;
}

// Optional-input adapter: prepares a present Tensor for DistTensor kernels
// and propagates an absent optional unchanged.
paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  // An absent optional input simply yields an absent optional output.
  if (!input) {
    return paddle::none;
  }
  // Delegate to the single-Tensor overload, then copy the prepared tensor
  // into the optional result.
  // NOTE(review): this dereferences the returned shared_ptr without a null
  // check — presumably the Tensor overload never returns null for a present
  // input; verify.
  auto prepared = PrepareDataForDistTensor(
      *input, target_args_def, transform_flag, is_stride_kernel);
  return {*prepared};
}

// Optional-vector-input adapter: prepares a present std::vector<Tensor> for
// DistTensor kernels and propagates an absent optional unchanged.
paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
PrepareDataForDistTensor(const paddle::optional<std::vector<Tensor>>& input,
                         const phi::TensorArgDef& target_args_def,
                         const TransformFlag& transform_flag,
                         bool is_stride_kernel) {
  // Absent optional input -> absent optional output.
  if (!input) {
    return paddle::none;
  }
  // Delegate the present value to the std::vector<Tensor> overload; the
  // returned vector converts implicitly into the optional result.
  return PrepareDataForDistTensor(
      *input, target_args_def, transform_flag, is_stride_kernel);
}

} // namespace experimental
} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/phi/api/lib/data_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,17 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
const TransformFlag& transform_flag,
bool is_stride_kernel);

// Prepares an optional Tensor input for a DistTensor kernel: an absent
// optional is returned as paddle::none; a present Tensor is forwarded to the
// single-Tensor overload and its result is returned by value.
paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

// Prepares an optional std::vector<Tensor> input for a DistTensor kernel:
// an absent optional is returned as paddle::none; a present vector is
// forwarded to the std::vector<Tensor> overload.
paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
PrepareDataForDistTensor(const paddle::optional<std::vector<Tensor>>& input,
                         const phi::TensorArgDef& target_args_def,
                         const TransformFlag& transform_flag,
                         bool is_stride_kernel);

} // namespace experimental
} // namespace paddle
Loading

0 comments on commit 79c4f77

Please sign in to comment.