Skip to content

Commit

Permalink
[pir] Support lowering to kernel for if_grad op (PaddlePaddle#59033)
Browse files Browse the repository at this point in the history
* support lower to kernel for if_grad op

* add PD_DECLARE_KERNEL

* fix
  • Loading branch information
chen2016013 authored and SecretXV committed Nov 28, 2023
1 parent 38f14ab commit 6fddfb0
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 3 deletions.
44 changes: 43 additions & 1 deletion paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ const std::unordered_set<std::string> SpecialLowerOps = {
pir::YieldOp::name(),
IfOp::name(),
WhileOp::name(),
pir::CreateStackOp::name(),
pir::PushBackOp::name(),
pir::PopBackOp::name(),
"cinn_runtime.jit_kernel"};

static bool NeedFallBackCpu(const pir::Operation* op,
Expand Down Expand Up @@ -1017,6 +1020,45 @@ void HandleForSpecialOp(
}
}

if (op_item->isa<::pir::CreateStackOp>() ||
op_item->isa<::pir::PushBackOp>()) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
auto new_in = GetNewInput(
cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
vec_inputs.push_back(new_in);
}
for (size_t i = 0; i < op_item->num_results(); ++i) {
op_output_types.push_back(op_item->result(i).type());
}
}

if (op_item->isa<::pir::PopBackOp>()) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
auto new_in = GetNewInput(
cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
vec_inputs.push_back(new_in);
}

auto pop_back_op = op_item->dyn_cast<::pir::PopBackOp>();
for (size_t i = 0; i < op_item->num_results(); ++i) {
auto cur_inlet_element = pop_back_op.inlet_element(i);
PADDLE_ENFORCE_EQ(map_value_pair->count(cur_inlet_element),
true,
phi::errors::PreconditionNotMet(
"[%d]'s output of [%s] op MUST be in map pair",
i,
op_item->name()));
auto new_inlet_element = map_value_pair->at(cur_inlet_element);

op_output_types.push_back(new_inlet_element.type());
}
}
if (op_item->name() == "cinn_runtime.jit_kernel") {
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
Expand Down Expand Up @@ -1049,7 +1091,7 @@ void HandleForSpecialOp(
(*map_value_pair)[op_item->result(i)] = op->result(i);
}
}
VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
VLOG(6) << "Deep copy a new special op: " << op_item->name();
}

std::vector<pir::Type> BuildOutputs(pir::Operation* op_item,
Expand Down
6 changes: 4 additions & 2 deletions paddle/pir/dialect/control_flow/ir/cf_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,10 @@ void PopBackOp::VerifySig() {
"The pop elements size must equal to push elements size.");
for (size_t index = 0; index < inlet_size; ++index) {
IR_ENFORCE(outlet_element(index).type() == inlet_element(index).type(),
"The %d element's push type isn't equal to pop type",
index);
"The %d element's push type (%s) isn't equal to pop type (%s)",
index,
outlet_element(index).type(),
inlet_element(index).type());
}
VLOG(4) << "End Verifying for PopBackOp.";
}
Expand Down
4 changes: 4 additions & 0 deletions test/cpp/pir/control_flow_dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ cc_test_old(
DEPS
pir
op_dialect_vjp
pir_transforms
op_dialect
gtest)

cc_test_old(
Expand All @@ -14,4 +16,6 @@ cc_test_old(
DEPS
pir
op_dialect_vjp
pir_transforms
op_dialect
gtest)
13 changes: 13 additions & 0 deletions test/cpp/pir/control_flow_dialect/if_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,25 @@
#include <gtest/gtest.h>
#include <iostream>

#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"

PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT);

using namespace paddle::dialect; // NOLINT

TEST(if_op_test, base) {
Expand Down Expand Up @@ -99,6 +109,7 @@ TEST(if_op_test, network_with_backward) {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<OperatorDialect>();
ctx->GetOrRegisterDialect<pir::ControlFlowDialect>();
ctx->GetOrRegisterDialect<paddle::dialect::KernelDialect>();

pir::Program program(ctx);
pir::Block* block = program.block();
Expand Down Expand Up @@ -169,4 +180,6 @@ TEST(if_op_test, network_with_backward) {
builder.SetInsertionPointToEnd(block);

LOG(INFO) << program;

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
}

0 comments on commit 6fddfb0

Please sign in to comment.