Skip to content

Commit

Permalink
[HOST] add tril_triu; fix expand_as (#5507)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang authored Feb 20, 2021
1 parent fd8bac9 commit 5f38270
Show file tree
Hide file tree
Showing 12 changed files with 424 additions and 29 deletions.
1 change: 1 addition & 0 deletions lite/kernels/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ add_kernel(expand_v2_compute_host Host extra SRCS expand_v2_compute.cc DEPS ${li
add_kernel(strided_slice_compute_host Host extra SRCS strided_slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fill_any_like_compute_host Host extra SRCS fill_any_like_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scatter_nd_add_compute_host Host extra SRCS scatter_nd_add_compute.cc DEPS ${lite_kernel_deps})
add_kernel(tril_triu_compute_host Host extra SRCS tril_triu_compute.cc DEPS ${lite_kernel_deps})

if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
Expand Down
22 changes: 19 additions & 3 deletions lite/kernels/host/expand_as_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ void ExpandAsCompute<T, PType>::Run() {
const T* src = x->template data<T>();
T* dst = out->template mutable_data<T>();

// int dims = expand_times.size();
for (int i = 0; i < target->dims().size(); ++i) {
int times = target->dims()[i] / x->dims()[i];
expand_times.push_back(times);
Expand Down Expand Up @@ -75,12 +74,29 @@ REGISTER_LITE_KERNEL(expand_as, kHost, kFloat, kAny, expand_as_float, def)
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.BindInput("Target",
.BindInput("target_tensor",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.Finalize();

// int64 variant of the expand_as host kernel, bound under the alias "int64".
// NOTE(review): the class template argument is PRECISION(kFloat) while the
// bound X/Out precision below is kInt64 — presumably only the element type
// (int64_t) matters inside Run(); verify against the kFloat alias above.
using expand_as_int64 =
paddle::lite::kernels::host::ExpandAsCompute<int64_t, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(expand_as, kHost, kFloat, kAny, expand_as_int64, int64)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt64),
DATALAYOUT(kAny))})
// "target_tensor" matches the input name read in ExpandAsOpLite::AttachImpl.
.BindInput("target_tensor",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt64),
DATALAYOUT(kAny))})
.Finalize();
72 changes: 72 additions & 0 deletions lite/kernels/host/tril_triu_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/host/tril_triu_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

// Copies the lower (tril) or upper (triu) triangular part of a single
// h x w row-major matrix `in` into `out`, zero-filling every element on
// the discarded side of the dividing diagonal.
//
// `diagonal` shifts the dividing diagonal: an element at (row, col) is
// kept when col - row <= diagonal (tril) or col - row >= diagonal (triu).
template <class T>
void TrilTriu(const T* in,
              const int64_t diagonal,
              const bool lower,
              const int64_t h,
              const int64_t w,
              T* out) {
  for (int64_t r = 0; r < h; ++r) {
    for (int64_t c = 0; c < w; ++c) {
      const int64_t band = c - r;
      // Keep the element when it lies inside the requested triangle.
      const bool keep = lower ? (band <= diagonal) : (band >= diagonal);
      const int64_t idx = r * w + c;
      out[idx] = keep ? in[idx] : static_cast<T>(0);
    }
  }
}

// Applies tril/triu to every trailing h x w matrix of the input tensor:
// the last two dims of X form the matrix, all leading dims are batches.
template <class T>
void TrilTriuCompute<T>::Run() {
  auto& param = this->template Param<param_t>();
  const lite::Tensor* x = param.x;
  lite::Tensor* out = param.out;
  const int64_t diagonal = param.diagonal;
  const bool lower = param.lower;

  const auto dims = x->dims();
  const int64_t height = dims[dims.size() - 2];
  const int64_t width = dims[dims.size() - 1];
  const int64_t mat_size = height * width;
  // Number of matrices = product of all leading (batch) dims.
  const int64_t batches = dims.production() / mat_size;

  const T* src = x->template data<T>();
  T* dst = out->template mutable_data<T>();
  for (int64_t b = 0; b < batches; ++b) {
    const int64_t offset = b * mat_size;
    TrilTriu(src + offset, diagonal, lower, height, width, dst + offset);
  }
}

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle

// Registers the float32 tril_triu kernel for the host target. The kernel
// class is declared with PRECISION(kAny); the bindings below select it for
// kFloat input/output tensors under the alias "float32".
using TrilTriuFloat32 = paddle::lite::kernels::host::TrilTriuCompute<float>;
REGISTER_LITE_KERNEL(tril_triu, kHost, kAny, kNCHW, TrilTriuFloat32, float32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
.Finalize();
37 changes: 37 additions & 0 deletions lite/kernels/host/tril_triu_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

// Host (CPU) kernel for the tril_triu op: keeps the lower (tril) or upper
// (triu) triangular part of each matrix in the input tensor and zeroes the
// rest. T is the element type; the kernel precision is kAny so one class
// can back multiple bound precisions.
template <class T>
class TrilTriuCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
public:
// Parameter bundle: input/output tensors plus diagonal/lower attributes.
using param_t = operators::TrilTriuParam;

// Executes the kernel on the tensors carried by param_t.
void Run() override;

virtual ~TrilTriuCompute() = default;
};

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
1 change: 1 addition & 0 deletions lite/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ add_operator(matmul_v2_op extra SRCS matmul_v2_op.cc DEPS ${op_DEPS})
add_operator(sum_op extra SRCS sum_op.cc DEPS ${op_DEPS})
add_operator(expand_v2_op_lite extra SRCS expand_v2_op.cc DEPS ${op_DEPS})
add_operator(scatter_nd_add_op extra SRCS scatter_nd_add_op.cc DEPS ${op_DEPS})
add_operator(tril_triu_op extra SRCS tril_triu_op.cc DEPS ${op_DEPS})

# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
Expand Down
2 changes: 1 addition & 1 deletion lite/operators/expand_as_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ bool ExpandAsOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto Out_name = opdesc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
auto Target_name = opdesc.Input("Target").front();
auto Target_name = opdesc.Input("target_tensor").front();
param_.Target = GetVar<lite::Tensor>(scope, Target_name);
return true;
}
Expand Down
8 changes: 8 additions & 0 deletions lite/operators/op_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,14 @@ struct TransposeParam : ParamBase {
}
};

// Parameters for the tril_triu op (lower/upper triangular extraction).
struct TrilTriuParam : ParamBase {
const lite::Tensor* x{nullptr};  // input tensor, rank >= 2
lite::Tensor* out{nullptr};  // output tensor, same shape/LoD as x

int diagonal{0};  // diagonal offset; 0 selects the main diagonal
bool lower{true};  // true => tril (keep lower part), false => triu
};

/// ----------------------- element wise operators ----------------------
struct ElementwiseParam : ParamBase {
const lite::Tensor* X{};
Expand Down
48 changes: 48 additions & 0 deletions lite/operators/tril_triu_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/tril_triu_op.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

// Validates that both the input and output tensors have been attached.
bool TrilTriuOp::CheckShape() const {
CHECK(param_.x);
CHECK(param_.out);
return true;
}

// Output has the same shape and LoD as the input; the input must be a
// matrix or a batch of matrices (rank >= 2).
bool TrilTriuOp::InferShapeImpl() const {
CHECK_GE(param_.x->dims().size(), 2UL);
param_.out->Resize(param_.x->dims());
param_.out->set_lod(param_.x->lod());
return true;
}

// Binds the op's tensors and attributes from the framework op description.
bool TrilTriuOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
  const auto x_name = op_desc.Input("X").front();
  const auto out_name = op_desc.Output("Out").front();
  param_.x = scope->FindTensor(x_name);
  param_.out = scope->FindMutableTensor(out_name);

  // "diagonal" and "lower" are read unconditionally — presumably both are
  // required attributes of the tril_triu op; GetAttr fails if absent.
  param_.diagonal = op_desc.GetAttr<int>("diagonal");
  param_.lower = op_desc.GetAttr<bool>("lower");
  return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

// Registers TrilTriuOp with the framework under the op name "tril_triu".
REGISTER_LITE_OP(tril_triu, paddle::lite::operators::TrilTriuOp);
45 changes: 45 additions & 0 deletions lite/operators/tril_triu_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"

namespace paddle {
namespace lite {
namespace operators {

// Operator wrapper for tril_triu: validates the attached tensors, infers
// the output shape from the input, and binds the "diagonal" and "lower"
// attributes into TrilTriuParam for the kernel.
class TrilTriuOp : public OpLite {
public:
TrilTriuOp() {}
explicit TrilTriuOp(const std::string &op_type) : OpLite(op_type) {}

bool CheckShape() const override;

bool InferShapeImpl() const override;

bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

// Hands the populated parameter bundle to the selected kernel.
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "tril_triu"; }

private:
// mutable: AttachImpl/InferShapeImpl are const but must fill the params.
mutable TrilTriuParam param_;
};

} // namespace operators
} // namespace lite
} // namespace paddle
1 change: 1 addition & 0 deletions lite/tests/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_sequence_expand_as_compute SRCS sequence_expand_as_compute_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_kernel_sin_compute SRCS sin_compute_test.cc DEPS arena_framework ${test_kernel_deps})
lite_cc_test(test_kernel_cos_compute SRCS cos_compute_test.cc DEPS arena_framework ${test_kernel_deps})
lite_cc_test(test_kernel_tril_triu_compute SRCS tril_triu_compute_test.cc DEPS arena_framework ${test_kernel_deps})
lite_cc_test(test_kernel_pad3d_compute SRCS pad3d_compute_test.cc DEPS arena_framework ${test_kernel_deps})
lite_cc_test(test_kernel_select_input_compute SRCS select_input_compute_test.cc DEPS arena_framework ${test_kernel_deps})
# lite_cc_test(test_kernel_tensor_array_to_tensor_compute SRCS tensor_array_to_tensor_compute_test.cc DEPS arena_framework ${test_kernel_deps})
Expand Down
Loading

0 comments on commit 5f38270

Please sign in to comment.