[Dygraph]Add group sharded api #40129
Conversation
Thanks for your contribution!
@@ -136,7 +136,7 @@ def __init__(self,
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self._update_opt_status()

-    @paddle.no_grad()
+    @fluid.dygraph.no_grad()
What is the reason for introducing fluid here? The APIs under fluid will be deprecated.
Changed to paddle.autograd.no_grad().
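For reference, a minimal sketch of the agreed-upon decorator usage (the wrapper class name below is a hypothetical stand-in for the PR's optimizer wrapper):

```python
import paddle

class _Stage2OptimizerWrapper:  # hypothetical name for illustration
    @paddle.autograd.no_grad()
    def _update_opt_status(self):
        # Rebuild rank-local parameter storage without recording
        # gradients; avoids the deprecated fluid.dygraph API.
        ...
```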
logger_ = get_logger(logging.INFO)


class ShardedLevel(Enum):
Suggest removing this object; there is no need to introduce a separate object just for a parameter definition.
- Use level='os' directly in the group_sharded_parallel function, or simply level=1. See the level definition in amp: level is generally understood to be an integer, similar to verbose.
- What are os, os_g, and p_g_os abbreviations for? Readability is poor; is there a better representation?
After discussion, ShardedLevel has been removed and the string names "os", "os_g", "p_g_os" are used as level; the level names are aligned with the paper.
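How the string-valued level reads at the call site — a sketch based on the names agreed above; the ZeRO-paper mapping in the comments is my annotation:

```python
# "os"     -> shard optimizer state only            (ZeRO stage 1)
# "os_g"   -> shard optimizer state and gradients   (ZeRO stage 2)
# "p_g_os" -> shard parameters, gradients and state (ZeRO stage 3)
model, optimizer, scaler = group_sharded_parallel(
    model, optimizer, level="os_g")
```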
def group_sharded_parallel(model,
                           optimizer,
                           shard_level,
shard_level -> level
Since the API name already contains "sharded", the parameters here are by default shard-related.
Fixed.
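A sketch of the signature after the rename; the trailing keyword arguments and their defaults are assumptions, not confirmed by this hunk:

```python
def group_sharded_parallel(model,
                           optimizer,
                           level,
                           scaler=None,
                           group=None,
                           offload=False):
    ...
```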
    p_g_os = 3


def group_sharded_parallel(model,
Besides group_sharded, are there any other sharded approaches?
Here group_sharded means grouped parameter sharding, a distributed strategy on a par with data parallelism, hence the name group_sharded. There are no other sharded approaches at the moment.
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

__all__ = ['ShardedLevel', 'group_sharded_parallel', 'save_for_group_sharded']
There is no need to expose the API with __all__ here; exposing it through __init__.py is enough, so that it becomes
paddle.distributed.sharding.group_sharded_parallel
rather than
paddle.distributed.sharding.group_sharded.group_sharded_parallel
Fixed.
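A minimal sketch of the agreed exposure through the package __init__.py (the file path and the final function names are assumed):

```python
# python/paddle/distributed/sharding/__init__.py
from .group_sharded import group_sharded_parallel  # noqa: F401
from .group_sharded import save_group_sharded_model  # noqa: F401

__all__ = ['group_sharded_parallel', 'save_group_sharded_model']
```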
    return model, optimizer, scaler


def save_for_group_sharded(model, output, optimizer=None):
Besides group, are there other parameter forms?
save_for_group_sharded -> save_sharded_model? Or save_group_sharded_model?
Similar to save_inference_model.
After discussion, renamed to save_group_sharded_model.
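A hedged usage sketch of the renamed API, following the signature shown in the diff above (model, output, optimizer=None):

```python
from paddle.distributed.sharding import save_group_sharded_model

# Gathers the sharded parameters (and optimizer state, if the optimizer
# is passed) and writes the checkpoint under `output`.
save_group_sharded_model(model, output="./checkpoint", optimizer=optimizer)
```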
LGTM
LGTM for set_tests_properties(test_dygraph_group_sharded_api PROPERTIES TIMEOUT 120)
@@ -55,6 +55,7 @@
 from . import cloud_utils  # noqa: F401
 from . import utils  # noqa: F401

+from .sharding import *  # noqa: F401
Why import *?
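If the wildcard is only meant to re-export the two public entry points, an explicit form would look like this (a sketch, not necessarily what the PR settled on):

```python
from .sharding import group_sharded_parallel  # noqa: F401
from .sharding import save_group_sharded_model  # noqa: F401
```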
PR types
New features
PR changes
APIs
Describe
Add group sharded api
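Putting the pieces together, a minimal end-to-end sketch of the new API as discussed in this thread (a two-GPU collective launch is assumed; model, hyperparameters, and paths are illustrative):

```python
# Launch with: python -m paddle.distributed.launch --gpus=0,1 train.py
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import (group_sharded_parallel,
                                         save_group_sharded_model)

fleet.init(is_collective=True)

model = paddle.nn.Linear(1000, 1000)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())

# Shard optimizer state and gradients across ranks ("os_g").
model, optimizer, scaler = group_sharded_parallel(
    model, optimizer, level="os_g")

for step in range(10):
    x = paddle.randn([8, 1000])
    loss = paddle.mean(model(x))
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()

save_group_sharded_model(model, output="./ckpt", optimizer=optimizer)
```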