Skip to content

Commit

Permalink
[PIR] add side_effect trait for dead_code_elimination pass. (PaddlePa…
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Nov 10, 2023
1 parent aafbad4 commit ff72c36
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 28 deletions.
7 changes: 7 additions & 0 deletions paddle/fluid/operators/generator/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def check_op_config(op_entry, op_name):
'composite',
'support_dygraph_mode',
'support_tensor',
'traits',
)
infer_meta_key_set = ('func', 'param', 'spmd_rule')
kernel_key_set = (
Expand Down Expand Up @@ -514,6 +515,11 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
else:
support_tensor = []

if "traits" in op_entry.keys():
trait_list = parse_plain_list(op_entry["traits"])
else:
trait_list = []

op = {
"name": op_name,
"inputs": inputs,
Expand All @@ -522,6 +528,7 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
"no_need_buffer": no_buffer_args,
"data_transform": data_trans,
"support_tensor": support_tensor,
"traits": trait_list,
}

# op should be is_base_op or is_invoke_op or is_only_composite_op
Expand Down
41 changes: 26 additions & 15 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/operation_utils.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/op_trait.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
Expand Down Expand Up @@ -417,6 +418,15 @@ def __init__(self, op_yaml_item, op_compat_item):
# parse forward input name list and attribute name list
self.forward_input_name_list = self.parse_forward_input_name()

# parse traits list
self.traits_list = self.parse_op_traits()

def parse_op_traits(self):
if 'traits' in self.op_yaml_item:
return self.op_yaml_item['traits']
else:
return []

def parse_forward_input_name(self):
if 'forward' in self.op_yaml_item:
forward_input_name_list = []
Expand Down Expand Up @@ -1104,7 +1114,7 @@ def OpGenerator(
op_view_map = op_info.view_map
op_data_transform_map = op_info.data_transform_map
op_interfaces = ["paddle::dialect::OpYamlInfoInterface"]
op_traits = []
op_traits = op_info.traits_list

if op_info.infer_meta_func:
op_interfaces += ["paddle::dialect::InferMetaInterface"]
Expand Down Expand Up @@ -1146,6 +1156,21 @@ def OpGenerator(
else:
op_interfaces = op_interfaces_tmp
exclusive_interface_str = exclusive_interface_str_tmp

# =================================== #
# gen interface/trait list str #
# =================================== #
op_interfaces_str = ""
if len(op_interfaces) > 0:
op_interfaces_str = "," + ",".join(op_interfaces)

if op_name[-1] == "_":
op_traits += ["paddle::dialect::InplaceTrait"]

op_traits_str = ""
if len(op_traits) > 0:
op_traits_str = "," + ",".join(op_traits)

if op_name in PD_MANUAL_OP_LIST:
continue
if op_kernel_map is None:
Expand Down Expand Up @@ -1179,20 +1204,6 @@ def OpGenerator(
kernel_func_name
]

# =================================== #
# gen interface/trait list str #
# =================================== #
op_interfaces_str = ""
if len(op_interfaces) > 0:
op_interfaces_str = "," + ",".join(op_interfaces)

if op_name[-1] == "_":
op_traits += ["paddle::dialect::InplaceTrait"]

op_traits_str = ""
if len(op_traits) > 0:
op_traits_str = "," + ",".join(op_traits)

# =================================== #
# gen get input/output methods str #
# =================================== #
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
kernel :
func : fetch
param : [x]
traits : pir::SideEffectTrait

- op : get_tensor_from_selected_rows
args : (Tensor x)
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/op_result.h"
#include "paddle/pir/core/op_trait.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/parameter.h"
#include "paddle/pir/core/program.h"
Expand Down Expand Up @@ -67,11 +68,9 @@ class ConstantFoldingPattern : public pir::RewritePattern {
}

bool Match(pir::Operation* op) const override {
if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() ||
op->isa<pir::ShadowOutputOp>() || op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::FeedOp>())
if (op->HasTrait<pir::SideEffectTrait>() ||
op->isa<pir::GetParameterOp>() || op->isa<paddle::dialect::FeedOp>())
return false;

if (!ValidOp(op)) {
return false;
}
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/op_trait.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/pass/pass.h"
Expand All @@ -36,10 +37,7 @@ class DeadCodeEliminationPattern : public pir::RewritePattern {
}

bool Match(pir::Operation* op) const override {
if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>() ||
op->isa<pir::YieldOp>()) {
return false;
}
if (op->HasTrait<pir::SideEffectTrait>()) return false;
return op->use_empty();
}

Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2317,6 +2317,7 @@
kernel:
func : shadow_output
param : [x]
traits : pir::SideEffectTrait

- op : shape
args : (Tensor input)
Expand Down
5 changes: 3 additions & 2 deletions paddle/pir/core/builtin_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/op_trait.h"

namespace pir {

Expand Down Expand Up @@ -66,7 +67,7 @@ class IR_API GetParameterOp : public pir::Op<GetParameterOp> {
/// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute,
/// StrAttribute})
///
class IR_API SetParameterOp : public pir::Op<SetParameterOp> {
class IR_API SetParameterOp : public pir::Op<SetParameterOp, SideEffectTrait> {
public:
using Op::Op;
static const char *name() { return "builtin.set_parameter"; }
Expand All @@ -83,7 +84,7 @@ class IR_API SetParameterOp : public pir::Op<SetParameterOp> {
/// \brief ShdowOutputOp: ShdowOutputOp(OpOperand, {StrAttribute,
/// StrAttribute})
///
class IR_API ShadowOutputOp : public pir::Op<ShadowOutputOp> {
class IR_API ShadowOutputOp : public pir::Op<ShadowOutputOp, SideEffectTrait> {
public:
using Op::Op;
static const char *name() { return "builtin.shadow_output"; }
Expand Down
1 change: 1 addition & 0 deletions paddle/pir/core/op_trait.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::OneResultTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::SideEffectTrait)
7 changes: 7 additions & 0 deletions paddle/pir/core/op_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ class OneResultTrait : public OpTraitBase<OneResultTrait> {
static void Verify(Operation *op);
};

///
/// \brief This trait marks the op can't be removed even if which has no output
/// or the output isn't used.
///
class SideEffectTrait : public OpTraitBase<SideEffectTrait> {};

} // namespace pir

IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsShapeTrait)
Expand All @@ -118,3 +124,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::OneResultTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::SideEffectTrait)
5 changes: 3 additions & 2 deletions paddle/pir/dialect/control_flow/ir/cf_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
#include <functional>
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/op_trait.h"

namespace pir {
class IR_API YieldOp : public Op<YieldOp> {
class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
public:
using Op::Op;
static const char *name() { return "cf.yield"; }
Expand Down Expand Up @@ -56,7 +57,7 @@ class IR_API CreateStackOp : public Op<CreateStackOp> {
void Print(pir::IrPrinter &printer); // NOLINT
};

class IR_API PushBackOp : public Op<PushBackOp> {
class IR_API PushBackOp : public Op<PushBackOp, SideEffectTrait> {
public:
using Op::Op;
static const char *name() { return "cf.push_back"; }
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/control_flow_dialect/while_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ TEST(while_op_test, network_with_backward) {
auto bwd_cond = builder.Build<pir::HasElementsOp>(stack).out();
auto while_grad = builder.Build<WhileOp>(
bwd_cond, std::vector<pir::Value>{x_out_grad, zero});
pir::Block* bwd_body_block = while_op.body_block();
pir::Block* bwd_body_block = while_grad.body_block();
builder.SetInsertionPointToStart(bwd_body_block);
auto local_x_out_grad_arg = bwd_body_block->AddArgument(x.type());
auto local_y_grad_arg = bwd_body_block->AddArgument(y.type());
Expand Down

0 comments on commit ff72c36

Please sign in to comment.