Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Host] add split #5796

Merged
merged 2 commits
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lite/backends/arm/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
concat.cc
type_trans.cc
box_coder.cc
split.cc
shuffle_channel.cc
activation.cc
yolo_box.cc
Expand Down
1 change: 0 additions & 1 deletion lite/backends/arm/math/funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
#include "lite/backends/arm/math/shuffle_channel.h"
#include "lite/backends/arm/math/slice.h"
#include "lite/backends/arm/math/softmax.h"
#include "lite/backends/arm/math/split.h"
#include "lite/backends/arm/math/split_merge_lod_tenosr.h"
#include "lite/backends/arm/math/yolo_box.h"

Expand Down
1 change: 1 addition & 0 deletions lite/backends/host/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ lite_cc_library(math_host SRCS
beam_search.cc
sequence_padding.cc
slice.cc
split.cc
gpc.cc
pad3d.cc
concat.cc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,45 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/arm/math/split.h"
#include "lite/backends/host/math/split.h"
#include <algorithm>
#include "lite/backends/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace arm {
namespace host {
namespace math {

template <>
void split_cpy<float>(const float* din, float* dout, int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* din_ptr = din + (i << 4);
float* dout_ptr = dout + (i << 4);

float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);

vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
}
if (remain > 0) {
const float* din_ptr = din + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr;
dout_ptr++;
din_ptr++;
}
}
}

template <typename T>
void split(const T* din,
const std::vector<lite::Tensor*>& dout,
Expand All @@ -61,7 +30,7 @@ void split(const T* din,
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
for (int i = out_dim.size() - 2; i >= 0; --i) {
for (int i = static_cast<int>(out_dim.size()) - 2; i >= 0; --i) {
out_strides[i] = out_strides[i + 1] * out_dim[i];
}

Expand All @@ -85,12 +54,16 @@ template void split(const float* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);
template void split(const int* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);
template void split(const int64_t* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);

} // namespace math
} // namespace arm
} // namespace host
} // namespace lite
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,21 @@
// limitations under the License.

#pragma once

#include <vector>
#include "lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace arm {
namespace host {
namespace math {

template <typename T>
void split_cpy(const T* din, T* dout, int num);

template <typename T>
void split(const T* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);

} // namespace math
} // namespace arm
} // namespace host
} // namespace lite
} // namespace paddle
5 changes: 2 additions & 3 deletions lite/kernels/arm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ add_kernel(batch_norm_compute_arm ARM basic SRCS batch_norm_compute.cc DEPS ${li
add_kernel(elementwise_compute_arm ARM basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm)

add_kernel(pool_compute_arm ARM basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(split_compute_arm ARM basic SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(split_compute_arm ARM basic SRCS split_compute.cc DEPS ${lite_kernel_deps} split_compute_host)
add_kernel(concat_compute_arm ARM basic SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(pad2d_compute_arm ARM basic SRCS pad2d_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(prior_box_compute_arm ARM basic SRCS prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm)
Expand All @@ -48,7 +48,7 @@ add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_ker
add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(grid_sampler_compute_arm ARM basic SRCS grid_sampler_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(rnn_compute_arm ARM extra SRCS rnn_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(rnn_compute_arm ARM extra SRCS rnn_compute.cc DEPS ${lite_kernel_deps} math_host)

## 2.other basic kernels: basic kernels that not used in basic models
add_kernel(activation_extra_compute_arm ARM extrta SRCS activation_extra_compute.cc DEPS ${lite_kernel_deps} math_arm)
Expand Down Expand Up @@ -119,7 +119,6 @@ lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_
lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm)
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm)
lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra)
lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm)
Expand Down
14 changes: 7 additions & 7 deletions lite/kernels/arm/rnn_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "lite/backends/arm/math/funcs.h"
#include "lite/backends/arm/math/lstm.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/backends/arm/math/split.h"
#include "lite/backends/host/math/split.h"

namespace paddle {
namespace lite {
Expand Down Expand Up @@ -266,9 +266,9 @@ void runLSTMLayer(ARMContext* ctx,
output_tensors[i].Resize(dims);
output_tensors_t.push_back(&output_tensors[i]);
}
lite::arm::math::split(
lite::host::math::split(
gate_value->data<float>(), input_tensors_t, 0, stride1);
lite::arm::math::split(output->data<float>(), output_tensors_t, 0, stride2);
lite::host::math::split(output->data<float>(), output_tensors_t, 0, stride2);
auto sd = output->mutable_data<float>();

if (is_reverse) {
Expand Down Expand Up @@ -426,12 +426,12 @@ void RnnCompute::Run() {
last_h_unbind_t.push_back(&last_h_unbind[i]);
last_c_unbind_t.push_back(&last_c_unbind[i]);
}
lite::arm::math::split(
lite::host::math::split(
pre_state[0]->data<float>(), init_h_unbind_t, 0, stride);
lite::arm::math::split(
lite::host::math::split(
pre_state[1]->data<float>(), init_c_unbind_t, 0, stride);
lite::arm::math::split(state[0]->data<float>(), last_h_unbind_t, 0, stride);
lite::arm::math::split(state[1]->data<float>(), last_c_unbind_t, 0, stride);
lite::host::math::split(state[0]->data<float>(), last_h_unbind_t, 0, stride);
lite::host::math::split(state[1]->data<float>(), last_c_unbind_t, 0, stride);

for (int i = 0; i < num_layers; i++) {
if (i > 0) {
Expand Down
39 changes: 3 additions & 36 deletions lite/kernels/arm/split_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/arm/split_compute.h"
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/kernels/host/split_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

template <typename T, PrecisionType PType>
void SplitCompute<T, PType>::Run() {
auto& param = this->template Param<operators::SplitParam>();
const T* din = param.x->template data<T>();
auto& dout = param.output;
auto in_dim = param.x->dims();
std::vector<int> in_strides(in_dim.size());
in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
for (int i = in_dim.size() - 2; i >= 0; --i) {
in_strides[i] = in_strides[i + 1] * in_dim[i];
}
for (auto out : dout) {
out->set_lod(param.x->lod());
}
lite::arm::math::split(din, dout, param.axis, in_strides);
}

} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle

using split_float =
paddle::lite::kernels::arm::SplitCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW, split_float, def)
REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW, SplitFloat, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
Expand All @@ -54,9 +23,7 @@ REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW, split_float, def)
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();

using split_int64 =
paddle::lite::kernels::arm::SplitCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(split, kARM, kInt64, kNCHW, split_int64, def)
REGISTER_LITE_KERNEL(split, kARM, kInt64, kNCHW, SplitInt64T, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
Expand Down
Loading