From 966a743e0b4f90e9a6100d1686b4f3bd9ed8c3be Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Wed, 3 Jan 2024 16:02:28 +0800 Subject: [PATCH] fix reshard dist_attr --- paddle/phi/core/distributed/auto_parallel/dist_tensor.cc | 3 +++ .../distributed/auto_parallel/reshard/reshard_function.cc | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index 626b5bdf5e441..fff9af10339a6 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -117,6 +117,9 @@ DistTensor::DistTensor() : value_(std::make_shared()) {} DistTensor::DistTensor(const std::shared_ptr& global_value, const TensorDistAttr& dist_attr) : global_dims_(global_value->dims()), dist_attr_(dist_attr) { + process_mesh_ = dist_attr_.process_mesh(); + placements_ = ToPlacements(dist_attr); + // If the current rank doesn't in process_mesh, we should create an // uninitialized tensor only with tensor_meta. if (IsCurRankInMesh(dist_attr.process_mesh())) { diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc index 9644c0b28e916..99da6feb54eba 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc @@ -52,6 +52,8 @@ void ReshardFunction::SetDistProps(DistTensor* tensor, tensor->global_dims_ = dims; tensor->dist_attr_ = dist_attr; + tensor->process_mesh_ = dist_attr.process_mesh(); + tensor->placements_ = ToPlacements(dist_attr); } void ReshardFunction::SetDistProps(DistTensor* tensor, @@ -64,6 +66,8 @@ void ReshardFunction::SetDistProps(DistTensor* tensor, str_join(vectorize(tensor->dims())))); tensor->dist_attr_ = dist_attr; + tensor->process_mesh_ = dist_attr.process_mesh(); + tensor->placements_ = ToPlacements(dist_attr); } DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) {