Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoParallel] Add backward inferspmd rule for reduction #58149

Merged
merged 1 commit into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@
infer_meta :
func : UnchangedInferMeta
param: [x]
spmd_rule : ReductionGradInferSpmd
kernel :
func : mean_grad
backward : mean_double_grad
Expand Down Expand Up @@ -702,6 +703,7 @@
infer_meta :
func : UnchangedInferMeta
param : [x]
spmd_rule : ReductionGradInferSpmd
kernel :
func : sum_grad
composite : sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad)
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@
output : Tensor(out)
infer_meta :
func : ReduceIntArrayAxisInferMeta
spmd_rule : ReductionMeanInferSpmdDynamic
kernel :
func : mean
backward : mean_grad
Expand Down Expand Up @@ -1015,6 +1016,7 @@
output : Tensor(out)
infer_meta :
func : SumInferMeta
spmd_rule : ReductionSumInferSpmdDynamic
kernel :
func : sum
data_type : x
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,
// 1.3 Calculate the input one dim dist attr
TensorDistAttr in_one_dim_dist_attr(vectorize(in.dims()));
in_one_dim_dist_attr.set_process_mesh(sub_mesh);
in_one_dim_dist_attr.set_partial_status(std::vector<int64_t>{0});
in_one_dim_dist_attr.set_partial_status(std::vector<int64_t>{0},
kv.second);

