Skip to content

Commit

Permalink
polish
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhenhai93 committed Nov 13, 2023
1 parent b25e06e commit 27ef949
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
spmd_rule : StridedGradInferSpmdDynamic
spmd_rule : StridedSliceGradInferSpmdDynamic
kernel :
func : strided_slice_grad
no_need_buffer : x
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/spmd_rules/slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,12 @@ SpmdInfo StridedSliceInferSpmdDynamic(const DistMetaTensor& input,
return SliceInferSpmdBase(input, axes_bridge);
}

SpmdInfo StridedGradInferSpmdDynamic(const DistMetaTensor& input,
const DistMetaTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides) {
SpmdInfo StridedSliceGradInferSpmdDynamic(const DistMetaTensor& input,
const DistMetaTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides) {
std::vector<int64_t> axes_bridge(axes.begin(), axes.end());
return SliceGradInferBase(input, out_grad, axes_bridge);
}
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/spmd_rules/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ SpmdInfo StridedSliceInferSpmdDynamic(const DistMetaTensor& input,
const IntArray& ends,
const IntArray& strides);

SpmdInfo StridedGradInferSpmdDynamic(const DistMetaTensor& input,
const DistMetaTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides);
SpmdInfo StridedSliceGradInferSpmdDynamic(const DistMetaTensor& input,
const DistMetaTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides);

} // namespace distributed
} // namespace phi

0 comments on commit 27ef949

Please sign in to comment.