Refine param conversion logic in layer.to #36862
Conversation
else:
    size_dtype = core.size_of_dtype(dtype)
    # Estimate the bytes needed for the converted copy: round the tensor's
    # byte size up to the next multiple of 256 and add ~20% headroom.
    waiting_alloc_memory = (
        (t.numel().numpy()[0] * size_dtype) / 256 + 1) * 256 * 1.2
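As a worked example of this estimate (assuming a tensor with 1,000,000 elements being cast to float16, so size_dtype is 2 bytes; these values are illustrative, not from the PR):

numel = 1_000_000                     # example element count (assumed)
size_dtype = 2                        # bytes per float16 element (assumed)
# Same expression as above: round up to a 256-byte multiple, add ~20% headroom.
waiting_alloc_memory = ((numel * size_dtype) / 256 + 1) * 256 * 1.2
print(waiting_alloc_memory)           # ~2,400,307 bytes for a 2,000,000-byte cast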
Add comments on this.
Done, thanks!
@@ -121,8 +121,8 @@ def __init__(self, name_scope=None, dtype="float32"):
        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

-       self._parameters_transform_map = {}
-       self._buffers_transform_map = {}
+       #self._parameters_transform_map = {}
Remove unused code.
Done, thanks!
LGTM
* refine layer to
* delete comment
* refine logic
* refine code
* refine pure_fp16_init
* refine comment
PR types
Performance optimization
PR changes
APIs
Describe
The layer.to() method of class Layer converts the network's parameters to a given device and/or dtype.
1. Original processing logic:
(1) Iterate over the Layer's parameters, gradients, and buffers;
(2) Copy each tensor to the target device: new_t = t._copy_to(device, blocking);
(3) Cast the data type: new_t = new_t.cast(dtype=dtype);
(4) Return new_t (a minimal sketch of this flow follows).
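A minimal sketch of that original per-tensor flow, using a hypothetical helper name (the real logic lives inside paddle.nn.Layer.to):

def _convert_tensor_original(t, device, dtype, blocking):
    # Hypothetical helper illustrating the original logic; not the actual code.
    # (2) copy the tensor to the target device
    new_t = t._copy_to(device, blocking)
    # (3) cast to the target dtype; here the old and the new tensor coexist,
    #     which is the memory spike described in the next section
    if dtype is not None and dtype != t.dtype:
        new_t = new_t.cast(dtype=dtype)
    # (4) return the converted tensor
    return new_t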
2. Problem and requirement:
In fp16 training mode, layer.to is used to convert the network's parameters from fp32 to fp16. With the logic above, during step (3) the GPU simultaneously holds both the fp32 and the fp16 copies of the parameters, which inflates GPU memory usage.
3. Revised logic:

Take a float32 network whose parameters are converted to float16 via layer.to('float16') as an example:
(1) Check whether there is enough GPU memory to create one more copy of the parameter;
(2) If GPU memory is sufficient, cast the data to float16 directly on the GPU;
(3) If GPU memory is insufficient, copy the parameter to the CPU and release the tensor held on the GPU, cast the data to float16 on the CPU, and finally move the converted data back to the GPU.
The program's execution flow is shown in the flow chart below.
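As an illustration, a minimal sketch of this branching (convert_param and the gpu_memory_sufficient flag are hypothetical stand-ins for the waiting_alloc_memory check, not the PR's actual API):

import paddle

def convert_param(t, device, dtype, blocking, gpu_memory_sufficient):
    # Hypothetical sketch of the revised flow, not the actual implementation.
    if gpu_memory_sufficient:
        # (2) enough room for a second copy: cast directly on the GPU
        return t.cast(dtype=dtype)
    # (3) not enough room: go through the CPU instead
    t_cpu = t._copy_to(paddle.CPUPlace(), blocking)  # copy the data to CPU
    del t          # drop this reference; the real code frees the GPU storage explicitly
    t_cpu = t_cpu.cast(dtype=dtype)                  # cast on the CPU
    return t_cpu._copy_to(device, blocking)          # move the result back to the GPU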