[hybrid] optimizer sharding support optimize cast #35878
Conversation
Thanks for your contribution!
Force-pushed from 09537d3 to 044ec8b
LGTM
```diff
  startup_block.append_op(
      type='c_sync_comm_stream',
-     inputs={'X': broadcast_params},
-     outputs={'Out': broadcast_params},
+     inputs={'X': params_name},
```
If the broadcast is launched into the calc stream, there is no need to sync the calc stream at the end of the broadcasts.
Yes, I originally wanted to delete it in this PR, but there are too many unittests that would need to be changed, so I kept it for now... will remove it in a future PR.
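A minimal sketch of the pattern being discussed, assuming Paddle's `c_broadcast` op and its `use_calc_stream` attribute; the ring/root ids here are placeholders, not the PR's actual values:

```python
# If each c_broadcast is launched on the calc stream, the calc stream
# itself orders the broadcasts before any subsequent kernels, so the
# trailing c_sync_comm_stream becomes redundant.
for param in broadcast_params:
    startup_block.append_op(
        type='c_broadcast',
        inputs={'X': param},
        outputs={'Out': param},
        attrs={
            'ring_id': 0,             # placeholder ring id
            'root': 0,                # placeholder root rank
            'use_calc_stream': True,  # broadcast runs on the calc stream
        })

# With use_calc_stream=True above, this sync op (kept for now only to
# avoid churn in existing unittests) could be dropped:
startup_block.append_op(
    type='c_sync_comm_stream',
    inputs={'X': broadcast_params},
    outputs={'Out': broadcast_params},
    attrs={'ring_id': 0})
```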
```python
# param is only used by the cast op,
# which casts fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
```
Better to use a global variable to record the 'cast_fp16' naming rule; otherwise, if this pattern is changed in AMP, we would have to change it everywhere in sharding.
Got it, good idea.
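The suggestion could look like the hypothetical helper below; `AMP_CAST_FP16_SUFFIX` and `is_fp16_cast_output` are illustrative names introduced here, not existing Paddle APIs:

```python
# Centralize AMP's fp16-cast naming rule in one module-level constant,
# so the sharding pass never hard-codes the pattern in multiple places.
AMP_CAST_FP16_SUFFIX = 'cast_fp16'  # assumed to match AMP's current rule

def is_fp16_cast_output(op):
    """Return True if op's first output follows the AMP fp16-cast naming rule."""
    output_name = op.output_arg_names[0]
    return AMP_CAST_FP16_SUFFIX in output_name
```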
```diff
- offload_helper.cast_fp32param_in_optimize(main_block, startup_block)
+ offload_helper = OffloadHelper(ring_id=dp_ring_id)
+ if self._optimizer_sharding:
+     offload_helper.opt_sharding_cast_fp32param(
```
Great job! Not only does this reduce the number of cast ops from two per param to one per param, it also reduces the frequency of cast calls to 1/acc_step!
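A back-of-envelope sketch with hypothetical numbers of why the cast count drops: with gradient accumulation, the fp32→fp16 cast moves from every micro-batch into the optimizer update, so it runs once per `acc_step` micro-batches.

```python
num_params = 100   # assumed parameter count, for illustration only
acc_step = 4       # gradient accumulation steps per optimizer update

# Before: two casts per param, executed on every micro-batch.
casts_before = 2 * num_params * acc_step   # 800 per optimizer step

# After: one cast per param, executed only at the optimizer update,
# i.e. the per-micro-batch cast cost is divided by acc_step.
casts_after = 1 * num_params               # 100 per optimizer step

print(casts_before, casts_after)  # 800 100
```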
PR types
Performance optimization
PR changes
Others
Describe
Optimizer sharding now supports the optimize-cast optimization.
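A hypothetical usage sketch of enabling optimizer sharding with this optimization through fleet's DistributedStrategy; the config keys `"_dp_as_optimizer_sharding"` and `"optimize_cast"` are assumptions based on this PR's description, not verified API:

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "mp_degree": 2,
    "pp_degree": 2,
    "dp_degree": 2,
    "_dp_as_optimizer_sharding": True,  # assumed flag: shard optimizer over dp
    "optimize_cast": True,              # assumed flag added by this PR
}
```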
Accuracy test
Ernie3.0, base model, single machine with 8 GPUs
baseline=2mp+2pp+2dp, optimize_cast=2mp+2pp+2opt_sharding+optimize_cast
Speed test