[auto parallel] llama attention subgraph verification (PaddlePaddle#59491)
* auto parallel: llama attention and mlp

* llama mlp, attention dp + mp

* remove log

* skip test

* polish

* polish

* polish
liuzhenhai93 authored and SigureMo committed Dec 5, 2023
1 parent 097241d commit a3be97c
Showing 13 changed files with 672 additions and 11 deletions.
4 changes: 3 additions & 1 deletion paddle/fluid/eager/grad_node_info.cc
@@ -330,7 +330,9 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
true,
phi::errors::InvalidArgument(
"The forward input DistTensor's dist attr is empty."));
meta.SetDistAttr(dist_tensor->dist_attr());
auto dist_attr = dist_tensor->dist_attr();
dist_attr.clean_partial_status();
meta.SetDistAttr(dist_attr);
meta.SetDistTensorGlobalDims(dist_tensor->dims());
SetIsRunAutoParallel(true);
} else {
4 changes: 3 additions & 1 deletion paddle/fluid/eager/grad_tensor_holder.cc
@@ -92,8 +92,10 @@ void GradTensorHolder::CopyValueFromTensor(size_t slot_id,
std::static_pointer_cast<phi::DenseTensor>(init_grad.impl());
auto dist_t =
static_cast<phi::distributed::DistTensor*>(t.impl().get());
auto dist_attr = dist_t->dist_attr();
dist_attr.clean_partial_status();
init_grad.set_impl(std::make_shared<phi::distributed::DistTensor>(
global_dense_t, dist_t->dist_attr()));
global_dense_t, dist_attr));
buffer_[slot_id][rank] = init_grad;
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
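Both files above apply the same pattern: copy the forward tensor's dist attr, drop its partial status with clean_partial_status(), and only then attach it to the gradient-side object, apparently so a freshly created gradient buffer does not inherit partial-reduction state from the forward tensor. Below is a minimal mock sketch of that pattern; MockDistAttr is a hypothetical stand-in, not Paddle's real phi::distributed::TensorDistAttr.

    // Mock sketch of the clean-partial-status-before-attaching pattern.
    #include <cstdint>
    #include <iostream>
    #include <set>
    #include <vector>

    struct MockDistAttr {                   // hypothetical stand-in for TensorDistAttr
      std::vector<int64_t> dims_mapping;    // tensor dim -> mesh dim (-1 = replicated)
      std::set<int64_t> partial_dims;       // mesh dims on which values are partial sums
      void clean_partial_status() { partial_dims.clear(); }
    };

    int main() {
      MockDistAttr fwd_attr{{0, -1}, {1}};  // sharded along dim 0, partial on mesh dim 1

      MockDistAttr grad_attr = fwd_attr;    // copy, as both hunks do
      grad_attr.clean_partial_status();     // the newly added call: grads start non-partial

      std::cout << "fwd partial dims: " << fwd_attr.partial_dims.size() << "\n";   // 1
      std::cout << "grad partial dims: " << grad_attr.partial_dims.size() << "\n"; // 0
      return 0;
    }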
3 changes: 2 additions & 1 deletion paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -456,7 +456,8 @@
for (size_t i = 0; i < shape.GetData().size(); i++) {
auto& out_dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.second[0]);
if (out_dist_attr.dims_mapping()[i] >= 0) {
int64_t mesh_dim = out_dist_attr.process_mesh().shape()[i];
int64_t dim = out_dist_attr.dims_mapping()[i];
int64_t mesh_dim = out_dist_attr.process_mesh().shape()[dim];
// TODO: Support aliquant condition.
PADDLE_ENFORCE_EQ(shape.GetData()[i] % mesh_dim,
0,
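The generator fix above changes which mesh axis the emitted divisibility check reads: the divisor for tensor dim i is the size of the mesh dim that i maps to (dims_mapping()[i]), not the size of mesh dim i itself. A minimal stand-alone sketch of the corrected indexing, with hypothetical mesh, mapping, and shape values:

    // Mock sketch: index the process-mesh shape by the mapped mesh dim, not by i.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      std::vector<int64_t> mesh_shape{2, 4};     // hypothetical 2x4 process mesh
      std::vector<int64_t> dims_mapping{1, -1};  // tensor dim 0 is sharded on mesh dim 1
      std::vector<int64_t> shape{8, 16};         // tensor shape being checked

      for (size_t i = 0; i < shape.size(); ++i) {
        if (dims_mapping[i] >= 0) {
          int64_t dim = dims_mapping[i];         // new code: look up the mapped mesh dim
          int64_t mesh_dim = mesh_shape[dim];    // 4 here; mesh_shape[i] would wrongly give 2
          assert(shape[i] % mesh_dim == 0);      // the divisibility check, as in the diff
        }
      }
      return 0;
    }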
8 changes: 5 additions & 3 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -68,6 +68,10 @@ void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
void TensorDistAttr::set_dims_mapping(
const std::vector<int64_t>& dims_mapping) {
dims_mapping_ = dims_mapping;
// dynamic_dims_ and dims_mapping may be not consistent
if (dynamic_dims_.empty() || dims_mapping.empty()) {
set_default_dynamic_dims(dims_mapping);
}
}

void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
@@ -138,9 +142,7 @@ void TensorDistAttr::set_default_dims_mapping(

void TensorDistAttr::set_default_dynamic_dims(
const std::vector<int64_t>& tensor_shape) {
if (!tensor_shape.empty()) {
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
}
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
}

void TensorDistAttr::mark_annotated(const std::string& name) {
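The two hunks above keep dims_mapping_ and dynamic_dims_ consistent: set_default_dynamic_dims now always resizes dynamic_dims_ to the given rank, and set_dims_mapping falls back to that default when either vector is empty (the explicit set_default_dynamic_dims calls added in elementwise.cc and reduction.cc below presumably cover the case where a non-empty dynamic_dims_ would keep a stale rank after broadcast dims are erased). A minimal mock sketch of the rule; MockDistAttr is a hypothetical stand-in, not the real TensorDistAttr:

    // Mock sketch: dynamic_dims_ is defaulted so its length can match dims_mapping_.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct MockDistAttr {                  // hypothetical stand-in for TensorDistAttr
      std::vector<int64_t> dims_mapping_;
      std::vector<bool> dynamic_dims_;

      void set_default_dynamic_dims(const std::vector<int64_t>& ref) {
        dynamic_dims_ = std::vector<bool>(ref.size(), false);  // always resize, as in the diff
      }
      void set_dims_mapping(const std::vector<int64_t>& dims_mapping) {
        dims_mapping_ = dims_mapping;
        if (dynamic_dims_.empty() || dims_mapping.empty()) {   // same fallback condition
          set_default_dynamic_dims(dims_mapping);
        }
      }
    };

    int main() {
      MockDistAttr attr;               // dynamic_dims_ starts empty
      attr.set_dims_mapping({0, -1});  // triggers the default: dynamic_dims_ -> {false, false}
      assert(attr.dynamic_dims_.size() == attr.dims_mapping_.size());
      return 0;
    }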
@@ -23,6 +23,8 @@
namespace phi {
namespace distributed {

using phi::distributed::auto_parallel::str_join;

std::shared_ptr<DistTensor> ReshardFunction::Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
@@ -44,7 +46,9 @@ void ReshardFunction::SetDistProps(DistTensor* tensor,
PADDLE_ENFORCE_EQ(dist_attr.verify(vectorize(dims)),
true,
phi::errors::InvalidArgument(
"The input dist_attr and dims are improper."));
"The input dist_attr [%s] and dims [%s] are improper.",
dist_attr.to_string(),
str_join(vectorize(dims))));

tensor->dims_ = dims;
tensor->dist_attr_ = dist_attr;
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/spmd_rules/elementwise.cc
@@ -361,7 +361,9 @@ SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x,
auto dims_mapping = x_dist_attr.dims_mapping();
dims_mapping.erase(dims_mapping.begin(), dims_mapping.begin() + diff);
x_dist_attr.set_dims_mapping(dims_mapping);
x_dist_attr.set_default_dynamic_dims(dims_mapping);
x_grad_dist_attr.set_dims_mapping(dims_mapping);
x_grad_dist_attr.set_default_dynamic_dims(dims_mapping);
for (int64_t i = 0; i < diff; ++i) {
if (out_grad.dist_attr().dims_mapping()[i] != -1) {
x_grad_dist_attr.set_partial_status(
@@ -375,7 +377,9 @@ SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x,
auto dims_mapping = y_dist_attr.dims_mapping();
dims_mapping.erase(dims_mapping.begin(), dims_mapping.begin() + diff);
y_dist_attr.set_dims_mapping(dims_mapping);
y_dist_attr.set_default_dynamic_dims(dims_mapping);
y_grad_dist_attr.set_dims_mapping(dims_mapping);
y_grad_dist_attr.set_default_dynamic_dims(dims_mapping);
for (int64_t i = 0; i < diff; ++i) {
if (out_grad.dist_attr().dims_mapping()[i] != -1) {
y_grad_dist_attr.set_partial_status(
61 changes: 57 additions & 4 deletions paddle/phi/infermeta/spmd_rules/matmul.cc
@@ -286,6 +286,43 @@ static bool DistAttrsAreBasicallyEqual(
in_dist_attr.partial_status() == out_dist_attr.partial_status());
}

TensorDistAttr ReduceGradBroadCastDims(const TensorDistAttr& input,
const ArgDistAttr& grad) {
auto& grad_in = PADDLE_GET_CONST(TensorDistAttr, grad);
auto grad_dim = grad_in.dims_mapping().size();
auto input_dim = input.dims_mapping().size();
PADDLE_ENFORCE_GE(
grad_dim,
input_dim,
phi::errors::InvalidArgument("grad dim must ge than input dim, but we "
"got grad_dim [%d], input_dim[%d]",
grad_dim,
input_dim));
if (grad_dim == input_dim) {
return grad_in;
}
size_t broadcast_dim = grad_dim - input_dim;
// gather partial status
auto partial_dims = grad_in.partial_dims();
auto& grad_dims_mapping = grad_in.dims_mapping();
auto dims_mapping = input.dims_mapping();
for (size_t i = 0; i < grad_dim; ++i) {
auto mapping = grad_dims_mapping[i];
if (i < broadcast_dim) {
if (mapping >= 0) {
partial_dims.insert(mapping);
}
} else {
dims_mapping[i - broadcast_dim] = mapping;
}
}
auto grad_out = CopyTensorDistAttrForOutput(input);
grad_out.set_dims_mapping(dims_mapping);
std::vector<int64_t> partial_status(partial_dims.begin(), partial_dims.end());
grad_out.set_partial_status(partial_status);
return grad_out;
}

SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& out_grad,
@@ -369,9 +406,13 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
confirm_dist_attr_same_fn(
dy_spmd_info.first[0], out_grad, "trans x&y: dy-out_grad");
confirm_dist_attr_same_fn(dy_spmd_info.first[1], x, "trans x&y: dy-x");
auto x_grad =
ReduceGradBroadCastDims(x.dist_attr(), dx_spmd_info.second[0]);
auto y_grad =
ReduceGradBroadCastDims(y.dist_attr(), dy_spmd_info.second[0]);
return {
{dy_spmd_info.first[1], dx_spmd_info.first[0], dx_spmd_info.first[1]},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
{x_grad, y_grad}};
} else {
// X'Y: dX = YG', dY = XG
dx_spmd_info =
@@ -384,9 +425,13 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
confirm_dist_attr_same_fn(dy_spmd_info.first[0], x, "trans x: dy-x");
confirm_dist_attr_same_fn(
dy_spmd_info.first[1], out_grad, "trans x: dy-out_grad");
auto x_grad =
ReduceGradBroadCastDims(x.dist_attr(), dx_spmd_info.second[0]);
auto y_grad =
ReduceGradBroadCastDims(y.dist_attr(), dy_spmd_info.second[0]);
return {
{dy_spmd_info.first[0], dx_spmd_info.first[0], dx_spmd_info.first[1]},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
{x_grad, y_grad}};
}
} else {
if (trans_y) {
@@ -401,9 +446,13 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
confirm_dist_attr_same_fn(
dy_spmd_info.first[0], out_grad, "trans y: dy-out_grad");
confirm_dist_attr_same_fn(dy_spmd_info.first[1], x, "trans y: dy-x");
auto x_grad =
ReduceGradBroadCastDims(x.dist_attr(), dx_spmd_info.second[0]);
auto y_grad =
ReduceGradBroadCastDims(y.dist_attr(), dy_spmd_info.second[0]);
return {
{dy_spmd_info.first[1], dx_spmd_info.first[1], dx_spmd_info.first[0]},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
{x_grad, y_grad}};
} else {
// XY: dX = GY', dY = X'G
dx_spmd_info =
@@ -415,9 +464,13 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
confirm_dist_attr_with_arg_same_fn(dx_spmd_info.first[0],
dy_spmd_info.first[1],
"no trans: dy-out_grad");
auto x_grad =
ReduceGradBroadCastDims(x.dist_attr(), dx_spmd_info.second[0]);
auto y_grad =
ReduceGradBroadCastDims(y.dist_attr(), dy_spmd_info.second[0]);
return {
{dy_spmd_info.first[0], dx_spmd_info.first[1], dx_spmd_info.first[0]},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
{x_grad, y_grad}};
}
}
}
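ReduceGradBroadCastDims, added above, folds the extra leading (broadcast) dims of a matmul gradient back onto the input's rank: surviving dims keep the gradient's mapping, while any broadcast dim that was sharded on a mesh axis turns that axis into a partial dim of the reduced gradient. A minimal stand-alone sketch of that reduction, using hypothetical stand-in types and example mappings rather than Paddle's real TensorDistAttr/ArgDistAttr:

    // Mock sketch of reducing a gradient's broadcast dims into partial status.
    #include <cstdint>
    #include <iostream>
    #include <set>
    #include <vector>

    struct MockAttr {                     // hypothetical stand-in for TensorDistAttr
      std::vector<int64_t> dims_mapping;  // tensor dim -> mesh dim (-1 = replicated)
      std::set<int64_t> partial_dims;     // mesh dims carrying partial sums
    };

    MockAttr ReduceGradBroadCastDims(const MockAttr& input, const MockAttr& grad) {
      size_t broadcast_dim = grad.dims_mapping.size() - input.dims_mapping.size();
      MockAttr out;
      out.dims_mapping = input.dims_mapping;  // start from the input's rank
      out.partial_dims = grad.partial_dims;   // keep whatever was already partial
      for (size_t i = 0; i < grad.dims_mapping.size(); ++i) {
        int64_t mapping = grad.dims_mapping[i];
        if (i < broadcast_dim) {
          if (mapping >= 0) out.partial_dims.insert(mapping);  // sharded broadcast dim -> partial
        } else {
          out.dims_mapping[i - broadcast_dim] = mapping;       // surviving dims keep grad's mapping
        }
      }
      return out;
    }

    int main() {
      MockAttr input{{-1, 0}, {}};    // rank-2 weight: dim 1 sharded on mesh dim 0
      MockAttr grad{{1, -1, 0}, {}};  // rank-3 grad: leading batch dim sharded on mesh dim 1
      MockAttr reduced = ReduceGradBroadCastDims(input, grad);
      // reduced.dims_mapping == {-1, 0}; reduced.partial_dims == {1}
      std::cout << reduced.dims_mapping[0] << " " << reduced.dims_mapping[1]
                << " partial on mesh dim " << *reduced.partial_dims.begin() << "\n";
      return 0;
    }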
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/spmd_rules/reduction.cc
@@ -254,6 +254,8 @@ SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x,
}
x_dist_attr.set_dims_mapping(dims_mapping);
x_grad_dist_attr.set_dims_mapping(dims_mapping);
x_dist_attr.set_default_dynamic_dims(dims_mapping);
x_grad_dist_attr.set_default_dynamic_dims(dims_mapping);
}

return {{x_dist_attr, out_grad_dist_attr}, {x_grad_dist_attr}};
5 changes: 5 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -127,6 +127,11 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_semi_auto_parallel_basic
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300)

py_test_modules(test_semi_auto_parallel_for_llama_subnet MODULES
test_semi_auto_parallel_for_llama_subnet)
set_tests_properties(test_semi_auto_parallel_for_llama_subnet
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300)

py_test_modules(test_semi_auto_parallel_softmax_basic MODULES
test_semi_auto_parallel_softmax_basic)
set_tests_properties(test_semi_auto_parallel_softmax_basic