diff --git a/paddle/fluid/ir/CMakeLists.txt b/paddle/fluid/ir/CMakeLists.txt
index d778d433528817..6186c2df6336fc 100644
--- a/paddle/fluid/ir/CMakeLists.txt
+++ b/paddle/fluid/ir/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(interface)
+add_subdirectory(trait)
 add_subdirectory(dialect)
 add_subdirectory(transforms)
 add_subdirectory(phi_kernel_adaptor)
diff --git a/paddle/fluid/ir/dialect/CMakeLists.txt b/paddle/fluid/ir/dialect/CMakeLists.txt
index a54938ed6fa896..77dfaa85251533 100644
--- a/paddle/fluid/ir/dialect/CMakeLists.txt
+++ b/paddle/fluid/ir/dialect/CMakeLists.txt
@@ -52,5 +52,5 @@ file(GLOB PD_DIALECT_SRCS "*.cc")
 cc_library(
   pd_dialect
   SRCS ${PD_DIALECT_SRCS} ${op_source_file}
-  DEPS framework_proto phi phi_utils pd_interface ir)
+  DEPS framework_proto phi phi_utils pd_interface pd_trait ir)
 target_include_directories(pd_dialect PRIVATE ${PD_DIALECT_BINARY_DIR})
diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index 4b25921c8bb3cf..5ae3306c476539 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -43,6 +43,7 @@
 #include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
 #include "paddle/fluid/ir/interface/op_yaml_info.h"
 #include "paddle/fluid/ir/interface/infermeta.h"
+#include "paddle/fluid/ir/trait/inplace.h"
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/phi/core/infermeta_utils.h"
@@ -713,6 +714,10 @@ def OpGenerator(
     op_interfaces_str = ""
     if len(op_interfaces) > 0:
         op_interfaces_str = "," + ",".join(op_interfaces)
+
+    if op_name[-1] == "_":
+        op_traits += ["InplaceTrait"]
+
     op_traits_str = ""
     if len(op_traits) > 0:
         op_traits_str = "," + ",".join(op_traits)
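For orientation: with the generator change above, any op whose YAML name ends in `_` (the inplace-variant convention, e.g. `sqrt_`) gets `InplaceTrait` appended to its trait list, so downstream passes can detect inplace semantics without string-matching on the op name. A minimal sketch of that query (the helper name is hypothetical; `Operation::HasTrait` is the same API the lowering pass below relies on):

#include "paddle/fluid/ir/trait/inplace.h"

// Returns true when `op` was generated from an inplace op entry,
// i.e. its declaration carries paddle::dialect::InplaceTrait.
bool IsInplaceVariant(ir::Operation* op) {
  return op->HasTrait<paddle::dialect::InplaceTrait>();
}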
diff --git a/paddle/fluid/ir/interface/op_yaml_info_parser.cc b/paddle/fluid/ir/interface/op_yaml_info_parser.cc
index 7373eeccd38709..b21e4f82a70cc9 100644
--- a/paddle/fluid/ir/interface/op_yaml_info_parser.cc
+++ b/paddle/fluid/ir/interface/op_yaml_info_parser.cc
@@ -84,8 +84,30 @@ const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
   return std::get<3>(op_info_tuple_);
 }
 
-const std::map<std::string, int>& OpYamlInfoParser::Name2Id() const {
-  return name2id_;
+const std::map<std::string, int>& OpYamlInfoParser::InputName2Id() const {
+  return input_name2id_;
+}
+
+bool OpYamlInfoParser::HasInplace(const std::string& out_name) const {
+  auto& inplace_info = std::get<3>(op_info_tuple_).inplace;
+  for (size_t i = 0; i < inplace_info.size(); i++) {
+    if (out_name == inplace_info[i].first) {
+      return true;
+    }
+  }
+  return false;
+}
+
+const std::string& OpYamlInfoParser::InplaceName(
+    const std::string& out_name) const {
+  auto& inplace_info = std::get<3>(op_info_tuple_).inplace;
+  for (size_t i = 0; i < inplace_info.size(); i++) {
+    if (out_name == inplace_info[i].first) {
+      return inplace_info[i].second;
+    }
+  }
+  PADDLE_THROW(phi::errors::PreconditionNotMet(
+      "Can not find inplace input of [%s].", out_name));
 }
 
 void OpYamlInfoParser::parse() {
@@ -94,30 +116,30 @@ void OpYamlInfoParser::parse() {
   int start_index = 0;
 
   for (size_t i = 0; i < input_info.size(); ++i) {
-    name2id_[input_info[i].name] = start_index++;
-
+    input_name2id_[input_info[i].name] = start_index++;
+    input_name_list_.push_back(input_info[i].name);
+    input_info_[input_info[i].name] = input_info[i];
     if (!input_info[i].is_mutable_attribute) {
       input_tensor_number_++;
     }
-
-    input_info_[input_info[i].name] = input_info[i];
   }
 
   auto attribute_info = std::get<1>(op_info_tuple_);
   for (size_t i = 0; i < attribute_info.size(); ++i) {
+    attribute_name_list_.push_back(attribute_info[i].name);
     attr_info_[attribute_info[i].name] = attribute_info[i];
   }
 
   auto output_info = std::get<2>(op_info_tuple_);
-
   for (size_t i = 0; i < output_info.size(); ++i) {
+    output_name_list_.push_back(output_info[i].name);
     output_info_[output_info[i].name] = output_info[i];
   }
 
   auto runtime_info = std::get<3>(op_info_tuple_);
 
   for (auto& name : runtime_info.infer_meta_param) {
-    if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
+    if (input_name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
       infer_meta_tensor_params_.push_back(name);
     } else {
       infer_meta_attr_params_.push_back(name);
@@ -125,7 +147,7 @@ void OpYamlInfoParser::parse() {
   }
 
   for (auto& name : runtime_info.kernel_param) {
-    if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
+    if (input_name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
       kernel_fn_tensor_params_.push_back(name);
     } else {
       kernel_fn_attr_params_.push_back(name);
diff --git a/paddle/fluid/ir/interface/op_yaml_info_parser.h b/paddle/fluid/ir/interface/op_yaml_info_parser.h
index f95ae7b2106000..b2897b0fc2ecd6 100644
--- a/paddle/fluid/ir/interface/op_yaml_info_parser.h
+++ b/paddle/fluid/ir/interface/op_yaml_info_parser.h
@@ -34,7 +34,21 @@ class OpYamlInfoParser {
   const std::vector<std::string>& TensorParams(bool is_kernel = false) const;
   const std::vector<std::string>& AttrParams(bool is_kernel = false) const;
   const OpRunTimeInfo& OpRuntimeInfo() const;
-  const std::map<std::string, int>& Name2Id() const;
+  const std::map<std::string, int>& InputName2Id() const;
+
+  const std::vector<std::string>& InputNames() const {
+    return input_name_list_;
+  }
+  const std::vector<std::string>& AttributeNames() const {
+    return attribute_name_list_;
+  }
+  const std::vector<std::string>& OutputNames() const {
+    return output_name_list_;
+  }
+
+  bool HasInplace(const std::string& out_name) const;
+
+  const std::string& InplaceName(const std::string& out_name) const;
 
  private:
   void parse();
@@ -44,18 +58,25 @@ class OpYamlInfoParser {
 
   OpInfoTuple op_info_tuple_;
 
-  std::map<std::string, int> name2id_;
-
+  // input info
+  std::map<std::string, int> input_name2id_;
+  std::vector<std::string> input_name_list_;
   std::map<std::string, OpInputInfo> input_info_;
+  int input_tensor_number_{0};
+
+  // attribute info
+  std::vector<std::string> attribute_name_list_;
   std::map<std::string, OpAttributeInfo> attr_info_;
+
+  // output info
+  std::vector<std::string> output_name_list_;
   std::map<std::string, OpOutputInfo> output_info_;
 
+  // runtime info
   std::vector<std::string> infer_meta_tensor_params_;
   std::vector<std::string> infer_meta_attr_params_;
   std::vector<std::string> kernel_fn_tensor_params_;
   std::vector<std::string> kernel_fn_attr_params_;
-
-  int input_tensor_number_{0};
 };
 
 }  // namespace dialect
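Taken together, the new accessors make the YAML-declared inplace mapping queryable by name. A minimal usage sketch, assuming a parser already built from some op's `OpInfoTuple` (the dump function itself is hypothetical):

#include <iostream>
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"

// Walk the outputs and pair each inplace output with the input slot
// it reuses; non-inplace outputs are skipped.
void DumpInplacePairs(const paddle::dialect::OpYamlInfoParser& parser) {
  for (auto& out : parser.OutputNames()) {
    if (parser.HasInplace(out)) {
      const std::string& in = parser.InplaceName(out);
      int slot = parser.InputName2Id().at(in);
      std::cout << out << " reuses input " << in << " (operand " << slot
                << ")\n";
    }
  }
}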
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index 0c63781e9cb4bb..bc454897e73b99 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -64,6 +64,66 @@ paddle::framework::Variable* CreateVar(ir::Value value,
   }
 }
 
+void CheckInputVars(
+    ir::Operation* op,
+    const std::string& op_name,
+    const std::unordered_map<ir::Value, std::string>& name_map) {
+  size_t input_num = op->num_operands();
+  if (input_num > 0) {
+    for (size_t i = 0; i < input_num; ++i) {
+      auto value = op->operand(i);
+      if (value) {
+        PADDLE_ENFORCE_NE(
+            name_map.find(value),
+            name_map.end(),
+            phi::errors::PreconditionNotMet(
+                "input should be in name map, [%d]'th input of [%s] op",
+                i,
+                op_name));
+      }
+    }
+  }
+}
+
+void BuildValue(ir::Value value,
+                paddle::framework::Scope* scope,
+                paddle::framework::Scope* local_scope,
+                std::unordered_map<ir::Value, std::string>* name_map,
+                int& count) {  // NOLINT
+  auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
+  std::string name;
+  if (name_map->find(value) != name_map->end()) {
+    name = name_map->at(value);
+  } else {
+    name = "inner_var_" + std::to_string(count++);
+    name_map->emplace(value, name);
+  }
+  auto var = CreateVar(value, name, scope, inner_local_scope);
+  // Only support DenseTensor or Vector<DenseTensor>
+  if (!value.type()) {
+    var->GetMutable<phi::DenseTensor>();
+  } else if (value.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
+    var->GetMutable<phi::DenseTensor>();
+  } else if (value.type().isa<ir::VectorType>()) {
+    auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
+    for (size_t i = 0; i < value.type().dyn_cast<ir::VectorType>().size();
+         i++) {
+      PADDLE_ENFORCE(value.type()
+                         .dyn_cast<ir::VectorType>()[i]
+                         .isa<paddle::dialect::AllocatedDenseTensorType>(),
+                     paddle::platform::errors::Fatal(
+                         "Element of VectorType output only support "
+                         "DenseTensorType"));
+      std::string name_i = "inner_var_" + std::to_string(count++);
+      auto var_i = CreateVar(value, name_i, scope, inner_local_scope);
+      tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
+    }
+  } else {
+    PADDLE_THROW(phi::errors::PreconditionNotMet(
+        "Output only support DenseTensorType or VectorType"));
+  }
+}
+
 void HandleForSpecialOp(ir::Operation* op,
                         paddle::framework::Scope* scope,
                         paddle::framework::Scope* local_scope,
@@ -91,10 +151,10 @@ void HandleForSpecialOp(ir::Operation* op,
 
   if (op_name == "pd.feed") {
     VLOG(6) << "Handle for pd.feed:";
-    auto ptr = op->result(0);
+    auto value = op->result(0);
     std::string name = "inner_var_" + std::to_string(count++);
-    name_map->emplace(ptr, name);
-    auto var = CreateVar(ptr, name, scope, local_scope);
+    name_map->emplace(value, name);
+    auto var = CreateVar(value, name, scope, local_scope);
 
     // TODO(phlrain): need to update here, support StringTensor
     auto out_tensor = var->GetMutable<phi::DenseTensor>();
@@ -122,14 +182,14 @@ void HandleForSpecialOp(ir::Operation* op,
     auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
 
     for (size_t i = 0; i < input_num; ++i) {
-      auto ptr = op->operand(i);
+      auto value = op->operand(i);
 
       PADDLE_ENFORCE_EQ(
-          name_map->count(ptr),
+          name_map->count(value),
           true,
           phi::errors::PreconditionNotMet("can not found input of combine op"));
 
       tensor_array->emplace_back(
-          &(CreateVar(ptr, name_map->at(ptr), scope, local_scope)
+          &(CreateVar(value, name_map->at(value), scope, local_scope)
                 ->Get<phi::DenseTensor>()));
     }
   }
@@ -160,6 +220,41 @@ void HandleForSpecialOp(ir::Operation* op,
   }
 }
 
+void HandleForInplaceOp(ir::Operation* op,
+                        paddle::framework::Scope* scope,
+                        paddle::framework::Scope* local_scope,
+                        std::unordered_map<ir::Value, std::string>* name_map,
+                        int& count) {  // NOLINT
+  if (op->num_results() < 1) return;
+
+  ir::IrContext* ctx = ir::IrContext::Instance();
+  std::string op_name = op->name();
+  if (op->attributes().count("op_name")) {
+    op_name =
+        op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
+  }
+  VLOG(4) << "HandleForInplaceOp: " << op_name;
+
+  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
+  paddle::dialect::OpYamlInfoParser yaml_parser(
+      op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
+          ->get_op_info_());
+
+  for (size_t i = 0; i < op->num_results(); ++i) {
+    ir::Value value = op->result(i);
+    std::string value_name = yaml_parser.OutputNames()[i];
+    if (yaml_parser.HasInplace(value_name)) {
+      std::string inplace_name = yaml_parser.InplaceName(value_name);
+      ir::Value inplace_value =
+          op->operand(yaml_parser.InputName2Id().at(inplace_name));
+      std::string var_name = name_map->at(inplace_value);
+      VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
+              << " (var: " << var_name << ")";
+      name_map->emplace(value, var_name);
+    } else {
+      BuildValue(value, scope, local_scope, name_map, count);
+    }
+  }
+}
+
 void BuildScope(const ir::Block& block,
                 paddle::framework::Scope* scope,
                 paddle::framework::Scope* local_scope,
@@ -178,77 +273,39 @@ void BuildScope(const ir::Block& block,
   for (auto it = block.begin(); it != block.end(); ++it) {
     ir::Operation* op = *it;
 
-    auto attr_map = op->attributes();
-
     std::string op_name = op->name();
-    if (attr_map.count("op_name")) {
-      op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
+    if (op->attributes().count("op_name")) {
+      op_name =
+          op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
     }
+
+    VLOG(4) << "BuildScope for: " << op_name;
+
     if (op_name == "pd.feed" || op_name == "pd.fetch" ||
         op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
         op_name == "builtin.get_parameter") {
-      VLOG(6) << "HandleForSpecialOp: " << op_name;
+      VLOG(4) << "HandleForSpecialOp: " << op_name;
       HandleForSpecialOp(op, scope, inner_local_scope, name_map, count);
       continue;
     }
 
-    size_t input_num = op->num_operands();
-    if (input_num > 0) {
-      for (size_t i = 0; i < input_num; ++i) {
-        auto ptr = op->operand(i);
-        if (ptr) {
-          PADDLE_ENFORCE_NE(
-              name_map->find(ptr),
-              name_map->end(),
-              phi::errors::PreconditionNotMet(
-                  "input should in name map, [%d] 'th input of [%s] op",
-                  i,
-                  op_name));
-        }
-      }
-    }
+    CheckInputVars(op, op_name, *name_map);
 
-    int out_num = op->num_results();
-    if (out_num > 0) {
-      for (int i = 0; i < out_num; ++i) {
-        ir::Value ptr = op->result(i);
-        std::string name;
-        if (name_map->find(ptr) != name_map->end()) {
-          name = name_map->at(ptr);
-        } else {
-          name = "inner_var_" + std::to_string(count++);
-          name_map->emplace(ptr, name);
-        }
-        auto var = CreateVar(ptr, name, scope, inner_local_scope);
-        // Only support DenseTensor or Vector<DenseTensor>
-        if (!ptr.type()) {
-          var->GetMutable<phi::DenseTensor>();
-        } else if (ptr.type()
-                       .isa<paddle::dialect::AllocatedDenseTensorType>()) {
-          var->GetMutable<phi::DenseTensor>();
-        } else if (ptr.type().isa<ir::VectorType>()) {
-          auto tensor_array =
-              var->GetMutable<paddle::framework::TensorRefArray>();
-          for (size_t i = 0; i < ptr.type().dyn_cast<ir::VectorType>().size();
-               i++) {
-            PADDLE_ENFORCE(
-                ptr.type()
-                    .dyn_cast<ir::VectorType>()[i]
-                    .isa<paddle::dialect::AllocatedDenseTensorType>(),
-                paddle::platform::errors::Fatal(
-                    "Element of VectorType output only support "
-                    "DenseTensorType"));
-            std::string name_i = "inner_var_" + std::to_string(count++);
-            auto var_i = CreateVar(ptr, name_i, scope, inner_local_scope);
-            tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
-          }
-        } else {
-          PADDLE_THROW(phi::errors::PreconditionNotMet(
-              "Output only support DenseTensorType or VectorType"));
-        }
-      }
-    }
+    if (op->num_results() < 1) continue;
+
+    if (op->attributes().count("is_inplace") != 0 &&
+        op->attributes()
+            .at("is_inplace")
+            .dyn_cast<ir::BoolAttribute>()
+            .data()) {
+      HandleForInplaceOp(op, scope, inner_local_scope, name_map, count);
+      continue;
+    } else {
+      for (size_t i = 0; i < op->num_results(); ++i) {
+        BuildValue(op->result(i), scope, local_scope, name_map, count);
+      }
+    }
   }
 
+  VLOG(4) << "***** [after build] scope: ******\n"
+          << paddle::framework::GenScopeTreeDebugInfo(
+                 const_cast<paddle::framework::Scope*>(scope->root()));
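The effect of `HandleForInplaceOp` is purely on the value-to-variable map: an inplace result is bound to the variable already owned by the operand it overwrites, so no new `inner_var_N` entry is created. Sketched as comments (the block and names are illustrative):

// Suppose the block is: x = full(...); y = sqrt_(x)
//   after full  : { x -> "inner_var_0" }
//   after sqrt_ : { x -> "inner_var_0", y -> "inner_var_0" }  // alias
// A non-inplace result would instead fall through to BuildValue and
// receive a fresh "inner_var_1".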
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
index 5e92e27e2e17a8..32ffa663144d8d 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -45,12 +45,27 @@ paddle::framework::Variable* CreateVar(ir::Value value,
                                        paddle::framework::Scope* scope,
                                        paddle::framework::Scope* local_scope);
 
+void BuildValue(ir::Value value,
+                paddle::framework::Scope* scope,
+                paddle::framework::Scope* local_scope,
+                std::unordered_map<ir::Value, std::string>* name_map,
+                int& count);  // NOLINT
+
 void HandleForSpecialOp(ir::Operation* op,
                         paddle::framework::Scope* scope,
                         paddle::framework::Scope* local_scope,
                         std::unordered_map<ir::Value, std::string>* name_map,
                         int& count);  // NOLINT
 
+void HandleForInplaceOp(ir::Operation* op,
+                        paddle::framework::Scope* scope,
+                        paddle::framework::Scope* local_scope,
+                        std::unordered_map<ir::Value, std::string>* name_map,
+                        int& count);  // NOLINT
+
+void CheckInputVars(
+    ir::Operation* op,
+    const std::string& op_name,
+    const std::unordered_map<ir::Value, std::string>& name_map);
+
 void BuildScope(const ir::Block& block,
                 paddle::framework::Scope* scope,
                 paddle::framework::Scope* local_scope,
@@ -80,13 +95,13 @@ void BuildPhiContext(
 
   auto& vec_kernel_fn_tensor_params = op_yaml_info.TensorParams(is_kernel);
 
-  auto& name2id = op_yaml_info.Name2Id();
+  auto& name2id = op_yaml_info.InputName2Id();
   for (auto& t : vec_kernel_fn_tensor_params) {
     PADDLE_ENFORCE_EQ(
         name2id.count(t),
         true,
         phi::errors::NotFound("param [%s] MUST in name2id map", t));
-    auto index = op_yaml_info.Name2Id().at(t);
+    auto index = op_yaml_info.InputName2Id().at(t);
     ir::Value ptr = op->operand(index);
     if (!ptr) {
       phi::DenseTensor* ptr = nullptr;
@@ -97,7 +112,7 @@ void BuildPhiContext(
     auto in_var_name = name_map.at(ptr);
     VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
 
-    PADDLE_ENFORCE_NOT_NULL(inner_scope->FindLocalVar(in_var_name),
+    PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name),
                             phi::errors::PreconditionNotMet(
                                 "can not find var[%s] in scope", in_var_name));
     auto var = inner_scope->FindVar(in_var_name);
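One subtle point in `BuildPhiContext`: the input check switches from `FindLocalVar` to `FindVar`. `FindVar` walks up the scope chain while `FindLocalVar` only inspects the innermost scope; presumably this is needed because an inplace kernel's input variable may have been created in a parent scope rather than in the execution-local scope.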
diff --git a/paddle/fluid/ir/trait/CMakeLists.txt b/paddle/fluid/ir/trait/CMakeLists.txt
new file mode 100644
index 00000000000000..949c44bf53ed10
--- /dev/null
+++ b/paddle/fluid/ir/trait/CMakeLists.txt
@@ -0,0 +1,6 @@
+file(GLOB PD_TRAIT_SRCS "*.cc")
+
+cc_library(
+  pd_trait
+  SRCS ${PD_TRAIT_SRCS}
+  DEPS ir)
diff --git a/paddle/fluid/ir/trait/inplace.h b/paddle/fluid/ir/trait/inplace.h
new file mode 100644
index 00000000000000..38dfaaeac000ef
--- /dev/null
+++ b/paddle/fluid/ir/trait/inplace.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ir/core/op_base.h"
+
+namespace paddle {
+namespace dialect {
+
+class InplaceTrait : public ir::OpTraitBase<InplaceTrait> {
+ public:
+  explicit InplaceTrait(ir::Operation *op)
+      : ir::OpTraitBase<InplaceTrait>(op) {}
+};
+
+}  // namespace dialect
+}  // namespace paddle
+
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::InplaceTrait)
diff --git a/paddle/fluid/ir/trait/trait.cc b/paddle/fluid/ir/trait/trait.cc
new file mode 100644
index 00000000000000..e20e25cda9e16a
--- /dev/null
+++ b/paddle/fluid/ir/trait/trait.cc
@@ -0,0 +1,17 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/ir/trait/inplace.h"
+
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InplaceTrait)
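`InplaceTrait` is a pure marker: it carries no methods or state, only a type id that is declared once in the header and defined once in `trait.cc`. Any further marker trait would follow the same CRTP shape; a sketch with a hypothetical name:

#include "paddle/ir/core/op_base.h"

namespace paddle {
namespace dialect {
// Hypothetical marker trait, mirroring InplaceTrait line for line.
class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
 public:
  explicit ReadOnlyTrait(ir::Operation *op)
      : ir::OpTraitBase<ReadOnlyTrait>(op) {}
};
}  // namespace dialect
}  // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ReadOnlyTrait)
// ...with IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReadOnlyTrait)
// in exactly one .cc file.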
diff --git a/paddle/fluid/ir/transforms/CMakeLists.txt b/paddle/fluid/ir/transforms/CMakeLists.txt
index d131b458dddf1c..83e508dbd4099a 100644
--- a/paddle/fluid/ir/transforms/CMakeLists.txt
+++ b/paddle/fluid/ir/transforms/CMakeLists.txt
@@ -1,9 +1,9 @@
 cc_library(
   transform_general_functions
   SRCS transform_general_functions.cc
-  DEPS ir phi pd_dialect)
+  DEPS phi pd_dialect ir)
 
 cc_library(
   pd_op_to_kernel_pass
   SRCS pd_op_to_kernel_pass.cc
-  DEPS ir phi_utils pd_interface)
+  DEPS phi_utils pd_interface pd_trait ir)
diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
index e5471e0d873ddf..d31d316ac222c1 100644
--- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
@@ -26,11 +26,13 @@
 #include "paddle/fluid/ir/dialect/utils.h"
 #include "paddle/fluid/ir/interface/op_yaml_info.h"
 #include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
+#include "paddle/fluid/ir/trait/inplace.h"
 #include "paddle/phi/api/lib/data_transform.h"
 #include "paddle/phi/api/lib/kernel_dispatch.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/kernel_factory.h"
+
 namespace paddle {
 namespace dialect {
 
@@ -63,7 +65,7 @@ phi::KernelKey GetKernelKey(
   if (data_type_info.size() > 0 && data_type_info[0] != "") {
     // only support single input and attribute
     auto slot_name = data_type_info[0];
-    auto& input_map = op_info_parser->Name2Id();
+    auto& input_map = op_info_parser->InputName2Id();
 
     if (input_map.count(slot_name)) {
       // parse from input
@@ -340,6 +342,10 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
       op_attribute.emplace(it1->first, it1->second);
     }
 
+    if ((*it)->HasTrait<paddle::dialect::InplaceTrait>()) {
+      op_attribute.emplace("is_inplace", ir::BoolAttribute::get(ctx, true));
+    }
+
     ir::Operation* op = ir::Operation::Create(
         vec_inputs, op_attribute, op_output_types, op_info);
diff --git a/test/cpp/ir/kernel_dialect/CMakeLists.txt b/test/cpp/ir/kernel_dialect/CMakeLists.txt
index fd5842cd8d22ef..1cc167f783cf0b 100644
--- a/test/cpp/ir/kernel_dialect/CMakeLists.txt
+++ b/test/cpp/ir/kernel_dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ cc_test_old(
   pd_op_to_kernel_pass
   pd_dialect
   phi_kernel_adaptor
+  pd_trait
   ir
   phi
   gtest)
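The `is_inplace` attribute emplaced above is the hand-off point: the lowering pass flattens the compile-time trait into a bool attribute on the kernel-dialect op, and `BuildScope` (earlier in this patch) reads that attribute back to route the op through `HandleForInplaceOp`. Keeping the trait on the pd-dialect op and the attribute on the lowered op means neither side needs the other's type registered.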
diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc
index 8501980b251e7d..8f88a744d5c7ac 100644
--- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc
+++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc
@@ -39,6 +39,7 @@ PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT);
 
 bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }
 
@@ -246,5 +247,44 @@ TEST(StandaloneExecutor, data_transfer) {
 }
 #endif
 
+TEST(StandaloneExecutor, run_inplace_sqrt) {
+  ir::IrContext* ctx = ir::IrContext::Instance();
+  ir::Program program((ctx));
+  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
+  ir::Builder builder = ir::Builder(ctx, program.block());
+
+  paddle::dialect::FullOp full = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{2, 2}, 4.0, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  builder.Build<paddle::dialect::Sqrt_Op>(full->result(0));
+
+  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
+
+  kernel_program->Print(std::cout);
+
+  auto place = platform::CPUPlace();
+  Scope scope;
+  InterpreterCore test_core(place, std::move(kernel_program), &scope);
+  test_core.Run({});
+
+  auto out_tensor =
+      test_core.local_scope() == nullptr
+          ? scope.FindVar("inner_var_0")->Get<phi::DenseTensor>()
+          : test_core.local_scope()
+                ->FindVar("inner_var_0")
+                ->Get<phi::DenseTensor>();
+
+  bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
+  bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
+  bool res2 = simple_cmp(out_tensor.data<float>()[2], 2.0);
+  bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);
+
+  EXPECT_EQ(scope.kids().size(), 1u);
+  EXPECT_EQ(scope.kids().front()->Size(), 1u);
+  EXPECT_EQ(res0, true);
+  EXPECT_EQ(res1, true);
+  EXPECT_EQ(res2, true);
+  EXPECT_EQ(res3, true);
+}
+
 }  // namespace framework
 }  // namespace paddle
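The assertions pin down the inplace contract end to end: every element of the 2x2 result equals sqrt(4.0) = 2.0, and the executor's child scope holds exactly one variable, because the result of the inplace sqrt op is aliased onto full's output (`inner_var_0`) instead of being given a fresh variable.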