diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
index a3c616f5e8b35f..381f77991f72dc 100644
--- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
@@ -51,11 +51,11 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
 
   DenseTensor out_physical_tensor_cur_rank;
 
-  std::map<int64_t, int64_t> split_axis_to_mesh_axis =
+  std::map<int, int64_t> split_axis_to_mesh_axis =
       GetSplitAxisWithDimsMapping(out_dims_mapping);
   std::vector<int64_t> coord_in_mesh = GetCurRankCoordInMesh(out_process_mesh);
 
-  int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
+  int split_axis = split_axis_to_mesh_axis.begin()->first;
   int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;
 
   int64_t num_of_process = out_process_mesh.shape()[mesh_axis];
@@ -65,7 +65,7 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
                     << " process participate in.";
 
   std::vector<int64_t> split_num_vec =
-      BalancedSplit(in.dims()[static_cast<int>(split_axis)], num_of_process);
+      BalancedSplit(in.dims()[split_axis], num_of_process);
   IntArray sections(split_num_vec);
 
   std::vector<DenseTensor> split_out_vec;
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
index 1d9677a0a2bc51..2767dfa836394b 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -112,9 +112,9 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
   return comm_context;
 }
 
-std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
+std::map<int, int64_t> GetSplitAxisWithDimsMapping(
     const std::vector<int64_t>& dims_mapping) {
-  std::map<int64_t, int64_t> split_axis_to_mesh_axis;
+  std::map<int, int64_t> split_axis_to_mesh_axis;
   for (size_t i = 0; i < dims_mapping.size(); ++i) {
     if (dims_mapping[i] != -1) {
       split_axis_to_mesh_axis.emplace(i, dims_mapping[i]);
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
index 61a3fbbfcad874..b947c70bb5bc9f 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
@@ -40,7 +40,7 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
 // input vector, return a key-value map of tensor_split_axis and
 // process_mesh_split_axis.
 // For example, if dims_mapping is [-1, 1, -1, 0], will return {1: 1, 3: 0}.
-std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
+std::map<int, int64_t> GetSplitAxisWithDimsMapping(
     const std::vector<int64_t>& dims_mapping);
 
 // If given a number, balance split it to multiple pieces.
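Note: for reference, below is a minimal standalone sketch of the two helpers touched above. The GetSplitAxisWithDimsMapping body mirrors the reshard_utils.cc hunk (keying each sharded tensor axis to its process-mesh axis; -1 means replicated, and the header's example [-1, 1, -1, 0] -> {1: 1, 3: 0}). The BalancedSplit body is an assumption about the documented "balance split" behavior, not the Paddle implementation; the choice of giving the remainder to the leading pieces is a guess.

// Standalone sketch; BalancedSplit semantics are assumed, not taken from Paddle.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Maps each split tensor axis (key) to the process-mesh axis it is split
// along (value); entries equal to -1 in dims_mapping mean "replicated".
std::map<int, int64_t> GetSplitAxisWithDimsMapping(
    const std::vector<int64_t>& dims_mapping) {
  std::map<int, int64_t> split_axis_to_mesh_axis;
  for (size_t i = 0; i < dims_mapping.size(); ++i) {
    if (dims_mapping[i] != -1) {
      split_axis_to_mesh_axis.emplace(static_cast<int>(i), dims_mapping[i]);
    }
  }
  return split_axis_to_mesh_axis;
}

// Splits `total` into `num_of_pieces` near-equal parts. Assigning the
// remainder to the first pieces is an assumption of this sketch.
std::vector<int64_t> BalancedSplit(int64_t total, int64_t num_of_pieces) {
  std::vector<int64_t> result(num_of_pieces, total / num_of_pieces);
  for (int64_t i = 0; i < total % num_of_pieces; ++i) {
    result[i] += 1;
  }
  return result;
}

int main() {
  // Example from the reshard_utils.h comment: [-1, 1, -1, 0] -> {1: 1, 3: 0}.
  for (const auto& kv : GetSplitAxisWithDimsMapping({-1, 1, -1, 0})) {
    std::cout << kv.first << ": " << kv.second << "\n";
  }
  // Splitting a dim of size 10 across 3 processes: 4 3 3.
  for (int64_t n : BalancedSplit(10, 3)) std::cout << n << " ";
  std::cout << "\n";
  return 0;
}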
diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
index 20adea5ea90bb6..efa5035c495ed0 100644
--- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
@@ -44,9 +44,7 @@ bool SToRReshardFunction::IsSuitable(const DistTensor& in,
 
   // Ensure the tensor is balanced split, or we need send/recv rather than
   // all_gather
-  std::map<int64_t, int64_t> split_axis_to_mesh_axis =
-      GetSplitAxisWithDimsMapping(in_dims_mapping);
-  int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
+  int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
   int64_t num_of_process = in_process_mesh.size();
   flag &= (in.local_dims()[static_cast<int>(split_axis)] * num_of_process ==
            in.dims()[static_cast<int>(split_axis)]);
@@ -74,9 +72,7 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
                                 in.value(),
                                 in_process_ids.size(),
                                 GetMutableTensor(out));
-  std::map<int64_t, int64_t> split_axis_to_mesh_axis =
-      GetSplitAxisWithDimsMapping(in_dims_mapping);
-  int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
+  int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
 
   if (split_axis == 0) {
     // If the input dist tensor is shard(0), the subsequent split
diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc
index de60a94f9452a5..45ec2909734463 100644
--- a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc
@@ -14,7 +14,6 @@
 
 #include "paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h"
 
-#include "glog/logging.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
@@ -52,9 +51,9 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
   auto dtype = in.dtype();
   const auto& logical_ddim = in.dims();
   int64_t nranks = in_process_ids.size();
-  int64_t in_split_axis =
+  int in_split_axis =
       GetSplitAxisWithDimsMapping(in.dist_attr().dims_mapping()).begin()->first;
-  int64_t out_split_axis =
+  int out_split_axis =
       GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first;
 
   DenseTensor in_all_to_all = in.value();
@@ -74,7 +73,7 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
     std::vector<int> axis;
     axis.emplace_back(out_split_axis);
     for (size_t i = 0; i < pre_shape_vec.size(); ++i) {
-      if (static_cast<int64_t>(i) != out_split_axis) {
+      if (static_cast<int>(i) != out_split_axis) {
        axis.emplace_back(i);
      }
    }
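Note: the last s_to_s hunk builds the transpose permutation applied after all_to_all: the output split axis is moved to the front while all other axes keep their relative order. A minimal sketch of that permutation logic follows; BuildPostAllToAllPerm is a hypothetical name introduced here, while the loop body mirrors the hunk above.

#include <iostream>
#include <vector>

// Builds the post-all_to_all transpose permutation: out_split_axis goes to
// position 0, remaining axes keep their original relative order.
std::vector<int> BuildPostAllToAllPerm(int ndim, int out_split_axis) {
  std::vector<int> axis;
  axis.emplace_back(out_split_axis);
  for (int i = 0; i < ndim; ++i) {
    if (i != out_split_axis) {
      axis.emplace_back(i);
    }
  }
  return axis;
}

int main() {
  // A 4-D tensor resharded to split axis 2 yields the permutation {2, 0, 1, 3}.
  for (int a : BuildPostAllToAllPerm(4, 2)) std::cout << a << " ";
  std::cout << "\n";
  return 0;
}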