Skip to content

Commit

Permalink
[auto parallel] Nd reshard cross mesh (#59777)
Browse files Browse the repository at this point in the history
  • Loading branch information
FeixLiu authored Dec 8, 2023
1 parent c45faed commit cc07f54
Show file tree
Hide file tree
Showing 6 changed files with 525 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h"
#include "paddle/phi/core/distributed/store/store_utils.h"

namespace phi {
namespace distributed {
Expand Down Expand Up @@ -233,6 +235,9 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,
real_out_dist_attr.dims_mapping();
real_dims_mapping[i] = out_mesh_axis;
real_out_dist_attr.set_dims_mapping(real_dims_mapping);
if (real_out_dist_attr.is_partial(out_mesh_axis)) {
real_out_dist_attr.clean_partial_dims({out_mesh_axis});
}

// 4.2 Calculate the process_mesh on specific axis
ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, out_mesh_axis);
Expand Down Expand Up @@ -266,5 +271,53 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,
}
}

// Decides whether this reshard function applies to moving `in` onto
// `out_dist_attr`. All three conditions below have to hold; each one is
// guarded by RESHARD_SHORTCUT_IF_FALSE (defined outside this file —
// presumably it returns false from this function when its argument is false).
bool CrossNdMeshReshardFunction::IsSuitable(
    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
  const auto& src_dist_attr = in.dist_attr();
  const auto& dst_mesh = out_dist_attr.process_mesh();

  // Source and destination must be placed on different process meshes.
  RESHARD_SHORTCUT_IF_FALSE(src_dist_attr.process_mesh() != dst_mesh);

  // The destination process mesh must have more than one dimension.
  RESHARD_SHORTCUT_IF_FALSE(dst_mesh.ndim() > 1);

  // The complete distributed attributes (not only the dims_mapping) of the
  // input and the requested output must differ.
  RESHARD_SHORTCUT_IF_FALSE(src_dist_attr != out_dist_attr);

  return true;
}

// Reshards `in` to `out_dist_attr` in two stages:
//   1. On the input's own process mesh, reshard to the output's dims_mapping
//      and partial status (delegated to SameNdMeshReshardFunction).
//   2. Hand the intermediate tensor to SameStatusReshardFunction to reach
//      `out_dist_attr`, writing the result into `out`.
// Ranks that are not contained in the input's process mesh skip stage 1 and
// only set the distributed properties of the intermediate tensor.
//
// dev_ctx:       device context forwarded to the delegated reshard functions.
// in:            source distributed tensor.
// out_dist_attr: requested distributed attributes of the result.
// out:           destination tensor, filled by the stage-2 reshard function.
// A PADDLE_ENFORCE with an InvalidArgument error guards each stage's
// IsSuitable() check.
void CrossNdMeshReshardFunction::Eval(DeviceContext* dev_ctx,
                                      const DistTensor& in,
                                      const TensorDistAttr& out_dist_attr,
                                      DistTensor* out) {
  VLOG(3) << "Call CrossNdMeshReshardFunction Eval";
  const auto& src_dist_attr = in.dist_attr();

  // Intermediate tensor produced by stage 1 and consumed by stage 2.
  DistTensor staged;

  // Stage-1 target: starts as a copy of the input attributes (so the process
  // mesh stays the input's), then takes the output's partial status and
  // dims_mapping.
  TensorDistAttr staged_dist_attr = src_dist_attr;
  staged_dist_attr.set_partial_status(out_dist_attr.partial_status());
  staged_dist_attr.set_dims_mapping(out_dist_attr.dims_mapping());

  const int64_t global_rank = GetCurGlobalRank();
  const bool rank_in_src_mesh =
      src_dist_attr.process_mesh().contains(global_rank);
  if (!rank_in_src_mesh) {
    // This rank is outside the input mesh: no stage-1 reshard is run, only
    // the intermediate tensor's dims and dist attr are set.
    SetDistProps(&staged, in.dims(), staged_dist_attr);
  } else {
    SameNdMeshReshardFunction nd_mesh_func;
    PADDLE_ENFORCE(
        nd_mesh_func.IsSuitable(in, staged_dist_attr),
        phi::errors::InvalidArgument(
            "Invoke the same nd reshard function is not valid from %s to %s.",
            src_dist_attr,
            staged_dist_attr));
    nd_mesh_func.Eval(dev_ctx, in, staged_dist_attr, &staged);
  }

  // Stage 2: go from the intermediate tensor to the requested attributes.
  SameStatusReshardFunction status_func;
  PADDLE_ENFORCE(
      status_func.IsSuitable(staged, out_dist_attr),
      phi::errors::InvalidArgument("Invoke the same status reshard function "
                                   "is not valid from %s to %s.",
                                   staged.dist_attr(),
                                   out_dist_attr));
  status_func.Eval(dev_ctx, staged, out_dist_attr, out);
}

} // namespace distributed
} // namespace phi
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,18 @@ class SameNdMeshReshardFunction final : public ReshardFunction {
std::string Name() override { return "SameNdMeshReshard"; }
};

// Reshard function for N-d process meshes where the input and the requested
// output live on different process meshes. Identified by the name
// "CrossNdMeshReshard".
class CrossNdMeshReshardFunction final : public ReshardFunction {
 public:
  // Name under which this reshard function identifies itself.
  std::string Name() override { return "CrossNdMeshReshard"; }

  // Returns whether this function can reshard `in` to `out_dist_attr`.
  bool IsSuitable(const DistTensor& in,
                  const TensorDistAttr& out_dist_attr) override;

  // Reshards `in` to `out_dist_attr` and stores the result in `out`.
  void Eval(DeviceContext* dev_ctx,
            const DistTensor& in,
            const TensorDistAttr& out_dist_attr,
            DistTensor* out) override;
};

} // namespace distributed
} // namespace phi
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ REGISTER_RESHARD_FUNC(SToSReshardFunction);
REGISTER_RESHARD_FUNC(SToSReshardFunctionCrossMesh);
REGISTER_RESHARD_FUNC(SameStatusReshardFunction);
REGISTER_RESHARD_FUNC(SameNdMeshReshardFunction);
// NOTE(review): CrossNdMeshReshardFunction is registered after the other
// reshard functions. If the registry picks the first function whose
// IsSuitable() returns true, this position is significant — confirm before
// reordering.
REGISTER_RESHARD_FUNC(CrossNdMeshReshardFunction);

} // namespace distributed
} // namespace phi
Loading

0 comments on commit cc07f54

Please sign in to comment.