Skip to content

Commit

Permalink
[Phi] Move searchsorted kernel to phi (#40520)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzSean authored Mar 15, 2022
1 parent 1a32391 commit 85f8fd9
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 73 deletions.
10 changes: 1 addition & 9 deletions paddle/fluid/operators/searchsorted_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/searchsorted_op.h"

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
Expand Down Expand Up @@ -117,10 +116,3 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;

REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker);

REGISTER_OP_CPU_KERNEL(
searchsorted,
ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, float>,
ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, double>,
ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int>,
ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int64_t>);
28 changes: 28 additions & 0 deletions paddle/phi/kernels/cpu/searchsorted_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/searchsorted_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h"

// Register the phi searchsorted kernel for CPU with ALL_LAYOUT for the
// four supported sequence/value element types. The empty brace body means
// no per-kernel argument customization is needed.
PD_REGISTER_KERNEL(searchsorted,
CPU,
ALL_LAYOUT,
phi::SearchsortedKernel,
float,
double,
int,
int64_t) {}
28 changes: 28 additions & 0 deletions paddle/phi/kernels/gpu/searchsorted_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/searchsorted_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h"

// Register the phi searchsorted kernel for GPU with ALL_LAYOUT for the
// four supported sequence/value element types — mirrors the CPU
// registration in paddle/phi/kernels/cpu/searchsorted_kernel.cc.
PD_REGISTER_KERNEL(searchsorted,
GPU,
ALL_LAYOUT,
phi::SearchsortedKernel,
float,
double,
int,
int64_t) {}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,16 +16,11 @@

#include <math.h>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace phi {

