[2.0API] Reconstruct all API related to LR Scheduler, unify dygraph and static #26550
Changes from all commits: f1522d5, 8205ce0, 6cb899b, 801df84, 146e191, 38559f1, cef2787, ecfea6e, f7f000f, 1252607, e5cb9fb
```diff
@@ -68,14 +68,16 @@ def __init__(self,
                  regularization=None,
                  grad_clip=None,
                  name=None):
+        # Because of the loop import, so place it in the function body
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         self._parameter_list = list(
             parameter_list) if parameter_list is not None else None
         self._name = name
         if framework.in_dygraph_mode():
-            if not isinstance(learning_rate, float) and \
-                    not isinstance(learning_rate, LearningRateDecay):
+            if not isinstance(learning_rate,
+                              (float, LearningRateDecay, _LRScheduler)):
                 raise TypeError(
-                    "learning rate should be float or LearningRateDecay, got %s here"
+                    "learning rate should be float or _LRScheduler, got %s here"
                     % type(learning_rate))
             if self._parameter_list is None:
                 raise AttributeError(
```
```diff
@@ -90,11 +92,11 @@ def __init__(self,
                         % regularization.__str__())
                     break
         else:
-            if not isinstance(learning_rate, float) and \
-                    not isinstance(learning_rate, framework.Variable):
+            if not isinstance(learning_rate,
+                              (float, framework.Variable, _LRScheduler)):
                 raise TypeError(
-                    "learning rate should be float or Variable, got %s here" %
-                    type(learning_rate))
+                    "learning rate should be float or _LRScheduler, got %s here"
+                    % type(learning_rate))

         if grad_clip is not None:
             if not isinstance(grad_clip, GradientClipBase):
```
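Both hunks above collapse the chained `not isinstance(...) and not isinstance(...)` conditions into a single tuple-based `isinstance` check that also admits the new `_LRScheduler`. A minimal, framework-free sketch of the resulting validation pattern follows; the stub class and function names are illustrative and not taken from the diff.

```python
# Framework-free sketch of the tuple-based type check used in the hunks above.
# LearningRateDecayStub and LRSchedulerStub stand in for the real Paddle
# classes; only the isinstance pattern itself is being illustrated.
class LearningRateDecayStub: ...
class LRSchedulerStub: ...

def validate_lr(learning_rate):
    if not isinstance(learning_rate,
                      (float, LearningRateDecayStub, LRSchedulerStub)):
        raise TypeError(
            "learning rate should be float or _LRScheduler, got %s here"
            % type(learning_rate))
    return learning_rate

validate_lr(0.01)                 # ok: plain float
validate_lr(LRSchedulerStub())    # ok: new-style scheduler
# validate_lr("0.01")             # would raise TypeError
```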
```diff
@@ -144,11 +146,15 @@ def state_dict(self):
                 state_dict = adam.state_dict()

         '''
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         state_dict = {}
         for k, v in self._accumulators.items():
             for para_name, var_tmp in v.items():
                 state_dict[var_tmp.name] = var_tmp
         # global step if use lr decay
+        if isinstance(self._learning_rate, _LRScheduler):
+            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
+            return state_dict
         if isinstance(self._learning_rate, LearningRateDecay):
             state_dict["LR_Scheduler"] = self._learning_rate.state_dict()

```
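A simplified, framework-free sketch of the `state_dict` change above: when the learning rate is one of the new schedulers, its state is stored under the `"LR_Scheduler"` key and the method returns early, bypassing the old `LearningRateDecay` branch. The stub class and helper function below are illustrative only.

```python
# Framework-free sketch of the new state_dict branch shown above.
class SchedulerStub:
    def __init__(self):
        self.last_epoch = 3
    def state_dict(self):
        return {"last_epoch": self.last_epoch}

def optimizer_state_dict(accumulators, learning_rate):
    # accumulators: {acc_name: {param_name: var}}, mirroring self._accumulators
    state_dict = {}
    for per_param in accumulators.values():
        for var in per_param.values():
            state_dict[var["name"]] = var
    if isinstance(learning_rate, SchedulerStub):      # _LRScheduler branch
        state_dict["LR_Scheduler"] = learning_rate.state_dict()
        return state_dict                             # early return
    # the old LearningRateDecay branch would follow here
    return state_dict

print(optimizer_state_dict({}, SchedulerStub()))
# {'LR_Scheduler': {'last_epoch': 3}}
```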
```diff
@@ -192,6 +198,9 @@ def set_dict(self, state_dict):
                 adam.set_dict(opti_state_dict)

         '''
+        from paddle.optimizer.lr_scheduler import _LRScheduler
+        if isinstance(self._learning_rate, _LRScheduler):
+            self._learning_rate.set_dict(state_dict["LR_Scheduler"])

         if isinstance(self._learning_rate, LearningRateDecay):
             self._learning_rate.set_dict(state_dict["LR_Scheduler"])
```
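The matching `set_dict` change restores the scheduler from the same `"LR_Scheduler"` key before the old branch is reached. A small framework-free sketch of that restore path; the stub class is illustrative, and only the `state_dict()`/`set_dict()` calls already quoted in the docstrings are assumed.

```python
# Framework-free sketch of the set_dict counterpart: the scheduler state is
# restored from the "LR_Scheduler" entry saved by state_dict().
class SchedulerStub:
    def __init__(self):
        self.last_epoch = 0
    def state_dict(self):
        return {"last_epoch": self.last_epoch}
    def set_dict(self, state):
        self.last_epoch = state["last_epoch"]

saved = {"LR_Scheduler": {"last_epoch": 7}}   # e.g. produced earlier by state_dict()
scheduler = SchedulerStub()
if "LR_Scheduler" in saved:                   # the isinstance(_LRScheduler) branch
    scheduler.set_dict(saved["LR_Scheduler"])
assert scheduler.last_epoch == 7
```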
```diff
@@ -252,6 +261,30 @@ def get_opti_var_name_list(self):
         return self._opti_name_list

+    def _create_global_learning_rate(self):
+        from paddle.optimizer.lr_scheduler import _LRScheduler
+        if isinstance(self._learning_rate, _LRScheduler):
+            lr_var = self._global_learning_rate()
+            # only create global lr_var once
+            if not isinstance(lr_var, framework.Variable):
+                lr_name = unique_name.generate('learning_rate')
+                self._learning_rate._var_name = lr_name
+                lr_var = self.helper.create_global_variable(
+                    name=lr_name,
+                    shape=[1],
+                    persistable=True,
+                    stop_gradient=True,
+                    dtype='float32' if self._dtype is None else self._dtype)
+                main_prog = framework.default_main_program()
+                main_prog.lr_sheduler = self._learning_rate
+                main_prog.lr_var = lr_var
+                self._learning_rate_map[framework.default_main_program(
+                )] = lr_var
+
+            lr_value = float(self._learning_rate())
+            self.helper.set_variable_initializer(
+                lr_var, initializer=Constant(value=lr_value))
+            return
+
         if imperative_base.enabled():
             # create learning rate Variable
             if isinstance(self._learning_rate, float):
```

Review discussion on the line `main_prog = framework.default_main_program()`:

Comment: Why is `default_main_program()` used here? Would there be a problem if the optimize op is not in the main program?

Reply: The scheduler has to be attached to whichever Program the optimizer op is in. A Program that carries this attribute will, on every executor run, feed the corresponding float learning rate into the corresponding Variable, then run forward, backward, and the optimize step; it follows the optimize op.
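The review exchange above describes a contract: the Program that owns the optimize op carries the scheduler, and each executor run feeds `float(scheduler())` into the global learning-rate Variable. The following is a framework-free model of that contract, not Paddle code; apart from the `lr_sheduler`/`lr_var` attribute names (spelled as in the diff), all names and the decay rule are illustrative assumptions.

```python
# Framework-free model of the create-once / feed-per-run contract from the
# review reply: the program carries the scheduler and the lr Variable name,
# and every "run" feeds the scheduler's current float value for that Variable.
class ProgramStub:
    def __init__(self):
        self.lr_sheduler = None   # attribute name as in the diff
        self.lr_var = None

class SchedulerStub:
    def __init__(self, base_lr):
        self.base_lr, self.epoch = base_lr, 0
    def __call__(self):
        return self.base_lr * (0.5 ** self.epoch)   # illustrative decay rule
    def step(self):
        self.epoch += 1

def run(program, feed=None):
    feed = dict(feed or {})
    if program.lr_sheduler is not None:
        # mimic the executor feeding the scheduler's current learning rate
        feed[program.lr_var] = float(program.lr_sheduler())
    return feed

prog = ProgramStub()
prog.lr_sheduler, prog.lr_var = SchedulerStub(0.1), "learning_rate_0"
print(run(prog))                  # {'learning_rate_0': 0.1}
prog.lr_sheduler.step()
print(run(prog))                  # {'learning_rate_0': 0.05}
```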
General review discussion:

Comment: Why is paddle.fluid.optimizer.py modified here rather than paddle.optimizer.optimizer.py? Code written against version 1.8 would then see its runtime behavior change.

Reply: The new optimizer does not yet support most of the optimizers, so the colleagues migrating the optimizers have been asked to move the fluid optimizer behavior into the paddle optimizer. This change is a backward-compatible upgrade: code written for 1.8 sees no behavior change, but the new logic is supported on top of it.