// 1.4 Calculate the output one dim dist attr
TensorDistAttr out_one_dim_dist_attr(vectorize(in.dims()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/kernels/all_reduce_kernel.h"
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {
namespace distributed {
Expand Down Expand Up @@ -50,9 +52,18 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx,
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
const auto& in_partial_status = in_dist_attr.partial_status();
auto in_reduce_type = in_partial_status.at(0);
bool reduce_mean = false;
auto dtype = in.dtype();

int64_t reduce_type = static_cast<int64_t>(in_partial_status.at(0));
if (in_reduce_type == ReduceType::kRedAvg) {
in_reduce_type = ReduceType::kRedSum;
reduce_mean = true;
}
int64_t reduce_type = static_cast<int64_t>(in_reduce_type);
VLOG(3) << "Transfer from partial to replicated status with reduce type "
<< reduce_type;

RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
AllReduce,
dtype,
Expand All @@ -61,6 +72,24 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx,
reduce_type,
GetMutableTensor(out));

if (reduce_mean) {
VLOG(3) << "Do reduce mean after all reduce sum";
DenseTensor tensor_of_num_process;
IntArray shape({1});
RESHARD_FUNCTOR(dev_ctx,
Full,
in.dtype(),
shape,
static_cast<int64_t>(in_process_ids.size()),
&tensor_of_num_process);
RESHARD_FUNCTOR(dev_ctx,
Divide,
dtype,
out->value(),
tensor_of_num_process,
GetMutableTensor(out));
}

SetDistProps(out, in.dims(), out_dist_attr);
}

Expand Down
88 changes: 79 additions & 9 deletions paddle/phi/infermeta/spmd_rules/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@ using phi::distributed::auto_parallel::str_join;
////////////////// Utils Functions //////////////////
std::string GetOutputNotation(int input_ndim,
const std::string& input_axes,
std::vector<int> reduce_dims,
std::vector<int64_t> reduce_dims,
bool keep_dim) {
// if input_axes is empty means reduce all
if (reduce_dims.empty()) {
for (int i = 0; i < input_ndim; ++i) {
reduce_dims.emplace_back(i);
}
}

// convert the negative dim value to normal dim value
for (auto& reduce_dim : reduce_dims) {
if (reduce_dim < 0) {
Expand All @@ -40,7 +47,7 @@ std::string GetOutputNotation(int input_ndim,

std::string output_axes = "";
for (int i = 0; i < input_ndim; i++) {
std::vector<int>::iterator iter =
std::vector<int64_t>::iterator iter =
std::find(reduce_dims.begin(), reduce_dims.end(), i);
if (iter != reduce_dims.end()) {
// if i is reduce dim, the corresponding input axis
Expand All @@ -58,9 +65,10 @@ std::string GetOutputNotation(int input_ndim,
return output_axes;
}

SpmdInfo ReductionInferSpmd(const DistMetaTensor& x,
const std::vector<int>& axis,
bool keep_dim) {
SpmdInfo ReductionInferSpmdBase(const DistMetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
int reduce_type) {
// Step0: Verify input args based on reduction logic
auto x_shape = phi::vectorize(x.dims());
int x_ndim = x_shape.size();
Expand Down Expand Up @@ -102,8 +110,8 @@ SpmdInfo ReductionInferSpmd(const DistMetaTensor& x,
// Step3.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, out_axes);
out_dist_attr.set_partial_status(
partial_on_dims /*, handle reduce_type in future */);
out_dist_attr.set_partial_status(partial_on_dims,
static_cast<ReduceType>(reduce_type));

// Step3.2 handle input tensor partial (TODO)
// If the op is a linear op, i.e. `linearity` is true, it supports
Expand All @@ -116,14 +124,37 @@ SpmdInfo ReductionInferSpmd(const DistMetaTensor& x,
VLOG(4) << "Input0 shape: [" << str_join(x_shape) << "] "
<< "dims_mapping: [" << str_join(x_dims_mapping) << "]";
VLOG(4) << "Output dims_mapping: [" + str_join(out_dims_mapping) + "] "
<< "partial_on_dims: [" + str_join(partial_on_dims) + "]\n\n";
<< "partial_on_dims: [" + str_join(partial_on_dims)
<< " with reduce_type " << reduce_type << "]\n\n";

return {{x_dist_attr_src}, {out_dist_attr}};
}

SpmdInfo ReductionInferSpmd(const DistMetaTensor& x,
                            const std::vector<int64_t>& axis,
                            bool keep_dim) {
  // Static-mode entry point: a generic reduction marks its output partial
  // with a sum reduce type by default.
  const int sum_reduce_type = static_cast<int>(ReduceType::kRedSum);
  return ReductionInferSpmdBase(x, axis, keep_dim, sum_reduce_type);
}

SpmdInfo ReductionMeanInferSpmdDynamic(const DistMetaTensor& x,
                                       const IntArray& axis,
                                       bool keep_dim) {
  // Dynamic-mode mean reduction: the output's partial status carries an
  // average reduce type so the reshard layer can divide after all-reduce.
  const int avg_reduce_type = static_cast<int>(ReduceType::kRedAvg);
  return ReductionInferSpmdBase(x, axis.GetData(), keep_dim, avg_reduce_type);
}

SpmdInfo ReductionSumInferSpmdDynamic(const DistMetaTensor& x,
                                      const IntArray& axis,
                                      DataType dtype,
                                      bool keep_dim) {
  // Dynamic-mode sum reduction. `dtype` mirrors the sum op's attribute list
  // and is not needed for the spmd inference itself.
  const int sum_reduce_type = static_cast<int>(ReduceType::kRedSum);
  return ReductionInferSpmdBase(x, axis.GetData(), keep_dim, sum_reduce_type);
}

SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
const std::vector<int>& axis,
const std::vector<int64_t>& axis,
bool keep_dim) {
// Step0: Verify input args based on reduction logic
auto x_shape = phi::vectorize(x.dims());
Expand Down Expand Up @@ -174,5 +205,44 @@ SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x,
return {{x_dist_attr_dst}, {out_dist_attr_src}};
}

SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x,
                                const DistMetaTensor& out_grad,
                                const IntArray& axis,
                                bool keep_dim,
                                bool reduce_all) {
  // Both the forward input and its gradient inherit out_grad's dist attr.
  // When out_grad has fewer dims than x (i.e. keep_dim was false), the
  // reduced axes are re-inserted into the dims mapping as replicated (-1).
  TensorDistAttr x_attr = out_grad.dist_attr();
  TensorDistAttr x_grad_attr = out_grad.dist_attr();

  const std::vector<int64_t> x_shape = phi::vectorize(x.dims());
  const std::vector<int64_t> out_grad_shape = phi::vectorize(out_grad.dims());

  if (x_shape.size() != out_grad_shape.size()) {
    auto dims_mapping = x_attr.dims_mapping();
    std::vector<int64_t> reduce_axes = axis.GetData();

    // Normalize negative axes into [0, ndim) before sorting, so the
    // insertion positions below are processed in ascending order.
    for (auto& reduce_axis : reduce_axes) {
      if (reduce_axis < 0) {
        reduce_axis += x_shape.size();
      }
    }
    std::sort(reduce_axes.begin(), reduce_axes.end());

    // An empty axis list means every dimension was reduced.
    // NOTE(review): `reduce_all` is not consulted here — presumably an empty
    // axis list already encodes that case; confirm against the sum/mean
    // kernels' attribute semantics.
    if (reduce_axes.empty()) {
      for (size_t i = 0; i < x_shape.size(); ++i) {
        reduce_axes.emplace_back(i);
      }
    }

    // Each reduced dimension re-enters the gradient mapping as replicated.
    for (const auto& reduce_axis : reduce_axes) {
      dims_mapping.insert(dims_mapping.begin() + reduce_axis, -1);
    }
    x_attr.set_dims_mapping(dims_mapping);
    x_grad_attr.set_dims_mapping(dims_mapping);
  }

  return {{x_attr, out_grad.dist_attr()}, {x_grad_attr}};
}

} // namespace distributed
} // namespace phi
24 changes: 22 additions & 2 deletions paddle/phi/infermeta/spmd_rules/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,40 @@ limitations under the License. */

#include <vector>

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

SpmdInfo ReductionInferSpmd(const DistMetaTensor& x,
const std::vector<int>& axis,
const std::vector<int64_t>& axis,
bool keep_dim);

// This infer spmd function only use in dynamic mode for it uses
// IntArray as parameter. The IntArray may contain vector of tensor
// which is not support in static mode. So we separate these two and
// use dynamic infer_spmd invoke static infer_spmd function.
SpmdInfo ReductionMeanInferSpmdDynamic(const DistMetaTensor& x,
const IntArray& axis,
bool keep_dim);

SpmdInfo ReductionSumInferSpmdDynamic(const DistMetaTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim);

SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
const std::vector<int>& axis,
const std::vector<int64_t>& axis,
bool keep_dim);

SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& out_grad,
const IntArray& axis,
bool keep_dim,
bool reduce_all);

} // namespace distributed
} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/spmd_rules/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ TensorDistAttr CopyTensorDistAttrForOutput(
TensorDistAttr new_dist_attr = TensorDistAttr();
new_dist_attr.set_process_mesh(src_dist_attr.process_mesh());
new_dist_attr.set_batch_dim(src_dist_attr.batch_dim());
new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
// new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
// new_dist_attr.set_annotated(false); TODO unset field is false by default.
new_dist_attr.clean_partial_status(); // in partial-stage I, partial is allow
// to propagate
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ PD_REGISTER_KERNEL(divide_grad,
phi::DivideGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
phi::dtype::complex<float>,
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/cpu/elementwise_divide_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ PD_REGISTER_KERNEL(divide,
phi::DivideKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
complex64,
Expand Down
16 changes: 13 additions & 3 deletions paddle/phi/kernels/elementwise_divide_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,24 @@ void DivideKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
void Divide(const Context& dev_ctx,
            const DenseTensor& x,
            const DenseTensor& y,
            DenseTensor* dense_out) {
  // Resolve the output's meta (shape/dtype) first so it is always set,
  // even when no computation runs.
  MetaTensor meta_out(dense_out);
  ElementwiseInferMeta(x, y, &meta_out);
  if (!x.initialized()) {
    // No data to divide; leave dense_out with inferred meta only.
    return;
  }
  DivideKernel<T, Context>(dev_ctx, x, y, dense_out);
}

template <typename T, typename Context>
DenseTensor Divide(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
ElementwiseInferMeta(x, y, &meta_out);
DivideKernel<T, Context>(dev_ctx, x, y, &dense_out);
Divide<T, Context>(dev_ctx, x, y, &dense_out);
return dense_out;
}

Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ PD_REGISTER_KERNEL(divide_grad,
phi::dtype::float16,
phi::dtype::bfloat16,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
phi::dtype::complex<float>,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/reduce_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ PD_REGISTER_KERNEL(sum_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/kps/elementwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ PD_REGISTER_KERNEL(divide,
phi::DivideKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
float16,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/kps/reduce_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ PD_REGISTER_KERNEL(sum_raw,
double,
float16,
bfloat16,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/legacy/kps/elementwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ PD_REGISTER_KERNEL(divide_raw,
phi::DivideRawKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
float16,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/reduce_sum_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ PD_REGISTER_KERNEL(sum,
int16_t,
int,
int64_t,
uint8_t,
int8_t,
complex64,
complex128) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
Expand Down
Loading