Skip to content

Commit

Permalink
Refine types and remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio committed Sep 14, 2023
1 parent 2a2b727 commit 544b039
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,

DenseTensor out_physical_tensor_cur_rank;

std::map<int64_t, int64_t> split_axis_to_mesh_axis =
std::map<int, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(out_dims_mapping);
std::vector<int64_t> coord_in_mesh = GetCurRankCoordInMesh(out_process_mesh);

int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
int split_axis = split_axis_to_mesh_axis.begin()->first;
int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;

int64_t num_of_process = out_process_mesh.shape()[mesh_axis];
Expand All @@ -65,7 +65,7 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
<< " process participate in.";

std::vector<int64_t> split_num_vec =
BalancedSplit(in.dims()[static_cast<int>(split_axis)], num_of_process);
BalancedSplit(in.dims()[split_axis], num_of_process);
IntArray sections(split_num_vec);

std::vector<DenseTensor> split_out_vec;
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
return comm_context;
}

std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
std::map<int, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping) {
std::map<int64_t, int64_t> split_axis_to_mesh_axis;
std::map<int, int64_t> split_axis_to_mesh_axis;
for (size_t i = 0; i < dims_mapping.size(); ++i) {
if (dims_mapping[i] != -1) {
split_axis_to_mesh_axis.emplace(i, dims_mapping[i]);
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/core/distributed/auto_parallel/reshard_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
// input vector, return a key-value map of tensor_split_axis and
// process_mesh_split_axis.
// For example, if dims_mapping is [-1, 1, -1, 0], will return {1: 1, 3: 0}.
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
std::map<int, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);

// If given a number, balance split it to multiple pieces.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ bool SToRReshardFunction::IsSuitable(const DistTensor& in,

// Ensure the tensor is balanced split, or we need send/recv rather than
// all_gather
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping);
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
int64_t num_of_process = in_process_mesh.size();
flag &= (in.local_dims()[static_cast<int>(split_axis)] * num_of_process ==
in.dims()[static_cast<int>(split_axis)]);
Expand Down Expand Up @@ -74,9 +72,7 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
in.value(),
in_process_ids.size(),
GetMutableTensor(out));
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping);
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;

if (split_axis == 0) {
// If the input dist tensor is shard(0), the subsequent split
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

#include "paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
Expand Down Expand Up @@ -52,9 +51,9 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
auto dtype = in.dtype();
const auto& logical_ddim = in.dims();
int64_t nranks = in_process_ids.size();
int64_t in_split_axis =
int in_split_axis =
GetSplitAxisWithDimsMapping(in.dist_attr().dims_mapping()).begin()->first;
int64_t out_split_axis =
int out_split_axis =
GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first;

DenseTensor in_all_to_all = in.value();
Expand All @@ -74,7 +73,7 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
std::vector<int> axis;
axis.emplace_back(out_split_axis);
for (size_t i = 0; i < pre_shape_vec.size(); ++i) {
if (static_cast<int64_t>(i) != out_split_axis) {
if (static_cast<int>(i) != out_split_axis) {
axis.emplace_back(i);
}
}
Expand Down

0 comments on commit 544b039

Please sign in to comment.