[Prim] fix loss of composite rule #52120
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open source project!
Two comments, can be fixed in next pr
comp_flag = (lookup_fn(op_name) is not None) and filter_(op)
# Attr op_role will be set after grad op has been attached to origin op.
# Currently non primitive ops in prim vjp rule will not be processed here.
if op.desc.attr("op_role") == 1:
Add a comment explaining what op_role == 1 means.
deleted
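For context on the question above: in Paddle's framework, the op_role attribute distinguishes forward ops from backward ops, with 1 marking ops appended by the backward pass. The constant values and helper name below are assumptions mirroring Paddle's OpRole enum, not Paddle's actual API:

```python
# Assumed values mirroring Paddle's OpRole enum (Forward = 0, Backward = 1);
# the helper name is hypothetical, for illustration only.
OP_ROLE_FORWARD = 0
OP_ROLE_BACKWARD = 1

def is_backward_op(op_role):
    """True when the op was appended while building the backward pass."""
    return op_role == OP_ROLE_BACKWARD
```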
python/paddle/jit/dy2static/utils.py
Outdated
@@ -1478,7 +1478,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
    min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
):
    op = program_desc.block(0).op(i)
if op.type() == 'fill_any_like':
if op.type() in ('fill_any_like', 'fill_constant'):
Add a comment to show why fill_constant is needed here; better to add a flag showing that this is only needed in prim mode.
done
thanks
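A minimal sketch of the intent behind the diff above, assuming that in prim mode an output gradient can be initialized by fill_constant where fill_any_like used to appear; the helper name is hypothetical:

```python
# Op types that may initialize an output-gradient tensor; fill_constant
# is included for prim mode (assumption based on the diff above).
GRAD_INIT_OP_TYPES = ('fill_any_like', 'fill_constant')

def is_out_grad_init_op(op_type):
    """True for op types that fill an output-gradient tensor."""
    return op_type in GRAD_INIT_OP_TYPES
```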
some comments
block,
filter_: typing.Callable[[framework.Operator], bool] = lambda x: True,
start_idx=0,
backward_length=0,
Default value should be -1 or None
done
assert (
    0 <= start_idx <= length
), f'expect 0 <= start_idx <= {length}, but got start_idx: {start_idx}'
assert not (
Same
done
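The suggestion above is to use None (or -1) as the default so that "not provided" is distinguishable from a real value. A sketch of how the validation could look with None defaults; the function and argument handling are illustrative, not the final Paddle code:

```python
def validate_range(length, start_idx=None, backward_length=None):
    # None means "not provided"; normalize to 0 for the checks below.
    start_idx = 0 if start_idx is None else start_idx
    backward_length = 0 if backward_length is None else backward_length
    assert (
        0 <= start_idx <= length
    ), f'expect 0 <= start_idx <= {length}, but got start_idx: {start_idx}'
    # start_idx and backward_length select a sub-range in mutually
    # exclusive ways, so both being positive is rejected.
    assert not (
        backward_length > 0 and start_idx > 0
    ), f'got start_idx: {start_idx} and backward_length: {backward_length}'
    return start_idx, backward_length
```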
if lookup_fn(op.type) is not None and filter_(op):
op_name = op.type
comp_flag = (
Bad name; a boolean's name should describe what is true when it holds.
done
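Per the naming comment above, a boolean should read as a predicate. A hypothetical rename of comp_flag, with the surrounding values passed in so the sketch is self-contained:

```python
def has_composite_rule(op_type, lookup_fn, filter_fn, op):
    # Renamed from comp_flag: the name now states what is true when it holds.
    return lookup_fn(op_type) is not None and filter_fn(op)
```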
@@ -1264,8 +1264,10 @@ def before_append_backward(self, forward_program):
def after_append_backward(self, whole_program, backward_start_idx):
    backward_length = len(whole_program.block(0).ops) - backward_start_idx
    if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
        _to_prim(whole_program.blocks, whitelist=self.custom_vjps)
        _to_prim(whole_program.blocks, backward_length=backward_length)
Add a comment explaining what this means.
done
    backward_length > 0 and start_idx > 0
), f'got start_idx: {start_idx} and backward_length: {backward_length}'
if backward_length > 0:
    idx_list = range(length - backward_length)
Add comments.
done
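A sketch of the index selection in the diff above: when backward_length is positive, only the forward part of the block (everything before the backward section) is visited, otherwise iteration starts at start_idx; names are illustrative:

```python
def select_op_indices(length, start_idx=0, backward_length=0):
    # backward_length > 0: process only ops before the backward section.
    if backward_length > 0:
        return range(length - backward_length)
    # Otherwise process from start_idx to the end of the block.
    return range(start_idx, length)
```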
def _to_prim(
    blocks,
    blacklist=frozenset(),
    whitelist=frozenset(),
Do we need it?
It avoids a circular dependency, and it will be deleted later.
PR types
Others
PR changes
Others
Describe
Pcard-66969
Fix bug: when an op has a prim vjp rule attached, the recursive call of the composite rule is stopped (case: _to_prim(whole_program.blocks, whitelist=self.custom_vjps)). For example, reduce_mean is a non-primitive operator, but it is not processed by the composite rule of batch_norm.
Support: _to_prim now supports processing a particular part of a block.
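The bug described above can be illustrated with a toy model (the rule table and function below are invented for illustration and are not Paddle's API): with a whitelist restricted to custom-vjp ops, ops emitted while expanding batch_norm, such as reduce_mean, are themselves filtered out and never lowered.

```python
# Toy composite-rule table: op -> ops it expands into (illustrative only).
COMPOSITE_RULES = {
    'batch_norm': ['reduce_mean', 'elementwise_sub'],
    'reduce_mean': ['reduce_sum', 'scale'],
}

def lower(ops, whitelist=None):
    """Recursively expand ops; a whitelist limits which ops may expand."""
    out = []
    for op in ops:
        expandable = op in COMPOSITE_RULES and (
            whitelist is None or op in whitelist
        )
        if expandable:
            out.extend(lower(COMPOSITE_RULES[op], whitelist))
        else:
            out.append(op)
    return out
```

With whitelist={'batch_norm'}, reduce_mean survives unlowered in the output (the reported bug); without a whitelist, lowering recurses all the way down to reduce_sum and scale.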