Skip to content

Commit

Permalink
fix (#59460)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarioLulab authored Nov 29, 2023
1 parent 61c4231 commit ede6b06
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
7 changes: 7 additions & 0 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,13 @@ void BindOperation(py::module *m) {
})
.def("get_input_names",
[](Operation &self) -> py::list {
if (self.HasInterface<paddle::dialect::OpYamlInfoInterface>() ==
false) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get input names of Operation that "
"has OpYamlInfoInterface"));
}

py::list op_list;
paddle::dialect::OpYamlInfoInterface yaml_interface =
self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
Expand Down
24 changes: 24 additions & 0 deletions test/ir/pir/test_ir_pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,30 @@ def test_opresult_id(self):
self.assertIsInstance(a.id, str)
self.assertIsInstance(result.id, str)

def test_operation_get_input_names_error(self):
"""It will Raise error if operation `builtin.set_parameter` calls `get_input_names`. Because `builtin.set_parameter` does not have OpYamlInfoInterface"""
with paddle.pir_utils.IrGuard():
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
param1 = paddle.pir.core.create_parameter(
dtype="float32",
shape=[5, 10],
name="param1",
initializer=paddle.nn.initializer.Uniform(),
)

block = startup.global_block()
set_parameter_ops = [
op
for op in block.ops
if op.name() == 'builtin.set_parameter'
]
set_parameter_op = set_parameter_ops[0]
parameter_name = set_parameter_op.attrs()["parameter_name"]
with self.assertRaises(ValueError):
input_names = set_parameter_op.get_input_names()


if __name__ == "__main__":
unittest.main()

0 comments on commit ede6b06

Please sign in to comment.