【PRIM】Support custom_vjp for reducing video memory #50885
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
Force-pushed 18cdd67 to d75ef48
some comments
@@ -293,7 +316,8 @@ def _create_pure_fp16_program(self, is_infer_mode=False):
     @switch_to_static_graph
     def _create_forward_backward_train_program(self):
         whole_program = self._train_program
-        _, forward_end_op_index = self._infer_info('fp32', self._create_program)
+        # _, forward_end_op_index = self._infer_info('fp32', self._create_program)
Please remove the commented-out code.
done
class PrimHooker(PartialProgramLayerHook):
    def __init__(self):
        self.custom_vjps = set()
        if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
Use _is_all_prim_enabled?
done
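For reference, a minimal sketch of the suggested check (assuming core exposes _is_all_prim_enabled, the flag the reviewer references):

    class PrimHooker(PartialProgramLayerHook):
        def __init__(self):
            self.custom_vjps = set()
            # one combined switch instead of testing fwd and bwd flags separately
            if core._is_all_prim_enabled():
                self.custom_vjps = ...  # collect op types that register a custom vjp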
    self, partial_program_layer, whole_program, backward_start_idx
):
    backward_length = (
        len(whole_program.block(0).ops) - backward_start_idx
Add a comment noting that we need to support other blocks later.
Done. We now check the number of blocks and raise an exception when it is greater than 1.
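A minimal sketch of such a guard (the message wording is an assumption; num_blocks is a standard Program property):

    if whole_program.num_blocks > 1:
        # only block(0) is handled so far; other blocks need support later
        raise NotImplementedError(
            'programs with more than one block are not supported yet'
        )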
    to_prim(infer_program.block(0))
    return infer_program

partial_program = partial_program_from(concrete_program)
Is this the only entry for dy2st?
Done; fp16 support is now enabled.
python/paddle/jit/dy2static/utils.py (outdated)
@@ -1519,7 +1519,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
         min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
     ):
         op = program_desc.block(0).op(i)
-        if op.type() == 'fill_any_like':
+        if op.type() in ['fill_any_like', "fill_constant"]:
why this?
done
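The likely rationale (an assumption, not stated in the thread): with prim enabled, the initial output gradient is materialized by a fill_constant op instead of fill_any_like, so the scan for out-grad ops must match both types:

    for i in range(
        fwd_end_op_index,
        min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
    ):
        op = program_desc.block(0).op(i)
        # the composite (prim) backward builds the initial out-grad with
        # fill_constant, while the legacy path uses fill_any_like
        if op.type() in ['fill_any_like', 'fill_constant']:
            ...  # record this op's output as an out-grad name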
@switch_to_static_graph
def to_prim(blocks, exclude=frozenset()):
Maybe change the name; there are too many functions called to_prim.
done
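For context, a hedged sketch of how the exclude parameter serves custom_vjp in this PR (the call site and variable names are assumptions):

    # lower composite ops in the forward program, but leave ops that register
    # a custom vjp intact so their hand-written backward is preserved
    to_prim(whole_program.blocks, exclude=self.custom_vjps)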
@@ -99,7 +100,12 @@ def _train(self, use_prim, data, axis, keep_dim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
Leave a comment, or seal this into a function, in case changes happen here later.
done
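One way to seal it into a helper, as suggested (the helper name is hypothetical):

    def _collect_op_types(program):
        # gather op types from block 0 so tests keep working if this
        # expression needs to change later
        return [op.type for op in program.block(0).ops]

    fwd_ops = _collect_op_types(net.forward.main_program)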
Force-pushed 74fd37a to 156c830
@@ -183,6 +194,7 @@ def __init__(
         # Set default mode to train
         self.training = True
         self._infer_info = ProgramInfo()
+        self._forward_end_index_map = {}
Could this field be integrated into ProgramInfo?
Let's merge this PR first; after it lands I'll move this field into ProgramInfo.
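A sketch of that follow-up (hypothetical; ProgramInfo's existing fields are not shown in this thread):

    class ProgramInfo:
        def __init__(self):
            ...  # existing fields elided
            # keep the forward-end bookkeeping beside the other
            # per-program metadata instead of on the layer itself
            self.forward_end_index_map = {}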
Force-pushed d67a0cc to 6191f33
…Paddle#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>

* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* Pr 50885 (#7)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix cast prim and vjp dtype mapping error bug
* [dy2static-ci] fix dy2static ci errors.
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
LGTM
LGTM, with some questions that can be resolved later.
@@ -3751,7 +3751,6 @@ def __init__(self, program, idx):
         self.vars = collections.OrderedDict()  # var_name --> var
         self.ops = list()  # operator list
         self.program = program
-        self.removed_vars = collections.OrderedDict()
why remove this?
It was removed by pre-commit auto-format.
@@ -701,6 +736,7 @@ def _prepare_attributes(self):
             'program_id',
             self.program_id,
         ]
+
Useless blank line?
It was added by pre-commit auto-format.
@@ -1119,5 +1155,8 @@ def add_build_strategy_for(
         if hasattr(compiled_program._program, 'lr_sheduler'):
             builded_program.lr_sheduler = compiled_program._program.lr_sheduler
     else:
-        builded_program = program
+        # can't just create a new program, we need copy the vardesc.
+        builded_program = paddle.static.Program()
why this?
This fixes a bug when the program contains only a variable and no ops, e.g.:

    @paddle.jit.to_static
    def f(x):
        return x
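A hedged sketch of what "copy the vardesc" means here (the loop is an assumption; _clone_variable is an internal Block helper):

    # the sliced program has no ops, so start from a fresh Program and clone
    # the variable descriptors so inputs and outputs still resolve
    builded_program = paddle.static.Program()
    for var in program.block(0).vars.values():
        builded_program.block(0)._clone_variable(var, force_persistable=False)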
ok
LGTM
LGTM
PR types: New features
PR changes: Others
Describe: Pcard-66975