Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slice infer spmd #58920

Merged
merged 4 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,7 @@
infer_meta :
func : UnchangedInferMeta
param : [input]
spmd_rule : SliceGradInferSpmdDynamic
kernel :
func : slice_grad
composite: slice_grad(input, out_grad, axes, starts, ends, infer_flags, decrease_axis, input_grad)
Expand Down Expand Up @@ -661,6 +662,7 @@
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
spmd_rule : StridedSliceGradInferSpmdDynamic
kernel :
func : strided_slice_grad
no_need_buffer : x
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,7 @@
output : Tensor
infer_meta :
func : SliceRawInferMeta
spmd_rule : SliceInferSpmdDynamic
kernel :
func : slice
backward : slice_grad
Expand Down Expand Up @@ -1003,6 +1004,7 @@
output : Tensor
infer_meta :
func : StridedSliceInferMeta
spmd_rule : StridedSliceInferSpmdDynamic
kernel :
func : strided_slice
backward : strided_slice_grad
Expand Down
162 changes: 144 additions & 18 deletions paddle/phi/infermeta/spmd_rules/slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@ namespace distributed {

using phi::distributed::auto_parallel::str_join;

SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input,
const std::vector<int64_t>& axes) {
auto input_shape = phi::vectorize(input.dims());
int input_ndim = input_shape.size();
auto input_dist_attr_src = input.dist_attr();
Expand Down Expand Up @@ -90,13 +86,18 @@ SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
return {{input_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
const DistMetaTensor& output,
const std::vector<int64_t>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
// Static-graph entry point of the slice SPMD rule.
// Only `axes` is forwarded to SliceInferSpmdBase: the inferred sharding
// depends solely on which dimensions are sliced, so `starts`, `ends`,
// `infer_flags` and `decrease_axis` are accepted to match the op signature
// but are not used here.
SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
                        const std::vector<int64_t>& axes,
                        const std::vector<int>& starts,
                        const std::vector<int>& ends,
                        const std::vector<int64_t>& infer_flags,
                        const std::vector<int64_t>& decrease_axis) {
  return SliceInferSpmdBase(input, axes);
}

SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
const DistMetaTensor& output,
const std::vector<int64_t>& axes) {
auto output_shape = phi::vectorize(output.dims());
int out_ndim = output_shape.size();
auto out_dist_attr = output.dist_attr();
Expand Down Expand Up @@ -133,11 +134,6 @@ SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,

std::string out_axes(input_axes);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
out_axes[axis] = special_axes[i];
}

std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
std::vector<int64_t> out_dims_mapping = output.dist_attr().dims_mapping();
axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping));
Expand Down Expand Up @@ -172,5 +168,135 @@ SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
return {{input_dist_attr}, {out_dist_attr}};
}

// Reverse (output-to-input) slice SPMD inference, static-graph signature.
// Delegates to SliceInferSpmdReverseBase; only `input`, `output` and `axes`
// participate in the inference — the remaining slice attributes do not
// affect the dims mapping.
SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
                               const DistMetaTensor& output,
                               const std::vector<int64_t>& axes,
                               const std::vector<int>& starts,
                               const std::vector<int>& ends,
                               const std::vector<int64_t>& infer_flags,
                               const std::vector<int64_t>& decrease_axis) {
  return SliceInferSpmdReverseBase(input, output, axes);
}

// Dynamic-graph entry point of the slice SPMD rule (IntArray attributes).
// The inferred sharding depends only on which axes are sliced, so just
// `axes` is forwarded to SliceInferSpmdBase.
// NOTE(review): the previous version copied starts/ends into unused
// std::vector<int> locals; that dead work (and the resulting
// unused-variable warnings) has been removed.
SpmdInfo SliceInferSpmdDynamic(const DistMetaTensor& input,
                               const std::vector<int64_t>& axes,
                               const IntArray& starts,
                               const IntArray& ends,
                               const std::vector<int64_t>& infer_flags,
                               const std::vector<int64_t>& decrease_axis) {
  return SliceInferSpmdBase(input, axes);
}

