Skip to content

Commit

Permalink
[PHI]Optimizer kernel args parser (#57445)
Browse files Browse the repository at this point in the history
* optimize kernel args parser

* remove useless code

* polish code

* update

* update
  • Loading branch information
phlrain authored Sep 19, 2023
1 parent d29524f commit 0eaf59a
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 207 deletions.
1 change: 1 addition & 0 deletions paddle/phi/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ collect_srcs(
mixed_vector.cc
generator.cc
kernel_factory.cc
kernel_registry.cc
tensor_utils.cc
utils/type_info.cc)

Expand Down
233 changes: 233 additions & 0 deletions paddle/phi/core/kernel_registry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"

#include <typeindex>
#include <typeinfo>

#include "paddle/phi/core/custom_kernel.h"
#include "paddle/phi/core/kernel_utils.h"

namespace phi {

// Parses a kernel function's C++ argument types and fills `args_def` with the
// corresponding input/output/attribute descriptors, in declaration order.
//
// Classification rules (driven purely by std::type_index comparison):
//   - device-context arguments (CPUContext, GPUContext, ...) are skipped;
//   - const references to tensor-like types (DenseTensor, SelectedRows,
//     sparse tensors, vectors of tensor pointers, optionals of these, ...)
//     are registered as inputs;
//   - mutable pointers to tensor-like types are registered as outputs;
//   - everything else (bool, int, std::vector<float>, Scalar, IntArray,
//     DataType, Place, ...) is registered as an attribute.
//
// @param args_type    Argument types of the kernel function, in order.
// @param default_key  Kernel key whose backend/layout/dtype seed every
//                     input/output descriptor.
// @param args_def     Out-parameter: the argument definition to populate.
// @throws phi::errors::Unavailable if an argument type is not recognized.
void SetKernelArgsDef(const std::vector<std::type_index>& args_type,
                      const KernelKey& default_key,
                      KernelArgsDef* args_def) {
  // A key registered with layout ANY contributes no layout information, so
  // fall back to NCHW for the per-argument descriptors.
  auto default_tensor_layout = phi::DataLayout::NCHW;
  if (default_key.layout() != phi::DataLayout::ANY) {
    default_tensor_layout = default_key.layout();
  }
  for (auto arg_type : args_type) {
    // --- Device contexts: recognized but not recorded. -------------------
    // The condition is assembled across preprocessor branches so that only
    // the context types compiled into this build participate; note the
    // closing `) {` itself is emitted by the CUSTOM_DEVICE #if/#else pair.
    if (arg_type == std::type_index(typeid(const CPUContext&))
#if defined(PADDLE_WITH_DNNL)
        || arg_type == std::type_index(typeid(const OneDNNContext&))
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
        || arg_type == std::type_index(typeid(const GPUContext&))
#elif defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
        || arg_type == std::type_index(typeid(const XPUContext&))
#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_KP)
        || arg_type == std::type_index(typeid(const KPSContext&))
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
        || arg_type == std::type_index(typeid(const CustomContext&))) {
#else
    ) {
#endif
      // do nothing, skip context arg now
      // --- Inputs: const references to tensor-like types. -----------------
    } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type ==
               std::type_index(typeid(const paddle::optional<DenseTensor>&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type ==
               std::type_index(typeid(
                   const paddle::optional<std::vector<const DenseTensor*>>&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type ==
               std::type_index(typeid(const paddle::optional<SelectedRows>&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(
                               const std::vector<const DenseTensor*>&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type ==
               std::type_index(typeid(const phi::ExtendedTensor&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(
                               const std::vector<const ExtendedTensor*>&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(
                               const std::vector<const SelectedRows*>&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type ==
               std::type_index(typeid(const std::vector<const TensorBase*>&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(
                               const std::vector<const TensorArray*>&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(const StringTensor&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
      // NOTE(review): the sparse optionals below are spelled
      // `paddle::optional<const T&>` while the dense ones above use
      // `const paddle::optional<T>&` — presumably both forms appear in real
      // kernel signatures; confirm against the sparse kernels before
      // normalizing.
    } else if (arg_type == std::type_index(typeid(
                               paddle::optional<const SparseCooTensor&>))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(const SparseCsrTensor&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(
                               paddle::optional<const SparseCsrTensor&>))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
    } else if (arg_type == std::type_index(typeid(const TensorArray&))) {
      args_def->AppendInput(default_key.backend(),
                            default_tensor_layout,
                            default_key.dtype(),
                            arg_type);
      // --- Outputs: mutable pointers to tensor-like types. ----------------
    } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
      args_def->AppendOutput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
    } else if (arg_type == std::type_index(typeid(std::vector<DenseTensor*>))) {
      args_def->AppendOutput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
    } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
      args_def->AppendOutput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
    } else if (arg_type == std::type_index(typeid(TensorArray*))) {
      args_def->AppendOutput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
    } else if (arg_type == std::type_index(typeid(SparseCooTensor*))) {
      args_def->AppendOutput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
    } else if (arg_type == std::type_index(typeid(SparseCsrTensor*))) {
      args_def->AppendOutput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
    } else if (arg_type == std::type_index(typeid(StringTensor*))) {
      args_def->AppendOutput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
    } else if (arg_type == std::type_index(typeid(ExtendedTensor*))) {
      args_def->AppendOutput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
      // --- Attributes: scalar / vector / enum value types. ----------------
    } else if (arg_type == std::type_index(typeid(bool))) {
      args_def->AppendAttribute(AttributeType::BOOL);
    } else if (arg_type == std::type_index(typeid(int))) {
      args_def->AppendAttribute(AttributeType::INT32);
    } else if (arg_type == std::type_index(typeid(int64_t))) {
      args_def->AppendAttribute(AttributeType::INT64);
    } else if (arg_type == std::type_index(typeid(float))) {
      args_def->AppendAttribute(AttributeType::FLOAT32);
    } else if (arg_type == std::type_index(typeid(double))) {
      args_def->AppendAttribute(AttributeType::FLOAT64);
    } else if (arg_type == std::type_index(typeid(std::string))) {
      args_def->AppendAttribute(AttributeType::STRING);
    } else if (arg_type == std::type_index(typeid(const std::vector<bool>&))) {
      args_def->AppendAttribute(AttributeType::BOOLS);
    } else if (arg_type == std::type_index(typeid(const std::vector<int>&))) {
      args_def->AppendAttribute(AttributeType::INT32S);
    } else if (arg_type ==
               std::type_index(typeid(const std::vector<int64_t>&))) {
      args_def->AppendAttribute(AttributeType::INT64S);
    } else if (arg_type == std::type_index(typeid(const std::vector<float>&))) {
      args_def->AppendAttribute(AttributeType::FLOAT32S);
    } else if (arg_type ==
               std::type_index(typeid(const std::vector<double>&))) {
      args_def->AppendAttribute(AttributeType::FLOAT64S);
    } else if (arg_type ==
               std::type_index(typeid(const std::vector<std::string>&))) {
      args_def->AppendAttribute(AttributeType::STRINGS);
    } else if (arg_type == std::type_index(typeid(const Scalar&))) {
      args_def->AppendAttribute(AttributeType::SCALAR);
    } else if (arg_type ==
               std::type_index(typeid(const std::vector<Scalar>&))) {
      args_def->AppendAttribute(AttributeType::SCALARS);
    } else if (arg_type == std::type_index(typeid(const IntArray&))) {
      args_def->AppendAttribute(AttributeType::INT_ARRAY);
    } else if (arg_type == std::type_index(typeid(DataType))) {
      args_def->AppendAttribute(AttributeType::DATA_TYPE);
    } else if (arg_type == std::type_index(typeid(DataLayout))) {
      args_def->AppendAttribute(AttributeType::DATA_LAYOUT);
    } else if (arg_type == std::type_index(typeid(Place))) {
      args_def->AppendAttribute(AttributeType::PLACE);
    } else {
      // Unknown argument type: fail loudly rather than mis-register the
      // kernel signature.
      PADDLE_THROW(phi::errors::Unavailable(
          "Unsupported kernel argument type `%s`.", arg_type.name()));
    }
  }
}
}  // namespace phi
Loading

0 comments on commit 0eaf59a

Please sign in to comment.