Skip to content

Commit

Permalink
Auto-geneate kernel signature in C++ API (#39281)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg authored Jan 28, 2022
1 parent 543f3de commit fc5fa0d
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 142 deletions.
128 changes: 0 additions & 128 deletions paddle/pten/api/include/kernel_signature.h

This file was deleted.

10 changes: 5 additions & 5 deletions python/paddle/utils/code_gen/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, api_item_yaml):
# args:
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
Expand Down Expand Up @@ -91,8 +92,8 @@ def gene_output(self, output_type_list):

def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.out_type_list,
self.kernel['param'])
outputs_args, output_create = self.gene_output(self.out_type_list)
return f"""
Expand All @@ -103,8 +104,8 @@ def gene_api_code(self):
{input_tensors}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{output_create}
auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.api}_kernel>();
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});
return out;
Expand Down Expand Up @@ -136,7 +137,6 @@ def source_include(header_file_path):
#include "glog/logging.h"
#include "paddle/pten/api/include/kernel_signature.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
Expand Down
19 changes: 12 additions & 7 deletions python/paddle/utils/code_gen/backward_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,24 @@ def gene_output(self, output_type_list):
output_create = ""

if len(output_type_list) == 1:
return_type = output_type_list[0]
kernel_output = 'dense_out'
output_create = f"""
{self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""

elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out;"""
{self.return_type} out({len(output_type_list)});"""

for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'dense_out_{i}, '
get_out_code = f'&out[{i}][0]' if out_type_item == 'Tensor' else f'&out[{i}]'
if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]'
output_create = output_create + f"""
out[{i}].emplace_back();"""

else:
get_out_code = f'&out[{i}]'
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});"""

Expand All @@ -134,8 +139,8 @@ def gene_output(self, output_type_list):

def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.output_type_list,
self.kernel['param'])
outputs_args, output_create = self.gene_output(
self.output_type_list)
Expand All @@ -149,7 +154,8 @@ def gene_api_code(self):
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{output_create}
auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.backward_api}_kernel>();
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});
return out;
Expand Down Expand Up @@ -197,7 +203,6 @@ def source_include(header_file_path):
#include "glog/logging.h"
#include "paddle/pten/api/include/kernel_signature.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
Expand Down
29 changes: 27 additions & 2 deletions python/paddle/utils/code_gen/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,21 @@ def gene_infer_meta(input_names, attr_names, infer_meta) -> str:
"""


def get_kernel_args(input_names, attrs, kernel_param):
def get_kernel_args(inputs, attrs, out_type_list, kernel_param):
input_trans_map = {
'const Tensor&': 'const pten::DenseTensor&',
'const Tensor &': 'const pten::DenseTensor&',
'const std::vector<Tensor>&': 'const std::vector<pten::DenseTensor>&',
'const std::vector<Tensor> &': 'const std::vector<pten::DenseTensor>&'
}
out_trans_map = {
'Tensor': 'pten::DenseTensor*',
'std::vector<Tensor>': 'std::vector<pten::DenseTensor*>&'
}
input_names = inputs['names']
input_infos = inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']

input_tensor_code = ""
for input_name in input_names:
# set input code
Expand All @@ -302,15 +316,26 @@ def get_kernel_args(input_names, attrs, kernel_param):
for param in kernel_param:
if param in input_names:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_args_type_list.append(input_trans_map[input_infos[param]])
elif param in attr_names:
# set attr for kernel_context
if 'ScalarArray' in attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::ScalarArray&')
param = 'pten::ScalarArray(' + param + ')'
elif 'Scalar' in attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::Scalar&')
param = 'pten::Scalar(' + param + ')'
else:
kernel_args_type_list.append(attrs['attr_info'][param][0])
kernel_args = kernel_args + param + ", "
elif isinstance(param, bool):
kernel_args = kernel_args + str(param).lower() + ", "
else:
kernel_args = kernel_args + str(param) + ", "
return input_tensor_code, kernel_args[:-2]

for out_type in out_type_list:
kernel_args_type_list.append(out_trans_map[out_type])

kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"

return input_tensor_code, kernel_args[:-2], kernel_signature

0 comments on commit fc5fa0d

Please sign in to comment.