// Shared SPMD inference for slice_grad / strided_slice_grad.
// Strategy: force the sliced axes of both `input` and `out_grad` to be
// replicated (their shardings cannot be propagated through a slice), then
// merge the remaining axes of the two tensors into one aligned dims
// mapping that is used for input, out_grad and the produced input_grad.
// Returns {{input_dist_attr, out_grad_dist_attr}, {input_grad_dist_attr}}.
SpmdInfo SliceGradInferBase(const DistMetaTensor& input,
                            const DistMetaTensor& out_grad,
                            const std::vector<int64_t>& axes) {
  auto input_dist_attr = input.dist_attr();
  auto out_dist_attr = out_grad.dist_attr();
  // Sliced dimensions must not stay sharded on either tensor.
  input_dist_attr = UnShardTensorDims(input_dist_attr, axes);
  out_dist_attr = UnShardTensorDims(out_dist_attr, axes);
  auto output_shape = phi::vectorize(out_grad.dims());
  int out_ndim = output_shape.size();
  int out_dims_mapping_size = out_dist_attr.dims_mapping().size();
  auto input_shape = phi::vectorize(input.dims());
  int input_ndim = input_shape.size();

  PADDLE_ENFORCE_EQ(
      input_ndim,
      out_ndim,
      phi::errors::InvalidArgument("The Tensor Input's rank [%d] is not equal "
                                   "to the Tensor Output's rank [%d]",
                                   input_ndim,
                                   out_ndim));

  PADDLE_ENFORCE_EQ(
      out_ndim,
      out_dims_mapping_size,
      phi::errors::InvalidArgument("The Tensor Output's rank [%d] and Its "
                                   "dims_mapping size [%d] are not matched.",
                                   out_ndim,
                                   out_dims_mapping_size));

  // Einsum-like axis labels. Sliced axes receive "special" labels taken
  // from the unused tail of the alphabet so they never merge with the
  // aligned axes during ShardingMergeForTensors.
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string align_axes = alphabet.substr(0, input_ndim);
  std::string input_axes = align_axes;
  std::string special_axes = alphabet.substr(input_ndim);

  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
    input_axes[axis] = special_axes[i];
  }
  std::string out_axes(input_axes);

  // Merge the (already unsharded-on-sliced-axes) mappings of out_grad and
  // input into a single consistent mapping over the aligned labels.
  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
  axes_sharding_info.emplace_back(
      std::make_pair(out_axes, out_dist_attr.dims_mapping()));
  axes_sharding_info.emplace_back(
      std::make_pair(input_axes, input_dist_attr.dims_mapping()));
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors(axes_sharding_info);
  auto aligned_dim_mapping =
      GetDimsMappingForAxes(align_axes, axis_to_dim_map, true);
  // All three tensors share the aligned mapping.
  TensorDistAttr aligned_dist_attr = CopyTensorDistAttrForOutput(out_dist_attr);
  input_dist_attr.set_dims_mapping(aligned_dim_mapping);
  out_dist_attr.set_dims_mapping(aligned_dim_mapping);
  aligned_dist_attr.set_dims_mapping(aligned_dim_mapping);

  VLOG(4) << "SliceGradInfer:";

  VLOG(4) << "input"
          << " shape: [" << str_join(input_shape) << "] "
          << "src_dims_mapping: [" << str_join(input.dist_attr().dims_mapping())
          << "] "
          << "dst_dims_mapping: [" << str_join(input_dist_attr.dims_mapping())
          << "]";

  VLOG(4) << "Output Grad"
          << " shape: [" << str_join(output_shape) << "] "
          << "src_dims_mapping: ["
          << str_join(out_grad.dist_attr().dims_mapping()) << "] "
          << "dst_dims_mapping: [" << str_join(out_dist_attr.dims_mapping())
          << "]";

  // NOTE(review): input_grad has the shape of `input`, not of out_grad;
  // the previous log printed output_shape here. Also dropped an unused
  // `input_dims_mapping` local.
  VLOG(4) << "input Grad"
          << " shape: [" << str_join(input_shape) << "] "
          << "dims_mapping: [" << str_join(aligned_dist_attr.dims_mapping())
          << "] ";

  return {{input_dist_attr, out_dist_attr}, {aligned_dist_attr}};
}

