[Reshard] Implement replicated to split with same placement (#55552)
* Implement replicated to split reshard function

* Fix link error when building with Clang

* Refine split functor

* Simplify reshard code
LiYuRio authored Jul 26, 2023
1 parent f5830c0 commit 9f3b5f1
Showing 15 changed files with 693 additions and 4 deletions.
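
As background for the changes below, a hedged illustration (not code from this commit) of the dims_mapping convention that the new IsSuitable check relies on: in Paddle's auto-parallel convention a dims_mapping entry of -1 means the corresponding tensor axis is not sharded, while any other value names the process-mesh axis that the tensor axis is split over. A tiny standalone sketch of replicated-vs-sharded detection:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy mirror of the IsDimsMappingReplicated / IsDimsMappingShard checks used
// by RToSReshardFunction::IsSuitable; the helper names here are stand-ins.
bool IsReplicated(const std::vector<int64_t>& dims_mapping) {
  return std::all_of(dims_mapping.begin(), dims_mapping.end(),
                     [](int64_t m) { return m == -1; });
}

bool IsShard(const std::vector<int64_t>& dims_mapping) {
  return std::any_of(dims_mapping.begin(), dims_mapping.end(),
                     [](int64_t m) { return m != -1; });
}

int main() {
  std::vector<int64_t> replicated = {-1, -1};  // nothing sharded
  std::vector<int64_t> sharded = {0, -1};      // axis 0 split over mesh axis 0
  std::cout << std::boolalpha << IsReplicated(replicated) << " "
            << IsShard(sharded) << "\n";  // prints: true true
  return 0;
}
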
8 changes: 8 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -26,6 +26,9 @@

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#endif

namespace py = pybind11;

@@ -107,6 +110,11 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
}

void BindAutoParallel(py::module *m) {
#ifdef PADDLE_WITH_DISTRIBUTE
  py::class_<phi::distributed::RToSReshardFunction>(*m, "RToSReshardFunction")
      .def(py::init<>());
#endif

  py::class_<ProcessMesh>(*m, "ProcessMesh")
      .def(py::init<>())
      .def(py::init<const std::vector<int64_t> &,
15 changes: 14 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -1,10 +1,23 @@
proto_library(auto_parallel_proto SRCS auto_parallel.proto)

+set(DISTRIBUTED_SRCS "")
+
+if(WITH_DISTRIBUTE)
+  list(
+    APPEND
+    DISTRIBUTED_SRCS
+    dist_tensor.cc
+    reshard_function.cc
+    reshard_split_functor.cc
+    reshard_utils.cc
+    r_to_s_reshard_function.cc)
+endif()
+
collect_srcs(
  core_srcs
  SRCS
  device_mesh.cc
  process_mesh.cc
  dist_attr.cc
  dist_mapper.cc
-  dist_tensor.cc)
+  ${DISTRIBUTED_SRCS})
2 changes: 1 addition & 1 deletion paddle/phi/core/distributed/auto_parallel/dist_tensor.h
@@ -31,7 +31,7 @@ class DistTensor final
 public:
  /// \brief Construct a dist tensor and allocate space.
  /// \param a The allocator used to allocate space.
-  /// \param meta The meta data of dense tensor.
+  /// \param meta The meta data of dist tensor.
  DistTensor(Allocator* a,
             const DenseTensorMeta& meta,
             const std::shared_ptr<TensorDistAttr>& dist_attr)
Expand Down
100 changes: 100 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
@@ -0,0 +1,100 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"

#include "glog/logging.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/kernel_factory.h"

namespace phi {
namespace distributed {

bool RToSReshardFunction::IsSuitable(
    const DistTensor& in,
    const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
  bool flag = true;
  const auto& in_dist_attr = in.dist_attr();

  const auto& in_dims_mapping = in_dist_attr->dims_mapping();
  const auto& out_dims_mapping = out_dist_attr->dims_mapping();

  flag &= IsDimsMappingReplicated(in_dims_mapping);
  flag &= IsDimsMappingShard(out_dims_mapping);

  const auto& in_process_mesh = in_dist_attr->process_mesh();
  const auto& out_process_mesh = out_dist_attr->process_mesh();

  flag &= (in_process_mesh.ndim() == 1);
  flag &= (out_process_mesh.ndim() == 1);
  flag &= (in_process_mesh == out_process_mesh);

  return flag;
}

std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
    const phi::DeviceContext& dev_ctx,
    const DistTensor& in,
    const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
  const auto& out_dims_mapping = out_dist_attr->dims_mapping();
  const auto& out_process_mesh = out_dist_attr->process_mesh();
  const DenseTensor& in_physical_tensor_cur_rank = in.value();

  DenseTensor out_physical_tensor_cur_rank;

  std::map<int64_t, int64_t> split_axis_to_mesh_axis =
      GetSplitAxisWithDimsMapping(out_dims_mapping);
  std::vector<int64_t> coord_in_mesh = GetCurRankCoordInMesh(out_process_mesh);

  int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
  int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;

  PADDLE_ENFORCE_LT(
      mesh_axis,
      out_process_mesh.ndim(),
      phi::errors::OutOfRange(
          "The mesh axis %lld exceeds the size of process mesh %lld.",
          mesh_axis,
          out_process_mesh.ndim()));

  int64_t num_of_process = out_process_mesh.shape()[mesh_axis];
  VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis
          << ". The split uses axis " << mesh_axis << " of the process mesh,"
          << " with " << num_of_process << " processes participating.";

  // TODO(liyurui): Handle the case where the tensor cannot be split evenly,
  // e.g. a tensor with shape {6} split across 4 processes.
  IntArray sections(std::vector<int64_t>(
      num_of_process, in.dims()[split_axis] / num_of_process));

  std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
      dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);

  VLOG(3) << "The current process keeps piece " << coord_in_mesh[mesh_axis]
          << " of the split tensor.";
  out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];

  return std::make_shared<DistTensor>(
      std::make_shared<DenseTensor>(out_physical_tensor_cur_rank),
      out_dist_attr);
}

} // namespace distributed
} // namespace phi
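
For orientation, a minimal standalone sketch (not part of the commit) of the slice-selection arithmetic that RToSReshardFunction::Eval performs on each rank, assuming a 1-D process mesh and a split axis that divides evenly; the tensor shape, mesh size, and rank coordinates below are made-up example values.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Made-up example: a replicated tensor of shape {8, 6} on a 1-D mesh of 4
  // processes, resharded so that axis 0 is split across mesh axis 0
  // (dims_mapping = {0, -1}).
  std::vector<int64_t> tensor_shape = {8, 6};
  int64_t split_axis = 0;      // tensor axis mapped to a mesh axis
  int64_t num_of_process = 4;  // size of the process mesh along that axis

  // Equal sections per process; the commit's TODO notes that uneven shapes
  // are not handled yet.
  int64_t section_len = tensor_shape[split_axis] / num_of_process;
  std::vector<int64_t> sections(num_of_process, section_len);

  // Each rank keeps the piece indexed by its coordinate in the mesh,
  // mirroring split_out_vec[coord_in_mesh[mesh_axis]] in Eval.
  for (int64_t rank_coord = 0; rank_coord < num_of_process; ++rank_coord) {
    int64_t offset = rank_coord * section_len;
    std::cout << "rank coord " << rank_coord << " keeps rows [" << offset
              << ", " << offset + section_len << ") of axis " << split_axis
              << "\n";
  }
  return 0;
}
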
38 changes: 38 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
@@ -0,0 +1,38 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"

namespace phi {
namespace distributed {

class RToSReshardFunction final : public ReshardFunction {
 public:
  RToSReshardFunction() = default;
  ~RToSReshardFunction() = default;

  bool IsSuitable(
      const DistTensor& in,
      const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;

  std::shared_ptr<DistTensor> Eval(
      const DeviceContext& dev_ctx,
      const DistTensor& in,
      const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
};

} // namespace distributed
} // namespace phi
22 changes: 22 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_function.cc
@@ -0,0 +1,22 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

namespace phi {
namespace distributed {} // namespace distributed
} // namespace phi
45 changes: 45 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_function.h
@@ -0,0 +1,45 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>

namespace phi {
class DeviceContext;

namespace distributed {
namespace auto_parallel {
class TensorDistAttr;
} // namespace auto_parallel

class DistTensor;
using auto_parallel::TensorDistAttr;

class ReshardFunction {
 public:
  ReshardFunction() = default;
  virtual ~ReshardFunction() = default;

  virtual bool IsSuitable(
      const DistTensor& in,
      const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;

  virtual std::shared_ptr<DistTensor> Eval(
      const DeviceContext& dev_ctx,
      const DistTensor& in,
      const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
};

} // namespace distributed
} // namespace phi
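
The diff does not show how a ReshardFunction gets selected and invoked; as a hedged illustration of the IsSuitable/Eval pattern only, the toy sketch below (all types and names are stand-ins, not Paddle APIs) picks the first registered strategy whose IsSuitable returns true and calls its Eval. The real driver logic lives elsewhere in Paddle and may differ.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for DistTensor / TensorDistAttr, just to show the pattern.
struct ToyAttr { bool sharded = false; };
struct ToyTensor { ToyAttr attr; };

class ToyReshardFunction {
 public:
  virtual ~ToyReshardFunction() = default;
  virtual bool IsSuitable(const ToyTensor& in, const ToyAttr& out) = 0;
  virtual std::string Eval(const ToyTensor& in, const ToyAttr& out) = 0;
};

class ToyRToS final : public ToyReshardFunction {
 public:
  bool IsSuitable(const ToyTensor& in, const ToyAttr& out) override {
    return !in.attr.sharded && out.sharded;  // replicated -> shard
  }
  std::string Eval(const ToyTensor&, const ToyAttr&) override {
    return "split the local tensor and keep this rank's piece";
  }
};

int main() {
  std::vector<std::unique_ptr<ToyReshardFunction>> registry;
  registry.push_back(std::make_unique<ToyRToS>());

  ToyTensor replicated;       // dims_mapping all -1 in the real code
  ToyAttr want_shard{true};   // some axis mapped to a mesh axis

  // Pick the first strategy that declares itself suitable, then run it.
  for (auto& fn : registry) {
    if (fn->IsSuitable(replicated, want_shard)) {
      std::cout << fn->Eval(replicated, want_shard) << "\n";
      break;
    }
  }
  return 0;
}
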
78 changes: 78 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc
@@ -0,0 +1,78 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/split_kernel.h"

namespace phi {
namespace distributed {

std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
                                             const DenseTensor& input,
                                             const IntArray& sections,
                                             int64_t axis) {
  size_t out_number = sections.size();
  std::vector<DenseTensor> result(out_number);

  std::vector<MetaTensor> out_meta;
  std::vector<MetaTensor*> out_meta_ptr;

  out_meta.reserve(out_number);
  out_meta_ptr.reserve(out_number);
  for (size_t i = 0; i < out_number; ++i) {
    out_meta.emplace_back(result[i]);
    out_meta_ptr.emplace_back(&out_meta.back());
  }
  SplitInferMeta(phi::MetaTensor(input), sections, axis, out_meta_ptr);

  std::vector<DenseTensor*> outs;
  for (size_t i = 0; i < out_number; ++i) {
    outs.emplace_back(&result[i]);
  }

  if (phi::CPUContext::classof(&dev_ctx)) {
    PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
                         SplitKernel<data_t>(
                             static_cast<const CPUContext&>(dev_ctx),
                             input,
                             sections,
                             axis,
                             outs);
                       }));
    return result;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (phi::GPUContext::classof(&dev_ctx)) {
    PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
                         SplitKernel<data_t>(
                             static_cast<const GPUContext&>(dev_ctx),
                             input,
                             sections,
                             axis,
                             outs);
                       }));
    return result;
  }
#endif
  PADDLE_THROW(phi::errors::Unimplemented(
      "The split in reshard is only supported on CPU and GPU for now."));
}

} // namespace distributed
} // namespace phi
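
A note on the dispatch pattern used above: ReshardSplitFunctor first branches on the concrete device context (CPU vs. GPU), then relies on PD_VISIT_ALL_TYPES to instantiate SplitKernel<data_t> for the tensor's runtime dtype. The toy sketch below (plain C++, all names made up, not Paddle code) shows the same runtime-dtype-to-template dispatch idea.

#include <cstdint>
#include <iostream>
#include <vector>

// Toy dtype tag and a templated "kernel", standing in for phi::DataType and
// SplitKernel<data_t>.
enum class ToyDType { kFloat32, kInt64 };

template <typename T>
void ToySplit(const std::vector<T>& input, size_t pieces) {
  std::cout << "split " << input.size() << " elements into " << pieces
            << " pieces of " << input.size() / pieces << "\n";
}

// Mirrors the PD_VISIT_ALL_TYPES idea: resolve a runtime dtype to a concrete
// template instantiation, then run the kernel.
void DispatchSplit(ToyDType dtype, size_t n, size_t pieces) {
  switch (dtype) {
    case ToyDType::kFloat32:
      ToySplit(std::vector<float>(n), pieces);
      break;
    case ToyDType::kInt64:
      ToySplit(std::vector<int64_t>(n), pieces);
      break;
  }
}

int main() {
  DispatchSplit(ToyDType::kFloat32, 12, 4);
  DispatchSplit(ToyDType::kInt64, 6, 2);
  return 0;
}
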
32 changes: 32 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h
@@ -0,0 +1,32 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <vector>
#include "paddle/phi/common/int_array.h"

namespace phi {
class DeviceContext;
class DenseTensor;

namespace distributed {
std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
                                             const DenseTensor& input,
                                             const IntArray& sections,
                                             int64_t axis);

} // namespace distributed
} // namespace phi
