diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc index bbd5b9c4e7db9..d0290795455db 100644 --- a/paddle/fluid/operators/searchsorted_op.cc +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/searchsorted_op.h" - +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -117,10 +116,3 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker); - -REGISTER_OP_CPU_KERNEL( - searchsorted, - ops::SearchSortedKernel, - ops::SearchSortedKernel, - ops::SearchSortedKernel, - ops::SearchSortedKernel); diff --git a/paddle/phi/kernels/cpu/searchsorted_kernel.cc b/paddle/phi/kernels/cpu/searchsorted_kernel.cc new file mode 100644 index 0000000000000..c036c2d438a36 --- /dev/null +++ b/paddle/phi/kernels/cpu/searchsorted_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/searchsorted_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h" + +PD_REGISTER_KERNEL(searchsorted, + CPU, + ALL_LAYOUT, + phi::SearchsortedKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/searchsorted_kernel.cu b/paddle/phi/kernels/gpu/searchsorted_kernel.cu new file mode 100644 index 0000000000000..4a2ce2241c22d --- /dev/null +++ b/paddle/phi/kernels/gpu/searchsorted_kernel.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/searchsorted_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h" + +PD_REGISTER_KERNEL(searchsorted, + GPU, + ALL_LAYOUT, + phi::SearchsortedKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/phi/kernels/impl/searchsorted_kernel_impl.h similarity index 58% rename from paddle/fluid/operators/searchsorted_op.h rename to paddle/phi/kernels/impl/searchsorted_kernel_impl.h index 6aa38a8158132..82bd9fba2a66d 100644 --- a/paddle/fluid/operators/searchsorted_op.h +++ b/paddle/phi/kernels/impl/searchsorted_kernel_impl.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,16 +16,11 @@ #include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/for_range.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/kernels/funcs/algorithm.h" +#include "paddle/phi/kernels/funcs/for_range.h" -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; +namespace phi { template class GpuAndCpuSearchSortedCompute { @@ -65,9 +60,11 @@ class GpuAndCpuSearchSortedCompute { static HOSTDEVICE bool IsInf(int64_t x) { return false; } HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data, - const T2* value_data, bool right, + const T2* value_data, + bool right, bool is_1d_boundaries, - int64_t val_size, int64_t seq_size, + int64_t val_size, + int64_t seq_size, OutType* out_data) : sequence_data_(sequence_data), value_data_(value_data), @@ -104,12 +101,13 @@ class GpuAndCpuSearchSortedCompute { OutType* out_data_; }; -template +template class SearchSortedFunctor { public: - SearchSortedFunctor(const framework::ExecutionContext& context, - const framework::Tensor* sorted_sequence, - const framework::Tensor* value, bool right, + SearchSortedFunctor(const Context& context, + const DenseTensor* sorted_sequence, + const DenseTensor* value, + bool right, OutType* out_data) : context_(context), sorted_sequence_(sorted_sequence), @@ -121,74 +119,73 @@ class SearchSortedFunctor { void apply() { const T1* sequence_data = sorted_sequence_->data(); const T2* value_data = value_->data(); - const framework::DDim& seq_dims = sorted_sequence_->dims(); - const framework::DDim& val_dims = value_->dims(); + const phi::DDim& seq_dims = sorted_sequence_->dims(); + const phi::DDim& val_dims = value_->dims(); bool is_1d_boundaries = seq_dims.size() == 1; int64_t val_size = val_dims[val_dims.size() - 1]; int64_t seq_size = seq_dims[seq_dims.size() - 1]; - auto& dev_ctx = context_.template device_context(); - platform::ForRange for_range(dev_ctx, value_->numel()); + funcs::ForRange for_range(context_, value_->numel()); GpuAndCpuSearchSortedCompute - gpu_and_cpu_search_sorted_compute(sequence_data, value_data, right_, - is_1d_boundaries, val_size, seq_size, + gpu_and_cpu_search_sorted_compute(sequence_data, + value_data, + right_, + is_1d_boundaries, + val_size, + seq_size, out_data_); for_range(gpu_and_cpu_search_sorted_compute); } private: - const framework::ExecutionContext& context_; - const framework::Tensor* sorted_sequence_; - const framework::Tensor* value_; + const Context& context_; + const DenseTensor* sorted_sequence_; + const DenseTensor* value_; bool right_; OutType* out_data_; }; template -static void VisitDataType(framework::proto::VarType::Type type, - Visitor visitor) { - if (type == framework::proto::VarType::FP32) { +static void VisitDataType(DataType type, Visitor visitor) { + if (type == DataType::FLOAT32) { visitor.template apply(); - } else if (type == framework::proto::VarType::FP64) { + } else if (type == DataType::FLOAT64) { visitor.template apply(); - } else if (type == framework::proto::VarType::INT32) { + } else if (type == DataType::INT32) { visitor.template apply(); - } else if (type == framework::proto::VarType::INT64) { + } else if (type == DataType::INT64) { visitor.template apply(); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(errors::InvalidArgument( "The recieved values data type %s can not meet input requirements. " "Because the given values data type of searchsorted operators must be " "float32, float64, int32 or int64. Please input appropriate " "sorted_sequence again! ", - framework::DataTypeToString(type))); + type)); } } -template -class SearchSortedKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* sorted_sequence = context.Input("SortedSequence"); - auto* value = context.Input("Values"); - bool out_int32 = context.Attr("out_int32"); - bool right = context.Attr("right"); - auto* out = context.Output("Out"); - - if (out_int32) { - int* out_data = out->mutable_data(context.GetPlace()); - SearchSortedFunctor functor( - context, sorted_sequence, value, right, out_data); - VisitDataType(framework::TransToProtoVarType(value->dtype()), functor); - } else { - int64_t* out_data = out->mutable_data(context.GetPlace()); - SearchSortedFunctor functor( - context, sorted_sequence, value, right, out_data); - VisitDataType(framework::TransToProtoVarType(value->dtype()), functor); - } +template +void SearchsortedKernel(const Context& ctx, + const DenseTensor& sorted_sequence, + const DenseTensor& value, + bool out_int32, + bool right, + DenseTensor* out) { + if (out_int32) { + ctx.template Alloc(out); + int* out_data = out->data(); + SearchSortedFunctor functor( + ctx, &sorted_sequence, &value, right, out_data); + VisitDataType(value.dtype(), functor); + } else { + ctx.template Alloc(out); + int64_t* out_data = out->data(); + SearchSortedFunctor functor( + ctx, &sorted_sequence, &value, right, out_data); + VisitDataType(value.dtype(), functor); } -}; +} -} // namespace operators -} // namespace paddle +} // namespace phi diff --git a/paddle/fluid/operators/searchsorted_op.cu b/paddle/phi/kernels/searchsorted_kernel.h similarity index 54% rename from paddle/fluid/operators/searchsorted_op.cu rename to paddle/phi/kernels/searchsorted_kernel.h index 4633ab43efba1..e425c7fd79555 100644 --- a/paddle/fluid/operators/searchsorted_op.cu +++ b/paddle/phi/kernels/searchsorted_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/searchsorted_op.h" -namespace ops = paddle::operators; -namespace plat = paddle::platform; +#pragma once -REGISTER_OP_CUDA_KERNEL( - searchsorted, ops::SearchSortedKernel, - ops::SearchSortedKernel, - ops::SearchSortedKernel, - ops::SearchSortedKernel); +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SearchsortedKernel(const Context& ctx, + const DenseTensor& sorted_sequence, + const DenseTensor& value, + bool out_int32, + bool right, + DenseTensor* out); + +} // namespace phi