-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Prim][NewIR] Support forward decomposition in new IR (#55480)
* Support Prim Forward in New IR * Fix test case * polish code * fix code * polish format * format code
- Loading branch information
1 parent
a5ba0b6
commit 523916f
Showing
11 changed files
with
461 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .decomp import decompose # noqa: F401 | ||
from . import rules # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
import typing | ||
|
||
from paddle import ir | ||
from paddle.fluid.libpaddle.ir import Block, Program | ||
from paddle.framework import core | ||
|
||
from . import register | ||
|
||
|
||
def _build_tensor_tuple(xs):
    """
    Normalize the output of a decomposition rule into a tuple of OpResult.

    Args:
        xs (ir.OpResult|Sequence): A single op result or a sequence of them.

    Returns:
        tuple: ``xs`` wrapped in (or converted to) a tuple.

    Raises:
        TypeError: If ``xs`` is neither an OpResult nor a sequence.
    """
    if isinstance(xs, ir.OpResult):
        return (xs,)
    elif isinstance(xs, typing.Sequence):
        return tuple(xs)
    # Bug fix: the original `return TypeError(...)` silently handed the
    # exception object back to the caller instead of raising it.
    raise TypeError(f"Type {type(xs)} is not supported")
|
||
|
||
def _prepare_python_api_arguments(op): | ||
""" | ||
For standard api of operator, its inputs should keep consistent with organization of its inputs and attrs. | ||
Args: | ||
op (Operator): The target operator. | ||
""" | ||
op_inputs = [x.source() for x in op.operands()] | ||
op_attrs_dict = op.attrs() | ||
op_attrs_name = op.get_attr_names() | ||
op_attrs = [op_attrs_dict[x] for x in op_attrs_name] | ||
api_arguments = op_inputs + op_attrs | ||
return tuple(api_arguments) | ||
|
||
|
||
def _check_op_results(op_name, orig_outs, new_outs):
    """
    Check whether the replaced outputs are consistent with origin outputs.

    Args:
        op_name (str): The name of operator.
        orig_outs (tuple): The outputs of original operator.
        new_outs (tuple): The outputs of replaced operator.

    Raises:
        ValueError: If an output is None while ``op_name`` is not registered
            in ``core.ops_contain_none``.
        AssertionError: If output counts, dtypes, or shapes disagree, if a
            composite output shape contains -1 (dynamic dim), or if exactly
            one of a paired (orig, new) output is None.
    """
    # Outputs are compared positionally, so the counts must agree first.
    assert len(orig_outs) == len(new_outs), (
        f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
        f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
    )

    for orig_out, new_out in zip(
        orig_outs,
        new_outs,
    ):
        # None outputs are only tolerated for ops explicitly whitelisted in
        # core.ops_contain_none; anything else with a None output is an error.
        if (orig_out is None or new_out is None) and (
            op_name not in core.ops_contain_none
        ):
            raise ValueError(
                f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}"
            )
        if orig_out is None:
            # to keep same as phi op definition, orig_out may receive None
            continue
        elif new_out is not None:
            # Both outputs exist: dtype and static shape must match exactly.
            orig_dtype = orig_out.dtype
            new_dtype = new_out.dtype
            orig_shape = orig_out.shape
            new_shape = new_out.shape
            assert orig_dtype == new_dtype, (
                f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
                f'but orig_out dtype={orig_dtype} and new_out dtype={new_dtype}'
            )
            # -1 marks a dynamic dimension; composite outputs must be static.
            assert (
                -1 not in new_shape
            ), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
            assert orig_shape == new_shape, (
                f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
                f'but orig_out shape={orig_shape} and new_out shape={new_shape}'
            )
        # Reached only when orig_out is not None but new_out is None (for an
        # op in ops_contain_none): the pair must be None together or not at all.
        assert not (orig_out is None) ^ (
            new_out is None
        ), "orig_out and new_out should match."
    return
