Skip to content

Commit

Permalink
[XPU] speed up for special case of strided_slice op. (#55166)
Browse files Browse the repository at this point in the history
  • Loading branch information
houj04 authored Jul 6, 2023
1 parent a095118 commit 2ff949d
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 0 deletions.
66 changes: 66 additions & 0 deletions paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/stride_slice_util.h"

namespace phi {

Expand Down Expand Up @@ -77,6 +78,71 @@ void StridedSliceRawGradKernel(const Context& dev_ctx,
strides_in[cur_axe] = strides_[i];
}

if (is_strided_slice_special_case(xshape, starts_in, ends_in, strides_in)) {
PADDLE_ENFORCE_EQ(
x.numel(),
x_grad->numel(),
errors::PreconditionNotMet(
"x.numel() should be equal to x_grad->numel() in special case."));
PADDLE_ENFORCE_EQ(
x.numel(),
out_grad.numel() * 2,
errors::PreconditionNotMet("x.numel() should be equal to "
"out_grad->numel() * 2 in special case."));

/*
* sample input: [1 2 3 4 5]
* starts = [0/1]
* strides = [2]
* sample output: [1 0 2 0 3 0 4 0 5 0] (last value in starts is 0)
* sample output: [0 1 0 2 0 3 0 4 0 5] (last value in starts is 1)
*/
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* x_transpose = RAII_GUARD.alloc_l3_or_gm<XPUType>(x.numel());

// step 1: set all value to 0

// int constant(Context* ctx, T* x, int len, T val)
int r = xpu::constant(
dev_ctx.x_context(), x_transpose, x.numel(), static_cast<XPUType>(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");

/*
* step 2: copy dy (out_grad) into the temporary buffer x_transpose:
* if starts from 0: [1 2 3 4 5 0 0 0 0 0]
* if starts from 1: [0 0 0 0 0 1 2 3 4 5]
*/
int offset = 0;
if (starts_in.back() == 1) {
offset = x.numel() / 2;
}
// int copy(Context* ctx, const T* x, T* y, int64_t len)
r = xpu::copy<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
x_transpose + offset,
x.numel() / 2);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
/*
* step3: transpose, input shape is (2, x.numel/2):
* input:
* [1 2 3 4 5
* 0 0 0 0 0]
* after transpose:
* [1 0
* 2 0
* 3 0
* 4 0
* 5 0]
*/
r = xpu::transpose<XPUType>(dev_ctx.x_context(),
x_transpose,
reinterpret_cast<XPUType*>(x_grad->data<T>()),
{2, x.numel() / 2},
{1, 0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
return;
}

int r = xpu::strided_slice_grad(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
Expand Down
52 changes: 52 additions & 0 deletions paddle/phi/kernels/xpu/stride_slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/phi/kernels/xpu/stride_slice_util.h"

namespace phi {

Expand Down Expand Up @@ -99,6 +100,57 @@ void StridedSliceRawKernel(const Context& dev_ctx,
strides_in[cur_axe] = strides_[i];
}

if (is_strided_slice_special_case(xshape, starts_in, ends_in, strides_in)) {
PADDLE_ENFORCE_EQ(
x.numel(),
out->numel() * 2,
errors::PreconditionNotMet(
"x.numel() should be equal to out->numel() * 2 in special case."));
/*
* sample input: [1 2 3 4 5 6 7 8 9 10]
* starts = [0/1]
* strides = [2]
* sample output: [1 3 5 7 9] (last value in starts is 0)
* sample output: [2 4 6 8 10] (last value in starts is 1)
*/
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* x_transpose = RAII_GUARD.alloc_l3_or_gm<XPUType>(x.numel());
/*
* step 1: transpose, input shape is (x.numel/2, 2):
* input:
* [1 2
* 3 4
* 5 6
* 7 8
* 9 10]
* after transpose:
* [1 3 5 7 9
* 2 4 6 8 10]
*/
// int transpose(Context* ctx, const T* x, T* y, const std::vector<int>&
// xshape, const std::vector<int>& permute)
int r =
xpu::transpose<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
x_transpose,
{x.numel() / 2, 2},
{1, 0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
// step 2: if starts from 0, use "first half" data as result, otherwise use
// "second half".
int offset = 0;
if (starts_in.back() == 1) {
offset = x.numel() / 2;
}
// int copy(Context* ctx, const T* x, T* y, int64_t len)
r = xpu::copy<XPUType>(dev_ctx.x_context(),
x_transpose + offset,
reinterpret_cast<XPUType*>(out->data<T>()),
x.numel() / 2);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}

int r = xpu::strided_slice(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
Expand Down
54 changes: 54 additions & 0 deletions paddle/phi/kernels/xpu/stride_slice_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

namespace phi {

inline bool is_strided_slice_special_case(const std::vector<int>& xshape,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int>& strides) {
// starts match {0, 0, ..., 0, 0} or {0, 0, ..., 0, 1}
for (size_t i = 0; i < starts.size() - 1; i++) {
if (starts[i] != 0) {
return false;
}
}
if (starts.back() != 0 && starts.back() != 1) {
return false;
}
// xshape match ends
if (xshape != ends) {
return false;
}
// strides match {1, 1, ..., 1, 2}
for (size_t i = 0; i < strides.size() - 1; i++) {
if (strides[i] != 1) {
return false;
}
}
if (strides.back() != 2) {
return false;
}
// last dim of xshape is even number
if (xshape.back() % 2 != 0) {
return false;
}
return true;
}

} // namespace phi
18 changes: 18 additions & 0 deletions test/xpu/test_strided_slice_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,24 @@ def initTestCase(self):
self.strides = [1, 1, 1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1, 1]

# Test case: every axis is taken whole except the last one, which is sliced
# with stride 2 starting at index 0. ends equals inshape and the last dimension
# (128) is even, so this configuration satisfies the conditions checked by
# is_strided_slice_special_case in stride_slice_util.h.
class XPUTestStrideSliceOp_eb_1(XPUTestStrideSliceOp):
def initTestCase(self):
self.inshape = (1, 4, 4096, 128)
self.axes = [0, 1, 2, 3]
# last start is 0
self.starts = [0, 0, 0, 0]
self.ends = [1, 4, 4096, 128]
# stride 2 only on the last axis
self.strides = [1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1]

# Test case: same shape and strides as XPUTestStrideSliceOp_eb_1, but the last
# axis starts at index 1. ends equals inshape and the last dimension (128) is
# even, so this configuration satisfies the conditions checked by
# is_strided_slice_special_case in stride_slice_util.h.
class XPUTestStrideSliceOp_eb_2(XPUTestStrideSliceOp):
def initTestCase(self):
self.inshape = (1, 4, 4096, 128)
self.axes = [0, 1, 2, 3]
# last start is 1
self.starts = [0, 0, 0, 1]
self.ends = [1, 4, 4096, 128]
# stride 2 only on the last axis
self.strides = [1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1]


support_types = get_xpu_op_support_types('strided_slice')
for stype in support_types:
Expand Down

0 comments on commit 2ff949d

Please sign in to comment.