// Dynamic-graph slice_grad SPMD rule. Only `input`, `out_grad` and `axes`
// take part in the inference; the remaining slice attributes are accepted
// to match the op signature and are not used by SliceGradInferBase.
SpmdInfo SliceGradInferSpmdDynamic(const DistMetaTensor& input,
                                   const DistMetaTensor& out_grad,
                                   const std::vector<int64_t>& axes,
                                   const IntArray& starts,
                                   const IntArray& ends,
                                   const std::vector<int64_t>& infer_flags,
                                   const std::vector<int64_t>& decrease_axis) {
  return SliceGradInferBase(input, out_grad, axes);
}

// Dynamic-graph strided_slice SPMD rule. The sharding inference is shared
// with slice: only the sliced axes matter, so widen them to int64_t and
// delegate to SliceInferSpmdBase.
SpmdInfo StridedSliceInferSpmdDynamic(const DistMetaTensor& input,
                                      const std::vector<int>& axes,
                                      const IntArray& starts,
                                      const IntArray& ends,
                                      const IntArray& strides) {
  std::vector<int64_t> widened_axes;
  widened_axes.reserve(axes.size());
  for (int axis : axes) {
    widened_axes.push_back(static_cast<int64_t>(axis));
  }
  return SliceInferSpmdBase(input, widened_axes);
}

// Dynamic-graph strided_slice_grad SPMD rule. Shares the slice_grad
// inference: widen the int axes to int64_t and delegate to
// SliceGradInferBase.
SpmdInfo StridedSliceGradInferSpmdDynamic(const DistMetaTensor& input,
                                          const DistMetaTensor& out_grad,
                                          const std::vector<int>& axes,
                                          const IntArray& starts,
                                          const IntArray& ends,
                                          const IntArray& strides) {
  std::vector<int64_t> widened_axes;
  widened_axes.reserve(axes.size());
  for (int axis : axes) {
    widened_axes.push_back(static_cast<int64_t>(axis));
  }
  return SliceGradInferBase(input, out_grad, widened_axes);
}

} // namespace distributed
} // namespace phi
29 changes: 29 additions & 0 deletions paddle/phi/infermeta/spmd_rules/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include <string>
#include <vector>

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

Expand All @@ -40,5 +41,33 @@ SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis);

SpmdInfo SliceInferSpmdDynamic(const DistMetaTensor& input,
const std::vector<int64_t>& axes,
const IntArray& starts,
const IntArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis);

SpmdInfo SliceGradInferSpmdDynamic(const DistMetaTensor& input,
const DistMetaTensor& out_grad,
const std::vector<int64_t>& axes,
const IntArray& starts,
const IntArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis);

SpmdInfo StridedSliceInferSpmdDynamic(const DistMetaTensor& input,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides);

SpmdInfo StridedSliceGradInferSpmdDynamic(const DistMetaTensor& input,
const DistMetaTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides);

} // namespace distributed
} // namespace phi
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/spmd_rules/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,19 @@ TensorDistAttr FromPlacements(
return dst_dist_attr;
}