|
||
|
||
def decompose(
    program,
    blacklist=frozenset(),
    whitelist=frozenset(),
):
    """
    Search nonbasic ops which have been registered with composite rules and
    replace them with primitive ops.

    The operators in blacklist will be excluded from program when decomposed
    into primitives, and only the operators in whitelist will be decomposed.
    The priority of blacklist is higher than whitelist, which means an
    operator present in both blacklist and whitelist will not be decomposed.

    The final set that will be decomposed is:
        (block.ops & ops having a decomposition rule & whitelist) - blacklist

    Args:
        program (Program): The program to be processed.
        blacklist (frozenset): Operators excluded from decomposition.
        whitelist (frozenset): If non-empty, only these operators are
            decomposed into primitives.

    Raises:
        TypeError: If ``program`` is not a Program, or a filter set is not
            a set/frozenset.
    """
    if not isinstance(program, Program):
        raise TypeError(f"Expect type Program, but got type {type(program)}.")
    block = program.block()

    if not isinstance(blacklist, (set, frozenset)):
        # NOTE: original message misspelled "blacklisst"; fixed here.
        raise TypeError(
            f'Expected type of blacklist is set|frozenset, but got {type(blacklist)}.'
        )
    if not isinstance(whitelist, (set, frozenset)):
        # NOTE: original message misspelled "whiltelist"; fixed here.
        raise TypeError(
            f'Expected type of whitelist is set|frozenset, but got {type(whitelist)}.'
        )

    # The globally configured forward blacklist always applies in addition
    # to the caller-supplied one.
    blacklist = core.prim_config["forward_blacklist"] | blacklist

    logging.debug("Decompose composite forward ops begin...")

    # Build the op filter once (named defs instead of lambda assignment,
    # per PEP 8 / E731). Blacklist takes priority over whitelist.
    if blacklist and whitelist:

        def op_filter(x):
            return x.name() in whitelist and x.name() not in blacklist

    elif blacklist:

        def op_filter(x):
            return x.name() not in blacklist

    elif whitelist:

        def op_filter(x):
            return x.name() in whitelist

    else:

        def op_filter(x):
            return True

    with ir.core.program_guard(program):
        _decompose_subgraph(
            block,
            op_filter,
        )
    logging.debug(
        "Decompose composite forward ops finish: {}".format(
            core.prim_config["composite_ops_record"]
        )
    )
