[PRIM][PIR] Support optional and more vjp gen (#57693)
* Support optional input and output for the PIR API

* Polish generated code

* Fix Py3 CI compile error

* Fix code

* Fix error

* Fix Windows CI error

* Support optional in vjp and add vjp generation for more ops

---------

Co-authored-by: 0x45f <wangzhen45@baidu.com>
Charles-hit and 0x45f authored Sep 25, 2023
1 parent 16a45d7 commit 0932142
Showing 8 changed files with 195 additions and 25 deletions.
5 changes: 4 additions & 1 deletion paddle/fluid/operators/generator/filters.py
@@ -24,6 +24,7 @@
     input_types_map,
     opmaker_attr_types_map,
     optional_input_types_map,
+    optional_output_type_map,
     output_type_map,
     phi_attr_types_map,
     sr_output_types_map,
@@ -154,7 +155,9 @@ def delete_last_underline(op_name):
 
 
 # ------------------------------ output ----------------------------------
-def to_paddle_output_type(s):
+def to_paddle_output_type(s, optional=False):
+    if optional:
+        return optional_output_type_map[s]
     return output_type_map[s]


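For clarity, here is a minimal sketch of the new filter's behavior; the two maps below are abbreviated stand-ins for the full tables in type_mapping.py:

    # Abbreviated maps; the real tables live in type_mapping.py.
    output_type_map = {'Tensor': 'Tensor', 'Tensor[]': 'std::vector<Tensor>'}
    optional_output_type_map = {
        'Tensor': 'const paddle::optional<Tensor>&',
        'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
    }

    def to_paddle_output_type(s, optional=False):
        # Optional outputs render as paddle::optional wrappers in generated C++.
        if optional:
            return optional_output_type_map[s]
        return output_type_map[s]

    print(to_paddle_output_type('Tensor'))                 # Tensor
    print(to_paddle_output_type('Tensor', optional=True))  # const paddle::optional<Tensor>&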
5 changes: 5 additions & 0 deletions paddle/fluid/operators/generator/type_mapping.py
@@ -85,6 +85,11 @@
     'SelectedRows': 'SelectedRows',
 }
 
+optional_output_type_map = {
+    'Tensor': 'const paddle::optional<Tensor>&',
+    'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
+}
+
 # ------------------------------ phi attr ------------------------------
 phi_attr_types_map = attr_types_map.copy()
 phi_attr_types_map.update(
108 changes: 106 additions & 2 deletions paddle/fluid/primitive/codegen/gen.py
@@ -78,8 +78,60 @@
     'unsqueeze_grad',
     'poisson_grad',
     'gumbel_softmax_grad',
     'squeeze_grad',
     'unsqueeze_grad',
     'conv2d_grad',
     'depthwise_conv2d_grad',
+    'sqrt_grad',
+    'flatten_grad',
+    'relu_grad',
+    'abs_grad',
+    'log_grad',
+    'clip_grad',
+    'ceil_grad',
+    'frobenius_norm_grad',
+    'p_norm_grad',
+    'maximum_grad',
+    'argsort_grad',
+    'min_grad',
+    'batch_norm_grad',
+    'max_pool2d_with_index_grad',
+    'pool2d_grad',
+    'minimum_grad',
+    'prod_grad',
+    'round_grad',
+    'sin_grad',
+    'cos_grad',
+    'dot_grad',
+    'floor_grad',
+    'topk_grad',
+    'square_grad',
+    'gather_grad',
+    'label_smooth_grad',
+    'cross_entropy_with_softmax_grad',
+    'mean_all_grad',
+    'cumsum_grad',
+    'linear_interp_grad',
+    'bilinear_interp_grad',
+    'trilinear_interp_grad',
+    'nearest_interp_grad',
+    'bicubic_interp_grad',
+    'assign_grad',
+    'assign_out__grad',
+    'real_grad',
+    'flip_grad',
+    'softmax_grad',
+    'expand_grad',
+    'conv2d_transpose_grad',
+    'depthwise_conv2d_transpose_grad',
+    'sigmoid_grad',
+    'pad_grad',
+    'pad3d_grad',
+    'einsum_grad',
+    'leaky_relu_grad',
+    'log10_grad',
+    'conv3d_grad',
+    'solve_grad',
+    'diag_grad',
+    'trace_grad',
 ]
@@ -183,6 +235,58 @@
     'stack_grad',
     'squeeze_grad',
    'unsqueeze_grad',
+    'conv2d_grad',
+    'depthwise_conv2d_grad',
+    'sqrt_grad',
+    'flatten_grad',
+    'relu_grad',
+    'abs_grad',
+    'log_grad',
+    'clip_grad',
+    'ceil_grad',
+    'frobenius_norm_grad',
+    'p_norm_grad',
+    'maximum_grad',
+    'argsort_grad',
+    'min_grad',
+    'batch_norm_grad',
+    'max_pool2d_with_index_grad',
+    'pool2d_grad',
+    'minimum_grad',
+    'prod_grad',
+    'round_grad',
+    'sin_grad',
+    'cos_grad',
+    'dot_grad',
+    'floor_grad',
+    'topk_grad',
+    'square_grad',
+    'gather_grad',
+    'label_smooth_grad',
+    'cross_entropy_with_softmax_grad',
+    'mean_all_grad',
+    'cumsum_grad',
+    'linear_interp_grad',
+    'bilinear_interp_grad',
+    'trilinear_interp_grad',
+    'nearest_interp_grad',
+    'bicubic_interp_grad',
+    'assign_out__grad',
+    'real_grad',
+    'softmax_grad',
+    'conv2d_transpose_grad',
+    'depthwise_conv2d_transpose_grad',
+    'sigmoid_grad',
+    'pad_grad',
+    'pad3d_grad',
+    'einsum_grad',
+    'leaky_relu_grad',
+    'log10_grad',
+    'conv3d_grad',
+    'solve_grad',
+    'diag_grad',
+    'trace_grad',
+    'flip',
 ]


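The two lists above are the codegen whitelists in gen.py: an op only gets generated vjp/backend bindings if its name appears there. A simplified sketch of that selection step follows; the helper name and record shape are illustrative, not the generator's actual API:

    # Illustrative only: gen.py filters parsed op records against its whitelists.
    VJP_WHITELIST = {'sin_grad', 'cos_grad', 'batch_norm_grad'}  # excerpt

    def select_whitelisted(apis, whitelist):
        # Only whitelisted grad ops get vjp/backend code emitted for them.
        return [api for api in apis if api['name'] in whitelist]

    apis = [{'name': 'sin_grad'}, {'name': 'erf_grad'}]
    print([a['name'] for a in select_whitelisted(apis, VJP_WHITELIST)])  # ['sin_grad']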
@@ -21,11 +21,17 @@ using DataType = phi::DataType;
 
 {% for api in apis %}
 {%- if api.name in backend_white_list -%}
+{% set inplace_map = {} %}
+{% if 'inplace' in api and api.inplace != None %}
+{% for source, target in api.inplace.items() %}
+{% do inplace_map.update({source: target}) %}
+{% endfor %}
+{% endif %}
 {% if api.attrs is exist_mutable_attribute %}
-{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, True, True)}};
+{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, inplace_map, True, True)}};
 
 {% endif %}
-{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, False, True)}};
+{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, inplace_map, False, True)}};
 
 {% endif %}
 {% endfor %}
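The Jinja block added above is repeated in each template; in plain Python it amounts to collecting the op's inplace aliases before rendering the signature. The api record here is a hypothetical stand-in for a parsed op config:

    # Hypothetical parsed-op record; 'inplace' maps an output name to the
    # input it may alias.
    api = {'name': 'example_grad', 'inplace': {'out': 'x'}}

    inplace_map = {}
    if api.get('inplace') is not None:
        for source, target in api['inplace'].items():
            inplace_map[source] = target

    print(inplace_map)  # {'out': 'x'}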
@@ -16,9 +16,9 @@ namespace backend {
 {{common.sequence('', '', ', ', attrs)}}
 {%- endmacro -%}
 
-{%- macro sig(name, inputs, attrs, outputs) -%}
+{%- macro sig(name, inputs, attrs, outputs, inplace_map) -%}
 template <>
-{{common.ret(outputs)}} {{name}}<Tensor>({{common.params(inputs, attrs, False)}})
+{{common.ret(outputs, inplace_map)}} {{name}}<Tensor>({{common.params(inputs, attrs, False)}})
 {%- endmacro -%}
 
 {% macro body(name, inputs, attrs, outputs) %}
@@ -35,7 +35,13 @@ return ::{{name}}_ad_func({{common.args(input_names, attr_names)}});
 
 {% for api in apis %}
 {%- if api.is_prim and api.name in backend_white_list -%}
-{{sig(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate)}} {
+{% set inplace_map = {} %}
+{% if 'inplace' in api and api.inplace != None %}
+{% for source, target in api.inplace.items() %}
+{% do inplace_map.update({source: target}) %}
+{% endfor %}
+{% endif %}
+{{sig(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate, inplace_map)}} {
 {{body(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate)}}
 }

@@ -12,9 +12,9 @@ namespace backend {
 
 using LazyTensor = paddle::primitive::LazyTensor;
 
-{%- macro sig(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) -%}
+{%- macro sig(name, inputs, outputs, attrs, inplace_map, mutable_attribute_as_inputs=False) -%}
 template <>
-{{common.ret(outputs)}} {{name}}<LazyTensor>({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}})
+{{common.ret(outputs, inplace_map)}} {{name}}<LazyTensor>({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}})
 {%- endmacro -%}
 
 {%- macro prepare_ir_api_inputs(inputs)-%}
@@ -25,46 +25,80 @@ std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_res.
 return std::static_pointer_cast<LazyTensor>(t.impl())->value();
 });
 {% elif input.typename=='Tensor[]' and input.optional %}
-std::vector<pir::Value> {{input.name}}_res({{input.name}}.size());
+paddle::optional<std::vector<pir::Value>> {{input.name}}_res;
 if({{input.name}}) {
-std::transform({{input.name}}.get().begin(), {{input.name}}.get().end(), {{input.name}}_res.begin(), [](const Tensor& t) {
+std::vector<pir::Value> {{input.name}}_res_inner({{input.name}}.get().size());
+std::transform({{input.name}}.get().begin(), {{input.name}}.get().end(), {{input.name}}_res_inner.begin(), [](const Tensor& t) {
 return std::static_pointer_cast<LazyTensor>(t.impl())->value();
 });
+{{input.name}}_res = paddle::make_optional<std::vector<pir::Value>>({{input.name}}_res_inner);
 }
 {% elif input.typename=='Tensor' and not input.optional %}
 pir::Value {{input.name}}_res = std::static_pointer_cast<LazyTensor>({{input.name}}.impl())->value();
 {% else %}
-pir::Value {{input.name}}_res;
+paddle::optional<pir::Value> {{input.name}}_res;
 if({{input.name}}) {
-{{input.name}}_res = std::static_pointer_cast<LazyTensor>({{input.name}}.get().impl())->value();
+pir::Value {{input.name}}_res_inner;
+{{input.name}}_res_inner = std::static_pointer_cast<LazyTensor>({{input.name}}.get().impl())->value();
+{{input.name}}_res = paddle::make_optional<pir::Value>({{input.name}}_res_inner);
 }
 {% endif %}
 {% endfor %}
 {%- endmacro -%}

 {%- macro get_static_backend_outputs(outputs)-%}
 {%- if outputs|length == 1 -%}
-{%- if outputs[0].typename == 'Tensor' -%}
+{%- if outputs[0].typename == 'Tensor' and not outputs[0].optional -%}
 Tensor {{outputs[0].name}}(std::make_shared<LazyTensor>(op_res));
 return {{outputs[0].name}};
-{%- elif outputs[0].typename == 'Tensor[]' -%}
+{%- elif outputs[0].typename == 'Tensor' and outputs[0].optional -%}
+paddle::optional<Tensor> {{outputs[0].name}};
+if(op_res){
+{{outputs[0].name}} = paddle::make_optional<Tensor>(Tensor(std::make_shared<LazyTensor>(op_res.get())));
+}
+return {{outputs[0].name}};
+{%- elif outputs[0].typename == 'Tensor[]' and not outputs[0].optional -%}
 std::vector<Tensor> {{outputs[0].name}}(op_res.size());
 std::transform(op_res.begin(), op_res.end(), {{outputs[0].name}}.begin(), [](const pir::OpResult& res) {
 return Tensor(std::make_shared<LazyTensor>(res));
 });
 return {{outputs[0].name}};
+{%- elif outputs[0].typename == 'Tensor[]' and outputs[0].optional -%}
+paddle::optional<std::vector<Tensor>> {{outputs[0].name}};
+if(op_res) {
+std::vector<Tensor> {{outputs[0].name}}_inner(op_res.get().size());
+std::transform(op_res.get().begin(), op_res.get().end(), {{outputs[0].name}}_inner.begin(), [](const pir::OpResult& res) {
+return Tensor(std::make_shared<LazyTensor>(res));
+});
+{{outputs[0].name}} = paddle::make_optional<std::vector<Tensor>>({{outputs[0].name}}_inner);
+}
+return {{outputs[0].name}};
 {%- else -%} {#- render nothing -#}
 {%- endif -%}
 {%- elif outputs|length > 1 -%}
 {%- for i in range(outputs|length) %}
 auto op_res_{{i}} = std::get<{{i}}>(op_res);
-{% if outputs[i].typename == 'Tensor' %}
+{% if outputs[i].typename == 'Tensor' and not outputs[i].optional %}
 Tensor {{outputs[i].name}}(std::make_shared<LazyTensor>(op_res_{{i}}));
-{% elif outputs[i].typename == 'Tensor[]' %}
+{% elif outputs[i].typename == 'Tensor' and outputs[i].optional %}
+paddle::optional<Tensor> {{outputs[i].name}};
+if(op_res_{{i}}){
+{{outputs[i].name}} = paddle::make_optional<Tensor>(Tensor(std::make_shared<LazyTensor>(op_res_{{i}}.get())));
+}
+{% elif outputs[i].typename == 'Tensor[]' and not outputs[i].optional %}
 std::vector<Tensor> {{outputs[i].name}}(op_res_{{i}}.size());
 std::transform(op_res_{{i}}.begin(), op_res_{{i}}.end(), {{outputs[i].name}}.begin(), [](const pir::OpResult& res) {
 return Tensor(std::make_shared<LazyTensor>(res));
 });
+{% elif outputs[i].typename == 'Tensor[]' and outputs[i].optional %}
+paddle::optional<std::vector<Tensor>> {{outputs[i].name}};
+if(op_res_{{i}}){
+std::vector<Tensor> {{outputs[i].name}}_inner(op_res_{{i}}.get().size());
+std::transform(op_res_{{i}}.get().begin(), op_res_{{i}}.get().end(), {{outputs[i].name}}_inner.begin(), [](const pir::OpResult& res) {
+return Tensor(std::make_shared<LazyTensor>(res));
+});
+{{outputs[i].name}} = paddle::make_optional<std::vector<Tensor>>({{outputs[i].name}}_inner);
+}
 {% else %} {#- render nothing -#}
 {% endif %}
 {% endfor -%}
@@ -107,14 +141,20 @@ auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}});
 {% for api in apis %}
 {% if api.name in backend_white_list %}
 {% set api_outputs = api.outputs | trip_intermediate %}
-{{sig(api.name, api.inputs, api_outputs, api.attrs)}} {
+{% set inplace_map = {} %}
+{% if 'inplace' in api and api.inplace != None %}
+{% for source, target in api.inplace.items() %}
+{% do inplace_map.update({source: target}) %}
+{% endfor %}
+{% endif %}
+{{sig(api.name, api.inputs, api_outputs, api.attrs, inplace_map)}} {
 {% filter indent(2, True) %}
 {{body(api.name, api.inputs, api_outputs, api.attrs)}}
 {% endfilter %}
 }
 
 {% if api.attrs is exist_mutable_attribute %}
-{{sig(api.name, api.inputs, api_outputs, api.attrs, True)}} {
+{{sig(api.name, api.inputs, api_outputs, api.attrs, inplace_map, True)}} {
 {% filter indent(2, True) %}
 {{body(api.name, api.inputs, api_outputs, api.attrs, True)}}
 {% endfilter %}
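The optional-input branches above lower a possibly-absent Tensor (or Tensor list) to a possibly-absent pir::Value. A Python sketch of the same control flow follows; the generated code is C++, and LazyTensor here is a toy stand-in:

    from typing import List, Optional

    class LazyTensor:
        # Toy stand-in: wraps the pir::Value produced for a lazy tensor.
        def __init__(self, value):
            self._value = value
        def value(self):
            return self._value

    def lower_optional(x: Optional[LazyTensor]):
        # Mirrors the optional-Tensor branch: absent stays absent,
        # present is unwrapped to its underlying value.
        return None if x is None else x.value()

    def lower_optional_list(xs: Optional[List[LazyTensor]]):
        # Mirrors the optional-Tensor[] branch: unwrap element-wise.
        return None if xs is None else [t.value() for t in xs]

    print(lower_optional(None))                                 # None
    print(lower_optional_list([LazyTensor(1), LazyTensor(2)]))  # [1, 2]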
8 changes: 4 additions & 4 deletions paddle/fluid/primitive/codegen/templates/common.j2
@@ -1,6 +1,6 @@
-{%- macro sig(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False, default=False) -%}
+{%- macro sig(name, inputs, outputs, attrs, inplace_map, mutable_attribute_as_inputs=False, default=False) -%}
 template <typename T>
-{{ret(outputs)}} {{name}}({{params(inputs, attrs, mutable_attribute_as_inputs, default)}})
+{{ret(outputs, inplace_map)}} {{name}}({{params(inputs, attrs, mutable_attribute_as_inputs, default)}})
 {%- endmacro %}
 
 
@@ -40,9 +40,9 @@ template <typename T>
 {%- endmacro -%}
 
 
-{%- macro ret(outputs) -%}
+{%- macro ret(outputs, inplace_map) -%}
 {%- set names = [] -%}
-{%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type) -%} {%- endfor -%}
+{%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type(i.name in inplace_map and i.optional)) -%} {%- endfor -%}
 {%- if names|length > 1 -%}
 std::tuple<{{sequence('', '', ', ', names)}}>
 {%- else -%}
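To see the updated ret() macro pick an optional return type, here is a runnable Jinja2 sketch; the macro body is condensed, the filter maps are abbreviated, and the op/output names are hypothetical:

    import jinja2

    RET_TEMPLATE = (
        "{%- set names = [] -%}"
        "{%- for o in outputs -%}"
        "{%- do names.append(o.typename | to_paddle_output_type(o.name in inplace_map and o.optional)) -%}"
        "{%- endfor -%}"
        "{%- if names | length > 1 -%}std::tuple<{{ names | join(', ') }}>"
        "{%- else -%}{{ names[0] }}{%- endif -%}"
    )

    def to_paddle_output_type(s, optional=False):
        # Abbreviated stand-ins for the maps in type_mapping.py.
        if optional:
            return {'Tensor': 'const paddle::optional<Tensor>&'}[s]
        return {'Tensor': 'Tensor'}[s]

    env = jinja2.Environment(extensions=['jinja2.ext.do'])
    env.filters['to_paddle_output_type'] = to_paddle_output_type
    tpl = env.from_string(RET_TEMPLATE)

    # Hypothetical op: one optional output that aliases an input via inplace.
    print(tpl.render(outputs=[{'name': 'out', 'typename': 'Tensor', 'optional': True}],
                     inplace_map={'out': 'x'}))
    # -> const paddle::optional<Tensor>&

    # Same output, but not in the inplace map: keeps the plain type.
    print(tpl.render(outputs=[{'name': 'out', 'typename': 'Tensor', 'optional': True}],
                     inplace_map={}))
    # -> Tensor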
@@ -18,7 +18,13 @@ using IntArray = paddle::experimental::IntArray;
 {%- for i in api.inputs -%} {%- do input_names.append(i.name) -%} {%- endfor -%}
 {%- set attr_names = [] -%}
 {%- for i in api.attrs -%} {%- do attr_names.append(i.name) -%} {% endfor %}
-{{common.sig(api.name, api.inputs, api.outputs | trip_intermediate, api.attrs, False, True)}} {
+{% set inplace_map = {} %}
+{% if 'inplace' in api and api.inplace != None %}
+{% for source, target in api.inplace.items() %}
+{% do inplace_map.update({source: target}) %}
+{% endfor %}
+{% endif %}
+{{common.sig(api.name, api.inputs, api.outputs | trip_intermediate, api.attrs, inplace_map, False, True)}} {
 return backend::{{api.name}}<T>({{common.args(input_names, attr_names)}});
 }