// Returns a copy of `dist_attr` whose dims mapping is forced to
// replicated (kReplicateDim) on every dimension listed in `dims`.
// Negative entries in `dims` index from the end, Python-style.
TensorDistAttr UnShardTensorDims(const TensorDistAttr& dist_attr,
                                 std::vector<int64_t> dims) {
  TensorDistAttr result = CopyTensorDistAttrForOutput(dist_attr);
  std::vector<int64_t> mapping = dist_attr.dims_mapping();
  const int64_t rank = static_cast<int64_t>(mapping.size());
  for (int64_t axis : dims) {
    const int64_t normalized = axis < 0 ? axis + rank : axis;
    mapping[normalized] = kReplicateDim;
  }
  result.set_dims_mapping(mapping);
  return result;
}

std::vector<ArgDistAttr> ToArgDistAttr(
const std::vector<TensorDistAttr>& dist_attrs) {
std::vector<ArgDistAttr> items_dist_attrs;
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/spmd_rules/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
// null.
TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);

TensorDistAttr UnShardTensorDims(const TensorDistAttr& dist_attr,
std::vector<int64_t> dims);

// Resolute the partial mesh dimension of a output tensor, giving the
// merged sharding specifcation of input tensors and the axis names of output
// tensor. Input are
Expand Down
62 changes: 62 additions & 0 deletions test/auto_parallel/semi_auto_parallel_for_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
import paddle
import paddle.distributed as dist

"""
test for concat、slice 、split
"""


class TestSplitAndConcatSemiAutoParallel(SemiAutoParallelTestBase):
def __init__(self):
Expand Down Expand Up @@ -44,6 +48,60 @@ def test_concat_forward_reshard(self):
axis=0,
)

def test_slice(self):
    """Forward/backward slice with the input sharded on a non-sliced axis."""
    shapes = [64, 4, 4]
    specs = [None, None, 'x']
    # The return value was previously unpacked into unused locals; the
    # checks happen inside runfunc_and_check, so it is simply discarded.
    self.runfunc_and_check(
        inputs_shape=shapes,
        inputs_specs=specs,
        op_func=paddle.slice,
        with_backward=True,
        axes=[0, 1],
        starts=[1, 1],
        ends=[3, 3],
    )

def test_slice_reshard(self):
    """Forward/backward slice with the input sharded on a sliced axis,
    forcing a reshard before the op runs."""
    shapes = [64, 4, 4]
    specs = [None, 'x', None]
    # Return value discarded: runfunc_and_check performs the assertions.
    self.runfunc_and_check(
        inputs_shape=shapes,
        inputs_specs=specs,
        op_func=paddle.slice,
        with_backward=True,
        axes=[0, 1],
        starts=[1, 1],
        ends=[3, 3],
    )

def test_stride_slice(self):
    """Forward/backward strided_slice (negative stride) with the input
    sharded on a non-sliced axis."""
    shapes = [64, 4, 4]
    specs = [None, None, 'x']
    # Return value discarded: runfunc_and_check performs the assertions.
    self.runfunc_and_check(
        inputs_shape=shapes,
        inputs_specs=specs,
        op_func=paddle.strided_slice,
        with_backward=True,
        axes=[0, 1],
        starts=[1, 3],
        ends=[3, 1],
        strides=[1, -1],
    )

def test_stride_slice_reshard(self):
    """Forward/backward strided_slice (negative stride) with the input
    sharded on a sliced axis, forcing a reshard before the op runs."""
    shapes = [64, 4, 4]
    specs = [None, 'x', None]
    # Return value discarded: runfunc_and_check performs the assertions.
    self.runfunc_and_check(
        inputs_shape=shapes,
        inputs_specs=specs,
        op_func=paddle.strided_slice,
        with_backward=True,
        axes=[0, 1],
        starts=[1, 3],
        ends=[3, 1],
        strides=[1, -1],
    )

def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
Expand All @@ -53,9 +111,13 @@ def run_test_case(self):
raise ValueError("Only support cpu or gpu backend.")

self.test_concat_forward()
self.test_slice()
self.test_stride_slice()
# all to all is not supported yet for cpu
if self._backend == "gpu":
self.test_concat_forward_reshard()
self.test_slice_reshard()
self.test_stride_slice_reshard()


if __name__ == '__main__':
Expand Down