diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 099252571e194..41e4351c37851 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -662,7 +662,7 @@ infer_meta : func : GeneralUnaryGradInferMeta param : [x] - spmd_rule : StridedGradInferSpmdDynamic + spmd_rule : StridedSliceGradInferSpmdDynamic kernel : func : strided_slice_grad no_need_buffer : x diff --git a/paddle/phi/infermeta/spmd_rules/slice.cc b/paddle/phi/infermeta/spmd_rules/slice.cc index cfe6679269b13..1ec057a1734e5 100644 --- a/paddle/phi/infermeta/spmd_rules/slice.cc +++ b/paddle/phi/infermeta/spmd_rules/slice.cc @@ -288,12 +288,12 @@ SpmdInfo StridedSliceInferSpmdDynamic(const DistMetaTensor& input, return SliceInferSpmdBase(input, axes_bridge); } -SpmdInfo StridedGradInferSpmdDynamic(const DistMetaTensor& input, - const DistMetaTensor& out_grad, - const std::vector& axes, - const IntArray& starts, - const IntArray& ends, - const IntArray& strides) { +SpmdInfo StridedSliceGradInferSpmdDynamic(const DistMetaTensor& input, + const DistMetaTensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides) { std::vector axes_bridge(axes.begin(), axes.end()); return SliceGradInferBase(input, out_grad, axes_bridge); } diff --git a/paddle/phi/infermeta/spmd_rules/slice.h b/paddle/phi/infermeta/spmd_rules/slice.h index 9f892f37b888e..b23697b0bf2d9 100644 --- a/paddle/phi/infermeta/spmd_rules/slice.h +++ b/paddle/phi/infermeta/spmd_rules/slice.h @@ -62,12 +62,12 @@ SpmdInfo StridedSliceInferSpmdDynamic(const DistMetaTensor& input, const IntArray& ends, const IntArray& strides); -SpmdInfo StridedGradInferSpmdDynamic(const DistMetaTensor& input, - const DistMetaTensor& out_grad, - const std::vector& axes, - const IntArray& starts, - const IntArray& ends, - const IntArray& strides); +SpmdInfo StridedSliceGradInferSpmdDynamic(const DistMetaTensor& input, + const DistMetaTensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides); } // namespace distributed } // namespace phi