Skip to content

Commit

Permalink
cinn(py-dsl): add ir context used in python dsl
Browse files Browse the repository at this point in the history
  • Loading branch information
6clc committed Sep 20, 2023
1 parent c96b9cb commit f282c04
Show file tree
Hide file tree
Showing 16 changed files with 757 additions and 50 deletions.
10 changes: 10 additions & 0 deletions paddle/cinn/ir/lowered_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ LoweredFunc _LoweredFunc_::Make(const std::string& name,
return LoweredFunc(n);
}

LoweredFunc _LoweredFunc_::Make(const std::string& name,
const std::vector<Argument>& args,
const Expr& body) {
auto* n = make_shared<_LoweredFunc_>();
n->name = name;
n->args = args;
n->body = body;
return LoweredFunc(n);
}

void _LoweredFunc_::CheckValid() const {
// check there is at least one output
int out_count = 0;
Expand Down
13 changes: 11 additions & 2 deletions paddle/cinn/ir/lowered_func.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ class _LoweredFunc_;
* the function signature of generated code.
*/
struct Argument {
//! Input or output.
enum class IO { kInput = 0, kOutput = 1 };
//! kInput: arg is input
//! kOutput: arg is output
//! kUnknown: arg maybe input or output
enum class IO { kInput = 0, kOutput = 1, kUnknown = 2 };

IO io{IO::kInput};

Expand Down Expand Up @@ -164,6 +166,13 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
const Expr& body,
const std::vector<ir::Buffer>& temp_bufs);

// A simple version of the make function method,
// regardless of the argument buffer information and IO information of
// Argument, after building the function to optimize the buffer through pass
static LoweredFunc Make(const std::string& name,
const std::vector<Argument>& args,
const Expr& body);

bool is_gpu_host() const { return cuda_axis_info.valid(); }

void Verify() const override {}
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ void Module::Builder::Clear() {
module_->submodules.clear();
}

Target::Arch Module::Builder::GetTarget() { return module_->target.arch; }

Module Module::Builder::Build() {
if (module_->functions.empty()) {
VLOG(1) << "Module has no functions";
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Module : public ir::IrNodeRef {
void AddFunctionWithoutOptim(const ir::LoweredFunc& func);
void AddBuffer(ir::Buffer buffer);
void Clear();
Target::Arch GetTarget();

Module Build();

Expand Down
17 changes: 17 additions & 0 deletions paddle/cinn/ir/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ Tensor _Tensor_::Make(const std::string &name,

return Tensor(n);
}
Tensor _Tensor_::Make(const std::string &name,
Type dtype,
const std::vector<Expr> &shape,
const std::vector<Expr> &domain,
const std::vector<Var> &reduce_axis) {
CHECK(!name.empty()) << "Cannot set empty Tensor name in Tensor::Make";
auto n = make_shared<_Tensor_>();
n->name = name;
n->shape = shape;
n->domain = domain;
n->reduce_axis = reduce_axis;
n->operation = PlaceholderOp::Make(n->name, n->shape, Float(32));
n->set_type(dtype);
n->InitAxis();

return Tensor(n);
}

size_t Tensor::ndims() const { return operator->()->shape.size(); }

Expand Down
7 changes: 7 additions & 0 deletions paddle/cinn/ir/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ class _Tensor_ : public ExprNode<_Tensor_> {
FunctionRef fn,
const std::vector<Var>& reduce_axis = {});

// Manual tensor construction, no FunctionRef information
static Tensor Make(const std::string& name,
Type dtype,
const std::vector<Expr>& shape,
const std::vector<Expr>& domain,
const std::vector<Var>& reduce_axis = {});

void Verify() const override;

bool IsReduceInited(poly::StageMap stages) const;
Expand Down
86 changes: 57 additions & 29 deletions paddle/cinn/ir/utils/ir_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
return true;
}

if (only_compare_sturcture_ && !lhs.defined() && !rhs.defined()) {
return true;
}

if (!lhs.defined() || !rhs.defined()) { // someone invalid
return false;
VLOG(5) << "Not equal on Expr, someone not defined";
Expand All @@ -44,10 +48,9 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
return equal;
}

bool IrEqualVisitor::Compare(const std::string& lhs,
const std::string& rhs,
bool allow_name_suffix_diff) {
// if allow_name_suffix_diff=true then just compare the name prefix before the
bool IrEqualVisitor::Compare(const std::string& lhs, const std::string& rhs) {
// if allow_name_suffix_diff_=true then just compare the name prefix before
// the
// "_[0-9]+"
auto common_len = 0;
for (; common_len < lhs.size() && common_len < rhs.size(); ++common_len) {
Expand All @@ -65,7 +68,7 @@ bool IrEqualVisitor::Compare(const std::string& lhs,
equal = true;
} else {
equal = false;
if (allow_name_suffix_diff) {
if (allow_name_suffix_diff_) {
equal = is_endswith_index(lhs) && is_endswith_index(rhs);
}
}
Expand Down Expand Up @@ -179,17 +182,26 @@ bool IrEqualVisitor::Visit(const Block* lhs, const Expr* other) {

bool IrEqualVisitor::Visit(const Call* lhs, const Expr* other) {
auto* rhs = other->As<Call>();
return lhs->name == rhs->name && Compare(lhs->read_args, rhs->read_args) &&
Compare(lhs->write_args, rhs->write_args) &&
Compare(lhs->attrs, rhs->attrs) && lhs->call_type == rhs->call_type;
bool flag = Compare(lhs->read_args, rhs->read_args) &&
Compare(lhs->write_args, rhs->write_args) &&
Compare(lhs->attrs, rhs->attrs) &&
lhs->call_type == rhs->call_type;
if (only_compare_sturcture_) {
return flag;
}
return lhs->name == rhs->name && flag;
// TODO(CtfGo): Compare `func` field
}

bool IrEqualVisitor::Visit(const _Var_* lhs, const Expr* other) {
auto* rhs = other->As<_Var_>();
return lhs->name == rhs->name &&
Compare(lhs->lower_bound, rhs->lower_bound) &&
Compare(lhs->upper_bound, rhs->upper_bound) && lhs->tag == rhs->tag;
bool flag = Compare(lhs->lower_bound, rhs->lower_bound) &&
Compare(lhs->upper_bound, rhs->upper_bound) &&
lhs->tag == rhs->tag;
if (only_compare_sturcture_) {
return flag;
}
return lhs->name == rhs->name && flag;
}

bool IrEqualVisitor::Visit(const Load* lhs, const Expr* other) {
Expand Down Expand Up @@ -219,19 +231,25 @@ bool IrEqualVisitor::Visit(const Free* lhs, const Expr* other) {

bool IrEqualVisitor::Visit(const _Buffer_* lhs, const Expr* other) {
auto* rhs = other->As<_Buffer_>();
return Compare(lhs->shape, rhs->shape) &&
Compare(lhs->strides, rhs->strides) && lhs->name == rhs->name &&
lhs->scope == rhs->scope &&
Compare(lhs->elem_offset, rhs->elem_offset) &&
lhs->offset_factor == rhs->offset_factor &&
lhs->target == rhs->target &&
lhs->data_alignment == rhs->data_alignment &&
lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype;
bool flag =
Compare(lhs->shape, rhs->shape) && Compare(lhs->strides, rhs->strides) &&
lhs->scope == rhs->scope && Compare(lhs->elem_offset, rhs->elem_offset) &&
lhs->offset_factor == rhs->offset_factor && lhs->target == rhs->target &&
lhs->data_alignment == rhs->data_alignment &&
lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype;
if (only_compare_sturcture_) {
return flag;
}
return flag && lhs->name == rhs->name;
}

bool IrEqualVisitor::Visit(const _Tensor_* lhs, const Expr* other) {
auto* rhs = other->As<_Tensor_>();
return lhs->name == rhs->name && Compare(lhs->shape, rhs->shape);
bool flag = Compare(lhs->shape, rhs->shape);
if (only_compare_sturcture_) {
return flag;
}
return flag && Compare(lhs->name, rhs->name);
}

bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) {
Expand Down Expand Up @@ -280,10 +298,15 @@ bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) {

bool IrEqualVisitor::Visit(const _Module_* lhs, const Expr* other) {
auto* rhs = other->As<_Module_>();
return lhs->name == rhs->name && lhs->target == rhs->target &&
Compare(lhs->buffers, rhs->buffers) &&
Compare(lhs->functions, rhs->functions) &&
Compare(lhs->submodules, rhs->submodules);
bool flag = Compare(lhs->buffers, rhs->buffers) &&
Compare(lhs->functions, rhs->functions) &&
Compare(lhs->submodules, rhs->submodules);

if (only_compare_sturcture_) {
return flag;
}

return flag && lhs->name == rhs->name;
}

bool IrEqualVisitor::Visit(const Let* lhs, const Expr* other) {
Expand Down Expand Up @@ -345,11 +368,16 @@ bool IrEqualVisitor::Visit(const _BufferRange_* lhs, const Expr* other) {

bool IrEqualVisitor::Visit(const ScheduleBlock* lhs, const Expr* other) {
auto* rhs = other->As<ScheduleBlock>();
return Compare(lhs->name, rhs->name, allow_name_suffix_diff_) &&
Compare(lhs->iter_vars, rhs->iter_vars) &&
Compare(lhs->read_buffers, rhs->read_buffers) &&
Compare(lhs->write_buffers, rhs->write_buffers) &&
Compare(lhs->attrs, rhs->attrs) && Compare(lhs->body, rhs->body);
bool flag = Compare(lhs->iter_vars, rhs->iter_vars) &&
Compare(lhs->read_buffers, rhs->read_buffers) &&
Compare(lhs->write_buffers, rhs->write_buffers) &&
Compare(lhs->body, rhs->body);

if (only_compare_sturcture_) {
return flag;
}
return flag && Compare(lhs->attrs, rhs->attrs) &&
Compare(lhs->name, rhs->name);
}

bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) {
Expand Down
12 changes: 7 additions & 5 deletions paddle/cinn/ir/utils/ir_compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ namespace ir {
// fields of each node through dfs visitor
class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
public:
explicit IrEqualVisitor(bool allow_name_suffix_diff = false)
: allow_name_suffix_diff_(allow_name_suffix_diff) {}
explicit IrEqualVisitor(bool allow_name_suffix_diff = false,
bool only_compare_sturcture = false)
: allow_name_suffix_diff_(allow_name_suffix_diff),
only_compare_sturcture_(only_compare_sturcture) {}
// Return true if they are euqal, otherwise false;
bool Compare(const Expr& lhs, const Expr& rhs);

private:
bool Compare(const std::string& lhs,
const std::string& rhs,
bool allow_name_suffix_diff = false);
bool Compare(const std::string& lhs, const std::string& rhs);
bool Compare(const std::map<std::string, attr_t>& lhs,
const std::map<std::string, attr_t>& rhs);
template <typename T>
Expand All @@ -45,6 +45,8 @@ class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {

// whether allowing name suffix ends with "_[0-9]+" different
bool allow_name_suffix_diff_ = false;
// not compare name field of Expr
bool only_compare_sturcture_ = false;
};

} // namespace ir
Expand Down
4 changes: 3 additions & 1 deletion paddle/cinn/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ set(srcs
runtime.cc
common.cc
lang.cc
ir.cc
ir/ir.cc
ir/ir_api.cc
ir/ir_context.cc
poly.cc
backends.cc
bind.cc
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/pybind/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ void BindTarget(py::module *m) {
void BindType(py::module *m) {
py::class_<Type> type(*m, "Type");
type.def(py::init<>())
.def(py::init<Type &>())
.def(py::init<Type::type_t, int, int, Type::specific_type_t>());
#define DEFINE_TYPE_METHOD(__name) (type = type.def(#__name, &Type::__name))
DEFINE_TYPE_METHOD(is_primitive);
Expand Down Expand Up @@ -140,7 +141,7 @@ void BindType(py::module *m) {
.export_values();

py::enum_<Type::specific_type_t> specific_type_t(type, "specific_type_t");
specific_type_t.value("None", Type::specific_type_t::None)
specific_type_t.value("UNK", Type::specific_type_t::None)
.value("FP16", Type::specific_type_t::FP16)
.value("BF16", Type::specific_type_t::BF16)
.export_values();
Expand Down
98 changes: 98 additions & 0 deletions paddle/cinn/pybind/ir/ir.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/pybind/ir/ir.h"
#include "paddle/cinn/pybind/ir/ir_context.h"
namespace cinn {
namespace pybind {
void TensorStore(Expr tensor, Expr value, const std::vector<Expr>& indices) {
// TODO(6clc): Check the compatibility of data types for tensor and value
IRContext find_sch_block =
IRBuilder::CurrentIRBuilder()
.data_->FindContext<ScheduleBlockContextNode>();
if (!find_sch_block.data_.defined()) {
IRContext sch_block(new ScheduleBlockContextNode());
sch_block.data_->EnterWithContext();
LinkToParentContext(ir::Store::Make(tensor, value, indices));
sch_block.data_->ExitWithContext();
return;
}
LinkToParentContext(ir::Store::Make(tensor, value, indices));
}
std::vector<Expr> AxisMap(const std::string& kinds,
const std::vector<Expr>& iter_expression) {
std::vector<Expr> rets;
CHECK_EQ(kinds.size(), iter_expression.size());
int n = iter_expression.size();
rets.reserve(n);
for (int i = 0; i < n; i++) {
char c = kinds.c_str()[i];

// TODO(6clc): set bound of IterVar

Var iter_var = ir::_Var_::Make("iter_tmp", common::Int(32));
if (c == 'S') {
iter_var->is_reduce_axis = false;
} else if (c == 'R') {
iter_var->is_reduce_axis = true;
} else {
LOG(FATAL)
<< "kind of axis setting error, must be R(Reduce) or S(Spatial)";
}
rets.push_back(SetScheduleBlockIterVar(iter_var, iter_expression[i]));
}
return rets;
}
Var SetScheduleBlockIterVar(Var iter_var, Expr expr) {
IRContext cur_context =
IRBuilder::CurrentIRBuilder()
.data_->GetLastContext<ScheduleBlockContextNode>();
ScheduleBlockContextNode* cur_context_node =
cur_context.As<ScheduleBlockContextNode>();
cur_context_node->iter_vars.push_back(iter_var);
cur_context_node->iter_values.push_back(expr);
return iter_var.operator Expr();
}

Expr Arg(const std::string& name, Var var) {
IRContext ctx =
IRBuilder::CurrentIRBuilder().data_->FindContext<LowerFuncContextNode>();
var->name = name;
ctx.As<LowerFuncContextNode>()->args.emplace_back(var,
ir::Argument::IO::kUnknown);
return var.operator Expr();
}

Expr Arg(const std::string& name, ir::Buffer buffer) {
IRContext ctx =
IRBuilder::CurrentIRBuilder().data_->FindContext<LowerFuncContextNode>();
buffer->name = "_" + name;
// TODO(6clc): Unify cinn compilation and runtime Type,
// and add a Handle type to Var
ctx.As<LowerFuncContextNode>()->args.emplace_back(buffer,
ir::Argument::IO::kUnknown);
return buffer.operator Expr();
}

IRContext Sequential(Expr min, Expr extent) {
ForContextNode* for_ctx_node = new ForContextNode();
for_ctx_node->min = min;
for_ctx_node->extent = extent;
for_ctx_node->loop_var = ir::_Var_::Make("v", common::Int(32));
return IRContext(for_ctx_node);
}

} // namespace pybind

} // namespace cinn
Loading

0 comments on commit f282c04

Please sign in to comment.