Add reduce_max_grad composite rule #51653
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open source project!
PR name typo
done
    def test_check_grad(self):
        # only the composite op supports gradient check of reduce_max
        self.check_grad(['X'], 'Out', check_eager=True, only_check_prim=True)

    def test_raise_error(self):
If the original test passes FP16, please describe it in the PR description.
done
@@ -271,6 +277,10 @@ def setUp(self):
    def test_check_output(self):
        self.check_output(check_eager=True)

    def test_check_grad(self):
Why doesn't the original op test support grad check?
@Charles-hit if the original op has no grad test, how can we compare it with composite_grad?
The decorator notes: "reduce_max is a discontinuous, non-derivable function; its gradient check is not supported by the unittest framework."
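The decorator's point can be illustrated with a small NumPy demonstration (not from the PR): at a tie, an infinitesimal perturbation changes which element attains the max, so one-sided numerical gradients disagree and finite-difference checking is ill-posed there.

```python
import numpy as np

# At a tie, the left and right numerical gradients w.r.t. x[0] disagree,
# so max is not differentiable at that point and a finite-difference
# grad check cannot produce a consistent reference value.
x = np.array([1.0, 1.0])
eps = 1e-3
right = (np.max(x + np.array([eps, 0.0])) - np.max(x)) / eps  # ~1.0
left = (np.max(x) - np.max(x - np.array([eps, 0.0]))) / eps   # 0.0
```

This is why the unittest framework skips the numerical grad check and the PR instead compares the composite gradient against the original op's gradient.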
@@ -261,6 +266,7 @@ class TestMaxOp_ZeroDim(OpTest):

    def setUp(self):
        self.op_type = "reduce_max"
        self.prim_op_type = "prim"
        self.python_api = paddle.max
        self.inputs = {'X': np.random.random([]).astype("float64")}
        self.attrs = {'dim': []}
It's better to add a new op test to check keep_dim = True and FP32.
done
@@ -233,6 +233,7 @@ class TestMaxOp(OpTest):

    def setUp(self):
        self.op_type = "reduce_max"
        self.prim_op_type = "prim"
        self.python_api = paddle.max
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.attrs = {'dim': [-1]}
It's better to check keep_dim = True.
done
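For intuition on why keep_dim matters here: with keep_dim = True the reduced output keeps a size-1 axis, so the gradient mask broadcasts against the input directly with no reshape in the composite rule. A minimal NumPy sketch (illustrative only, not the Paddle test code):

```python
import numpy as np

x = np.array([[1.0, 3.0], [2.0, 2.0]])

# keep_dim=True: out has shape (2, 1), so it broadcasts against x
# directly when building the max-position mask.
out = x.max(axis=1, keepdims=True)
mask = np.equal(x, out)                 # True wherever x attains the row max
grad = np.where(mask, np.ones_like(out), 0.0)  # route out_grad (here all ones)
```

Tied maxima (the second row) all receive gradient, which is what the equal-then-where composite rule produces.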
paddle/fluid/prim/api/api.yaml
Outdated
@@ -37,6 +37,8 @@
  - pad
  - cumsum
  - put_along_axis
  - equal
  - where
There are two `where` entries; delete one.
done
PR types
New features
PR changes
Others
Describe
Add reduce_max_grad composite rule.
Because the original op test passes FP16, the composite op test passes FP16 too.
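For reference, the composite rule this PR adds can be sketched in NumPy: the gradient of reduce_max routes out_grad to the positions where the input equals the max, built from the `equal` and `where` primitives added to api.yaml. This is a hedged sketch with a hypothetical helper name, not Paddle's actual implementation.

```python
import numpy as np

def reduce_max_grad_composite(x, axis, keepdim, out, out_grad):
    # Sketch of the composite rule: mask the positions that attain the
    # max (equal primitive), then scatter out_grad there (where primitive).
    if not keepdim:
        out = np.expand_dims(out, axis)
        out_grad = np.expand_dims(out_grad, axis)
    mask = np.equal(x, out)
    return np.where(mask, np.broadcast_to(out_grad, x.shape), 0.0)

x = np.array([[1.0, 3.0], [2.0, 2.0]])
out = x.max(axis=1)
g = reduce_max_grad_composite(x, 1, False, out, np.ones_like(out))
```

Note that all tied maxima receive gradient under this rule, consistent with the discontinuity caveat discussed in the review.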