template <typename T1, typename T2, typename OutType>
class GpuAndCpuSearchSortedCompute {
Expand Down Expand Up @@ -65,9 +60,11 @@ class GpuAndCpuSearchSortedCompute {
static HOSTDEVICE bool IsInf(int64_t x) { return false; }

HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data,
const T2* value_data, bool right,
const T2* value_data,
bool right,
bool is_1d_boundaries,
int64_t val_size, int64_t seq_size,
int64_t val_size,
int64_t seq_size,
OutType* out_data)
: sequence_data_(sequence_data),
value_data_(value_data),
Expand Down Expand Up @@ -104,12 +101,13 @@ class GpuAndCpuSearchSortedCompute {
OutType* out_data_;
};

template <typename DeviceContext, typename T1, typename OutType>
template <typename Context, typename T1, typename OutType>
class SearchSortedFunctor {
public:
SearchSortedFunctor(const framework::ExecutionContext& context,
const framework::Tensor* sorted_sequence,
const framework::Tensor* value, bool right,
SearchSortedFunctor(const Context& context,
const DenseTensor* sorted_sequence,
const DenseTensor* value,
bool right,
OutType* out_data)
: context_(context),
sorted_sequence_(sorted_sequence),
Expand All @@ -121,74 +119,73 @@ class SearchSortedFunctor {
void apply() {
const T1* sequence_data = sorted_sequence_->data<T1>();
const T2* value_data = value_->data<T2>();
const framework::DDim& seq_dims = sorted_sequence_->dims();
const framework::DDim& val_dims = value_->dims();
const phi::DDim& seq_dims = sorted_sequence_->dims();
const phi::DDim& val_dims = value_->dims();

bool is_1d_boundaries = seq_dims.size() == 1;
int64_t val_size = val_dims[val_dims.size() - 1];
int64_t seq_size = seq_dims[seq_dims.size() - 1];

auto& dev_ctx = context_.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, value_->numel());
funcs::ForRange<Context> for_range(context_, value_->numel());
GpuAndCpuSearchSortedCompute<T1, T2, OutType>
gpu_and_cpu_search_sorted_compute(sequence_data, value_data, right_,
is_1d_boundaries, val_size, seq_size,
gpu_and_cpu_search_sorted_compute(sequence_data,
value_data,
right_,
is_1d_boundaries,
val_size,
seq_size,
out_data_);
for_range(gpu_and_cpu_search_sorted_compute);
}

private:
const framework::ExecutionContext& context_;
const framework::Tensor* sorted_sequence_;
const framework::Tensor* value_;
const Context& context_;
const DenseTensor* sorted_sequence_;
const DenseTensor* value_;
bool right_;
OutType* out_data_;
};

template <typename Visitor>
static void VisitDataType(framework::proto::VarType::Type type,
Visitor visitor) {
if (type == framework::proto::VarType::FP32) {
static void VisitDataType(DataType type, Visitor visitor) {
if (type == DataType::FLOAT32) {
visitor.template apply<float>();
} else if (type == framework::proto::VarType::FP64) {
} else if (type == DataType::FLOAT64) {
visitor.template apply<double>();
} else if (type == framework::proto::VarType::INT32) {
} else if (type == DataType::INT32) {
visitor.template apply<int>();
} else if (type == framework::proto::VarType::INT64) {
} else if (type == DataType::INT64) {
visitor.template apply<int64_t>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(errors::InvalidArgument(
"The recieved values data type %s can not meet input requirements. "
"Because the given values data type of searchsorted operators must be "
"float32, float64, int32 or int64. Please input appropriate "
"sorted_sequence again! ",
framework::DataTypeToString(type)));
type));
}
}

template <typename DeviceContext, typename T>
class SearchSortedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* sorted_sequence = context.Input<Tensor>("SortedSequence");
auto* value = context.Input<Tensor>("Values");
bool out_int32 = context.Attr<bool>("out_int32");
bool right = context.Attr<bool>("right");
auto* out = context.Output<Tensor>("Out");

if (out_int32) {
int* out_data = out->mutable_data<int>(context.GetPlace());
SearchSortedFunctor<DeviceContext, T, int> functor(
context, sorted_sequence, value, right, out_data);
VisitDataType(framework::TransToProtoVarType(value->dtype()), functor);
} else {
int64_t* out_data = out->mutable_data<int64_t>(context.GetPlace());
SearchSortedFunctor<DeviceContext, T, int64_t> functor(
context, sorted_sequence, value, right, out_data);
VisitDataType(framework::TransToProtoVarType(value->dtype()), functor);
}
// Phi kernel entry point for searchsorted: for every element of `value`,
// find its insertion index into the last dimension of `sorted_sequence`
// (left or right bound depending on `right`), writing indices into `out`.
//
// T:       element type of `sorted_sequence` (registered: float, double,
//          int, int64_t).
// ctx:     device context used for output allocation and ForRange launch.
// out_int32: when true, indices are emitted as int32, otherwise int64.
// Dispatch on the runtime dtype of `value` is done via VisitDataType.
// (Removed a stray "};" diff artifact that terminated the old class.)
template <typename T, typename Context>
void SearchsortedKernel(const Context& ctx,
                        const DenseTensor& sorted_sequence,
                        const DenseTensor& value,
                        bool out_int32,
                        bool right,
                        DenseTensor* out) {
  if (out_int32) {
    ctx.template Alloc<int>(out);
    int* out_data = out->data<int>();
    SearchSortedFunctor<Context, T, int> functor(
        ctx, &sorted_sequence, &value, right, out_data);
    VisitDataType(value.dtype(), functor);
  } else {
    ctx.template Alloc<int64_t>(out);
    int64_t* out_data = out->data<int64_t>();
    SearchSortedFunctor<Context, T, int64_t> functor(
        ctx, &sorted_sequence, &value, right, out_data);
    VisitDataType(value.dtype(), functor);
  }
}

} // namespace operators
} // namespace paddle
} // namespace phi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,12 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/searchsorted_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Declaration of the searchsorted kernel: computes, for each element of
// `value`, the insertion index into the last dimension of `sorted_sequence`.
// `out_int32` selects int32 vs int64 output indices; `right` selects the
// upper-bound (true) vs lower-bound (false) insertion point.
// (Removed interleaved removed-side REGISTER_OP_CUDA_KERNEL lines from the
// old fluid .cu file that were garbling this new header in the diff view.)
template <typename T, typename Context>
void SearchsortedKernel(const Context& ctx,
                        const DenseTensor& sorted_sequence,
                        const DenseTensor& value,
                        bool out_int32,
                        bool right,
                        DenseTensor* out);

}  // namespace phi

0 comments on commit 85f8fd9

Please sign in to comment.