concat spmd rule
liuzhenhai93 committed Oct 30, 2023
1 parent 3da2ae9 commit a91132f
Showing 32 changed files with 1,264 additions and 243 deletions.
9 changes: 5 additions & 4 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -74,6 +74,7 @@ using paddle::distributed::auto_parallel::SPMDRuleMap;
using paddle::framework::BlockDesc;
using paddle::framework::OpDesc;
using paddle::framework::VarDesc;
using phi::distributed::ArgDistAttr;
using phi::distributed::ProcessMesh;
using phi::distributed::TensorDistAttr;
using phi::distributed::auto_parallel::Device;
@@ -143,9 +144,9 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
dist_attr->clear_annotated();
}

-static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+static std::pair<std::vector<ArgDistAttr>, std::vector<ArgDistAttr>>
 infer_forward(const phi::distributed::SpmdRule &self, const py::args &args);
-static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+static std::pair<std::vector<ArgDistAttr>, std::vector<ArgDistAttr>>
 infer_backward(const phi::distributed::SpmdRule &self, const py::args &args);

void BindAutoParallel(py::module *m) {
@@ -703,15 +704,15 @@ static void prepare_ctx(phi::distributed::InferSpmdContext *ctx,
parse_single_pyobject(obj, ctx, i);
}
}
-static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+static std::pair<std::vector<ArgDistAttr>, std::vector<ArgDistAttr>>
infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) {
VLOG(6) << "infer_forward ";
phi::distributed::InferSpmdContext ctx;
prepare_ctx(&ctx, args);
return self.InferForward(ctx);
}

-static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+static std::pair<std::vector<ArgDistAttr>, std::vector<ArgDistAttr>>
infer_backward(const phi::distributed::SpmdRule &self, const py::args &args) {
VLOG(6) << "infer_backward ";
phi::distributed::InferSpmdContext ctx;
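Note on this file's change: infer_forward and infer_backward now return ArgDistAttr in place of bare TensorDistAttr. Judging from the holds_alternative/get<0>/get<1> calls elsewhere in this commit, ArgDistAttr behaves like a two-alternative variant: either one TensorDistAttr or a std::vector<TensorDistAttr>. Below is a minimal sketch of that consumption pattern, using std::variant and a hypothetical stand-in struct (not Paddle's real types, which live in phi::distributed and use paddle::variant):

```cpp
#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for illustration only.
struct TensorDistAttr {
  std::string dims_mapping;  // placeholder for the real attribute payload
};
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

// The dispatch pattern this commit applies at every consumption site:
// test which alternative is held, then get<I>() the matching index.
void Consume(const ArgDistAttr& attr) {
  if (std::holds_alternative<TensorDistAttr>(attr)) {
    const auto& single = std::get<0>(attr);  // a single-tensor argument
    std::cout << "single attr: " << single.dims_mapping << "\n";
  } else {
    const auto& many = std::get<1>(attr);    // a vector-of-tensors argument
    std::cout << many.size() << " attrs\n";
  }
}

int main() {
  Consume(TensorDistAttr{"[0, -1]"});
  Consume(std::vector<TensorDistAttr>{{"[0, -1]"}, {"[-1, 0]"}});
  return 0;
}
```

Packing both shapes into one type lets a single SPMD rule signature cover ops whose arguments mix single tensors and tensor lists, which is presumably what the concat rule in this commit's title needs.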
9 changes: 8 additions & 1 deletion paddle/phi/api/lib/api_custom_impl.cc
@@ -165,7 +165,14 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_x, dense_out);
}
-  auto current_process_mesh = spmd_info.first[0].process_mesh();
+  PADDLE_ENFORCE_EQ(
+      paddle::holds_alternative<phi::distributed::TensorDistAttr>(
+          spmd_info.first[0]),
+      true,
+      phi::errors::PreconditionNotMet(
+          "arg must be a single TensorDistAttr"));
+  auto current_process_mesh =
+      paddle::get<0>(spmd_info.first[0]).process_mesh();
SetReplicatedDistAttrForOutput(dist_out, current_process_mesh);
return api_output;
}
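Note: add_n_impl previously read spmd_info.first[0].process_mesh() directly; once that slot holds an ArgDistAttr, it must first verify that the variant carries a single TensorDistAttr. Checking holds_alternative before get<0> turns a wrong alternative into a descriptive PreconditionNotMet error rather than an opaque bad-variant throw. A sketch of the same guard with std::variant stand-ins (RequireSingle is a hypothetical helper, not Paddle API):

```cpp
#include <stdexcept>
#include <variant>
#include <vector>

struct TensorDistAttr {};  // stand-in
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

// Mirrors the PADDLE_ENFORCE_EQ guard above: validate the alternative
// before unwrapping, so the failure mode is a readable error instead of
// std::bad_variant_access.
const TensorDistAttr& RequireSingle(const ArgDistAttr& attr) {
  if (!std::holds_alternative<TensorDistAttr>(attr)) {
    throw std::invalid_argument("arg must be a single TensorDistAttr");
  }
  return std::get<0>(attr);
}
```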
16 changes: 16 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -559,6 +559,13 @@ phi::distributed::DistTensor* SetKernelDistOutput(
return nullptr;
}

phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
// TODO(liuzhenhai): add a check that dist_attr holds a single TensorDistAttr
return SetKernelDistOutput(
out, paddle::get<phi::distributed::TensorDistAttr>(dist_attr));
}

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
if (out) {
@@ -568,6 +575,15 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
return nullptr;
}

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
if (out) {
return std::make_shared<phi::distributed::DistTensor>(
phi::DDim(), paddle::get<phi::distributed::TensorDistAttr>(dist_attr));
}
return nullptr;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out) {
std::vector<phi::distributed::DistTensor*> result;
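Note: both new overloads in this file are thin adapters; they unwrap the variant and forward to the existing TensorDistAttr overloads, so code generated against ArgDistAttr reuses the original logic. Unlike the reshard adapters later in this commit, these call paddle::get without a holds_alternative guard (hence the TODO above). A sketch of the adapter shape with std::variant stand-ins (all names hypothetical):

```cpp
#include <memory>
#include <variant>
#include <vector>

struct TensorDistAttr {};  // stand-in
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;
struct DistTensor {
  explicit DistTensor(const TensorDistAttr&) {}
};

// Existing "real" overload (stand-in for CreateKernelDistOutput).
std::shared_ptr<DistTensor> CreateOutput(const TensorDistAttr& attr) {
  return std::make_shared<DistTensor>(attr);
}

// Variant adapter: unwrap, then forward. A holds_alternative guard (or
// std::visit) would turn a wrong alternative into a clear error instead
// of a bad_variant_access throw.
std::shared_ptr<DistTensor> CreateOutput(const ArgDistAttr& attr) {
  return CreateOutput(std::get<TensorDistAttr>(attr));
}
```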
7 changes: 7 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.h
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
@@ -147,11 +148,17 @@ phi::distributed::DistTensor* SetKernelDistOutput(
const phi::distributed::TensorDistAttr& dist_attr =
phi::distributed::TensorDistAttr());

phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
Tensor* out,
const phi::distributed::TensorDistAttr& dist_attr =
phi::distributed::TensorDistAttr());

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out);

129 changes: 125 additions & 4 deletions paddle/phi/api/lib/data_transform.cc
@@ -640,6 +640,70 @@ std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
return nullptr;
}

std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::ArgDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
true,
phi::errors::PreconditionNotMet("arg must be a TensorDistAttr"));
const auto& tensor_dist_attr = paddle::get<0>(dist_attr);
return ReshardApiInputToKernelInput(dev_ctx, tensor, tensor_dist_attr);
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const phi::distributed::ArgDistAttr& dist_attrs) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
dist_attrs),
true,
phi::errors::PreconditionNotMet(
"arg must be a vector of TensorDistAttr"));
const auto& tensor_dist_attrs = paddle::get<1>(dist_attrs);
return ReshardApiInputToKernelInput(dev_ctx, tensors, tensor_dist_attrs);
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs) {
std::vector<std::shared_ptr<phi::distributed::DistTensor>> output;
PADDLE_ENFORCE_EQ(tensors.size(),
dist_attrs.size(),
phi::errors::PreconditionNotMet(
"tensors size and dist_attrs size not equal: %d vs %d",
tensors.size(),
dist_attrs.size()));
for (size_t i = 0; i < dist_attrs.size(); i++) {
output.push_back(
ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i]));
}
return output;
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const std::vector<phi::distributed::ArgDistAttr>& dist_attrs) {
std::vector<std::shared_ptr<phi::distributed::DistTensor>> output;
PADDLE_ENFORCE_EQ(tensors.size(),
dist_attrs.size(),
phi::errors::PreconditionNotMet(
"tensors size and dist_attrs size not equal: %d vs %d",
tensors.size(),
dist_attrs.size()));
for (size_t i = 0; i < dist_attrs.size(); i++) {
output.push_back(
ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i]));
}
return output;
}

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
@@ -679,13 +743,69 @@ ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs) {
-  std::vector<std::shared_ptr<phi::distributed::DistTensor>> result;
-  result.reserve(tensors.size());
+  std::vector<std::shared_ptr<phi::distributed::DistTensor>> outputs;
   for (size_t i = 0; i < tensors.size(); ++i) {
-    result.emplace_back(ReshardApiInputToReplicatedKernelInput(
+    outputs.push_back(ReshardApiInputToReplicatedKernelInput(
         dev_ctx, tensors[i], dist_attrs[i]));
   }
-  return result;
+  return outputs;
}

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::ArgDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
true,
phi::errors::PreconditionNotMet("arg must be a TensorDistAttr"));
const auto& tensor_dist_attr = paddle::get<0>(dist_attr);
return ReshardApiInputToReplicatedKernelInput(
dev_ctx, tensor, tensor_dist_attr);
}

paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const paddle::optional<Tensor>& tensor,
const phi::distributed::ArgDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
true,
phi::errors::PreconditionNotMet("arg must be a TensorDistAttr"));
const auto& tensor_dist_attr = paddle::get<0>(dist_attr);
return ReshardApiInputToReplicatedKernelInput(
dev_ctx, tensor, tensor_dist_attr);
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const std::vector<phi::distributed::ArgDistAttr>& dist_attrs) {
std::vector<std::shared_ptr<phi::distributed::DistTensor>> outputs;
for (size_t i = 0; i < tensors.size(); ++i) {
outputs.push_back(ReshardApiInputToReplicatedKernelInput(
dev_ctx, tensors[i], dist_attrs[i]));
}
return outputs;
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const phi::distributed::ArgDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
dist_attr),
true,
phi::errors::PreconditionNotMet(
"arg must be a vector of TensorDistAttr"));
const auto& tensor_dist_attrs = paddle::get<1>(dist_attr);
return ReshardApiInputToReplicatedKernelInput(
dev_ctx, tensors, tensor_dist_attrs);
}

void ReshardOutputPartialAxisToReplicated(
Expand Down Expand Up @@ -819,6 +939,7 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
transform_flag) &&
!NeedTransform2Contiguous(is_stride_kernel,
dense_tensor.meta().is_contiguous()))) {
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
out.push_back(
std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
} else {
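Note: the reshard adapters in this file follow two patterns. The scalar ArgDistAttr overloads guard-then-unwrap, exactly as in add_n_impl; the vector overloads first enforce tensors.size() == dist_attrs.size(), then reshard element-wise. A condensed sketch of the element-wise pattern with stand-in types (Reshard is hypothetical shorthand for ReshardApiInputToKernelInput):

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

struct Tensor {};          // stand-in
struct TensorDistAttr {};  // stand-in
struct DistTensor {};      // stand-in

// Stand-in for the single-tensor reshard; the real work happens here.
std::shared_ptr<DistTensor> Reshard(const Tensor&, const TensorDistAttr&) {
  return std::make_shared<DistTensor>();
}

// Element-wise vector overload, mirroring the size check + loop above.
std::vector<std::shared_ptr<DistTensor>> Reshard(
    const std::vector<Tensor>& tensors,
    const std::vector<TensorDistAttr>& dist_attrs) {
  if (tensors.size() != dist_attrs.size()) {
    throw std::invalid_argument("tensors size and dist_attrs size not equal");
  }
  std::vector<std::shared_ptr<DistTensor>> out;
  out.reserve(tensors.size());  // one resharded result per input
  for (size_t i = 0; i < tensors.size(); ++i) {
    out.push_back(Reshard(tensors[i], dist_attrs[i]));
  }
  return out;
}
```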
53 changes: 53 additions & 0 deletions paddle/phi/api/lib/data_transform.h
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
@@ -180,24 +181,76 @@ std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr);

std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::ArgDistAttr& dist_attr);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const std::vector<phi::distributed::ArgDistAttr>& dist_attrs);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const phi::distributed::ArgDistAttr& dist_attrs);

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::ArgDistAttr& dist_attr);

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr);

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::ArgDistAttr& dist_attr);

paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const paddle::optional<Tensor>& tensor,
const phi::distributed::TensorDistAttr& dist_attr);

paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const paddle::optional<Tensor>& tensor,
const phi::distributed::ArgDistAttr& dist_attr);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensors,
const std::vector<phi::distributed::ArgDistAttr>& dist_attrs);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const std::vector<Tensor>& tensor,
const phi::distributed::ArgDistAttr& dist_attr);

void ReshardOutputPartialAxisToReplicated(
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor);

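Note: the header now declares parallel overloads keyed on TensorDistAttr, std::vector<TensorDistAttr>, and ArgDistAttr (single or vector payload) for each reshard entry point, so generated API code can hand over whatever the SPMD rule produced and let overload resolution pick the unwrapping path. A toy demonstration of that resolution with stand-in signatures (not the real API):

```cpp
#include <iostream>
#include <variant>
#include <vector>

struct Tensor {};          // stand-in
struct TensorDistAttr {};  // stand-in
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

// Parallel overloads, shaped like the declarations above.
void ReshardInput(const Tensor&, const TensorDistAttr&) {
  std::cout << "plain attr overload\n";
}
void ReshardInput(const Tensor&, const ArgDistAttr&) {
  std::cout << "variant attr overload\n";
}

int main() {
  Tensor t;
  TensorDistAttr plain;
  ArgDistAttr variant_attr = plain;
  ReshardInput(t, plain);         // exact match wins: plain overload
  ReshardInput(t, variant_attr);  // variant overload unwraps internally
  return 0;
}
```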
[Diff for the remaining 26 changed files not shown.]
