-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PIR API adaptor No.14-15】Assign and Bilinear #58876
Conversation
Sorry to inform you that b58508d's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice work ~
test/legacy_test/test_initializer.py
Outdated
@@ -541,6 +542,7 @@ def test_msra_initializer_bf16(self): | |||
|
|||
|
|||
class TestBilinearInitializer(unittest.TestCase): | |||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
initializer 相关的 pir 单测无法适配旧 ir,所以需要单独新增 pir 下的单测,可以参考:#59419
test/legacy_test/test_initializer.py
Outdated
@@ -628,6 +630,7 @@ def test_bilinear_initializer_fp16(self): | |||
|
|||
|
|||
class TestNumpyArrayInitializer(unittest.TestCase): | |||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
test/legacy_test/test_initializer.py
Outdated
with paddle.static.program_guard(main, startup): | ||
param = paddle.pir.core.create_parameter( | ||
dtype=dtype, | ||
shape=[5, 10], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shape 应与 TestBilinearInitializer.test_bilinear_initializer 单测对齐
test/legacy_test/test_initializer.py
Outdated
num_ops = 2 if dtype in ["float16", "uint16"] else 1 | ||
self.assertEqual(len(checked_ops), num_ops) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同理,也不需要做 op 数量的判断
test/legacy_test/test_initializer.py
Outdated
num_ops = 2 if dtype in ["float16", "uint16"] else 1 | ||
self.assertEqual(len(checked_ops), num_ops) | ||
init_op = checked_ops[0] | ||
self.assertEqual(init_op.type, 'assign_value') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用再去检查 init_op.type,因为前面的 self.get_init_ops_by_op_name 就是根据 op_name 去获取 op 的。这里改成
self.assertEqual(len(checked_ops), 1)
即可
test/legacy_test/test_initializer.py
Outdated
def test_numpy_array_initializer_bf16(self): | ||
"""Test the numpy array initializer with bfloat16""" | ||
block = self.test_numpy_array_initializer("uint16") | ||
self.assertTrue(block.ops[1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这行不需要了
Co-authored-by: Lu Qi <61354321+MarioLulab@users.noreply.github.com>
PR types
Others
PR changes
APIs
Description
#58067 14-15