Skip to content

Commit

Permalink
Auto-generate kernel signature in C++ API
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Jan 27, 2022
1 parent 80dfa01 commit b8c2dac
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 142 deletions.
128 changes: 0 additions & 128 deletions paddle/pten/api/include/kernel_signature.h

This file was deleted.

10 changes: 5 additions & 5 deletions python/paddle/utils/code_gen/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, api_item_yaml):
# args:
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
Expand Down Expand Up @@ -91,8 +92,8 @@ def gene_output(self, output_type_list):

def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.out_type_list,
self.kernel['param'])
outputs_args, output_create = self.gene_output(self.out_type_list)
return f"""
Expand All @@ -103,8 +104,8 @@ def gene_api_code(self):
{input_tensors}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{output_create}
auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.api}_kernel>();
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});
return out;
Expand Down Expand Up @@ -136,7 +137,6 @@ def source_include(header_file_path):
#include "glog/logging.h"
#include "paddle/pten/api/include/kernel_signature.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
Expand Down
19 changes: 12 additions & 7 deletions python/paddle/utils/code_gen/backward_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,24 @@ def gene_output(self, output_type_list):
output_create = ""

if len(output_type_list) == 1:
return_type = output_type_list[0]
kernel_output = 'dense_out'
output_create = f"""
{self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""

elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out;"""
{self.return_type} out({len(output_type_list)});"""

for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'dense_out_{i}, '
get_out_code = f'&out[{i}][0]' if out_type_item == 'Tensor' else f'&out[{i}]'
if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]'
output_create = output_create + f"""
out[{i}].emplace_back();"""

else:
get_out_code = f'&out[{i}]'
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});"""

Expand All @@ -134,8 +139,8 @@ def gene_output(self, output_type_list):

def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.output_type_list,
self.kernel['param'])
outputs_args, output_create = self.gene_output(
self.output_type_list)
Expand All @@ -149,7 +154,8 @@ def gene_api_code(self):
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{output_create}
auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.backward_api}_kernel>();
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});
return out;
Expand Down Expand Up @@ -197,7 +203,6 @@ def source_include(header_file_path):
#include "glog/logging.h"
#include "paddle/pten/api/include/kernel_signature.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
Expand Down
29 changes: 27 additions & 2 deletions python/paddle/utils/code_gen/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,21 @@ def gene_infer_meta(input_names, attr_names, infer_meta) -> str:
"""


def get_kernel_args(input_names, attrs, kernel_param):
def get_kernel_args(inputs, attrs, out_type_list, kernel_param):
input_trans_map = {
'const Tensor&': 'const pten::DenseTensor&',
'const Tensor &': 'const pten::DenseTensor&',
'const std::vector<Tensor>&': 'const std::vector<pten::DenseTensor>&',
'const std::vector<Tensor> &': 'const std::vector<pten::DenseTensor>&'
}
out_trans_map = {
'Tensor': 'pten::DenseTensor*',
'std::vector<Tensor>': 'std::vector<pten::DenseTensor*>&'
}
input_names = inputs['names']
input_infos = inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']

input_tensor_code = ""
for input_name in input_names:
# set input code
Expand All @@ -302,15 +316,26 @@ def get_kernel_args(input_names, attrs, kernel_param):
for param in kernel_param:
if param in input_names:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_args_type_list.append(input_trans_map[input_infos[param]])
elif param in attr_names:
# set attr for kernel_context
if 'ScalarArray' in attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::ScalarArray&')
param = 'pten::ScalarArray(' + param + ')'
elif 'Scalar' in attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::Scalar&')
param = 'pten::Scalar(' + param + ')'
else:
kernel_args_type_list.append(attrs['attr_info'][param][0])
kernel_args = kernel_args + param + ", "
elif isinstance(param, bool):
kernel_args = kernel_args + str(param).lower() + ", "
else:
kernel_args = kernel_args + str(param) + ", "
return input_tensor_code, kernel_args[:-2]

for out_type in out_type_list:
kernel_args_type_list.append(out_trans_map[out_type])

kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"

return input_tensor_code, kernel_args[:-2], kernel_signature

1 comment on commit b8c2dac

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI. You can now ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.