-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add diagflat op, test=develop #33334
Conversation
python/paddle/tensor/creation.py
Outdated
|
||
import paddle | ||
|
||
paddle.disable_static() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
框架现在默认是动态图运行,这个应该可以去掉。下面也是
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已去掉(两处)。
@@ -771,6 +771,125 @@ def meshgrid(*args, **kwargs): | |||
return out | |||
|
|||
|
|||
def diagflat(x, offset=0, padding_value=0, name=None): | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个接口和竞品对齐,没有padding_value=0, name=None这2个参数。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已对齐,并去掉了unittest中涉及padding_value的测试用例。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为了静态图的便利性,我们的api还是需要加上“name”参数的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
python/paddle/tensor/creation.py
Outdated
return core.ops.diag_v2(y, "offset", offset, "padding_value", | ||
padding_value) | ||
|
||
check_type(x, 'x', (Variable), 'diag_v2') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是做diagflat的检查,如果检查未通过报错,用户看到的报错信息提示的是diag_v2
会无法理解的。改为diagflat
.下同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修正为diagflat。
python/paddle/tensor/creation.py
Outdated
|
||
Args: | ||
x (Tensor): The input tensor. It can be any shape. Its data type should be float32, float64, int32, int64. | ||
offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
offset的说明上需要增加:Default: 0 (main diagonal).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加。
import paddle | ||
import paddle.fluid as fluid | ||
from paddle.fluid import core | ||
from paddle.fluid import Program, program_guard |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在不推荐用fluid的API,如果要用program_guard换为paddle.static.program_guard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已将所有涉及fluid的API替换成 paddle2.1 doc中的API。
self.outputs = {'Out': self.out} | ||
|
||
def init_config(self): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个接口没有对应的OP,你设置的self.op_type = "diagflat"是会报错的。不要用OpTest去检查了,换成unittest.TestCase,自己用numpy写一个函数计算期望值,和paddle.diagflat的实际结果做比较。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已去掉OP测试相关代码。
self.expected12 = np.diagflat(self.input_np4, k=-1) | ||
|
||
self.input_np5 = np.random.random(size=(2, 3, 4)).astype(np.int32) | ||
self.expected13 = np.diagflat(self.input_np5, k=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 上面的这些case太重复了,不同dtype的计算逻辑并没有任何区别。留下float64的即可。单测尽量用最小的测试集合覆盖代码中不同分支。
- 如果计算逻辑中,仅有1-D和N-D的区别,那2-D还是3-D的输入也是重复的测试样例,留一种
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1.已修正,只保留了float64数据类型的测试。
2.已修正,只保留了1-D和2-D的测试样例。
self.assertTrue(np.allclose(res10, self.expected10)) | ||
self.assertTrue(np.allclose(res11, self.expected11)) | ||
self.assertTrue(np.allclose(res12, self.expected12)) | ||
self.assertTrue(np.allclose(res13, self.expected13)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 静态图还是动态图,底层调用的c++ op都是一样的,只是前端接口逻辑有差异。这里不必在静态图模式下重复验证计算精确度了。
- 可以留下一个静态图的测试样例,以保证代码覆盖率即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
静态图和动态图在python前端组装算子(flatten+diag)时,代码实现方式不同。如果不对静态图做全面测试,若静态图组装错误,会检查不到。故最终决定留下了静态图的1-D和2-D两组测试用例。
6f53eed
to
6a43321
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
OPs
Describe
(1) Add diagflat op.
(2) Add unittest for diagflat.