Skip to content

Commit

Permalink
[host] add linspace op,test=develop (#5601)
Browse files Browse the repository at this point in the history
  • Loading branch information
Leonardo-Ding authored Mar 4, 2021
1 parent 40d02c1 commit acd40c8
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 0 deletions.
1 change: 1 addition & 0 deletions lite/kernels/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ add_kernel(scatter_nd_add_compute_host Host extra SRCS scatter_nd_add_compute.cc
add_kernel(tril_triu_compute_host Host extra SRCS tril_triu_compute.cc DEPS ${lite_kernel_deps})
add_kernel(topk_v2_compute_host Host extra SRCS topk_v2_compute.cc DEPS ${lite_kernel_deps})
add_kernel(meshgrid_compute_host Host extra SRCS meshgrid_compute.cc DEPS ${lite_kernel_deps})
add_kernel(linspace_compute_host Host extra SRCS linspace_compute.cc DEPS ${lite_kernel_deps})

if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
Expand Down
105 changes: 105 additions & 0 deletions lite/kernels/host/linspace_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/host/linspace_compute.h"
#include <vector>

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

template <typename Tin, typename Tout>
static void LinspaceFunc(const operators::LinspaceParam& param) {
const auto* start_tensor = param.Start;
const auto* stop_tensor = param.Stop;
const auto* num_tensor = param.Num;
auto* out_tensor = param.Out;
const Tout start = static_cast<Tout>(start_tensor->template data<Tin>()[0]);
const Tout stop = static_cast<Tout>(stop_tensor->template data<Tin>()[0]);
const int num = num_tensor->data<int>()[0];
Tout* out_data = out_tensor->template mutable_data<Tout>();

if (num > 1) {
// step should be of double type for all types
double step = (static_cast<double>(stop - start)) / (num - 1);
int half_num = num / 2;
for (int i = 0; i < num; ++i) {
if (i < half_num) {
out_data[i] = static_cast<Tout>(start + step * i);
} else {
out_data[i] = static_cast<Tout>(stop - step * (num - i - 1));
}
}
} else {
out_data[0] = static_cast<Tout>(start);
}
}

template <typename T, PrecisionType PType>
void LinspaceCompute<T, PType>::Run() {
auto& param = this->template Param<operators::LinspaceParam>();
switch (param.Out->precision()) {
case PRECISION(kFloat):
LinspaceFunc<T, float>(param);
break;
case PRECISION(kInt32):
LinspaceFunc<T, int32_t>(param);
break;
default:
LOG(FATAL) << "Linspace op unsupport output data type: "
<< lite_api::PrecisionToStr(param.Out->precision());
}
return;
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle

using linspace_float =
paddle::lite::kernels::host::LinspaceCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(linspace, kHost, kFloat, kAny, linspace_float, float32)
.BindInput("Start",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.BindInput("Stop",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.BindInput("Num",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();

using linspace_int32 =
paddle::lite::kernels::host::LinspaceCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(linspace, kHost, kInt32, kAny, linspace_int32, int32)
.BindInput("Start",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindInput("Stop",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindInput("Num",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
36 changes: 36 additions & 0 deletions lite/kernels/host/linspace_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

template <typename T, PrecisionType PType>
class LinspaceCompute
: public KernelLite<TARGET(kHost), PType, DATALAYOUT(kAny)> {
public:
void Run() override;

virtual ~LinspaceCompute() = default;
};

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
1 change: 1 addition & 0 deletions lite/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ add_operator(tile_op extra SRCS tile_op.cc DEPS ${op_DEPS})
add_operator(scatter_nd_add_op extra SRCS scatter_nd_add_op.cc DEPS ${op_DEPS})
add_operator(tril_triu_op extra SRCS tril_triu_op.cc DEPS ${op_DEPS})
add_operator(meshgrid_op_lite extra SRCS meshgrid_op.cc DEPS ${op_DEPS})
add_operator(linspace_op extra SRCS linspace_op.cc DEPS ${op_DEPS})

# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
Expand Down
77 changes: 77 additions & 0 deletions lite/operators/linspace_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/linspace_op.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

bool LinspaceOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Start);
CHECK_OR_FALSE(param_.Stop);
CHECK_OR_FALSE(param_.Num);
CHECK_OR_FALSE(param_.Out);

int start_dims_size = param_.Start->dims().size();
CHECK_EQ(start_dims_size, 1) << "The shape of input start must be 1.";
int stop_dims_size = param_.Stop->dims().size();
CHECK_EQ(stop_dims_size, 1) << "The shape of input stop must be 1.";
int num_dims_size = param_.Num->dims().size();
CHECK_EQ(num_dims_size, 1) << "The shape of input num must be 1.";

return true;
}

bool LinspaceOpLite::InferShapeImpl() const {
// param_.dtype(int) is defined in paddle/fluid/framework/framework.proto
// param_.dtype(int) means output dtype and lite supports fp32/int32.
// if param_.dtype is not defined, output dtype is fp32.
switch (param_.dtype) {
case 2:
param_.Out->set_precision(PRECISION(kInt32));
break;
case 5:
param_.Out->set_precision(PRECISION(kFloat));
break;
default:
param_.Out->set_precision(PRECISION(kFloat));
break;
}
param_.Out->Resize(param_.Num->dims());
return true;
}

bool LinspaceOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto start_name = opdesc.Input("Start").front();
auto stop_name = opdesc.Input("Stop").front();
auto num_name = opdesc.Input("Num").front();
auto Out_name = opdesc.Output("Out").front();
param_.Start = GetVar<lite::Tensor>(scope, start_name);
param_.Stop = GetVar<lite::Tensor>(scope, stop_name);
param_.Num = GetVar<lite::Tensor>(scope, num_name);
param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);

if (opdesc.HasAttr("dtype")) {
param_.dtype = opdesc.GetAttr<int>("dtype");
}
return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

REGISTER_LITE_OP(linspace, paddle::lite::operators::LinspaceOpLite);
44 changes: 44 additions & 0 deletions lite/operators/linspace_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace operators {

class LinspaceOpLite : public OpLite {
public:
LinspaceOpLite() {}
explicit LinspaceOpLite(const std::string &op_type) : OpLite(op_type) {}

bool CheckShape() const override;

bool InferShapeImpl() const override;

bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "linspace"; }

private:
mutable LinspaceParam param_;
};

} // namespace operators
} // namespace lite
} // namespace paddle
8 changes: 8 additions & 0 deletions lite/operators/op_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -2302,6 +2302,14 @@ struct PNormParam : ParamBase {
bool keepdim{false};
bool asvector{false};
};

struct LinspaceParam : ParamBase {
const lite::Tensor* Start{};
const lite::Tensor* Stop{};
const lite::Tensor* Num{};
lite::Tensor* Out{};
int dtype{};
};
} // namespace operators
} // namespace lite
} // namespace paddle

0 comments on commit acd40c8

Please sign in to comment.