[Dy2Stat]Support non-tensor type in input_spec #33464

Merged (6 commits, Jun 17, 2021)
13 changes: 3 additions & 10 deletions python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
@@ -193,14 +193,8 @@ def _verify_input_spec(self, input_spec):
raise TypeError(
"The type(input_spec) should be one of (tuple, list), but received {}.".
format(type_name(input_spec)))
input_spec = tuple(input_spec)
for spec in flatten(input_spec):
if not isinstance(spec, paddle.static.InputSpec):
raise ValueError(
"The type(elem) from input_spec should be `InputSpec`, but received {}.".
format(type_name(spec)))

return input_spec
return tuple(input_spec)

def __repr__(self):
return "function: {}({}), input_spec: {}".format(
@@ -326,9 +320,8 @@ def check_type_and_len(input, spec, check_length=False):
elif isinstance(input_spec, paddle.static.InputSpec):
return input_spec
else:
raise TypeError(
"The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.".
type_name(input_spec))
# NOTE(Aurelius84): Support non-Tensor type as input spec info
return input_spec


def replace_spec_empty_name(args_name, input_with_spec):
@@ -20,7 +20,6 @@
import six
import textwrap
import threading
import warnings
import weakref

from paddle.fluid import framework
@@ -314,7 +313,7 @@ def __call__(self, *args, **kwargs):
# Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
# will show up **only once**. StaticFunction.__call__ will run many times, it is appropriate to
# display this warning message only once.
warnings.warn(
logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. If you would like to get static graph output, please call API "
"ProgramTranslator.enable(True)")
@@ -481,6 +480,10 @@ def concrete_program_specify_input_spec(self, input_spec=None):
# NOTE(chenweihang): we should always translated program based on the `input_spec`
# decorated on forward if it is valid
desired_input_spec = self._function_spec.input_spec
if input_spec is not None:
logging_utils.warn(
"\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n".
format(desired_input_spec, input_spec))

has_input_spec = (desired_input_spec is not None)
if has_input_spec:
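A hedged sketch of the new precedence warning: when `input_spec` is declared both on the decorated `forward` and again in `paddle.jit.save`, the decorator's spec is used and the one passed to `save` is reported as ignored (the class and spec values below are illustrative, not part of the diff):

```python
import paddle
from paddle.static import InputSpec

class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = paddle.nn.Linear(8, 4)

    # Higher priority: the spec attached to the decorated forward.
    @paddle.jit.to_static(input_spec=[InputSpec([None, 8], 'float32', 'x')])
    def forward(self, x):
        return self.fc(x)

net = Net()
# This second spec triggers the warning added above and is ignored.
paddle.jit.save(net, 'net', input_spec=[InputSpec([None, 8], 'float32', 'x')])
```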
@@ -886,7 +889,7 @@ def func(x):
if not self.enable_to_static:
# Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
# will show up **only once**.
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
86 changes: 63 additions & 23 deletions python/paddle/fluid/dygraph/dygraph_to_static/utils.py
@@ -27,6 +27,7 @@
import textwrap
import numpy as np

import paddle
from paddle.fluid import unique_name
from paddle.fluid.data_feeder import convert_dtype

@@ -141,9 +142,9 @@ def make_hashable(x, error_msg=None):
"""
Makes input `x` hashable.

For some unhashable objects, such as `dict/list/np.ndarray`,applying hash function by using their values.
For some unhashable objects, such as `dict/list/set/np.ndarray`, the hash is computed from their values.
"""
if isinstance(x, (tuple, list)):
if isinstance(x, (tuple, list, set)):
return tuple(map(make_hashable, x))

try:
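With `set` added to the container branch, the helper can now build cache keys from set-valued arguments as well. A quick hedged illustration, using the internal `utils` module path of this diff:

```python
from paddle.fluid.dygraph.dygraph_to_static.utils import make_hashable

# Sets are converted element-wise, like lists/tuples, so the result can be
# used as a dict key (element order inside the resulting tuple is arbitrary).
key = make_hashable({'relu', 0.1, ('warmup', 5)})
cache = {key: 'cached-program'}
```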
@@ -1421,10 +1422,10 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
Returns True if the two input specs are compatible, otherwise False.

args:
src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
src_input_spec (list or tuple[InputSpec, etc.]): list/tuple of
paddle.static.InputSpec or non-tensor values such as int/str
desired_input_specs (list or tuple[InputSpec, etc.]): list/tuple of
paddle.static.InputSpec or non-tensor values such as int/str
"""
len_specs = len(src_input_specs)
if len_specs != len(desired_input_specs):
@@ -1433,30 +1434,69 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
for spec in src_input_specs:
if spec not in desired_input_specs:
return False

else:
for i in range(len_specs):
src_shape = src_input_specs[i].shape
other_shape = desired_input_specs[i].shape
len_shape = len(src_shape)
if len_shape != len(other_shape):
return False
for j in range(len_shape):
if src_shape[j] is None or src_shape[j] < 0:
continue
if other_shape[j] is None or other_shape[j] < 0:
continue
if src_shape[j] != other_shape[j]:
for (src_spec, desired_spec) in zip(src_input_specs,
desired_input_specs):
if isinstance(src_spec, paddle.static.InputSpec) or isinstance(
desired_spec, paddle.static.InputSpec):
if not _compatible_tensor_spec(src_spec, desired_spec):
return False
else:
if not _compatible_non_tensor_spec(src_spec, desired_spec):
return False

src_dtype = convert_dtype(src_input_specs[i].dtype)
other_dtype = convert_dtype(desired_input_specs[i].dtype)
if src_dtype != other_dtype:
return False
return True


def _compatible_tensor_spec(src_spec, desired_spec):
"""
Check whether two tensor-type specs are compatible.
"""
for spec in [src_spec, desired_spec]:
if not isinstance(spec, paddle.static.InputSpec):
return False
src_shape = src_spec.shape
other_shape = desired_spec.shape
len_shape = len(src_shape)
if len_shape != len(other_shape):
return False
for j in range(len_shape):
if src_shape[j] is None or src_shape[j] < 0:
continue
if other_shape[j] is None or other_shape[j] < 0:
continue
if src_shape[j] != other_shape[j]:
return False

src_dtype = convert_dtype(src_spec.dtype)
other_dtype = convert_dtype(desired_spec.dtype)
if src_dtype != other_dtype:
return False

return True


def _compatible_non_tensor_spec(src_spec, desired_spec):
"""
Check whether two non-tensor-type specs are compatible.
"""

def hash_value(spec):
try:
hash_val = make_hashable(spec)
except:
hash_val = None
return hash_val

src_hash_val = hash_value(src_spec)
desired_hash_val = hash_value(desired_spec)

if src_hash_val != desired_hash_val:
return False
else:
return True


def slice_is_num(slice_node):
# A slice_node.slice can be a:
# (1) ast.Index, which is a simple number such as [1], [-2]
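A hedged sketch of how the two new helpers combine inside `input_specs_compatible`: tensor specs are compared by shape/dtype (unknown dimensions act as wildcards), while non-tensor entries are compared through `make_hashable`. The expected results follow from the code above; the module path is the internal one touched by this diff:

```python
from paddle.static import InputSpec
from paddle.fluid.dygraph.dygraph_to_static import utils

src = [InputSpec([None, 10], 'float32'), 2, 'mean']
dst = [InputSpec([4, 10], 'float32'), 2, 'mean']
print(utils.input_specs_compatible(src, dst))   # True: None matches 4, values equal

dst2 = [InputSpec([4, 10], 'float32'), 3, 'mean']
print(utils.input_specs_compatible(src, dst2))  # False: 2 != 3
```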
23 changes: 15 additions & 8 deletions python/paddle/fluid/dygraph/jit.py
@@ -403,8 +403,15 @@ def _get_input_var_names(inputs, input_spec):
]
if input_spec is None:
# no prune
result_list = input_var_names
elif input_spec is not None and len(input_spec) == len(input_var_names):
return input_var_names
else:
# filter out non-tensor type spec info.
input_spec = [
spec for spec in input_spec
if isinstance(spec, paddle.static.InputSpec)
]

if len(input_spec) == len(input_var_names):
# no prune
result_list = input_var_names
# if input spec name not in input_var_names, only raise warning
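The filtering step above can be pictured as follows (a hedged sketch; the spec values are illustrative): only `InputSpec` entries participate in matching against the program's feed variable names, while bool/str/int entries are skipped before pruning:

```python
import paddle
from paddle.static import InputSpec

input_spec = [InputSpec([None, 8], 'float32', name='x'), True, 'mean']
tensor_specs = [s for s in input_spec if isinstance(s, paddle.static.InputSpec)]
print([s.name for s in tensor_specs])  # ['x']; the bool/str entries are dropped
```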
@@ -530,8 +537,9 @@ def save(layer, path, input_spec=None, **configs):
Args:
layer (Layer|function): The Layer or function to be saved.
path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. If None, all input variables of
input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. Moreover, non-tensor type arguments are supported,
such as int, float, string, or a list/dict of them. If None, all input variables of
the original Layer's forward method would be the inputs of the saved model. Default None.
**configs (dict, optional): Other save configuration options for compatibility. We do not
recommend using these configurations, they may be removed in the future. If not necessary,
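An end-to-end hedged sketch of the documented behavior: `input_spec` may now mix `InputSpec` entries with plain Python values such as a bool flag (the layer and names below are illustrative, not part of the diff):

```python
import paddle
from paddle.static import InputSpec

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(10, 3)

    def forward(self, x, use_softmax=True):
        out = self.linear(x)
        if use_softmax:
            out = paddle.nn.functional.softmax(out)
        return out

net = SimpleNet()
# The second spec element is a plain bool and is treated as constant spec info.
paddle.jit.save(
    net, path='simple_net',
    input_spec=[InputSpec(shape=[None, 10], dtype='float32', name='x'), True])
```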
@@ -698,9 +706,8 @@ def fun(inputs):
inner_input_spec.append(
paddle.static.InputSpec.from_tensor(var))
else:
raise TypeError(
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
% type(var))
# NOTE(Aurelius84): Support non-Tensor type in `input_spec`.
inner_input_spec.append(var)

# parse configs
configs = _parse_save_configs(configs)
@@ -719,7 +726,7 @@ def fun(inputs):
inner_input_spec)
elif 'forward' == attr_func:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
# inner_input_spec is list[InputSpec], it should be packed with same sturcture
# inner_input_spec is list[InputSpec], it should be packed with same structure
# as original input_spec here.
if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec,
@@ -39,10 +39,6 @@ def test_verify_input_spec(self):
with self.assertRaises(TypeError):
foo_spec = FunctionSpec(foo_func, input_spec=a_spec)

# each element of input_spec should be `InputSpec`
with self.assertRaises(ValueError):
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10])

foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
self.assertTrue(len(foo_spec.flat_input_spec) == 2)
