Skip to content

Commit

Permalink
[pir] Add pybind property id of OpResult (#59064)
Browse files Browse the repository at this point in the history
* add OpResult  pybind id

* remove startswith 0x
  • Loading branch information
MarioLulab authored Nov 20, 2023
1 parent d2ebb83 commit 4839766
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
14 changes: 14 additions & 0 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <Python.h>
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -746,6 +747,19 @@ void BindOpResult(py::module *m) {
"persistable"));
}
})
.def_property_readonly(
"id",
[](OpResult &self) {
if (self.impl() == nullptr) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get id of OpResult whose impl "
"is not nullptr"));
} else {
std::stringstream ss;
ss << std::hex << self.impl();
return ss.str();
}
})
.def("initialized",
[](OpResult &self) {
if (self.impl() == nullptr || self.type().storage() == nullptr) {
Expand Down
8 changes: 8 additions & 0 deletions test/ir/pir/test_ir_pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ def test_prog_seed(self):
p.global_seed(10)
self.assertEqual(p._seed, 10)

def test_opresult_id(self):
with paddle.pir_utils.IrGuard():
a = paddle.static.data(name='a', shape=[4, 4], dtype='float32')
result = paddle.tanh(a)

self.assertIsInstance(a.id, str)
self.assertIsInstance(result.id, str)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4839766

Please sign in to comment.