【Pten】Auto-Generate InferMeta register #39436

Merged: 8 commits, Feb 11, 2022
13 changes: 13 additions & 0 deletions paddle/pten/core/infermeta_utils.cc
@@ -66,6 +66,19 @@ const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }
const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
return *inputs_.at(idx);
}

+std::vector<MetaTensor> InferMetaContext::InputsBetween(size_t start,
+                                                        size_t end) const {
+  std::vector<MetaTensor> result;
+  result.reserve(end - start);
+
+  for (size_t i = start; i < end; ++i) {
+    result.emplace_back(*inputs_.at(i));
+  }
+
+  return result;
+}
+
MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get();
}
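The new InputsBetween pairs with InputRangeAt: the context records a [start, end) index range per argument, and InputsBetween materializes that range as one vector argument. A minimal sketch of the calling pattern, using only the two accessors shown in this diff (the wrapper function and the range values are illustrative):

#include <utility>
#include <vector>

#include "paddle/pten/core/infermeta_utils.h"

// Gather the MetaTensors registered for the in_idx-th (grouped) argument.
void GatherGroupedInput(pten::InferMetaContext* ctx, size_t in_idx) {
  const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
  // Copies the MetaTensor handles for inputs [range.first, range.second).
  std::vector<pten::MetaTensor> xs =
      ctx->InputsBetween(range.first, range.second);
  (void)xs;  // shape inference over xs would go here
}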
38 changes: 35 additions & 3 deletions paddle/pten/core/infermeta_utils.h
Expand Up @@ -17,6 +17,8 @@ limitations under the License. */
#include <string>
#include <utility>

#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/macros.h"
#include "paddle/pten/core/meta_tensor.h"
@@ -46,6 +48,7 @@ class InferMetaContext {

const MetaConfig& GetMetaConfig() const;
const MetaTensor& InputAt(size_t idx) const;
+  std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx);

template <typename AttrType>
@@ -85,7 +88,8 @@ class InferMetaContext {
"InferMeta's Attributes should appear before Outputs."); \
attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \
InferMetaFnCallHelper< \
-        Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(pargs..., \
+        Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx,      \
+                                                               pargs..., \
arg); \
} \
}
@@ -124,6 +128,35 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};

+template <typename... Tail>
+struct InferMetaFnCallHelper<const std::vector<MetaTensor>&, Tail...> {
+  template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+  static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
+    static_assert(attr_idx == 0,
+                  "InferMeta's Input should appear before Attributes.");
+    static_assert(out_idx == 0,
+                  "InferMeta's Input should appear before Outputs.");
+    const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
+    std::vector<MetaTensor> arg =
+        ctx->InputsBetween(range.first, range.second);
+    InferMetaFnCallHelper<
+        Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
+                                                               pargs...,
+                                                               arg);
+  }
+};
+
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
+    const std::vector<int64_t>&);
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
+PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
+
// TODO(chenweihang): support vector<MetaTensor> input later

template <typename... Tail>
@@ -227,7 +260,6 @@ struct InferMetaFnRegistrar {
"PT_REGISTER_INFER_META_FN must be called in global namespace."); \
static const ::pten::InferMetaFnRegistrar \
__registrar_arg_map_fn_for_##kernel_name_prefix( \
-          #kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \
-  int TouchInferMetaFnSymbol_##op_type() { return 0; }
+          #kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn))

} // namespace pten
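With the new call-helper specialization in place, an InferMeta function can take a grouped input directly; the helper unpacks it from the context before forwarding. A hypothetical end-to-end example (the op name and function body are invented for illustration; the registration macro is the one defined in this header):

#include <vector>

#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/meta_tensor.h"

namespace pten {

// Hypothetical multi-input InferMeta function, e.g. for a concat-like op.
void MyConcatInferMeta(const std::vector<MetaTensor>& xs,
                       int axis,
                       MetaTensor* out) {
  // ... infer out's dims/dtype/layout from xs and axis here ...
}

}  // namespace pten

// Must appear in the global namespace; the call-helper chain resolves the
// vector argument via ctx->InputsBetween(range.first, range.second).
PT_REGISTER_INFER_META_FN(my_concat, pten::MyConcatInferMeta);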
10 changes: 5 additions & 5 deletions paddle/pten/infermeta/nullary.cc
@@ -16,10 +16,10 @@ limitations under the License. */

namespace pten {

-void CreateInferMeta(const std::vector<int64_t>& shape,
-                     DataType dtype,
-                     DataLayout layout,
-                     MetaTensor* out) {
+void CreateInferMetaBase(const std::vector<int64_t>& shape,
+                         DataType dtype,
+                         DataLayout layout,
+                         MetaTensor* out) {
auto out_dims = pten::framework::make_ddim(shape);
out->set_dims(out_dims);
out->set_dtype(dtype);
@@ -30,7 +30,7 @@ void CreateInferMeta(const ScalarArray& shape,
DataType dtype,
DataLayout layout,
MetaTensor* out) {
-  CreateInferMeta(shape.GetData(), dtype, layout, out);
+  CreateInferMetaBase(shape.GetData(), dtype, layout, out);
}

} // namespace pten
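The rename to CreateInferMetaBase (and likewise ReduceInferMetaBase below) appears to serve the auto-generated registration: PT_INFER_META takes a function by name, and C++ cannot take the address of an overloaded name without a target type. Renaming one overload leaves a single, unambiguous CreateInferMeta for the generator to register. A standalone illustration of the language rule, with stand-in names rather than Paddle code:

void CreateMeta(int) {}     // imagine two overloads sharing one name
void CreateMeta(double) {}

int main() {
  // `auto* fn = &CreateMeta;` would not compile: the overload set is
  // ambiguous. A target type (or a unique name, as the *Base rename
  // provides) is needed to pick one overload:
  void (*fn)(int) = &CreateMeta;
  fn(42);
  return 0;
}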
8 changes: 4 additions & 4 deletions paddle/pten/infermeta/nullary.h
@@ -28,10 +28,10 @@ namespace pten {
// Because functions in this file not only can infer shape, but also need
// infer lod or other useful data.

-void CreateInferMeta(const std::vector<int64_t>& shape,
-                     DataType dtype,
-                     DataLayout layout,
-                     MetaTensor* out);
+void CreateInferMetaBase(const std::vector<int64_t>& shape,
+                         DataType dtype,
+                         DataLayout layout,
+                         MetaTensor* out);

void CreateInferMeta(const ScalarArray& shape,
DataType dtype,
16 changes: 7 additions & 9 deletions paddle/pten/infermeta/unary.cc
@@ -242,14 +242,14 @@ void SumInferMeta(const MetaTensor& x,
DataType dtype,
bool keep_dim,
MetaTensor* out) {
-  ReduceInferMeta(x, axis, keep_dim, dtype, std::move(out));
+  ReduceInferMetaBase(x, axis, keep_dim, dtype, out);
}

-void ReduceInferMeta(const MetaTensor& x,
-                     const std::vector<int64_t>& axis,
-                     bool keep_dim,
-                     DataType dtype,
-                     MetaTensor* out) {
+void ReduceInferMetaBase(const MetaTensor& x,
+                         const std::vector<int64_t>& axis,
+                         bool keep_dim,
+                         DataType dtype,
+                         MetaTensor* out) {
bool reduce_all = true;
std::set<int64_t> dims_set(axis.begin(), axis.end());
for (int64_t i = 0; i < x.dims().size(); ++i) {
@@ -304,7 +304,7 @@ void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out) {
-  ReduceInferMeta(x, axis, keep_dim, DataType::UNDEFINED, out);
+  ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out);
}

void TransferLayoutInferMeta(const MetaTensor& x,
@@ -316,5 +316,3 @@ void TransferLayoutInferMeta(const MetaTensor& x,
}

} // namespace pten

-PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);
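Besides the rename, note that SumInferMeta's body dropped std::move(out): `out` is a raw pointer, so the move was a no-op. The hand-written registration for sign also moves out of this file and into the generated API source (see api_gen.py below). A small sketch of how the reduce overloads now relate, with the signatures from this diff and an illustrative call site:

#include <cstdint>
#include <vector>

#include "paddle/pten/infermeta/unary.h"

// Illustrative only: mirrors the forwarding relationship in this diff.
void ReduceDemo(const pten::MetaTensor& x, pten::MetaTensor* out) {
  const std::vector<int64_t> axis = {0};
  // The dtype-less convenience overload keeps its old name...
  pten::ReduceInferMeta(x, axis, /*keep_dim=*/false, out);
  // ...and is equivalent to calling the renamed base explicitly:
  pten::ReduceInferMetaBase(x, axis, false, pten::DataType::UNDEFINED, out);
}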
10 changes: 5 additions & 5 deletions paddle/pten/infermeta/unary.h
@@ -53,11 +53,11 @@ void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* out);

-void ReduceInferMeta(const MetaTensor& x,
-                     const std::vector<int64_t>& axis,
-                     bool keep_dim,
-                     DataType dtype,
-                     MetaTensor* out);
+void ReduceInferMetaBase(const MetaTensor& x,
+                         const std::vector<int64_t>& axis,
+                         bool keep_dim,
+                         DataType dtype,
+                         MetaTensor* out);

void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
2 changes: 1 addition & 1 deletion paddle/pten/kernels/math_kernel.h
@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
bool keep_dim) {
auto dense_out = pten::Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out);
-  ReduceInferMeta(x, axis, keep_dim, x.dtype(), &meta_out);
+  ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out);
MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
return dense_out;
}
12 changes: 10 additions & 2 deletions python/paddle/utils/code_gen/api.yaml
@@ -161,6 +161,14 @@
  kernel :
    func : scale

+- api : sign
+  args : (const Tensor& x)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+  kernel :
+    func : sign

- api : subtract
  args : (const Tensor& x, const Tensor& y)
  output : Tensor
@@ -173,10 +181,10 @@
- api : sum
  args : (const Tensor& x, const std::vector<int64_t>& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
  output : Tensor
  infer_meta :
    func : SumInferMeta
    param : [x, axis, dtype, keep_dim]
  kernel :
    func : sum
    param : [x, axis, dtype, keep_dim]
    data_type : x
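Each entry maps infer_meta.func to a pten InferMeta function and kernel.func to a kernel. For the new sign entry, a caller of the generated C++ API would presumably write something like the following (the header path and namespace are assumptions based on the generated API of this era, not shown in this diff):

#include "paddle/pten/api/include/api.h"  // assumed generated header

void SignDemo(const paddle::experimental::Tensor& x) {
  // Output metadata comes from UnchangedInferMeta; values from `sign`.
  paddle::experimental::Tensor y = paddle::experimental::sign(x);
  (void)y;
}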
11 changes: 2 additions & 9 deletions python/paddle/utils/code_gen/api_base.py
@@ -379,14 +379,7 @@ def get_kernel_args(self):
        input_infos = self.inputs['input_info']
        kernel_args_type_list = ['const platform::DeviceContext&']

input_tensor_code = ""
for input_name in input_names:
# set input code
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""

        attr_names = self.attrs['names']

        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names
@@ -401,11 +394,11 @@ def get_kernel_args(self):
            elif input_name in self.data_transform['support_trans_dtype']:
                trans_flag = "{false, true}"
                input_tensor_code = input_tensor_code + f"""
  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""

            else:
                input_tensor_code = input_tensor_code + f"""
  auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""

kernel_args = "*dev_ctx, "
for param in kernel_param:
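The deleted loop emitted a TensorToDenseTensor line for every input up front, while the surviving per-kernel-param loop already emits exactly one preparation line per input, so the early loop produced duplicate declarations in the generated source. For a single input named x, the generated C++ takes roughly one of these two forms (PREFIX_TENSOR_NAME's value is not shown in this diff; `input_` is assumed):

// When a layout/dtype transform may apply ({trans_flag} expanded):
auto input_x = PrepareData(x, kernel.InputAt(0), {false, false});

// Otherwise:
auto input_x = TensorToDenseTensor(x);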
15 changes: 15 additions & 0 deletions python/paddle/utils/code_gen/api_gen.py
@@ -60,6 +60,14 @@ def gene_output(self, output_type_list):

        return kernel_output, output_names, output_create

+    def gene_infer_meta_register(self):
+        if self.is_base_api:
+            return f"""
+PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});"""
+
+        else:
+            return ''


def header_include():
return """
@@ -83,6 +91,7 @@ def source_include(header_file_path):
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/infermeta/multiary.h"
@@ -127,15 +136,21 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
    source_file.write(source_include(include_header_file))
    source_file.write(namespace[0])

+    infer_meta_register_code = ''
+
    for api in apis:
        api_code = ForwardAPI(api)
        print(api_code.gene_api_declaration())
        header_file.write(api_code.gene_api_declaration())
        source_file.write(api_code.gene_api_code())
+        infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register(
+        )

    header_file.write(namespace[1])
    source_file.write(namespace[1])

    source_file.write(api_register())
+    source_file.write(infer_meta_register_code)

    header_file.close()
    source_file.close()
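For example, given the sign entry added to api.yaml above, gene_infer_meta_register contributes exactly the line this PR deletes from paddle/pten/infermeta/unary.cc, now appended to the generated API source after api_register():

PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);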