Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【new ir】add ir pybind api #55745

Merged
merged 11 commits into from
Aug 2, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ void PhiKernelInstruction::InitInputsOutputsIds(
variable_2_var_name) {
std::unordered_map<ir::Value, std::vector<int>> inputs;
for (size_t i = 0; i < op->num_operands(); i++) {
ir::Value value = op->operand(i);
ir::Value value = op->operand_source(i);
if (value) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# generator op member function

OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_index}); }}
OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand_source({input_index}); }}
"""
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }}
"""
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/ir/dialect/op_generator/op_verify_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,26 @@
"""

INPUT_TYPE_CHECK_TEMPLATE = """
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """
if (auto vec_type = (*this)->operand({index}).type().dyn_cast<ir::VectorType>()) {{
if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->op_operand({index})) {{
if (auto val = (*this)->operand({index})) {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->op_operand({index})) {{
if (auto val = (*this)->operand({index})) {{
if (auto vec_type = val.type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
Expand Down
16 changes: 8 additions & 8 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void CheckInputVars(
size_t input_num = op->num_operands();
if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
auto value = op->operand_source(i);
if (value) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
Expand Down Expand Up @@ -250,7 +250,7 @@ void HandleForSpecialOp(
tensor_array->clear();
size_t input_num = op->num_operands();
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
auto value = op->operand_source(i);
PADDLE_ENFORCE_EQ(
value_2_var_name->count(value),
true,
Expand All @@ -267,7 +267,7 @@ void HandleForSpecialOp(
.dyn_cast<ir::StrAttribute>()
.AsString();

auto value = op->operand(0);
auto value = op->operand_source(0);
// change operand name to param_name
auto orig_name = value_2_var_name->at(value);

Expand All @@ -283,7 +283,7 @@ void HandleForSpecialOp(
auto var_name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();

auto value = op->operand(0);
auto value = op->operand_source(0);
// change operand name to param_name
auto orig_name = value_2_var_name->at(value);

Expand All @@ -307,7 +307,7 @@ void HandleForSpecialOp(
if (op_name == "builtin.slice") {
VLOG(6) << "Handle for builtin.slice";
auto out_value = op->result(0);
auto in_value = op->operand(0);
auto in_value = op->operand_source(0);
PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value),
true,
phi::errors::PreconditionNotMet(
Expand Down Expand Up @@ -361,7 +361,7 @@ void HandleForInplaceOp(
if (yaml_parser.HasInplace(value_name)) {
std::string inplace_name = yaml_parser.InplaceName(value_name);
ir::Value inplace_value =
op->operand(yaml_parser.InputName2Id().at(inplace_name));
op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
std::string var_name = value_2_var_name->at(inplace_value);
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")";
Expand Down Expand Up @@ -482,7 +482,7 @@ void BuildRuntimeContext(
true,
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);

auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name;
Expand Down Expand Up @@ -538,7 +538,7 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
true,
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);

auto in_var_name = name_map.at(ptr);

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void BuildPhiContext(ir::Operation* op,
true,
phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.InputName2Id().at(t);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);
if (!ptr) {
phi::DenseTensor* ptr = nullptr;
OutType in_ptr(ptr);
Expand Down Expand Up @@ -128,7 +128,7 @@ void BuildPhiContext(ir::Operation* op,
for (auto& t : vec_kernel_fn_attr_params) {
if (name2id.count(t)) {
// tensor attribute, get information from input
ir::Value ptr = op->operand(name2id.at(t));
ir::Value ptr = op->operand_source(name2id.at(t));

auto in_var_name = name_map.at(ptr);

Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/ir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ class ConstantFoldingPattern : public ir::RewritePattern {
std::vector<ir::OpResult> op_inputs;
for (uint32_t i = 0; i < op->num_operands(); i++) {
PADDLE_ENFORCE_EQ(
op->operand(i).type().isa<paddle::dialect::DenseTensorType>(),
op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's input must be a dense tensor type."));

auto [param_name, param] = ir::GetParameterFromValue(op->operand(i));
auto [param_name, param] =
ir::GetParameterFromValue(op->operand_source(i));
program->SetParameter(param_name,
std::make_unique<ir::Parameter>(*param));

Expand All @@ -128,8 +129,8 @@ class ConstantFoldingPattern : public ir::RewritePattern {
param_var,
phi::errors::InvalidArgument("Parameter var not in scope."));

auto get_parameter_op =
builder.Build<ir::GetParameterOp>(param_name, op->operand(i).type());
auto get_parameter_op = builder.Build<ir::GetParameterOp>(
param_name, op->operand_source(i).type());
op_inputs.push_back(get_parameter_op->result(0));
}

Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ phi::KernelKey GetKernelKey(
} else if (input_map.count(slot_name)) {
// parse from input
int in_index = input_map.at(slot_name);
auto type = map_value_pair.at(op->operand(in_index)).type();
auto type = map_value_pair.at(op->operand_source(in_index)).type();

if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
Expand Down Expand Up @@ -151,7 +151,7 @@ phi::KernelKey GetKernelKey(
if (op->name() == "pd.uniform") {
// try to process uniform, use shape to determine backend
// TODO(phlrain): should support other initialize ops
auto define_op = op->operand(0).GetDefiningOp();
auto define_op = op->operand_source(0).GetDefiningOp();
if (define_op->name() == "pd.full_int_array") {
auto shape = define_op->attributes()
.at("value")
Expand Down Expand Up @@ -183,7 +183,7 @@ phi::KernelKey GetKernelKey(
if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) {
continue;
}
auto input_tmp = op->operand(i);
auto input_tmp = op->operand_source(i);
// NOTE: if not input_tmp, it's an optional input
if (!input_tmp) {
continue;
Expand Down Expand Up @@ -343,7 +343,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,

if ((*it)->num_operands() > 0) {
for (size_t i = 0; i < (*it)->num_operands(); ++i) {
auto cur_in = (*it)->operand(i);
auto cur_in = (*it)->operand_source(i);
if (!cur_in) {
vec_inputs.push_back(ir::OpResult());
continue;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir/transforms/transform_general_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
return op->operand(index).GetDefiningOp();
return op->operand_source(index).GetDefiningOp();
}

Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
Expand Down
19 changes: 17 additions & 2 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ void BindProgram(py::module *m) {
void BindBlock(py::module *m) {
py::class_<Block> block(*m, "Block");
block.def("front", &Block::front, return_value_policy::reference)
.def("get_parent_program",
[](Block &self) { return self.GetParentOp()->GetParentProgram(); })
.def("get_ops",
[](Block &self) -> py::list {
py::list op_list;
Expand All @@ -81,14 +83,19 @@ void BindBlock(py::module *m) {
void BindOperation(py::module *m) {
py::class_<Operation> op(*m, "Operation");
op.def("name", &Operation::name)
.def("get_parent", &Operation::GetParent, return_value_policy::reference)
.def("get_parent_block",
&Operation::GetParent,
return_value_policy::reference)
.def("num_operands", &Operation::num_operands)
.def("num_results", &Operation::num_results)
.def("operand", &Operation::operand)
.def("result", &Operation::result)
.def("operand_source", &Operation::operand_source)
.def("operands",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_operands(); i++) {
op_list.append(self.op_operand(i));
op_list.append(self.operand(i));
}
return op_list;
})
Expand All @@ -100,6 +107,14 @@ void BindOperation(py::module *m) {
}
return op_list;
})
.def("operands_source",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_operands(); i++) {
op_list.append(self.operand_source(i));
}
return op_list;
})
.def("get_input_names",
[](Operation &self) -> py::list {
py::list op_list;
Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ void IrPrinter::PrintOpOperands(const Operation* op) {
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->operand(idx));
op_operands.push_back(op->operand_source(idx));
}
PrintInterleave(
op_operands.begin(),
Expand All @@ -254,7 +254,7 @@ void IrPrinter::PrintOperandsType(const Operation* op) {
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->op_operand(idx);
auto op_operand = op->operand(idx);
if (op_operand) {
op_operand_types.push_back(op_operand.type());
} else {
Expand Down
4 changes: 3 additions & 1 deletion paddle/ir/core/op_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ class IR_API OpBase {

const AttributeMap &attributes() const { return operation()->attributes(); }

Value operand(uint32_t index) const { return operation()->operand(index); }
// Returns the source Value feeding this op's `index`-th operand.
// Thin convenience wrapper that delegates to the underlying Operation;
// may return a null Value when the operand itself is null.
Value operand_source(uint32_t index) const {
return operation()->operand_source(index);
}

OpResult result(uint32_t index) const { return operation()->result(index); }

Expand Down
10 changes: 5 additions & 5 deletions paddle/ir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Operation *Operation::Create(OperationArgument &&argument) {

// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
// OpInlineResult, Operation, operand.
Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attributes,
const std::vector<ir::Type> &output_types,
Expand Down Expand Up @@ -132,7 +132,7 @@ void Operation::Destroy() {

// 4. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
op_operand(idx).impl()->~OpOperandImpl();
operand(idx).impl()->~OpOperandImpl();
}
// 5. Free memory.
uint32_t max_inline_result_num =
Expand Down Expand Up @@ -186,7 +186,7 @@ ir::OpResult Operation::result(uint32_t index) const {
}
}

OpOperand Operation::op_operand(uint32_t index) const {
OpOperand Operation::operand(uint32_t index) const {
if (index >= num_operands_) {
IR_THROW("index exceeds OP input range.");
}
Expand All @@ -195,8 +195,8 @@ OpOperand Operation::op_operand(uint32_t index) const {
return OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr));
}

Value Operation::operand(uint32_t index) const {
OpOperand val = op_operand(index);
// Returns the Value that feeds the operand at `index`.  A default
// (null) Value is returned when the operand slot itself is null,
// e.g. for an unset optional input.
Value Operation::operand_source(uint32_t index) const {
  OpOperand op_operand = operand(index);
  if (!op_operand) {
    return Value();
  }
  return op_operand.source();
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ class IR_API alignas(8) Operation final {

OpResult result(uint32_t index) const;

OpOperand op_operand(uint32_t index) const;
OpOperand operand(uint32_t index) const;

Value operand(uint32_t index) const;
Value operand_source(uint32_t index) const;

/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter {

void NotifyOperationRemoved(ir::Operation* op) override {
for (uint32_t i = 0; i < op->num_operands(); ++i) {
AddOperandToWorklist(op->operand(i));
AddOperandToWorklist(op->operand_source(i));
}
for (uint32_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->region(i);
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/ir/core/ir_program_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ TEST(program_test, program) {
// (8) Def SetParameterOp(c, "c")
auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");

EXPECT_EQ(op4->op_operand(0).type().dialect().id(), paddle_dialect->id());
EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
Interface *c_interface =
op4->op_operand(0).type().dialect().GetRegisteredInterface<Interface>();
op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get());
std::unique_ptr<ir::Parameter> parameter_c =
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/ir/core/ir_value_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ TEST(value_test, value_test) {

// Test 2: op1_first_output -> op4_first_input
ir::OpResult op1_first_output = op1->result(0);
ir::OpOperand op4_first_input = op4->op_operand(0);
ir::OpOperand op4_first_input = op4->operand(0);

EXPECT_EQ(op1_first_output.first_use(), op4_first_input);
ir::OpOperand op3_first_input = op3->op_operand(0);
ir::OpOperand op3_first_input = op3->operand(0);

EXPECT_EQ(op4_first_input.next_use(), op3_first_input);
EXPECT_EQ(op3_first_input.next_use(), nullptr);
Expand All @@ -110,11 +110,11 @@ TEST(value_test, value_test) {
// a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c);
//
c.ReplaceUsesWithIf(b, [](ir::OpOperand) { return true; });
EXPECT_EQ(op4->operand(1), b);
EXPECT_EQ(op4->operand_source(1), b);
EXPECT_TRUE(c.use_empty());

b.ReplaceAllUsesWith(a);
EXPECT_EQ(op4->operand(1), a);
EXPECT_EQ(op4->operand_source(1), a);
EXPECT_TRUE(b.use_empty());

// destroy
Expand Down
Loading