diff --git a/paddle/phi/api/lib/tensor_utils.cc b/paddle/phi/api/lib/tensor_utils.cc index 4d5711ecb4078f..aa9a678f2e10b5 100644 --- a/paddle/phi/api/lib/tensor_utils.cc +++ b/paddle/phi/api/lib/tensor_utils.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/phi/api/lib/api_registry.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #ifdef PADDLE_WITH_CUDA @@ -126,7 +127,10 @@ PADDLE_API std::shared_ptr reshard( if (input_tensor_impl) { phi::distributed::DistTensor* dist_tensor = static_cast(input_tensor_impl.get()); - if (dist_tensor->dist_attr() != dist_attr) { + if (dist_tensor->dist_attr() != dist_attr && + (phi::distributed::IsCurRankInMesh( + dist_tensor->dist_attr().process_mesh()) || + phi::distributed::IsCurRankInMesh(dist_attr.process_mesh()))) { VLOG(6) << "reshard func, reshard tensor from " << dist_tensor->dist_attr() << " to " << dist_attr; auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,