-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove Input requirement in dygraph for Model #27557
Remove Input requirement in dygraph for Model #27557
Conversation
… remove-input-requirment-in-dygraph-CModel
Thanks for your contribution! |
… remove-input-requirment-in-dygraph-CModel
… remove-input-requirment-in-dygraph-CModel
… remove-input-requirment-in-dygraph-CModel
python/paddle/hapi/model.py
Outdated
if isinstance(inputs, list): | ||
self._shapes = [list(input.shape) for input in inputs] | ||
elif isinstance(inputs, dict): | ||
self._shapes = [list(inputs[name]) for name in inputs] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是否是inputs[name].shape
呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,十分感谢~已经修改
python/paddle/hapi/model.py
Outdated
for i in range(len(data) - 1) | ||
] | ||
self._is_shape_inferred = True | ||
self._inputs = self._verify_spec(None, self._shapes, True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_run_one_epoch
会调用train_batch
等方法,应该不用在这里也另外实现这个功能
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢指出,现已改正。会更仔细些
… remove-input-requirment-in-dygraph-CModel
python/paddle/hapi/model.py
Outdated
@@ -598,6 +598,7 @@ def __init__(self, model): | |||
'test_batch': 0 | |||
} | |||
|
|||
self._shapes = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议使用一个更达义的变量名,比如self._input_shapes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢指出,已修改
@@ -844,14 +849,21 @@ def __init__(self, network, inputs=None, labels=None): | |||
self._loss = None | |||
self._loss_weights = None | |||
self._optimizer = None | |||
self._optimizer = None | |||
self._shapes = None | |||
self._is_shape_inferred = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议一个更达义的变量名替换self._shapes
。这里如果成员变量inputs
初始化后是不变的话,是否不需要一个额外的shape
变量。只需要一个成员函数解析一下self._inputs
即可?
如果保留这个shape变量的话,建议修改下变量名。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
动、静态图下在Model
初始化时都需要对self._inputs
进行初始化,因为目前train_batch
, eval_batch
等也需要用到self._inputs
。因此需要在动态图下用户没提供inputs
时用self._input_shapes
记录下在运行模型推导出的输入shape,以便能通过此次更新后的self._verify_spec
根据shape获取一个可传递给paddle.to_static
的较为合理的self._inputs
self._shapes = [list(input.shape) for input in inputs] | ||
elif isinstance(inputs, dict): | ||
self._shapes = [list(inputs[name].shape) for name in inputs] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如上个comment,elif
的这部分逻辑其实可以抽离到一个成员函数里,用于解析inputs
的shape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改,感谢~
@@ -947,7 +964,12 @@ def eval_batch(self, inputs, labels=None): | |||
loss = model.eval_batch([data], [label]) | |||
print(loss) | |||
""" | |||
return self._adapter.eval_batch(inputs, labels) | |||
loss = self._adapter.eval_batch(inputs, labels) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢~已修正
@@ -987,7 +1009,12 @@ def test_batch(self, inputs): | |||
out = model.test_batch([data]) | |||
print(out) | |||
""" | |||
return self._adapter.test_batch(inputs) | |||
loss = self._adapter.test_batch(inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改,感谢
python/paddle/hapi/model.py
Outdated
@@ -1677,6 +1704,14 @@ def get_inout_spec(all_vars, return_name=False): | |||
if fluid.in_dygraph_mode(): | |||
with fluid.framework._dygraph_guard(None): | |||
layer = self.network | |||
if self._shapes is None: # No provided or inferred | |||
raise RuntimeError( | |||
"Saving inference model needs `inputs` or running before saving." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
报错文案可以优化下。这里只给出了原因,也可以给出一些解决方式。比如:
- 提示用户指定inputs
- 可以输入数据,执行一次训练用于shape的推导。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢!已修改,在报错信息里进行了更详细的说明
python/paddle/hapi/model.py
Outdated
) | ||
if self._is_shape_inferred: | ||
warnings.warn( | ||
'Saving actual input shapes only if `inputs` is provided, otherwise variable input dimension is immutable.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个报错文案是否可以优化下。这里warning信息里有两点要提示用户:
- 提醒用户Model里没有指定inputs,这里保存时将使用从实际数据推导出来的shape信息保存模型
- 打印出推导出来的shape信息,让用户知道自己保存的输入shape
for i, n in enumerate(arg_names) | ||
] | ||
else: | ||
out_specs = [Input(name=n, shape=[None]) for n in arg_names] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的else
分支应该可以去掉吧。直接返回空就可以了,后续依靠推导来得到。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
train_batch
, eval_batch
目前会用到self._inputs
,因此Model
初始化时所初始化的self._inputs
不能是None
。
… remove-input-requirment-in-dygraph-CModel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Describe
1.移除了动态图下
Model
对inputs
的需求, #27272(如果动态图下提供
inputs
,保存预测模型将能够保存正确的输入shape,是完全正确的状态;如果动态图下不提供inputs
也不需要保存预测模型,也不会有任何影响)2.动态图下未指定
inputs
,并需要保存预测模型时:1)未运行模型,直接保存:报错
RuntimeError
2)运行
model.fit
,model.train_batch
,model.eval_batch
,或者model.test_batch
之后可以保存预测模型,不报错,但报UserWarning,提示用户保存的不是确切的shape(不支持保存batch_size
等可变输入维度),并提示用户输入inputs
以保存更准确的input shape