|
||
|
||
def _decompose_subgraph(block, op_filter):
    """
    Decompose every operator in ``block`` that satisfies the filter
    condition into its registered primitive sequence.

    Args:
        block (Block|Sequence[Block]): The blocks of program to be processed.
        op_filter (function): The filter to specify which ops to be processed.

    Raises:
        TypeError: If ``block`` is neither a Block nor a sequence of Blocks.
    """
    if isinstance(block, Block):
        for op in block.get_ops():
            op_name = op.name()
            rule = register.get_decomp_rule(op_name)
            # Skip ops without a registered rule or rejected by the filter.
            if not (rule and op_filter(op)):
                continue
            core.prim_config["composite_ops_record"].add(op_name)
            input_args = _prepare_python_api_arguments(op)
            ir.set_insertion_point(op)
            orig_outs = op.results()
            new_outs = _build_tensor_tuple(rule(*input_args))

            # Todo: To cover such case: some outputs are no longer needed after decomposition.
            _check_op_results(op_name, orig_outs, new_outs)

            op.replace_all_uses_with(new_outs)
            block.remove_op(op)
        return
    if isinstance(block, typing.Sequence):
        for sub_block in block:
            _decompose_subgraph(sub_block, op_filter)
        return
    raise TypeError(
        f"Expect type Block or Sequence of Block, but got type {type(block)}"
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from paddle.tensor import abs # noqa: F401 | ||
from paddle.tensor import acos # noqa: F401 | ||
from paddle.tensor import acosh # noqa: F401 | ||
from paddle.tensor import add # noqa: F401 | ||
from paddle.tensor import asin # noqa: F401 | ||
from paddle.tensor import asinh # noqa: F401 | ||
from paddle.tensor import atan # noqa: F401 | ||
from paddle.tensor import atanh # noqa: F401 | ||
from paddle.tensor import broadcast_shape # noqa: F401 | ||
from paddle.tensor import broadcast_to # noqa: F401 | ||
from paddle.tensor import concat # noqa: F401 | ||
from paddle.tensor import cos # noqa: F401 | ||
from paddle.tensor import cosh # noqa: F401 | ||
from paddle.tensor import cumprod # noqa: F401 | ||
from paddle.tensor import cumsum # noqa: F401 | ||
from paddle.tensor import digamma # noqa: F401 | ||
from paddle.tensor import divide # noqa: F401 | ||
from paddle.tensor import erf # noqa: F401 | ||
from paddle.tensor import erfinv # noqa: F401 | ||
from paddle.tensor import exp # noqa: F401 | ||
from paddle.tensor import expm1 # noqa: F401 | ||
from paddle.tensor import fill_constant # noqa: F401 | ||
from paddle.tensor import full # noqa: F401 | ||
from paddle.tensor import gather # noqa: F401 | ||
from paddle.tensor import greater_equal # noqa: F401 | ||
from paddle.tensor import lgamma # noqa: F401 | ||
from paddle.tensor import log # noqa: F401 | ||
from paddle.tensor import log1p # noqa: F401 | ||
from paddle.tensor import logcumsumexp # noqa: F401 | ||
from paddle.tensor import logit # noqa: F401 | ||
from paddle.tensor import logsumexp # noqa: F401 | ||
from paddle.tensor import max # noqa: F401 | ||
from paddle.tensor import min # noqa: F401 | ||
from paddle.tensor import multiply # noqa: F401 | ||
from paddle.tensor import ones # noqa: F401 | ||
from paddle.tensor import pow # noqa: F401 | ||
from paddle.tensor import prod # noqa: F401 | ||
from paddle.tensor import reshape # noqa: F401 | ||
from paddle.tensor import rsqrt # noqa: F401 | ||
from paddle.tensor import sign # noqa: F401 | ||
from paddle.tensor import sin # noqa: F401 | ||
from paddle.tensor import sinh # noqa: F401 | ||
from paddle.tensor import sqrt # noqa: F401 | ||
from paddle.tensor import subtract # noqa: F401 | ||
from paddle.tensor import sum # noqa: F401 | ||
from paddle.tensor import tan # noqa: F401 | ||
from paddle.tensor import tanh # noqa: F401 | ||
from paddle.tensor import tile # noqa: F401 | ||
from paddle.tensor import uniform # noqa: F401 | ||
from paddle.tensor import zeros # noqa: F401 | ||
from paddle.tensor.creation import assign # noqa: F401 | ||
from paddle.tensor.creation import zeros_like # noqa: F401 | ||
from paddle.tensor.manipulation import cast # noqa: F401 | ||
from paddle.tensor.math import maximum # noqa: F401 | ||
from paddle.tensor.math import minimum # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import inspect | ||
|
||
|
||
class Registry:
    """A general registry mapping operator names to rule functions."""

    __slots__ = ['name', 'rules']

    def __init__(self, name):
        # `name` is a human-readable label for this registry;
        # `rules` maps op_type -> registered rule function.
        self.name = name
        self.rules = {}

    def register(self, op_type, rule):
        """Register ``rule`` under ``op_type``; duplicates are rejected."""
        assert isinstance(op_type, str)
        assert inspect.isfunction(rule)
        already_registered = op_type in self.rules
        assert (
            not already_registered
        ), f'name "{op_type}" should not be registered before.'
        self.rules[op_type] = rule

    def lookup(self, op_type):
        """Return the rule registered for ``op_type``, or None if absent."""
        return self.rules.get(op_type)
|
||
|
||
_decomposition_ops = Registry('decomposition') | ||
|
||
|
||
def register_decomp(op_type):
    """
    Decorator for registering the lower function for an original op into a
    sequence of primitive ops.

    Args:
        op_type(str): The op name

    Returns:
        wrapper: Inner wrapper function

    Raises:
        TypeError: If ``op_type`` is not a string.

    Examples:
        .. code-block:: python
            @register_decomp('softmax')
            def softmax(x, axis):
                molecular = exp(x)
                denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
                res = divide(molecular, denominator)
                return res
    """
    if not isinstance(op_type, str):
        raise TypeError(f'op_type must be str, but got {type(op_type)}.')

    def _decorator(fn):
        # Record the rule in the global registry and hand the decorated
        # function back unchanged.
        _decomposition_ops.register(op_type, fn)
        return fn

    return _decorator
|
||
|
||
def get_decomp_rule(op_type):
    """Return the decomposition rule registered for ``op_type``, or None."""
    return _decomposition_ops.lookup(op_type)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .primitives import * # noqa: F403 | ||
from .register import register_decomp | ||
|
||
|
||
@register_decomp('pd.mean')
def mean(x, axis, keepdim):
    """
    Define composite rule of op mean:
    mean(x, axis) = sum(x, axis) / prod(x.shape[a] for a in axis).

    Args:
        x (Tensor): The input tensor.
        axis (int|Sequence[int]|None): Axes to reduce; None or an empty
            sequence means reduce over all axes.
        keepdim (bool): Whether reduced dims are kept with size 1.
    """
    x_shape = x.shape
    # Bug fix: the original `axis or tuple(range(...))` treated the falsy
    # values 0 and [0]-less inputs alike, so `axis=0` (a valid single axis)
    # was silently turned into a full reduction. Only None and an empty
    # sequence mean "all axes".
    if axis is None or (not isinstance(axis, int) and len(axis) == 0):
        axes = tuple(range(len(x_shape)))
    else:
        axes = (axis,) if isinstance(axis, int) else axis
    sum_x = sum(x, axis=axes, keepdim=keepdim)
    # Normalizer is the number of reduced elements.
    value_to_fill = 1
    for ax in axes:  # renamed from `axis` to avoid shadowing the parameter
        value_to_fill *= x_shape[ax]
    norm = fill_constant(
        shape=[],
        value=value_to_fill,
        dtype=sum_x.dtype,
    )
    return divide(sum_x, norm)
Oops, something went wrong.