cinn(py-dsl): add ir context used in python dsl #57515

Merged · 3 commits · Sep 22, 2023
10 changes: 10 additions & 0 deletions paddle/cinn/ir/lowered_func.cc
@@ -64,6 +64,16 @@ LoweredFunc _LoweredFunc_::Make(const std::string& name,
return LoweredFunc(n);
}

LoweredFunc _LoweredFunc_::Make(const std::string& name,
const std::vector<Argument>& args,
const Expr& body) {
auto* n = make_shared<_LoweredFunc_>();
n->name = name;
n->args = args;
n->body = body;
return LoweredFunc(n);
}

void _LoweredFunc_::CheckValid() const {
// check there is at least one output
int out_count = 0;
13 changes: 11 additions & 2 deletions paddle/cinn/ir/lowered_func.h
100755 → 100644
@@ -30,8 +30,10 @@ class _LoweredFunc_;
* the function signature of generated code.
*/
struct Argument {
//! Input or output.
enum class IO { kInput = 0, kOutput = 1 };
//! kInput: the arg is an input
//! kOutput: the arg is an output
//! kUnknown: the arg may be either an input or an output
enum class IO { kInput = 0, kOutput = 1, kUnknown = 2 };

IO io{IO::kInput};

@@ -164,6 +166,13 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
const Expr& body,
const std::vector<ir::Buffer>& temp_bufs);

// A simplified version of Make that ignores the buffer and IO information
// of each Argument; buffers are optimized by a pass after the function is
// built.
static LoweredFunc Make(const std::string& name,
const std::vector<Argument>& args,
const Expr& body);

bool is_gpu_host() const { return cuda_axis_info.valid(); }

void Verify() const override {}
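A minimal call-site sketch (not part of the diff) showing how the simplified overload pairs with the new kUnknown tag; x and out are assumed ir::Buffer handles, body an already-lowered Expr, and IO classification is deferred to later passes:

std::vector<ir::Argument> args;
args.emplace_back(x, ir::Argument::IO::kUnknown);    // buffer role not yet known
args.emplace_back(out, ir::Argument::IO::kUnknown);
ir::LoweredFunc fn = ir::_LoweredFunc_::Make("fn_add", args, body);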
2 changes: 2 additions & 0 deletions paddle/cinn/ir/module.cc
@@ -54,6 +54,8 @@ void Module::Builder::Clear() {
module_->submodules.clear();
}

Target::Arch Module::Builder::GetTargetArch() { return module_->target.arch; }

Module Module::Builder::Build() {
if (module_->functions.empty()) {
VLOG(1) << "Module has no functions";
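A sketch of how a caller might branch on the new accessor; the Builder constructor arguments and Arch enumerators are assumed to follow CINN's existing common::Target definitions:

ir::Module::Builder builder("demo_module", common::DefaultHostTarget());
if (builder.GetTargetArch() == common::Target::Arch::NVGPU) {
  // emit device kernels and host launch stubs
} else {
  // emit plain host code
}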
1 change: 1 addition & 0 deletions paddle/cinn/ir/module.h
@@ -45,6 +45,7 @@ class Module : public ir::IrNodeRef {
void AddFunctionWithoutOptim(const ir::LoweredFunc& func);
void AddBuffer(ir::Buffer buffer);
void Clear();
Target::Arch GetTargetArch();

Module Build();

17 changes: 17 additions & 0 deletions paddle/cinn/ir/tensor.cc
@@ -53,6 +53,23 @@ Tensor _Tensor_::Make(const std::string &name,

return Tensor(n);
}
Tensor _Tensor_::Make(const std::string &name,
Type dtype,
const std::vector<Expr> &shape,
const std::vector<Expr> &domain,
const std::vector<Var> &reduce_axis) {
CHECK(!name.empty()) << "Cannot set empty Tensor name in Tensor::Make";
auto n = make_shared<_Tensor_>();
n->name = name;
n->shape = shape;
n->domain = domain;
n->reduce_axis = reduce_axis;
n->operation = PlaceholderOp::Make(n->name, n->shape, Float(32));
n->set_type(dtype);
n->InitAxis();

return Tensor(n);
}

size_t Tensor::ndims() const { return operator->()->shape.size(); }

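A sketch of the manual-construction overload for a 16x32 -> 16 reduction tensor; names and extents are illustrative, and the reduce axis is created the same way the pybind code below builds iteration vars:

ir::Var k = ir::_Var_::Make("k", common::Int(32));
k->is_reduce_axis = true;
ir::Tensor t = ir::_Tensor_::Make("out",
                                  common::Float(32),             // dtype
                                  {ir::Expr(16)},                // shape
                                  {ir::Expr(16), ir::Expr(32)},  // domain
                                  {k});                          // reduce_axis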
7 changes: 7 additions & 0 deletions paddle/cinn/ir/tensor.h
@@ -149,6 +149,13 @@ class _Tensor_ : public ExprNode<_Tensor_> {
FunctionRef fn,
const std::vector<Var>& reduce_axis = {});

// Manual tensor construction, no FunctionRef information
static Tensor Make(const std::string& name,
Type dtype,
const std::vector<Expr>& shape,
const std::vector<Expr>& domain,
const std::vector<Var>& reduce_axis = {});

void Verify() const override;

bool IsReduceInited(poly::StageMap stages) const;
86 changes: 57 additions & 29 deletions paddle/cinn/ir/utils/ir_compare.cc
@@ -29,6 +29,10 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
return true;
}

if (only_compare_structure_ && !lhs.defined() && !rhs.defined()) {
return true;
}

if (!lhs.defined() || !rhs.defined()) {  // one side is not defined
VLOG(5) << "Not equal on Expr, someone not defined";
return false;
@@ -46,10 +50,9 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
return equal;
}

bool IrEqualVisitor::Compare(const std::string& lhs,
const std::string& rhs,
bool allow_name_suffix_diff) {
// if allow_name_suffix_diff=true then just compare the name prefix before the
bool IrEqualVisitor::Compare(const std::string& lhs, const std::string& rhs) {
// If allow_name_suffix_diff_ is true, compare only the name prefix before
// the "_[0-9]+" suffix.
auto common_len = 0;
for (; common_len < lhs.size() && common_len < rhs.size(); ++common_len) {
@@ -67,7 +70,7 @@ bool IrEqualVisitor::Compare(const std::string& lhs,
equal = true;
} else {
equal = false;
if (allow_name_suffix_diff) {
if (allow_name_suffix_diff_) {
equal = is_endswith_index(lhs) && is_endswith_index(rhs);
}
}
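To illustrate the relaxed rule (a sketch; a and b are assumed to be otherwise-identical IR trees): names pass when they share a prefix and both end in an "_[0-9]+" suffix, e.g. "block_1" vs "block_23", while "block_a" vs "block_b" still fail:

ir::ir_utils::IrEqualVisitor relaxed(/*allow_name_suffix_diff=*/true);
bool eq = relaxed.Compare(a, b);  // suffix-insensitive name comparison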
@@ -181,17 +184,26 @@ bool IrEqualVisitor::Visit(const Block* lhs, const Expr* other) {

bool IrEqualVisitor::Visit(const Call* lhs, const Expr* other) {
auto* rhs = other->As<Call>();
return lhs->name == rhs->name && Compare(lhs->read_args, rhs->read_args) &&
Compare(lhs->write_args, rhs->write_args) &&
Compare(lhs->attrs, rhs->attrs) && lhs->call_type == rhs->call_type;
bool flag = Compare(lhs->read_args, rhs->read_args) &&
Compare(lhs->write_args, rhs->write_args) &&
Compare(lhs->attrs, rhs->attrs) &&
lhs->call_type == rhs->call_type;
if (only_compare_structure_) {
return flag;
}
return lhs->name == rhs->name && flag;
// TODO(CtfGo): Compare `func` field
}

bool IrEqualVisitor::Visit(const _Var_* lhs, const Expr* other) {
auto* rhs = other->As<_Var_>();
return lhs->name == rhs->name &&
Compare(lhs->lower_bound, rhs->lower_bound) &&
Compare(lhs->upper_bound, rhs->upper_bound) && lhs->tag == rhs->tag;
bool flag = Compare(lhs->lower_bound, rhs->lower_bound) &&
Compare(lhs->upper_bound, rhs->upper_bound) &&
lhs->tag == rhs->tag;
if (only_compare_structure_) {
return flag;
}
return lhs->name == rhs->name && flag;
}

bool IrEqualVisitor::Visit(const Load* lhs, const Expr* other) {
@@ -221,19 +233,25 @@ bool IrEqualVisitor::Visit(const Free* lhs, const Expr* other) {

bool IrEqualVisitor::Visit(const _Buffer_* lhs, const Expr* other) {
auto* rhs = other->As<_Buffer_>();
return Compare(lhs->shape, rhs->shape) &&
Compare(lhs->strides, rhs->strides) && lhs->name == rhs->name &&
lhs->scope == rhs->scope &&
Compare(lhs->elem_offset, rhs->elem_offset) &&
lhs->offset_factor == rhs->offset_factor &&
lhs->target == rhs->target &&
lhs->data_alignment == rhs->data_alignment &&
lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype;
bool flag =
Compare(lhs->shape, rhs->shape) && Compare(lhs->strides, rhs->strides) &&
lhs->scope == rhs->scope && Compare(lhs->elem_offset, rhs->elem_offset) &&
lhs->offset_factor == rhs->offset_factor && lhs->target == rhs->target &&
lhs->data_alignment == rhs->data_alignment &&
lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype;
if (only_compare_structure_) {
return flag;
}
return flag && lhs->name == rhs->name;
}

bool IrEqualVisitor::Visit(const _Tensor_* lhs, const Expr* other) {
auto* rhs = other->As<_Tensor_>();
return lhs->name == rhs->name && Compare(lhs->shape, rhs->shape);
bool flag = Compare(lhs->shape, rhs->shape);
if (only_compare_structure_) {
return flag;
}
return flag && Compare(lhs->name, rhs->name);
}

bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) {
@@ -282,10 +300,15 @@ bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) {

bool IrEqualVisitor::Visit(const _Module_* lhs, const Expr* other) {
auto* rhs = other->As<_Module_>();
return lhs->name == rhs->name && lhs->target == rhs->target &&
Compare(lhs->buffers, rhs->buffers) &&
Compare(lhs->functions, rhs->functions) &&
Compare(lhs->submodules, rhs->submodules);
bool flag = Compare(lhs->buffers, rhs->buffers) &&
Compare(lhs->functions, rhs->functions) &&
Compare(lhs->submodules, rhs->submodules);

if (only_compare_structure_) {
return flag;
}

return flag && lhs->name == rhs->name;
}

bool IrEqualVisitor::Visit(const Let* lhs, const Expr* other) {
@@ -347,11 +370,16 @@ bool IrEqualVisitor::Visit(const _BufferRange_* lhs, const Expr* other) {

bool IrEqualVisitor::Visit(const ScheduleBlock* lhs, const Expr* other) {
auto* rhs = other->As<ScheduleBlock>();
return Compare(lhs->name, rhs->name, allow_name_suffix_diff_) &&
Compare(lhs->iter_vars, rhs->iter_vars) &&
Compare(lhs->read_buffers, rhs->read_buffers) &&
Compare(lhs->write_buffers, rhs->write_buffers) &&
Compare(lhs->attrs, rhs->attrs) && Compare(lhs->body, rhs->body);
bool flag = Compare(lhs->iter_vars, rhs->iter_vars) &&
Compare(lhs->read_buffers, rhs->read_buffers) &&
Compare(lhs->write_buffers, rhs->write_buffers) &&
Compare(lhs->body, rhs->body);

if (only_compare_structure_) {
return flag;
}
return flag && Compare(lhs->attrs, rhs->attrs) &&
Compare(lhs->name, rhs->name);
}

bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) {
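A sketch of the structure-only mode these changes introduce; with only_compare_structure set, two IR trees that differ only in variable, tensor, buffer, call, or module names still compare equal (lhs_fn and rhs_fn are assumed Exprs):

ir::ir_utils::IrEqualVisitor structural(/*allow_name_suffix_diff=*/false,
                                        /*only_compare_structure=*/true);
bool same = structural.Compare(lhs_fn, rhs_fn);  // name fields are skipped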
12 changes: 7 additions & 5 deletions paddle/cinn/ir/utils/ir_compare.h
@@ -26,15 +26,15 @@ namespace ir_utils {
// fields of each node through dfs visitor
class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
public:
explicit IrEqualVisitor(bool allow_name_suffix_diff = false)
: allow_name_suffix_diff_(allow_name_suffix_diff) {}
explicit IrEqualVisitor(bool allow_name_suffix_diff = false,
bool only_compare_structure = false)
: allow_name_suffix_diff_(allow_name_suffix_diff),
only_compare_structure_(only_compare_structure) {}
// Returns true if they are equal, otherwise false.
bool Compare(const Expr& lhs, const Expr& rhs);

private:
bool Compare(const std::string& lhs,
const std::string& rhs,
bool allow_name_suffix_diff = false);
bool Compare(const std::string& lhs, const std::string& rhs);
bool Compare(const std::map<std::string, attr_t>& lhs,
const std::map<std::string, attr_t>& rhs);
template <typename T>
@@ -46,6 +46,8 @@ class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {

// Whether to allow two names to differ only in their "_[0-9]+" suffix
bool allow_name_suffix_diff_ = false;
// Whether to compare structure only, skipping the name fields of Exprs
bool only_compare_structure_ = false;
};

bool IRCompare(const Expr& lhs,
4 changes: 3 additions & 1 deletion paddle/cinn/pybind/CMakeLists.txt
@@ -2,7 +2,9 @@ set(srcs
runtime.cc
common.cc
lang.cc
ir.cc
ir/ir.cc
ir/ir_api.cc
ir/ir_context.cc
poly.cc
backends.cc
bind.cc
3 changes: 2 additions & 1 deletion paddle/cinn/pybind/common.cc
@@ -94,6 +94,7 @@ void BindTarget(py::module *m) {
void BindType(py::module *m) {
py::class_<Type> type(*m, "Type");
type.def(py::init<>())
.def(py::init<Type &>())
.def(py::init<Type::type_t, int, int, Type::specific_type_t>());
#define DEFINE_TYPE_METHOD(__name) (type = type.def(#__name, &Type::__name))
DEFINE_TYPE_METHOD(is_primitive);
@@ -140,7 +141,7 @@ void BindType(py::module *m) {
.export_values();

py::enum_<Type::specific_type_t> specific_type_t(type, "specific_type_t");
specific_type_t.value("None", Type::specific_type_t::None)
specific_type_t.value("UNK", Type::specific_type_t::None)
.value("FP16", Type::specific_type_t::FP16)
.value("BF16", Type::specific_type_t::BF16)
.export_values();
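Side note on the two common.cc tweaks: py::init<Type &> exposes Type's copy constructor to Python, and the exported enum value is renamed from "None" to "UNK", presumably because None is a reserved word in Python and specific_type_t.None would be a syntax error at the call site.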
98 changes: 98 additions & 0 deletions paddle/cinn/pybind/ir/ir.cc
@@ -0,0 +1,98 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/pybind/ir/ir.h"
#include "paddle/cinn/pybind/ir/ir_context.h"
namespace cinn {
namespace pybind {
void TensorStore(Expr tensor, Expr value, const std::vector<Expr>& indices) {
// TODO(6clc): Check the compatibility of data types for tensor and value
IRContext find_sch_block =
IRBuilder::CurrentIRBuilder()
.data_->FindContext<ScheduleBlockContextNode>();
if (!find_sch_block.data_.defined()) {
IRContext sch_block(new ScheduleBlockContextNode());
sch_block.data_->EnterWithContext();
LinkToParentContext(ir::Store::Make(tensor, value, indices));
sch_block.data_->ExitWithContext();
return;
}
LinkToParentContext(ir::Store::Make(tensor, value, indices));
}
std::vector<Expr> AxisMap(const std::string& kinds,
const std::vector<Expr>& iter_expression) {
std::vector<Expr> rets;
CHECK_EQ(kinds.size(), iter_expression.size());
int n = iter_expression.size();
rets.reserve(n);
for (int i = 0; i < n; i++) {
char c = kinds[i];

// TODO(6clc): set bound of IterVar

Var iter_var = ir::_Var_::Make("iter_tmp", common::Int(32));
if (c == 'S') {
iter_var->is_reduce_axis = false;
} else if (c == 'R') {
iter_var->is_reduce_axis = true;
} else {
LOG(FATAL) << "Invalid axis kind: must be R(Reduce) or S(Spatial)";
}
rets.push_back(SetScheduleBlockIterVar(iter_var, iter_expression[i]));
}
return rets;
}
Var SetScheduleBlockIterVar(Var iter_var, Expr expr) {
IRContext cur_context =
IRBuilder::CurrentIRBuilder()
.data_->GetLastContext<ScheduleBlockContextNode>();
ScheduleBlockContextNode* cur_context_node =
cur_context.As<ScheduleBlockContextNode>();
cur_context_node->iter_vars.push_back(iter_var);
cur_context_node->iter_values.push_back(expr);
return iter_var.operator Expr();
}

Expr Arg(const std::string& name, Var var) {
IRContext ctx =
IRBuilder::CurrentIRBuilder().data_->FindContext<LowerFuncContextNode>();
var->name = name;
ctx.As<LowerFuncContextNode>()->args.emplace_back(var,
ir::Argument::IO::kUnknown);
return var.operator Expr();
}

Expr Arg(const std::string& name, ir::Buffer buffer) {
IRContext ctx =
IRBuilder::CurrentIRBuilder().data_->FindContext<LowerFuncContextNode>();
buffer->name = "_" + name;
// TODO(6clc): Unify cinn compilation and runtime Type,
// and add a Handle type to Var
ctx.As<LowerFuncContextNode>()->args.emplace_back(buffer,
ir::Argument::IO::kUnknown);
return buffer.operator Expr();
}

IRContext Sequential(Expr min, Expr extent) {
ForContextNode* for_ctx_node = new ForContextNode();
for_ctx_node->min = min;
for_ctx_node->extent = extent;
for_ctx_node->loop_var = ir::_Var_::Make("v", common::Int(32));
return IRContext(for_ctx_node);
}

} // namespace pybind

} // namespace cinn
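A rough sketch of how these helpers compose (mirroring what the Python DSL drives through the bindings); it assumes an IRBuilder with an active LowerFuncContextNode and an enclosing ScheduleBlockContextNode, and x_buffer, i_expr, k_expr, tensor_expr, value_expr are hypothetical inputs:

Expr x = Arg("x", x_buffer);  // registered with IO::kUnknown
std::vector<ir::Expr> iters = AxisMap("SR", {i_expr, k_expr});
// iters[0] is a spatial iter var, iters[1] a reduce iter var; both were
// appended to the innermost ScheduleBlockContextNode.
TensorStore(tensor_expr, value_expr, {iters[0]});  // lowers to ir::Store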