Skip to content

Commit

Permalink
fix (#59456)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 authored Nov 29, 2023
1 parent 87353ee commit 7c92cc5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 29 deletions.
8 changes: 0 additions & 8 deletions paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
from op_gen import (
OpCompatParser,
OpInfoParser,
check_need_update_ops,
to_pascal_case,
update_ops,
)

CPP_FILE_TEMPLATE = """
Expand Down Expand Up @@ -83,18 +81,12 @@ def __init__(self, op_yaml_files, op_compat_yaml_file, dialect_name):

def parse_yaml(self, op_yaml_files, op_compat_yaml_file):
op_compat_parser = OpCompatParser(op_compat_yaml_file)
need_update_ops, update_yaml_file = check_need_update_ops(op_yaml_files)

op_yaml_items = []
for yaml_file in op_yaml_files:
if update_yaml_file == yaml_file:
continue
with open(yaml_file, "r") as f:
ops = yaml.safe_load(f)
op_yaml_items = op_yaml_items + ops
# replace old ir ops with pir ops
if need_update_ops:
update_ops(op_yaml_items, update_yaml_file)

op_info_items = []
for op in op_yaml_items:
Expand Down
21 changes: 0 additions & 21 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,27 +1017,6 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items):
return mutable_attribute_grad_semantics


def check_need_update_ops(op_yaml_files):
need_update_ops = False
update_yaml_file = None
for yaml_file in op_yaml_files:
if yaml_file.find("update_ops.parsed.yaml") != -1:
need_update_ops = True
update_yaml_file = yaml_file
break
return need_update_ops, update_yaml_file


def update_ops(op_yaml_items, update_yaml_file):
with open(update_yaml_file, "r") as f:
update_ops = yaml.safe_load(f)
for i in range(len(op_yaml_items)):
for update_op in update_ops:
if op_yaml_items[i]['name'] == update_op['name']:
op_yaml_items[i] = update_op
break


def OpGenerator(
op_yaml_files,
op_compat_yaml_file,
Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ void BindValue(py::module *m) {
})
.def("replace_all_uses_with",
[](Value self, Value value) { self.ReplaceAllUsesWith(value); })
.def("set_type", [](Value self, Type type) { self.set_type(type); })
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &Value::use_empty)
Expand Down Expand Up @@ -1305,6 +1306,21 @@ SplitedResult SplitForwardBackward(
return std::make_pair(programs, attr);
}

pir::Type CreateSelectedRowsTypeByDenseTensor(pir::Type dense_tensor_type) {
if (dense_tensor_type.isa<DenseTensorType>()) {
DenseTensorType type = dense_tensor_type.dyn_cast<DenseTensorType>();
return SelectedRowsType::get(pir::IrContext::Instance(),
type.dtype(),
type.dims(),
type.data_layout(),
type.lod(),
type.offset());
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, input is not a dense tensor type."));
}
}

void BindUtils(pybind11::module *m) {
m->def("clone_program", CloneProgram);
m->def("split_program", SplitForwardBackward);
Expand All @@ -1331,6 +1347,8 @@ void BindUtils(pybind11::module *m) {
->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
pir::IrContext::Instance()->GetOrRegisterDialect<pir::ControlFlowDialect>();
});
m->def("create_selected_rows_type_by_dense_tensor",
CreateSelectedRowsTypeByDenseTensor);
m->def(
"translate_to_pir",
[](const ::paddle::framework::ProgramDesc &legacy_program) {
Expand Down
6 changes: 6 additions & 0 deletions test/ir/pir/test_ir_pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ def test_type(self):
self.assertEqual(
matmul_op.result(0).type() == add_op.result(0).type(), True
)
add_op.result(0).set_type(
paddle.base.libpaddle.pir.create_selected_rows_type_by_dense_tensor(
add_op.result(0).type()
)
)
self.assertEqual(add_op.result(0).is_selected_row_type(), True)

def test_attr(self):
main_program, start_program = (
Expand Down

0 comments on commit 7c92cc5

Please sign in to comment.