Skip to content

Commit

Permalink
【new ir】add ir pybind api (#55745)
Browse files Browse the repository at this point in the history
* add ir core

* add test

* modify name

* merge

* add test for __eq__

* shield  test for __eq__

* --amend

* Update new_ir_compiler.cc
  • Loading branch information
xiaoguoguo626807 authored Aug 2, 2023
1 parent 683287b commit ef29468
Show file tree
Hide file tree
Showing 19 changed files with 114 additions and 72 deletions.
6 changes: 3 additions & 3 deletions paddle/cinn/hlir/framework/new_ir_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ std::vector<ir::LoweredFunc> NewIRCompiler::GetOpFunc(const ::ir::Operation& op,
VLOG(4) << "GetOpFunc for op: " << op_name;
// step 1: Deal with Operands
for (int i = 0; i < op.num_operands(); ++i) {
auto in_value = op.operand(i);
auto in_value = op.operand_source(i);
// TODO(Aurelius84): For now, use addr as name but it's not wise.
std::string input_id = CompatibleInfo::kInputPrefix +
std::to_string(std::hash<::ir::Value>()(in_value));
Expand Down Expand Up @@ -215,7 +215,7 @@ std::vector<std::string> NewIRCompiler::OpGetInputNames(
std::vector<std::string> names;
std::unordered_set<std::string> repeat;
for (int i = 0; i < op.num_operands(); ++i) {
auto value = op.operand(i);
auto value = op.operand_source(i);
std::string name = CompatibleInfo::kInputPrefix +
std::to_string(std::hash<::ir::Value>()(value));
if (repeat.count(name)) {
Expand Down Expand Up @@ -264,7 +264,7 @@ std::shared_ptr<Scope> BuildScope(const Target& target,

for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
for (auto i = 0; i < (*it)->num_operands(); ++i) {
auto in_value = (*it)->operand(i);
auto in_value = (*it)->operand_source(i);
create_var(CompatibleInfo::kInputPrefix, in_value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ PhiKernelInstruction::PhiKernelInstruction(
auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds();
std::unordered_set<::ir::Value> no_need_buffer_values;
for (size_t id = 0; id < no_need_buffer_ids.size(); id++) {
no_need_buffer_values.insert(op->operand(no_need_buffer_ids[id]));
no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id]));
}
SetNoNeedBuffer(no_need_buffer_values);
VLOG(6) << "finish process no need buffer";
Expand Down Expand Up @@ -302,7 +302,7 @@ void PhiKernelInstruction::InitInputsOutputsIds(
variable_2_var_name) {
std::unordered_map<ir::Value, std::vector<int>> inputs;
for (size_t i = 0; i < op->num_operands(); i++) {
ir::Value value = op->operand(i);
ir::Value value = op->operand_source(i);
if (value) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# generator op member function

OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_index}); }}
OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand_source({input_index}); }}
"""
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }}
"""
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/ir/dialect/op_generator/op_verify_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,26 @@
"""

INPUT_TYPE_CHECK_TEMPLATE = """
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """
if (auto vec_type = (*this)->operand({index}).type().dyn_cast<ir::VectorType>()) {{
if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->op_operand({index})) {{
if (auto val = (*this)->operand({index})) {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->op_operand({index})) {{
if (auto val = (*this)->operand({index})) {{
if (auto vec_type = val.type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
Expand Down
16 changes: 8 additions & 8 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ void CheckInputVars(
size_t input_num = op->num_operands();
if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
auto value = op->operand_source(i);
if (value) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
Expand Down Expand Up @@ -298,7 +298,7 @@ void HandleForSpecialOp(
tensor_array->clear();
size_t input_num = op->num_operands();
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
auto value = op->operand_source(i);
PADDLE_ENFORCE_EQ(
value_2_var_name->count(value),
true,
Expand All @@ -315,7 +315,7 @@ void HandleForSpecialOp(
.dyn_cast<ir::StrAttribute>()
.AsString();

auto value = op->operand(0);
auto value = op->operand_source(0);
// change operand name to param_name
auto orig_name = value_2_var_name->at(value);

Expand All @@ -336,7 +336,7 @@ void HandleForSpecialOp(
auto var_name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();

auto value = op->operand(0);
auto value = op->operand_source(0);
// change operand name to param_name
auto orig_name = value_2_var_name->at(value);

Expand Down Expand Up @@ -372,7 +372,7 @@ void HandleForSpecialOp(
if (op_name == "builtin.slice") {
VLOG(6) << "Handle for builtin.slice";
auto out_value = op->result(0);
auto in_value = op->operand(0);
auto in_value = op->operand_source(0);
PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value),
true,
phi::errors::PreconditionNotMet(
Expand Down Expand Up @@ -426,7 +426,7 @@ void HandleForInplaceOp(
if (yaml_parser.HasInplace(value_name)) {
std::string inplace_name = yaml_parser.InplaceName(value_name);
ir::Value inplace_value =
op->operand(yaml_parser.InputName2Id().at(inplace_name));
op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
std::string var_name = value_2_var_name->at(inplace_value);
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")";
Expand Down Expand Up @@ -547,7 +547,7 @@ void BuildRuntimeContext(
true,
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);

auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name;
Expand Down Expand Up @@ -603,7 +603,7 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
true,
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);

auto in_var_name = name_map.at(ptr);

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void BuildPhiContext(ir::Operation* op,
true,
phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.InputName2Id().at(t);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);
if (!ptr) {
phi::DenseTensor* ptr = nullptr;
OutType in_ptr(ptr);
Expand Down Expand Up @@ -128,7 +128,7 @@ void BuildPhiContext(ir::Operation* op,
for (auto& t : vec_kernel_fn_attr_params) {
if (name2id.count(t)) {
// tensor attribute, get information from input
ir::Value ptr = op->operand(name2id.at(t));
ir::Value ptr = op->operand_source(name2id.at(t));

auto in_var_name = name_map.at(ptr);

Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/ir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ class ConstantFoldingPattern : public ir::RewritePattern {
std::vector<ir::OpResult> op_inputs;
for (uint32_t i = 0; i < op->num_operands(); i++) {
PADDLE_ENFORCE_EQ(
op->operand(i).type().isa<paddle::dialect::DenseTensorType>(),
op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's input must be a dense tensor type."));

auto [param_name, param] = ir::GetParameterFromValue(op->operand(i));
auto [param_name, param] =
ir::GetParameterFromValue(op->operand_source(i));
program->SetParameter(param_name,
std::make_unique<ir::Parameter>(*param));

Expand All @@ -128,8 +129,8 @@ class ConstantFoldingPattern : public ir::RewritePattern {
param_var,
phi::errors::InvalidArgument("Parameter var not in scope."));

auto get_parameter_op =
builder.Build<ir::GetParameterOp>(param_name, op->operand(i).type());
auto get_parameter_op = builder.Build<ir::GetParameterOp>(
param_name, op->operand_source(i).type());
op_inputs.push_back(get_parameter_op->result(0));
}

Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ phi::KernelKey GetKernelKey(
} else if (input_map.count(slot_name)) {
// parse from input
int in_index = input_map.at(slot_name);
auto type = map_value_pair.at(op->operand(in_index)).type();
auto type = map_value_pair.at(op->operand_source(in_index)).type();

if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
Expand Down Expand Up @@ -151,7 +151,7 @@ phi::KernelKey GetKernelKey(
if (op->name() == "pd.uniform") {
// try to process uniform, use shape to determine backend
// TODO(phlrain): should support other initialize ops
auto define_op = op->operand(0).GetDefiningOp();
auto define_op = op->operand_source(0).GetDefiningOp();
if (define_op->name() == "pd.full_int_array") {
auto shape = define_op->attributes()
.at("value")
Expand Down Expand Up @@ -183,7 +183,7 @@ phi::KernelKey GetKernelKey(
if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) {
continue;
}
auto input_tmp = op->operand(i);
auto input_tmp = op->operand_source(i);
// NOTE: if not input_tmp, it's an optional input
if (!input_tmp) {
continue;
Expand Down Expand Up @@ -341,7 +341,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,

if ((*it)->num_operands() > 0) {
for (size_t i = 0; i < (*it)->num_operands(); ++i) {
auto cur_in = (*it)->operand(i);
auto cur_in = (*it)->operand_source(i);
if (!cur_in) {
vec_inputs.push_back(ir::OpResult());
continue;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir/transforms/transform_general_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
return op->operand(index).GetDefiningOp();
return op->operand_source(index).GetDefiningOp();
}

Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
Expand Down
26 changes: 21 additions & 5 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ void BindProgram(py::module *m) {
void BindBlock(py::module *m) {
py::class_<Block> block(*m, "Block");
block.def("front", &Block::front, return_value_policy::reference)
.def("get_parent_program",
[](Block &self) { return self.GetParentOp()->GetParentProgram(); })
.def("get_ops",
[](Block &self) -> py::list {
py::list op_list;
Expand All @@ -94,19 +96,22 @@ void BindBlock(py::module *m) {
void BindOperation(py::module *m) {
py::class_<Operation> op(*m, "Operation");
op.def("name", &Operation::name)
.def("get_parent",
.def("get_parent_block",
py::overload_cast<>(&Operation::GetParent),
return_value_policy::reference)
.def("get_parent",
.def("get_parent_block",
py::overload_cast<>(&Operation::GetParent, py::const_),
return_value_policy::reference)
.def("num_operands", &Operation::num_operands)
.def("num_results", &Operation::num_results)
.def("operand", &Operation::operand)
.def("result", &Operation::result)
.def("operand_source", &Operation::operand_source)
.def("operands",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_operands(); i++) {
op_list.append(self.op_operand(i));
op_list.append(self.operand(i));
}
return op_list;
})
Expand All @@ -118,6 +123,14 @@ void BindOperation(py::module *m) {
}
return op_list;
})
.def("operands_source",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_operands(); i++) {
op_list.append(self.operand_source(i));
}
return op_list;
})
.def("get_input_names",
[](Operation &self) -> py::list {
py::list op_list;
Expand Down Expand Up @@ -159,8 +172,11 @@ void BindOperation(py::module *m) {

void BindValue(py::module *m) {
py::class_<Value> value(*m, "Value");
value.def(
"get_defining_op", &Value::GetDefiningOp, return_value_policy::reference);
value
.def("get_defining_op",
&Value::GetDefiningOp,
return_value_policy::reference)
.def("__eq__", &Value::operator==);
}

void BindOpOperand(py::module *m) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ void IrPrinter::PrintOpOperands(const Operation* op) {
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->operand(idx));
op_operands.push_back(op->operand_source(idx));
}
PrintInterleave(
op_operands.begin(),
Expand All @@ -254,7 +254,7 @@ void IrPrinter::PrintOperandsType(const Operation* op) {
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->op_operand(idx);
auto op_operand = op->operand(idx);
if (op_operand) {
op_operand_types.push_back(op_operand.type());
} else {
Expand Down
4 changes: 3 additions & 1 deletion paddle/ir/core/op_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ class IR_API OpBase {

const AttributeMap &attributes() const { return operation()->attributes(); }

Value operand(uint32_t index) const { return operation()->operand(index); }
Value operand_source(uint32_t index) const {
return operation()->operand_source(index);
}

OpResult result(uint32_t index) const { return operation()->result(index); }

Expand Down
10 changes: 5 additions & 5 deletions paddle/ir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Operation *Operation::Create(OperationArgument &&argument) {

// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
// OpInlineResult, Operation, operand.
Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attributes,
const std::vector<ir::Type> &output_types,
Expand Down Expand Up @@ -132,7 +132,7 @@ void Operation::Destroy() {

// 4. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
op_operand(idx).impl()->~OpOperandImpl();
operand(idx).impl()->~OpOperandImpl();
}
// 5. Free memory.
uint32_t max_inline_result_num =
Expand Down Expand Up @@ -186,7 +186,7 @@ ir::OpResult Operation::result(uint32_t index) const {
}
}

OpOperand Operation::op_operand(uint32_t index) const {
OpOperand Operation::operand(uint32_t index) const {
if (index >= num_operands_) {
IR_THROW("index exceeds OP input range.");
}
Expand All @@ -195,8 +195,8 @@ OpOperand Operation::op_operand(uint32_t index) const {
return OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr));
}

Value Operation::operand(uint32_t index) const {
OpOperand val = op_operand(index);
Value Operation::operand_source(uint32_t index) const {
OpOperand val = operand(index);
return val ? val.source() : Value();
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ class IR_API alignas(8) Operation final {

OpResult result(uint32_t index) const;

OpOperand op_operand(uint32_t index) const;
OpOperand operand(uint32_t index) const;

Value operand(uint32_t index) const;
Value operand_source(uint32_t index) const;

/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter {

void NotifyOperationRemoved(ir::Operation* op) override {
for (uint32_t i = 0; i < op->num_operands(); ++i) {
AddOperandToWorklist(op->operand(i));
AddOperandToWorklist(op->operand_source(i));
}
for (uint32_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->region(i);
Expand Down
Loading

0 comments on commit ef29468

Please sign in to comment.