[NewIR]Support Build(GroupPtr) Logic in NewIRCompiler and Add UT (PaddlePaddle#56960)

* [NewIR]Support Build(GroupOps) in NewIRCompiler and Add UT

* fix unittest
Aurelius84 authored and BeingGod committed Sep 9, 2023
1 parent d71d8f9 commit 00a9ffd
Showing 7 changed files with 120 additions and 68 deletions.
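Before the per-file diffs, a minimal end-to-end sketch of the compile path this commit settles on. The NewIRCompiler constructor arguments are assumed from the accompanying unit test (not shown in this diff); only Build() remains public after this change.

#include <memory>
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"

// Sketch only: drives the NewIR path end to end under the assumptions above.
std::unique_ptr<cinn::hlir::framework::Program> CompileProgram(
    const ::ir::Program& program, const cinn::common::Target& target) {
  // BuildScope names every SSA value uniformly ("var_<hash>") and creates
  // one scope tensor per value, visiting operands and results alike.
  auto scope = cinn::hlir::framework::BuildScope(target, program);
  // Assumed signature: (program, target, scope), following the new UT.
  cinn::hlir::framework::NewIRCompiler compiler(program, target, scope);
  // Build() wraps each op into a single-op newir::Group and lowers it;
  // Build(groups) becomes a private detail in this commit.
  return compiler.Build();
}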
25 changes: 16 additions & 9 deletions paddle/cinn/hlir/framework/new_ir/group.h
@@ -16,6 +16,7 @@
#include <string>
#include <vector>

#include "paddle/cinn/hlir/framework/new_ir/utils.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/ir/core/operation.h"

@@ -30,20 +31,26 @@ struct Group {
public:
explicit Group(const std::vector<::ir::Operation*>& group_ops)
: ops(group_ops) {
op_pattern_kind = OpPatternKind::kElementWise;
fn_name = "fn_";
for (auto& op : group_ops) {
fn_name += "_" + op->name();
}
Initialize();
}

explicit Group(std::initializer_list<::ir::Operation*> group_ops)
: ops(group_ops) {
Initialize();
}

int group_id;
std::string fn_name;
OpPatternKind op_pattern_kind;
std::vector<::ir::Operation*> ops;
std::vector<std::string> input_names;
std::vector<std::string> output_names;
int group_id;
// FIXME(Aurelius84): This should be refactored with CinnGroupOp
OpPatternKind op_pattern_kind;
std::string fn_name;

private:
void Initialize() {
op_pattern_kind = OpPatternKind::kElementWise;
fn_name = CompatibleInfo::GroupOpsName(ops);
}
};

} // namespace newir
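The net effect of the group.h change, as a small sketch: both constructors now funnel through the private Initialize(), so a group built from a vector and one built from a brace list get identical defaults.

#include "paddle/cinn/hlir/framework/new_ir/group.h"

// Sketch only: op_a / op_b are ::ir::Operation* obtained elsewhere.
void MakeGroups(::ir::Operation* op_a, ::ir::Operation* op_b) {
  std::vector<::ir::Operation*> vec_ops = {op_a, op_b};
  cinn::hlir::framework::newir::Group from_vector(vec_ops);
  cinn::hlir::framework::newir::Group from_list({op_a, op_b});  // new overload
  // Both start as OpPatternKind::kElementWise, with fn_name taken from
  // CompatibleInfo::GroupOpsName(ops) instead of the old ad-hoc
  // "fn_" + "_" + op->name() concatenation.
}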
26 changes: 17 additions & 9 deletions paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc
@@ -43,7 +43,7 @@ ir::Tensor GetTensor(const ::ir::Value& value) {
auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto in_shape = phi::vectorize<int>(type_info.dims());
auto dtype = type_info.dtype();
std::string input_id = CompatibleInfo::InputName(value);
std::string input_id = CompatibleInfo::ValueName(value);
return lang::CreatePlaceHolder(
in_shape, utils::ConvertIRType(dtype), input_id);
}
@@ -56,15 +56,16 @@ std::vector<ir::Tensor> CollectInputTensor(
for (auto& operand : op->operands()) {
CHECK(operand);
auto in_value = operand.source();
ir::Tensor tensor;
VLOG(4) << "input tensor name: " << CompatibleInfo::ValueName(in_value);
// NOTE(Aurelius84): Need always to create placeholder for input tensor.
ir::Tensor tensor = details::GetTensor(in_value);
if (!tensor_map->count(in_value)) {
tensor = details::GetTensor(in_value);
// record tensor.
(*tensor_map)[in_value] = tensor;
// record func input args
if (func_args != nullptr) func_args->push_back(tensor);
} else {
tensor = tensor_map->at(in_value);
if (func_args != nullptr) {
func_args->push_back(tensor);
}
}
tensors.push_back(tensor);
}
@@ -76,7 +77,7 @@ void CollectOutputInfo(const ::ir::Operation* op,
std::vector<std::vector<int>>* out_shapes) {
auto op_results = op->results();
for (auto& out_value : op_results) {
std::string output_id = CompatibleInfo::OutputName(out_value);
std::string output_id = CompatibleInfo::ValueName(out_value);
// group->output_names.push_back(output_id);
auto type_info =
out_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
@@ -265,11 +266,11 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
// output arg tensors
group_func_arg_tensors->push_back(tensor);
// output args
group->output_names.push_back(tensor->name);
group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
arg_name_set.insert(tensor->buffer->name);
}
}

if (!done_op_schedule) {
std::unordered_set<std::string> args_set;
for (auto arg : group_func_args) {
@@ -329,6 +330,8 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(

std::vector<ir::Tensor> op_func_arg_tensors =
details::CollectInputTensor(op, group_func_arg_tensors, tensor_map);
VLOG(4) << "input size:" << op_func_arg_tensors.size();

std::string cinn_op_name = CompatibleInfo::OpName(*op);
const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name);
auto op_impl = OpStrategy::SelectImpl(strategy[cinn_op](
@@ -348,6 +351,9 @@
}
}

VLOG(4) << "group_func_arg_tensors.size(): "
<< group_func_arg_tensors->size();

return func_bodies;
}

@@ -364,7 +370,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
// set tensor name = operand hash name
auto op_results = op->results();
for (const auto& result : op_results) {
std::string output_id = CompatibleInfo::OutputName(result);
std::string output_id = CompatibleInfo::ValueName(result);
cinn_inputs.push_back(common::CINNValue(output_id));
}

@@ -400,6 +406,8 @@
}
}

VLOG(4) << "op_func_arg_tensors.size(): " << op_func_arg_tensors->size();

// 2.Do lower
std::string lower_fn_name = CompatibleInfo::OpFuncName(*op);
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(lower_fn_name,
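The CollectInputTensor rewrite above follows an "always create, record once" pattern: a placeholder is created for every operand (per the NOTE in the hunk), but a value is recorded into tensor_map and func_args only the first time it is seen. A generic, self-contained rendering of that pattern (the names here are illustrative, not CINN API):

#include <unordered_map>
#include <vector>

// Key/T/make stand in for ::ir::Value, ir::Tensor, and details::GetTensor.
template <typename Key, typename T, typename MakeFn>
T GetOrRecord(const Key& key,
              std::unordered_map<Key, T>* seen,
              std::vector<T>* args,
              MakeFn make) {
  T value = make(key);  // placeholder is created unconditionally
  if (seen->emplace(key, value).second && args != nullptr) {
    args->push_back(value);  // only the first occurrence becomes a func arg
  }
  return value;
}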
17 changes: 6 additions & 11 deletions paddle/cinn/hlir/framework/new_ir/utils.cc
@@ -36,13 +36,8 @@ std::string CompatibleInfo::OpName(const ::ir::Operation& op) {
return cinn_op_name;
}

std::string CompatibleInfo::InputName(const ::ir::Value& value) {
return CompatibleInfo::kInputPrefix +
std::to_string(std::hash<::ir::Value>()(value));
}

std::string CompatibleInfo::OutputName(const ::ir::Value& value) {
return CompatibleInfo::kOutputPrefix +
std::string CompatibleInfo::ValueName(const ::ir::Value& value) {
return CompatibleInfo::kNamePrefix +
std::to_string(std::hash<::ir::Value>()(value));
}

@@ -55,10 +50,10 @@ std::string CompatibleInfo::OpFuncName(const ::ir::Operation& op) {

std::string CompatibleInfo::GroupOpsName(
const std::vector<::ir::Operation*>& ops) {
std::string name = "fn_";
std::string name = "fn";
for (auto* op : ops) {
std::string op_name = OpName(*op);
name += cinn::common::Context::Global().NewName(op_name);
name += "_" + cinn::common::Context::Global().NewName(op_name);
}
return name;
}
@@ -69,7 +64,7 @@ std::vector<std::string> CompatibleInfo::InputNames(const ::ir::Operation& op,
std::unordered_set<std::string> repeat;
for (int i = 0; i < op.num_operands(); ++i) {
auto value = op.operand_source(i);
std::string name = CompatibleInfo::InputName(value);
std::string name = CompatibleInfo::ValueName(value);
if (!allow_duplicate && repeat.count(name)) {
continue;
}
@@ -84,7 +79,7 @@ std::vector<std::string> CompatibleInfo::OutputNames(
std::vector<std::string> names;
for (int i = 0; i < op.num_results(); ++i) {
auto value = op.result(i);
std::string name = CompatibleInfo::OutputName(value);
std::string name = CompatibleInfo::ValueName(value);
names.push_back(std::move(name));
}
return names;
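The resulting naming surface is uniform across inputs and outputs. A fragment sketching both helpers (the hash digits and the NewName counter suffixes are illustrative assumptions, not guaranteed output):

// "var_" + std::to_string(std::hash<::ir::Value>()(value))
std::string v_name = CompatibleInfo::ValueName(value);   // e.g. "var_140351..."
// "fn" + "_" + Context::Global().NewName(OpName(*op)) per op in the group
std::string g_name = CompatibleInfo::GroupOpsName(ops);  // e.g. "fn_add_0_relu_1"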
7 changes: 2 additions & 5 deletions paddle/cinn/hlir/framework/new_ir/utils.h
@@ -24,17 +24,14 @@ namespace framework {
namespace newir {

struct CompatibleInfo {
static constexpr char* kInputPrefix = "input_";
static constexpr char* kOutputPrefix = "output_";
static constexpr char* kNamePrefix = "var_";
// TODO(Aurelius): Need to add name mapping logic in REGISTER_CINN_OP
// macros, or attempt to unify op names between Paddle and CINN.
static const std::unordered_map<std::string, std::string> OP_NAMES;

static std::string OpName(const ::ir::Operation& op);

static std::string InputName(const ::ir::Value& value);

static std::string OutputName(const ::ir::Value& value);
static std::string ValueName(const ::ir::Value& value);

static std::string OpFuncName(const ::ir::Operation& op);

35 changes: 14 additions & 21 deletions paddle/cinn/hlir/framework/new_ir_compiler.cc
@@ -35,7 +35,6 @@ std::unique_ptr<Program> NewIRCompiler::Build() {
++it) {
std::vector<::ir::Operation*> ops = {*it};
groups.push_back(std::make_shared<newir::Group>(ops));
groups.back()->fn_name = CompatibleInfo::GroupOpsName(groups.back()->ops);
}
VLOG(4) << "Groups size: " << groups.size();
return std::move(Build(groups));
@@ -103,23 +102,20 @@ std::vector<std::unique_ptr<Instruction>> NewIRCompiler::BuildInstructions(
const std::vector<newir::GroupPtr>& groups) {
std::vector<std::unique_ptr<Instruction>> instructions;
for (int idx = 0; idx < groups.size(); ++idx) {
// TODO(Aurelius84): only support single op in groups
auto& op = *(groups[idx]->ops[0]);

auto& fn_name = groups[idx]->fn_name;
auto instr = std::unique_ptr<Instruction>(
new Instruction(target_,
scope_.get(),
CompatibleInfo::InputNames(op),
CompatibleInfo::OutputNames(op),
fn_name));
auto instr =
std::unique_ptr<Instruction>(new Instruction(target_,
scope_.get(),
groups[idx]->input_names,
groups[idx]->output_names,
fn_name));
VLOG(1) << "Lookup kernel name: " << fn_name;
auto* fn_ptr = compiler_->Lookup(fn_name);
CHECK(fn_ptr);
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), fn_name);
// Some instructions, such as reduce, generate more than one kernel,
// so try to find the remaining kernels if they exist.
// SetSubKernels(instr.get(), op_func_name);
// SetSubKernels(instr.get(), fn_name);
instr->Finalize();
instructions.push_back(std::move(instr));
}
@@ -131,16 +127,15 @@ std::shared_ptr<Scope> BuildScope(const Target& target,
std::unordered_set<::ir::Value> visited;
auto scope = std::make_shared<Scope>();

auto create_var = [&](const std::string& name_prefix, ::ir::Value value) {
auto create_var = [&](::ir::Value value) {
if (visited.count(value) > 0) return;
visited.emplace(value);

std::string name =
name_prefix + std::to_string(std::hash<::ir::Value>()(value));
std::string name = CompatibleInfo::ValueName(value);
auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto* var = scope->Var<Tensor>(name);
auto& tensor = absl::get<Tensor>(*var);
// NOTE: can be replaced with phi::vectorized ?

std::vector<Shape::dim_t> shape;
for (auto i = 0; i < type_info.dims().size(); ++i) {
shape.push_back(Shape::dim_t(type_info.dims()[i]));
@@ -150,14 +145,12 @@
};

for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
for (auto i = 0; i < (*it)->num_operands(); ++i) {
auto in_value = (*it)->operand_source(i);
create_var(CompatibleInfo::kInputPrefix, in_value);
for (auto& operand : (*it)->operands()) {
create_var(operand.source());
}

for (auto i = 0; i < (*it)->num_results(); ++i) {
auto out_value = (*it)->result(i);
create_var(CompatibleInfo::kOutputPrefix, out_value);
for (auto& result : (*it)->results()) {
create_var(result);
}
}
return scope;
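Since operands and results now share one "var_" namespace, a value defined by one op and consumed by another resolves to the same scope slot. A hedged fragment checking that invariant (assumes Scope::FindVar exists with this shape; hash-based names as above):

// Every operand's unified name should already be present in the scope.
auto scope = BuildScope(target, program);
for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
  for (auto& operand : (*it)->operands()) {
    CHECK(scope->FindVar(CompatibleInfo::ValueName(operand.source())));
  }
}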
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/framework/new_ir_compiler.h
@@ -40,11 +40,11 @@ class NewIRCompiler final {

std::unique_ptr<Program> Build();

std::unique_ptr<Program> Build(const std::vector<newir::GroupPtr>& groups);

private:
CINN_DISALLOW_COPY_AND_ASSIGN(NewIRCompiler);

std::unique_ptr<Program> Build(const std::vector<newir::GroupPtr>& groups);

std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx);

void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs);