-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【AMP OP&Test】unit test for accuracy_op #51009
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -30,7 +31,7 @@ void AccuracyRawKernel(const Context& dev_ctx, | |||
DenseTensor* total) { | |||
int* correct_data = dev_ctx.template Alloc<int>(correct); | |||
int* total_data = dev_ctx.template Alloc<int>(total); | |||
float* accuracy_data = dev_ctx.template Alloc<float>(accuracy); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CPU的先不改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok,后续删除
@@ -45,12 +45,20 @@ def setUp(self): | |||
} | |||
|
|||
def init_dtype(self): | |||
pass | |||
self.dtype = np.float32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这可以不改,默认是FP32的话,下面的TestAccuracyOpFp32也不用加了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个应该是本来上面设置的FP32,不知道默认是不是。我统一到init_dtype中
self.dtype = np.uint16 | ||
|
||
def test_check_output(self): | ||
self.check_output(atol=1e-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1e-2是默认值,可以不做特殊设置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bf16默认1e-1,保持不变
744fd52
to
b4615b1
Compare
self.check_output() | ||
|
||
|
||
def create_test_fp16_class(parent): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fp16的单测不用加这个,内部有判断的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后面删除
self.check_output() | ||
|
||
|
||
def create_test_fp16_class(parent): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
只有一个case的话,是不是不用定义这样一个类?反而变得复杂了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
globals()[cls_name] = TestAccuracyOpFp16 | ||
|
||
|
||
def create_test_bf16_class(parent): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
ab280d9
to
a63d5e1
Compare
@@ -51,6 +52,14 @@ def test_check_output(self): | |||
self.check_output() | |||
|
|||
|
|||
class TestAccuracyOpFp32(TestAccuracyOp): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个默认就是fp32,这个case可以不用加了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for @unittest.skip
PR types
Others
PR changes
Others
Describe