[Hackathon No.46] Add float16 data type support for the Paddle gumbel_softmax operator #50923
@@ -103,6 +103,57 @@ def init_attrs(self):
        self.dtype = "float64"


class TestGumbelSoftmax_ZeroDim_FP16OP(OpTest):
    def setUp(self):
        self.op_type = "gumbel_softmax"
        self.python_api = F.gumbel_softmax
        self.dtype = np.float16
        x = np.random.uniform(0.1, 1, []).astype(self.dtype)
        out = np.array(1.0).astype(self.dtype)

        self.inputs = {'X': x}
        self.outputs = {'Out': out}
        self.attrs = {"hard": True, "axis": -1}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestGumbelSoftmaxFP16OP2(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [20, 10]
        self.attrs = {"hard": True, "axis": 0}
        self.count_expected = 10
        self.dtype = np.float16


class TestGumbelSoftmaxFP16OP3(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [100]
        self.attrs = {"hard": True, "axis": -1}
        self.count_expected = 1
        self.dtype = np.float16


class TestGumbelSoftmaxFP16OP4(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [20, 10, 5]
        self.attrs = {"hard": True, "axis": -1}
        self.count_expected = 200
        self.dtype = np.float16


class TestGumbelSoftmaxFP16OP5(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [20, 10, 5]
        self.attrs = {"hard": True, "axis": 1}
        self.count_expected = 100
        self.dtype = np.float16
Review comment: These four unit tests should inherit from TestGumbelSoftmaxFP16OP.

Reply: Hello, TestGumbelSoftmax_ZeroDim_FP16OP above targets the zero-dimensional case and therefore has no init_attrs() method, so it cannot be renamed to TestGumbelSoftmaxFP16OP. That is why these four tests inherit directly from TestGumbelSoftmaxOp instead.

Review comment: The FP16 unit tests need to be revised to follow the unit-test specification for low-precision operators:

Reply: done
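For reference, a minimal sketch of the kind of change that low-precision spec usually asks for: run the FP16 case on a CUDA place with relaxed tolerances. The class name, the tolerance values, and the core import path below are illustrative assumptions, not taken from this PR; TestGumbelSoftmaxOp is the base class from the test file under review.

import numpy as np
import paddle.fluid.core as core


class TestGumbelSoftmaxFP16Case(TestGumbelSoftmaxOp):  # hypothetical name
    def init_attrs(self):
        self.shape = [20, 10]
        self.attrs = {"hard": True, "axis": 0}
        self.count_expected = 10
        self.dtype = np.float16

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            # FP16 kernels run on GPU; relax the output tolerance (assumed value).
            self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)

    def test_check_grad(self):
        if core.is_compiled_with_cuda():
            # Assumed gradient tolerance for the half-precision case.
            self.check_grad_with_place(
                core.CUDAPlace(0), ["X"], "Out", max_relative_error=1e-2
            )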
class TestGumbelSoftmaxOpSampleDistribution(OpTest):
    def softmax(self, x):
        x_row_max = x.max(axis=-1)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FP16的单测需要继承TestGumbelSoftmaxOp,实际上只需要为fp16的case重写init_attrs,可以减少冗余代码。
TestGumbelSoftmax_ZeroDim_FP16OP -> TestGumbelSoftmaxFP16OP
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
老师,您好,这里是参考单测中原来写法。针对于ZeroDim单独继承optest进行测试,其余各test继承TestGumbelSoftmaxOp并重写init_attr()。我这里也是针对于ZeroDim单独处理了。所以直接继承了optest。后续四个test都是直接继承TestGumbelSoftmaxOp并重写init_attr()的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原始写法我想并不是最优的。TestGumbelSoftmax_ZeroDim里面其实重写init_attr也可以吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
老师您好,我尝试了直接用TestGumbelSoftmax_ZeroDim继承TestGumbelSoftmaxOp基类,但是由于基类中check_out是针对多维重写的check_out_custormized,并不适用于ZeroDim。因此我在TestGumbelSoftmax_ZeroDim中添加了init_attr方法,并令TestGumbelSoftmax_ZeroDimFP16继承修改。
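For reference, a minimal sketch of the refactor described in that last reply: the dtype choice moves into init_attrs so the FP16 zero-dim test overrides only that one method. The import path for OpTest is an assumption and differs across Paddle versions; everything else mirrors the diff above.

import numpy as np
import paddle.nn.functional as F
from op_test import OpTest  # assumed import path


class TestGumbelSoftmax_ZeroDim(OpTest):
    def init_attrs(self):
        # Default precision; FP16 subclass overrides only this hook.
        self.dtype = "float64"

    def setUp(self):
        self.op_type = "gumbel_softmax"
        self.python_api = F.gumbel_softmax
        self.init_attrs()
        x = np.random.uniform(0.1, 1, []).astype(self.dtype)
        out = np.array(1.0).astype(self.dtype)
        self.inputs = {'X': x}
        self.outputs = {'Out': out}
        self.attrs = {"hard": True, "axis": -1}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestGumbelSoftmax_ZeroDim_FP16(TestGumbelSoftmax_ZeroDim):
    def init_attrs(self):
        self.dtype = np.float16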