Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom implement for C++ API #39521

Merged
merged 21 commits into from
Feb 26, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
d18dfc0
Support custom implement for C++ API
zyfncg Feb 14, 2022
046b283
rename api_invoke_impl to api_custom_impl
zyfncg Feb 14, 2022
c8b7f37
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Feb 15, 2022
999a49a
remove manual_api
zyfncg Feb 15, 2022
edc2184
delete mutable_data in copy_to api
zyfncg Feb 15, 2022
620ed4c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Feb 15, 2022
1d5304b
Merge branch 'custom_api_impl' of github.com:zyfncg/Paddle into custo…
zyfncg Feb 15, 2022
2cd1df7
fix problem of copy_to
zyfncg Feb 15, 2022
77853a9
add unittest for infer_meta_fn_factory
zyfncg Feb 16, 2022
39d3776
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Feb 16, 2022
73a2a1a
fix split cofig in yaml
zyfncg Feb 16, 2022
e8d5b8e
fix split cofig in yaml
zyfncg Feb 16, 2022
23bbd71
modify sum api yaml
zyfncg Feb 16, 2022
85b19f6
add copy_to wrapped infermeta
zyfncg Feb 17, 2022
acd7577
Merge commit 'refs/pull/39703/head' of https://github.com/PaddlePaddl…
zyfncg Feb 18, 2022
972c950
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Feb 18, 2022
76d62c7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Feb 21, 2022
714a77e
rollback copy impl
zyfncg Feb 23, 2022
06f9a20
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Feb 23, 2022
4d8fe8c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Feb 24, 2022
6381b03
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Feb 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions paddle/pten/api/include/manual_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,5 @@ namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready
PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

manual_api和api_custom_impl定位有点模糊了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. 已将manual_api删掉 thx~


// TODO(chentianyu03): Split API has extra logic to calculate the outputs size,
// api_gen do not support
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis);

} // namespace experimental
} // namespace paddle
5 changes: 3 additions & 2 deletions paddle/pten/api/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ add_custom_command(
VERBATIM)

cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform)
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform api_custom_impl)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api api_custom_impl)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten)
102 changes: 102 additions & 0 deletions paddle/pten/api/lib/api_custom_impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/api/lib/api_custom_impl.h"

#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/common/backend.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/infermeta/multiary.h"
#include "paddle/pten/infermeta/nullary.h"
#include "paddle/pten/infermeta/unary.h"

#include "glog/logging.h"

namespace paddle {
namespace experimental {

PADDLE_API std::vector<Tensor> split_impl(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;

if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}

auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"split", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "split API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "split API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

auto dense_x = PrepareData(x, kernel.InputAt(0), {});

// Calculate the number of out tensors
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
}

std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<pten::MetaTensor> meta_outs;
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
}

pten::SplitInferMeta(
MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs);

using kernel_signature = void (*)(const platform::DeviceContext&,
const pten::DenseTensor&,
const pten::ScalarArray&,
const pten::Scalar&,
std::vector<pten::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*dense_x,
pten::ScalarArray(num_or_sections),
pten::Scalar(axis),
dense_outs);

return out;
}

} // namespace experimental
} // namespace paddle
29 changes: 29 additions & 0 deletions paddle/pten/api/lib/api_custom_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"

namespace paddle {
namespace experimental {

PADDLE_API std::vector<Tensor> split_impl(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis);

} // namespace experimental
} // namespace paddle
65 changes: 0 additions & 65 deletions paddle/pten/api/lib/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,71 +78,6 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
return out;
}

PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;

if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}

auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"split", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "split API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "split API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

auto dense_x = PrepareData(x, kernel.InputAt(0), {});

// Calculate the number of out tensors
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
}

std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<pten::MetaTensor> meta_outs;
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
}

pten::SplitInferMeta(
MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs);

using kernel_signature = void (*)(const platform::DeviceContext&,
const pten::DenseTensor&,
const pten::ScalarArray&,
const pten::Scalar&,
std::vector<pten::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*dense_x,
pten::ScalarArray(num_or_sections),
pten::Scalar(axis),
dense_outs);

return out;
}
} // namespace experimental
} // namespace paddle

Expand Down
1 change: 0 additions & 1 deletion paddle/pten/tests/api/test_split_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

#include "paddle/pten/api/include/api.h"

#include "paddle/pten/api/include/manual_api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
Expand Down
5 changes: 5 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@
kernel :
func : sign

- api : split
args : (const Tensor& x, const ScalarArray& num_or_sections, const Scalar& axis)
output : std::vector<Tensor>
invoke : split_impl(x, num_or_sections, axis)

- api : subtract
args : (const Tensor& x, const Tensor& y)
output : Tensor
Expand Down
1 change: 1 addition & 0 deletions python/paddle/utils/code_gen/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def source_include(header_file_path):

#include "glog/logging.h"

#include "paddle/pten/api/lib/api_custom_impl.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
Expand Down
1 change: 1 addition & 0 deletions python/paddle/utils/code_gen/backward_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def source_include(header_file_path):

#include "glog/logging.h"

#include "paddle/pten/api/lib/api_custom_impl.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
Expand Down