[IR] Refine the Build interface of split op #56924

Merged
7 changes: 2 additions & 5 deletions paddle/fluid/framework/new_executor/standalone_executor.cc
@@ -16,8 +16,8 @@
#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/program_interpreter.h"
#include "paddle/fluid/platform/flags.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/core/flags.h"

#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"

@@ -29,10 +29,7 @@

PHI_DECLARE_bool(enable_new_ir_in_executor);
PHI_DECLARE_bool(enable_new_ir_api);

PADDLE_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass,
true,
"new ir kernel program apply inplace pass.");
PHI_DECLARE_bool(new_ir_apply_inplace_pass);

namespace paddle {
namespace framework {
79 changes: 70 additions & 9 deletions paddle/fluid/ir/dialect/op_generator/op_build_gen.py
@@ -13,6 +13,10 @@
# limitations under the License.

# generator build function
_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'}

_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {'SplitOp'}

OP_BUILD_TEMPLATE = """
void {op_name}::Build({build_args}) {{
{get_attributes}
@@ -273,6 +277,7 @@ def GenBuildAttributes(


def GenBuildOutputs(
op_class_name,
op_input_name_list,
op_input_type_list,
op_mutable_attribute_name_list,
@@ -318,6 +323,40 @@
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""

CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name};
if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullIntArrayOp>()) {{
{name} = std::move(phi::IntArray({name}_.owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData()));
}}
else {{
PADDLE_ENFORCE(
{name}_.type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet("section Type should be VectorType."));
size_t {name}_size = {name}_.type().dyn_cast<ir::VectorType>().size();
{name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
{name}.SetFromTensor(true);
}}\n"""

CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name};
if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullOp>()) {{
{name} = std::move(phi::Scalar({name}_.owner()
->dyn_cast<paddle::dialect::FullOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int>()));
}}
else {{
{name} = std::move(phi::Scalar(-1));
{name}.SetFromTensor(true);
}}\n"""

CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
"""
@@ -350,19 +389,30 @@
attr_dtype = op_mutable_attribute_type_list[idx]
# int_array
if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
)
else:
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
)
# scalar
elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
build_output_str += (
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
)
else:
build_output_str += (
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
)
# string
elif attr_dtype[0] == "ir::StrAttribute":
build_output_str += ""
@@ -423,9 +473,19 @@ def GenBuildOutputs(
CREATE_INFER_META_FUNC_TEMPLATE = """
phi::{func}({args});
"""
build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """
phi::{func}({args}, phi::MetaConfig(false, false));
"""
if op_infer_meta_map['func'] in _INFERMETA_NEED_META_CONFIG:
build_output_str += (
CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
)
else:
build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)

# use dense_{name} or vec_dense_{name} to create Outputs type
build_output_str += "\n std::vector<ir::Type> argument_outputs;"
@@ -530,6 +590,7 @@ def gen_build_func_str(
op_non_mutable_attribute_type_list,
)
build_outputs_str = GenBuildOutputs(
op_class_name,
op_input_name_list,
op_input_type_list,
op_mutable_attribute_name_list,
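Taken together, these generator changes let the new-IR split op be built even when the sections are not a compile-time constant. A usage sketch of that motivating case (assuming the documented paddle.split semantics, where list entries may be Tensors in static mode):

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', [6, 4], 'float32')
    # One section size is a Tensor, so its value is unknown while the
    # split op is being built: exactly the fallback path added above.
    s0 = paddle.full([1], 2, dtype='int32')
    outs = paddle.split(x, num_or_sections=[s0, 2, 2], axis=0)
    print([o.shape for o in outs])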
20 changes: 11 additions & 9 deletions paddle/fluid/ir/transforms/inplace_pass.cc
@@ -25,10 +25,11 @@
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_registry.h"

namespace details {
// NOTE(zhangbo): Which kind of value can be deleted?
// (1) Value's type needs to be AllocatedDenseTensorType or
// AllocatedSelectedRowsType; (2) Value is not persistable.
bool CanBeDeleted(ir::Value value) {
static bool CanBeDeleted(ir::Value value) {
if (!value.type()) {
return false;
}
@@ -47,9 +48,9 @@ bool CanBeDeleted(ir::Value value) {
return true;
}

bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
ir::Value input,
ir::Value output) {
static bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
ir::Value input,
ir::Value output) {
if (input.type() != output.type()) {
VLOG(9) << " -- input's type != output's type, can't do inplace";
return false;
@@ -61,7 +62,7 @@ bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
return true;
}

bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
static bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
if (op->dialect()->name().compare(
paddle::dialect::PaddleKernelDialect::name()) != 0) {
VLOG(8) << op->name()
@@ -90,7 +91,7 @@ bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {

// NOTE(zhangbo): pd.feed's output and pd.fetch's input cannot be eagerly
// deleted.
std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
static std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
std::unordered_set<ir::Value> skip_dels;
for (auto& op : *block) {
if (op->dialect()->name().compare(
@@ -119,7 +120,7 @@ std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
// NOTE(zhangbo): For inplace Pass, currently only the kernel_dialect operator
// is supported. Therefore, this function only returns the values in the
// kernel_dialect operator that can be eagerly deleted.
std::unordered_map<ir::Operation*, std::unordered_set<ir::Value>>
static std::unordered_map<ir::Operation*, std::unordered_set<ir::Value>>
GetEagerDeletionValues(ir::Block* block) {
std::unordered_set<ir::Value> skip_dels = GetSkipDeletionValues(block);

@@ -167,7 +168,7 @@ GetEagerDeletionValues(ir::Block* block) {
return eager_dels;
}

std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
static std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
ir::Block* block) {
const auto eager_dels = GetEagerDeletionValues(block);

@@ -282,6 +283,7 @@ std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
}
return inplace_ops;
}
} // namespace details

class InplacePass : public ir::Pass {
public:
@@ -292,7 +294,7 @@ class InplacePass : public ir::Pass {
IR_ENFORCE(module_op, "DcePass should run on module op.");
auto* block = module_op.block();

auto inplace_ops = GetInplaceOps(block);
auto inplace_ops = details::GetInplaceOps(block);

for (auto kv : inplace_ops) {
VLOG(6) << "Do inplace for: "
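Functionally this file only changes linkage: the helpers become static and move into namespace details, leaving InplacePass as the sole externally visible symbol. The underlying analysis (find each value's last use; an inplace op may reuse an input buffer that dies at that op) can be sketched in a few lines of Python (an illustration of the idea, not Paddle's API):

def eager_deletion_points(ops):
    # ops: list of (op_name, inputs, outputs). A value can be deleted
    # right after the last op that reads it, which is also where an
    # inplace op may legally reuse its buffer.
    last_use = {}
    for i, (_, inputs, _) in enumerate(ops):
        for v in inputs:
            last_use[v] = i
    dels = {}
    for v, i in last_use.items():
        dels.setdefault(i, set()).add(v)
    return dels

ops = [("pd.relu", ["x"], ["y"]),
       ("pd.scale", ["y"], ["z"]),
       ("pd.fetch", ["z"], [])]
print(eager_deletion_points(ops))  # {0: {'x'}, 1: {'y'}, 2: {'z'}}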
2 changes: 1 addition & 1 deletion paddle/ir/core/ir_context.cc
@@ -118,7 +118,7 @@ class IrContextImpl {
<< ", OpInfo: ptr=" << iter->second.AsOpaquePointer() << "].";
return iter->second;
}
LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
VLOG(8) << "No cache found operation of: [Name=" << name << "].";
return OpInfo();
}
const OpInfoMap &registered_op_info_map() { return registed_op_infos_; }
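A missing OpInfo cache entry is a normal lookup outcome rather than a problem, hence the demotion from LOG(WARNING) to VLOG(8). To surface the message again, raise glog verbosity; a sketch, assuming Paddle's usual GLOG_v control:

import os

# Assumed control knob: verbosity must be set before paddle is imported
# for VLOG(8) messages such as this cache-miss note to be printed.
os.environ["GLOG_v"] = "8"
import paddle  # noqa: E402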
15 changes: 14 additions & 1 deletion paddle/phi/core/flags.cc
@@ -1289,7 +1289,7 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_api,
"Enable new IR API in Python");

/**
 * Using new IR in executor FLAG
* Name: enable_new_ir_in_executor_trace_run
* Since Version: 2.6.0
* Value Range: bool, default=false
@@ -1301,6 +1301,19 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_trace_run,
false,
"Enable new IR in executor");

/**
* Apply inplace pass to new IR FLAG
* Name: new_ir_apply_inplace_pass
* Since Version: 2.6.0
* Value Range: bool, default=true
* Example:
 * Note: If True, will apply inplace pass to new IR.
*/
PHI_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass,
true,
"Whether to apply inplace pass on lowering "
"::ir::Program to Kernel Dialect");

PHI_DEFINE_EXPORTED_bool(enable_record_memory, false, "Enable memory recorder");

PHI_DEFINE_EXPORTED_bool(
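With the definition centralized here, the flag keeps its Python-visible name. A sketch of toggling it (assuming PHI exported flags remain reachable through paddle.get_flags/set_flags):

import paddle

# Defaults to True: the inplace pass runs when lowering to kernel dialect.
print(paddle.get_flags(['FLAGS_new_ir_apply_inplace_pass']))
# Disable it, e.g. to compare memory-reuse behaviour across runs.
paddle.set_flags({'FLAGS_new_ir_apply_inplace_pass': False})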
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/activation.py
@@ -759,7 +759,7 @@ def relu(x, name=None):
if in_dynamic_mode():
return _C_ops.relu(x)
else:
if paddle.ir.core._use_new_ir_api():
if paddle.framework.in_dynamic_or_new_ir_mode():
Review comment (Contributor):
Why was this changed to this interface? Does it allow dynamic graph mode in here?

Review comment (Contributor):
This should be:

if in_dynamic_or_new_ir_mode():
    return _C_ops.relu(x)
else:
    # legacy static graph branch

Review comment (Contributor, Author):
I'll open a separate PR to polish this.

# Below code will be removed after we can generate IR api automatically
return paddle._ir_ops.relu(x)

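The reviewers' point is that in_dynamic_or_new_ir_mode() also returns True in eager mode, even though this else-branch is never reached from eager code; the author defers that cleanup to a follow-up PR. The merged behaviour, sketched from both modes (the static path assumes FLAGS_enable_new_ir_api is switched on):

import paddle
import paddle.nn.functional as F

# Eager mode: relu dispatches straight to _C_ops.relu.
print(F.relu(paddle.to_tensor([-1.0, 0.0, 2.0])))  # [0., 0., 2.]

# Static mode with the new IR API enabled: the same call is expected to
# route through paddle._ir_ops.relu until IR APIs are generated automatically.
paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    y = F.relu(paddle.static.data('x', [3], 'float32'))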
3 changes: 1 addition & 2 deletions test/ir/new_ir/test_pd_inplace_pass.py
@@ -17,14 +17,13 @@
import numpy as np

import paddle
from paddle.fluid import core

paddle.enable_static()


class TestPdInplacePass(unittest.TestCase):
def test_pd_inplace_pass(self):
place = core.Place()
place = paddle.framework.core.Place()
place.set_place(paddle.CPUPlace())
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
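For reference, a condensed, runnable version of the pattern this test follows after the import cleanup (assertion details assumed; the real test also exercises the inplace-rewritten program):

import unittest

import numpy as np

import paddle

paddle.enable_static()


class TestPdInplacePassSketch(unittest.TestCase):
    def test_relu_runs(self):
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            x = paddle.static.data('x', [4], 'float32')
            y = paddle.nn.functional.relu(x)
        exe = paddle.static.Executor(paddle.CPUPlace())
        feed = {'x': np.array([-1, 0, 1, 2], dtype='float32')}
        (out,) = exe.run(main_program, feed=feed, fetch_list=[y])
        np.testing.assert_allclose(out, np.maximum(feed['x'], 0.0))


if __name__ == '__main__':
    unittest.main()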