diff --git a/paddle/cinn/ir/lowered_func.cc b/paddle/cinn/ir/lowered_func.cc index ec5f4b2e64ce6b..410ac068df85db 100644 --- a/paddle/cinn/ir/lowered_func.cc +++ b/paddle/cinn/ir/lowered_func.cc @@ -64,6 +64,16 @@ LoweredFunc _LoweredFunc_::Make(const std::string& name, return LoweredFunc(n); } +LoweredFunc _LoweredFunc_::Make(const std::string& name, + const std::vector& args, + const Expr& body) { + auto* n = make_shared<_LoweredFunc_>(); + n->name = name; + n->args = args; + n->body = body; + return LoweredFunc(n); +} + void _LoweredFunc_::CheckValid() const { // check there is at least one output int out_count = 0; diff --git a/paddle/cinn/ir/lowered_func.h b/paddle/cinn/ir/lowered_func.h old mode 100755 new mode 100644 index 03ffacad817bd7..b305f84506fe41 --- a/paddle/cinn/ir/lowered_func.h +++ b/paddle/cinn/ir/lowered_func.h @@ -30,8 +30,10 @@ class _LoweredFunc_; * the function signature of generated code. */ struct Argument { - //! Input or output. - enum class IO { kInput = 0, kOutput = 1 }; + //! kInput: arg is input + //! kOutput: arg is output + //! kUnknown: arg maybe input or output + enum class IO { kInput = 0, kOutput = 1, kUnknown = 2 }; IO io{IO::kInput}; @@ -164,6 +166,13 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { const Expr& body, const std::vector& temp_bufs); + // A simple version of the make function method, + // regardless of the argument buffer information and IO information of + // Argument, after building the function to optimize the buffer through pass + static LoweredFunc Make(const std::string& name, + const std::vector& args, + const Expr& body); + bool is_gpu_host() const { return cuda_axis_info.valid(); } void Verify() const override {} diff --git a/paddle/cinn/ir/module.cc b/paddle/cinn/ir/module.cc index d52ee148b8bc85..fa791dcdbcd62d 100644 --- a/paddle/cinn/ir/module.cc +++ b/paddle/cinn/ir/module.cc @@ -54,6 +54,8 @@ void Module::Builder::Clear() { module_->submodules.clear(); } +Target::Arch Module::Builder::GetTargetArch() { return module_->target.arch; } + Module Module::Builder::Build() { if (module_->functions.empty()) { VLOG(1) << "Module has no functions"; diff --git a/paddle/cinn/ir/module.h b/paddle/cinn/ir/module.h index 9d2b3610830714..a057c4862cc0e5 100644 --- a/paddle/cinn/ir/module.h +++ b/paddle/cinn/ir/module.h @@ -45,6 +45,7 @@ class Module : public ir::IrNodeRef { void AddFunctionWithoutOptim(const ir::LoweredFunc& func); void AddBuffer(ir::Buffer buffer); void Clear(); + Target::Arch GetTargetArch(); Module Build(); diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc index 8ad8b9878d4bc1..ca7147db69249c 100644 --- a/paddle/cinn/ir/tensor.cc +++ b/paddle/cinn/ir/tensor.cc @@ -53,6 +53,23 @@ Tensor _Tensor_::Make(const std::string &name, return Tensor(n); } +Tensor _Tensor_::Make(const std::string &name, + Type dtype, + const std::vector &shape, + const std::vector &domain, + const std::vector &reduce_axis) { + CHECK(!name.empty()) << "Cannot set empty Tensor name in Tensor::Make"; + auto n = make_shared<_Tensor_>(); + n->name = name; + n->shape = shape; + n->domain = domain; + n->reduce_axis = reduce_axis; + n->operation = PlaceholderOp::Make(n->name, n->shape, Float(32)); + n->set_type(dtype); + n->InitAxis(); + + return Tensor(n); +} size_t Tensor::ndims() const { return operator->()->shape.size(); } diff --git a/paddle/cinn/ir/tensor.h b/paddle/cinn/ir/tensor.h index 5c252d35faceb5..56995559dba941 100644 --- a/paddle/cinn/ir/tensor.h +++ b/paddle/cinn/ir/tensor.h @@ -149,6 +149,13 @@ class _Tensor_ : public ExprNode<_Tensor_> { FunctionRef fn, const std::vector& reduce_axis = {}); + // Manual tensor construction, no FunctionRef information + static Tensor Make(const std::string& name, + Type dtype, + const std::vector& shape, + const std::vector& domain, + const std::vector& reduce_axis = {}); + void Verify() const override; bool IsReduceInited(poly::StageMap stages) const; diff --git a/paddle/cinn/ir/utils/ir_compare.cc b/paddle/cinn/ir/utils/ir_compare.cc index 87324be608048d..fbe7a65c43efca 100644 --- a/paddle/cinn/ir/utils/ir_compare.cc +++ b/paddle/cinn/ir/utils/ir_compare.cc @@ -29,6 +29,10 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) { return true; } + if (only_compare_structure_ && !lhs.defined() && !rhs.defined()) { + return true; + } + if (!lhs.defined() || !rhs.defined()) { // someone invalid return false; VLOG(5) << "Not equal on Expr, someone not defined"; @@ -46,10 +50,9 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) { return equal; } -bool IrEqualVisitor::Compare(const std::string& lhs, - const std::string& rhs, - bool allow_name_suffix_diff) { - // if allow_name_suffix_diff=true then just compare the name prefix before the +bool IrEqualVisitor::Compare(const std::string& lhs, const std::string& rhs) { + // if allow_name_suffix_diff_=true then just compare the name prefix before + // the // "_[0-9]+" auto common_len = 0; for (; common_len < lhs.size() && common_len < rhs.size(); ++common_len) { @@ -67,7 +70,7 @@ bool IrEqualVisitor::Compare(const std::string& lhs, equal = true; } else { equal = false; - if (allow_name_suffix_diff) { + if (allow_name_suffix_diff_) { equal = is_endswith_index(lhs) && is_endswith_index(rhs); } } @@ -181,17 +184,26 @@ bool IrEqualVisitor::Visit(const Block* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const Call* lhs, const Expr* other) { auto* rhs = other->As(); - return lhs->name == rhs->name && Compare(lhs->read_args, rhs->read_args) && - Compare(lhs->write_args, rhs->write_args) && - Compare(lhs->attrs, rhs->attrs) && lhs->call_type == rhs->call_type; + bool flag = Compare(lhs->read_args, rhs->read_args) && + Compare(lhs->write_args, rhs->write_args) && + Compare(lhs->attrs, rhs->attrs) && + lhs->call_type == rhs->call_type; + if (only_compare_structure_) { + return flag; + } + return lhs->name == rhs->name && flag; // TODO(CtfGo): Compare `func` field } bool IrEqualVisitor::Visit(const _Var_* lhs, const Expr* other) { auto* rhs = other->As<_Var_>(); - return lhs->name == rhs->name && - Compare(lhs->lower_bound, rhs->lower_bound) && - Compare(lhs->upper_bound, rhs->upper_bound) && lhs->tag == rhs->tag; + bool flag = Compare(lhs->lower_bound, rhs->lower_bound) && + Compare(lhs->upper_bound, rhs->upper_bound) && + lhs->tag == rhs->tag; + if (only_compare_structure_) { + return flag; + } + return lhs->name == rhs->name && flag; } bool IrEqualVisitor::Visit(const Load* lhs, const Expr* other) { @@ -221,19 +233,25 @@ bool IrEqualVisitor::Visit(const Free* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const _Buffer_* lhs, const Expr* other) { auto* rhs = other->As<_Buffer_>(); - return Compare(lhs->shape, rhs->shape) && - Compare(lhs->strides, rhs->strides) && lhs->name == rhs->name && - lhs->scope == rhs->scope && - Compare(lhs->elem_offset, rhs->elem_offset) && - lhs->offset_factor == rhs->offset_factor && - lhs->target == rhs->target && - lhs->data_alignment == rhs->data_alignment && - lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype; + bool flag = + Compare(lhs->shape, rhs->shape) && Compare(lhs->strides, rhs->strides) && + lhs->scope == rhs->scope && Compare(lhs->elem_offset, rhs->elem_offset) && + lhs->offset_factor == rhs->offset_factor && lhs->target == rhs->target && + lhs->data_alignment == rhs->data_alignment && + lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype; + if (only_compare_structure_) { + return flag; + } + return flag && lhs->name == rhs->name; } bool IrEqualVisitor::Visit(const _Tensor_* lhs, const Expr* other) { auto* rhs = other->As<_Tensor_>(); - return lhs->name == rhs->name && Compare(lhs->shape, rhs->shape); + bool flag = Compare(lhs->shape, rhs->shape); + if (only_compare_structure_) { + return flag; + } + return flag && Compare(lhs->name, rhs->name); } bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) { @@ -282,10 +300,15 @@ bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const _Module_* lhs, const Expr* other) { auto* rhs = other->As<_Module_>(); - return lhs->name == rhs->name && lhs->target == rhs->target && - Compare(lhs->buffers, rhs->buffers) && - Compare(lhs->functions, rhs->functions) && - Compare(lhs->submodules, rhs->submodules); + bool flag = Compare(lhs->buffers, rhs->buffers) && + Compare(lhs->functions, rhs->functions) && + Compare(lhs->submodules, rhs->submodules); + + if (only_compare_structure_) { + return flag; + } + + return flag && lhs->name == rhs->name; } bool IrEqualVisitor::Visit(const Let* lhs, const Expr* other) { @@ -347,11 +370,16 @@ bool IrEqualVisitor::Visit(const _BufferRange_* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const ScheduleBlock* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->name, rhs->name, allow_name_suffix_diff_) && - Compare(lhs->iter_vars, rhs->iter_vars) && - Compare(lhs->read_buffers, rhs->read_buffers) && - Compare(lhs->write_buffers, rhs->write_buffers) && - Compare(lhs->attrs, rhs->attrs) && Compare(lhs->body, rhs->body); + bool flag = Compare(lhs->iter_vars, rhs->iter_vars) && + Compare(lhs->read_buffers, rhs->read_buffers) && + Compare(lhs->write_buffers, rhs->write_buffers) && + Compare(lhs->body, rhs->body); + + if (only_compare_structure_) { + return flag; + } + return flag && Compare(lhs->attrs, rhs->attrs) && + Compare(lhs->name, rhs->name); } bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) { diff --git a/paddle/cinn/ir/utils/ir_compare.h b/paddle/cinn/ir/utils/ir_compare.h index d41e6db0441a7b..03ec82c2467507 100644 --- a/paddle/cinn/ir/utils/ir_compare.h +++ b/paddle/cinn/ir/utils/ir_compare.h @@ -26,15 +26,15 @@ namespace ir_utils { // fields of each node through dfs visitor class IrEqualVisitor : public IRVisitorRequireReImpl { public: - explicit IrEqualVisitor(bool allow_name_suffix_diff = false) - : allow_name_suffix_diff_(allow_name_suffix_diff) {} + explicit IrEqualVisitor(bool allow_name_suffix_diff = false, + bool only_compare_structure = false) + : allow_name_suffix_diff_(allow_name_suffix_diff), + only_compare_structure_(only_compare_structure) {} // Return true if they are euqal, otherwise false; bool Compare(const Expr& lhs, const Expr& rhs); private: - bool Compare(const std::string& lhs, - const std::string& rhs, - bool allow_name_suffix_diff = false); + bool Compare(const std::string& lhs, const std::string& rhs); bool Compare(const std::map& lhs, const std::map& rhs); template @@ -46,6 +46,8 @@ class IrEqualVisitor : public IRVisitorRequireReImpl { // whether allowing name suffix ends with "_[0-9]+" different bool allow_name_suffix_diff_ = false; + // not compare name field of Expr + bool only_compare_structure_ = false; }; bool IRCompare(const Expr& lhs, diff --git a/paddle/cinn/pybind/CMakeLists.txt b/paddle/cinn/pybind/CMakeLists.txt index bf6e3d095377fc..c00a64614f6430 100755 --- a/paddle/cinn/pybind/CMakeLists.txt +++ b/paddle/cinn/pybind/CMakeLists.txt @@ -2,7 +2,9 @@ set(srcs runtime.cc common.cc lang.cc - ir.cc + ir/ir.cc + ir/ir_api.cc + ir/ir_context.cc poly.cc backends.cc bind.cc diff --git a/paddle/cinn/pybind/common.cc b/paddle/cinn/pybind/common.cc index 170ebfc6d69164..bdb4b46c848ef9 100644 --- a/paddle/cinn/pybind/common.cc +++ b/paddle/cinn/pybind/common.cc @@ -94,6 +94,7 @@ void BindTarget(py::module *m) { void BindType(py::module *m) { py::class_ type(*m, "Type"); type.def(py::init<>()) + .def(py::init()) .def(py::init()); #define DEFINE_TYPE_METHOD(__name) (type = type.def(#__name, &Type::__name)) DEFINE_TYPE_METHOD(is_primitive); @@ -140,7 +141,7 @@ void BindType(py::module *m) { .export_values(); py::enum_ specific_type_t(type, "specific_type_t"); - specific_type_t.value("None", Type::specific_type_t::None) + specific_type_t.value("UNK", Type::specific_type_t::None) .value("FP16", Type::specific_type_t::FP16) .value("BF16", Type::specific_type_t::BF16) .export_values(); diff --git a/paddle/cinn/pybind/ir/ir.cc b/paddle/cinn/pybind/ir/ir.cc new file mode 100644 index 00000000000000..f569bd2c973bee --- /dev/null +++ b/paddle/cinn/pybind/ir/ir.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/pybind/ir/ir.h" +#include "paddle/cinn/pybind/ir/ir_context.h" +namespace cinn { +namespace pybind { +void TensorStore(Expr tensor, Expr value, const std::vector& indices) { + // TODO(6clc): Check the compatibility of data types for tensor and value + IRContext find_sch_block = + IRBuilder::CurrentIRBuilder() + .data_->FindContext(); + if (!find_sch_block.data_.defined()) { + IRContext sch_block(new ScheduleBlockContextNode()); + sch_block.data_->EnterWithContext(); + LinkToParentContext(ir::Store::Make(tensor, value, indices)); + sch_block.data_->ExitWithContext(); + return; + } + LinkToParentContext(ir::Store::Make(tensor, value, indices)); +} +std::vector AxisMap(const std::string& kinds, + const std::vector& iter_expression) { + std::vector rets; + CHECK_EQ(kinds.size(), iter_expression.size()); + int n = iter_expression.size(); + rets.reserve(n); + for (int i = 0; i < n; i++) { + char c = kinds.c_str()[i]; + + // TODO(6clc): set bound of IterVar + + Var iter_var = ir::_Var_::Make("iter_tmp", common::Int(32)); + if (c == 'S') { + iter_var->is_reduce_axis = false; + } else if (c == 'R') { + iter_var->is_reduce_axis = true; + } else { + LOG(FATAL) + << "kind of axis setting error, must be R(Reduce) or S(Spatial)"; + } + rets.push_back(SetScheduleBlockIterVar(iter_var, iter_expression[i])); + } + return rets; +} +Var SetScheduleBlockIterVar(Var iter_var, Expr expr) { + IRContext cur_context = + IRBuilder::CurrentIRBuilder() + .data_->GetLastContext(); + ScheduleBlockContextNode* cur_context_node = + cur_context.As(); + cur_context_node->iter_vars.push_back(iter_var); + cur_context_node->iter_values.push_back(expr); + return iter_var.operator Expr(); +} + +Expr Arg(const std::string& name, Var var) { + IRContext ctx = + IRBuilder::CurrentIRBuilder().data_->FindContext(); + var->name = name; + ctx.As()->args.emplace_back(var, + ir::Argument::IO::kUnknown); + return var.operator Expr(); +} + +Expr Arg(const std::string& name, ir::Buffer buffer) { + IRContext ctx = + IRBuilder::CurrentIRBuilder().data_->FindContext(); + buffer->name = "_" + name; + // TODO(6clc): Unify cinn compilation and runtime Type, + // and add a Handle type to Var + ctx.As()->args.emplace_back(buffer, + ir::Argument::IO::kUnknown); + return buffer.operator Expr(); +} + +IRContext Sequential(Expr min, Expr extent) { + ForContextNode* for_ctx_node = new ForContextNode(); + for_ctx_node->min = min; + for_ctx_node->extent = extent; + for_ctx_node->loop_var = ir::_Var_::Make("v", common::Int(32)); + return IRContext(for_ctx_node); +} + +} // namespace pybind + +} // namespace cinn diff --git a/paddle/cinn/pybind/ir/ir.h b/paddle/cinn/pybind/ir/ir.h new file mode 100644 index 00000000000000..9a4e2e2263f0ed --- /dev/null +++ b/paddle/cinn/pybind/ir/ir.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_base.h" +#include "paddle/cinn/pybind/ir/ir_context.h" +namespace cinn { +namespace pybind { + +template IRContext IRBuilderNode::GetLastContext() + const; +Var SetScheduleBlockIterVar(Var iter_var, Expr expr); +std::vector AxisMap(const std::string &kinds, + const std::vector &iter_expression); +void TensorStore(Expr tensor, Expr value, const std::vector &indices); +Expr Arg(const std::string &name, Var var); +Expr Arg(const std::string &name, ir::Buffer buffer); +IRContext Sequential(Expr min, Expr extent); +} // namespace pybind +} // namespace cinn diff --git a/paddle/cinn/pybind/ir.cc b/paddle/cinn/pybind/ir/ir_api.cc similarity index 85% rename from paddle/cinn/pybind/ir.cc rename to paddle/cinn/pybind/ir/ir_api.cc index b03b7181509d8c..66c0e2306d8cc9 100644 --- a/paddle/cinn/pybind/ir.cc +++ b/paddle/cinn/pybind/ir/ir_api.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/ir/ir.h" - #include #include #include @@ -22,21 +20,29 @@ #include #include +#include "paddle/cinn/common/shared.h" +#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/operation.h" #include "paddle/cinn/ir/registry.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/tensor.h" +#include "paddle/cinn/ir/utils/ir_compare.h" #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_visitor.h" #include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/poly/stage.h" #include "paddle/cinn/pybind/bind.h" #include "paddle/cinn/pybind/bind_utils.h" +#include "paddle/cinn/pybind/ir/ir.h" +#include "paddle/cinn/pybind/ir/ir_context.h" namespace py = pybind11; +PYBIND11_DECLARE_HOLDER_TYPE(T, cinn::common::Shared); + namespace cinn::pybind { using ir::IrNode; using ir::IrNodeRef; @@ -62,7 +68,8 @@ void BindLoweredFunc(py::module *m) { py::enum_ io(argument, "IO"); io.value("kInput", Argument::IO::kInput) - .value("kOutput", Argument::IO::kOutput); + .value("kOutput", Argument::IO::kOutput) + .value("kUnknown", Argument::IO::kUnknown); argument .def(py::init(), @@ -93,10 +100,12 @@ void BindLoweredFunc(py::module *m) { [](const ir::LoweredFunc &self) -> std::string { return utils::GetStreamCnt(Expr(self)); }) - .def("__repr__", [](const ir::LoweredFunc &self) -> std::string { - return llvm::formatv( - "", self.get(), self->name.c_str()); - }); + .def("__repr__", + [](const ir::LoweredFunc &self) -> std::string { + return llvm::formatv( + "", self.get(), self->name.c_str()); + }) + .def("body", [](const ir::LoweredFunc &self) { return self->body; }); } void BindNode(py::module *m) { @@ -258,6 +267,13 @@ void BindNode(py::module *m) { // empty visitor void BindIrVisitor(py::module *m) { + py::class_ ir_compare(*m, "IrCompare"); + ir_compare.def(py::init()) + .def("compare", + [](ir::ir_utils::IrEqualVisitor &self, + const cinn::ir::Expr &lhs, + const cinn::ir::Expr &rhs) { return self.Compare(lhs, rhs); }); + py::class_ ir_visitor(*m, "IRVisitor"); ir_visitor.def(py::init<>()) .def("visit", py::overload_cast(&ir::IRVisitor::Visit)); @@ -466,6 +482,7 @@ void BindIrIr(py::module *m) { .def(py::init()) .def(py::init()) .def(py::init()) + .def("rename", [](Var &self, std::string &name) { self->name = name; }) .def("get_mutable", py::overload_cast<>(&Var::get), py::return_value_policy::reference) @@ -537,6 +554,31 @@ void BindIrIr(py::module *m) { .def_readwrite("buffers", &ir::_Module_::buffers) .def_readwrite("functions", &ir::_Module_::functions) .def_readwrite("submodules", &ir::_Module_::submodules); + + DefineExprNode(m, "_Buffer_"); + py::class_> _buffer_(*m, "_Buffer_"); + _buffer_ + .def_static( + "make", + py::overload_cast(&ir::_Buffer_::Make)) + .def_static( + "make", + py::overload_cast &>( + &ir::_Buffer_::Make)); + py::class_ buffer(*m, "Buffer"); + buffer.def(py::init<>()); + + py::class_ module_expr(*m, "ModuleExpr"); + module_expr.def(py::init &>()); + + DefineExprNode(m, "IfThenElse"); + py::class_ if_then_else(*m, "IfThenElse"); + if_then_else.def_static( + "make", + py::overload_cast(&ir::IfThenElse::Make), + py::arg("condition"), + py::arg("true_case"), + py::arg("false_case") = ir::Expr()); } void BindOperation(py::module *m) { @@ -586,9 +628,24 @@ void BindIrTensor(py::module *m) { [](ir::Tensor &self, Expr a, Expr b, Expr c) { return self(a, b, c); }) - .def("__call__", [](ir::Tensor &self, Expr a, Expr b, Expr c, Expr d) { - return self(a, b, c, d); - }); + .def("__call__", + [](ir::Tensor &self, Expr a, Expr b, Expr c, Expr d) { + return self(a, b, c, d); + }) + .def("__getitem__", [](ir::Tensor &self, Expr a) { return self(a); }) + .def("__getitem__", + [](ir::Tensor &self, Expr a, Expr b) { return self(a, b); }) + .def("__getitem__", + [](ir::Tensor &self, Expr a, Expr b, Expr c) { + return self(a, b, c); + }) + .def("__getitem__", + [](ir::Tensor &self, Expr a, Expr b, Expr c, Expr d) { + return self(a, b, c, d); + }) + .def("__getitem__", + [](ir::Tensor &self, std::vector idx) { return self(idx); }) + .def("Expr", [](ir::Tensor &self) { return self.operator Expr(); }); DefineExprNode(m, "_Tensor_"); py::class_> _tensor_(*m, "_Tensor_"); @@ -600,7 +657,18 @@ void BindIrTensor(py::module *m) { .def("domain_with_reduce_axis", &ir::_Tensor_::domain_without_reduce_axis) .def("domain_without_reduce_axis", &ir::_Tensor_::domain_without_reduce_axis) - .def_static("make", &ir::_Tensor_::Make) + .def_static( + "make", + py::overload_cast &, + const std::vector &, + const std::vector &>(&ir::_Tensor_::Make), + py::arg("name"), + py::arg("dtype"), + py::arg("shape"), + py::arg("domain"), + py::arg("reduce_axis") = std::vector({})) .def("is_tuple", &ir::_Tensor_::is_tuple) .def("is_tuple_get", &ir::_Tensor_::is_tuple_get) .def("tuple_get", &ir::_Tensor_::TupleGet) @@ -741,6 +809,54 @@ void BindRegistry(py::module *m) { }); #endif } + +void BindIrContext(py::module *m) { + using ir::Expr; + using ir::IrNode; + using ir::IrNodeRef; + using ir::Var; + using py::arg; + + py::class_ ir_ctx(*m, "IRContext"); + ir_ctx.def(py::init<>()) + .def(py::init()) + .def("EnterWithContext", + [](IRContext &self) { self.data_->EnterWithContext(); }) + .def("ExitWithContext", + [](IRContext &self) { self.data_->ExitWithContext(); }) + .def("get_for_loop_var", + [](IRContext &self) { + return self.data_->safe_as()->loop_var; + }) + .def_static("MakeLowerFunctionContext", + [](std::string &name) { + return IRContext(new LowerFuncContextNode(name)); + }) + .def_static("MakeScheduleBlockContext", + [](std::string &name) { + return IRContext(new ScheduleBlockContextNode(name)); + }) + .def_static("MakeIfContext", + [](Expr expr) { return IRContext(new IfContextNode(expr)); }) + .def_static("MakeElseContext", + []() { return IRContext(new ElseContextNode()); }) + .def_static("MakeThenContext", + []() { return IRContext(new ThenContextNode()); }); + + py::class_ ir_builder(*m, "IRBuilder"); + ir_builder.def(py::init<>()) + .def("EnterWithContext", &IRBuilder::EnterWithContext) + .def("ExitWithContext", &IRBuilder::ExitWithContext) + .def("get_result", [](IRBuilder &self) { + return self.data_->GetResult().as_lowered_func_ref(); + }); + + m->def("AxisMap", &AxisMap); + m->def("TensorStore", &TensorStore); + m->def("Arg", py::overload_cast(&Arg)); + m->def("Arg", py::overload_cast(&Arg)); + m->def("Sequential", py::overload_cast(&Sequential)); +} } // namespace void BindIr(py::module *m) { @@ -750,6 +866,7 @@ void BindIr(py::module *m) { BindIrVisitor(m); BindIrIr(m); BindIrTensor(m); + BindIrContext(m); BindPackedFunc(m); BindRegistry(m); } diff --git a/paddle/cinn/pybind/ir/ir_context.cc b/paddle/cinn/pybind/ir/ir_context.cc new file mode 100644 index 00000000000000..8af89d974222f1 --- /dev/null +++ b/paddle/cinn/pybind/ir/ir_context.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/pybind/ir/ir_context.h" +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace pybind { +void IRContextNode::EnterWithContext() { + IRBuilder::CurrentIRBuilder().data_->contexts.emplace_back(this); +} +void IRContextNode::ExitWithContext() { + IRBuilder::CurrentIRBuilder().data_->contexts.pop_back(); +} + +void ScheduleBlockContextNode::ExitWithContext() { + IRContextNode::ExitWithContext(); + ir::Expr schedule_block = ir::ScheduleBlock::Make( + iter_vars, read_buffers, write_buffers, name, ir::Block::Make(exprs)); + + ir::Expr schedule_block_realize = + ir::ScheduleBlockRealize::Make(iter_values, schedule_block); + LinkToParentContext(schedule_block_realize); +} + +void ForContextNode::ExitWithContext() { + IRContextNode::ExitWithContext(); + LinkToParentContext(ir::For::Make(loop_var, + min, + extent, + ir::ForType::Serial, + ir::DeviceAPI::UNK, + ir::Block::Make(exprs))); +} + +void LowerFuncContextNode::ExitWithContext() { + IRContextNode::ExitWithContext(); + // TODO(6clc): implement Private Fields for intrinstic function, like + // allreduce + Expr body = ir::ScheduleBlockRealize::Make( + {}, ir::ScheduleBlock::Make({}, {}, {}, "root", ir::Block::Make(exprs))); + ir::LoweredFunc lower_func = + ir::_LoweredFunc_::Make(name, args, ir::Block::Make({body})); + IRBuilder ir_builder = IRBuilder::CurrentIRBuilder(); + ir_builder.data_->result = lower_func.operator Expr(); +} + +void IfContextNode::ExitWithContext() { + IRContextNode::ExitWithContext(); + if (!exprs.empty()) { + LOG(FATAL) << "Expr not be either in ThenBlock or ElseBlock in if"; + } + if (!true_case.defined()) { + LOG(FATAL) << "Expr not be defined in ThenBlock"; + } + LinkToParentContext(ir::IfThenElse::Make(condition, true_case, false_case)); +} + +void ThenContextNode::ExitWithContext() { + IRContextNode::ExitWithContext(); + IRContext for_ctx = + IRBuilder::CurrentIRBuilder().data_->GetLastContext(); + for_ctx.data_->safe_as()->true_case = ir::Block::Make(exprs); +} + +void ElseContextNode::ExitWithContext() { + IRContextNode::ExitWithContext(); + IRContext for_ctx = + IRBuilder::CurrentIRBuilder().data_->GetLastContext(); + for_ctx.data_->safe_as()->false_case = ir::Block::Make(exprs); +} + +Expr IRBuilderNode::GetResult() const { + CHECK(result.defined()) << "No result generated in IRBuilder"; + return result; +} + +void IRBuilderNode::Reset() { + contexts.clear(); + result.Reset(); +} + +IRBuilder::IRBuilder() { + common::Shared n(new IRBuilderNode()); + n->Reset(); + data_ = n; +} + +void IRBuilder::EnterWithContext() { + CHECK(data_->contexts.empty()) + << "There are still Contexts in IRBuilder that has not been fully " + "converted. Please build a new IR with the new IRbuilder"; + data_->result.Reset(); + std::vector* st = IRBuilderStack(); + st->push_back(*this); +} + +void IRBuilder::ExitWithContext() { + std::vector* st = IRBuilderStack(); + CHECK(!st->empty()); + st->pop_back(); +} +IRBuilder IRBuilder::CurrentIRBuilder() { + std::vector* st = IRBuilderStack(); + CHECK(!st->empty()) << "No IRBuilder Found"; + return st->back(); +} +std::vector* IRBuilderStack() { + thread_local std::vector stack; + return &stack; +} +void LinkToParentContext(ir::Expr expr) { + IRBuilder ir_builder = IRBuilder::CurrentIRBuilder(); + if (ir_builder.data_->contexts.empty()) { + ir_builder.data_->result = expr; + } else { + IRContext ir_context = ir_builder.data_->contexts.back(); + ir_context.add_expr(expr); + } +} + +} // namespace pybind +} // namespace cinn diff --git a/paddle/cinn/pybind/ir/ir_context.h b/paddle/cinn/pybind/ir/ir_context.h new file mode 100644 index 00000000000000..c96c423bb071e0 --- /dev/null +++ b/paddle/cinn/pybind/ir/ir_context.h @@ -0,0 +1,256 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/cinn/common/object.h" +#include "paddle/cinn/common/shared.h" +#include "paddle/cinn/common/type.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_base.h" +#include "paddle/cinn/ir/lowered_func.h" +#include "paddle/cinn/utils/error.h" + +namespace cinn { +namespace pybind { + +/** + * A base context that represents the CINN IR that need context information + */ +class IRContextNode : public common::Object { + public: + std::vector exprs; + + public: + // Corresponds to the __enter__ method in python's context manager + virtual void EnterWithContext(); + // Corresponds to the __exit__ method in python's context manager + virtual void ExitWithContext(); + const char* type_info() const override { return __type_info__; } + + public: + static constexpr char* __type_info__ = "IRContextNode"; +}; + +/** + * The lifecycle of RAII resource management for IRContextNode + * is determined at the Python. + */ +class IRContext { + public: + IRContext() = default; + IRContext(const IRContext& other) = default; + explicit IRContext(IRContextNode* x) : data_(x) {} + + const IRContextNode* get() const { return data_.get(); } + const IRContextNode* operator->() const { return data_.get(); } + + void add_expr(Expr expr) { data_->exprs.push_back(expr); } + + public: + common::Shared data_; + + public: + template + const TIRContextNode* As() const { + static_assert(std::is_base_of()); + CHECK(data_.get()) << "IrContext holds null"; + auto* ctx_node = data_.get()->safe_as(); + if (!ctx_node) { + std::stringstream err_msg; + err_msg << "TypeConvertError: convert " << data_.get()->type_info() + << " to " << TIRContextNode::__type_info__; + + CINN_THROW(err_msg.str()); + } + return ctx_node; + } + template + TIRContextNode* As() { + CHECK(data_.get()) << "IrContext holds null"; + auto* ctx_node = data_.get()->safe_as(); + if (!ctx_node) { + LOG(FATAL) << "TypeConvertError: convert " << data_.get()->type_info() + << " to " << TIRContextNode::__type_info__; + } + return ctx_node; + } +}; + +class ScheduleBlockContextNode : public IRContextNode { + public: + std::vector iter_vars; + // BufferRange(s) which is read in this schedule block, it is used to + // analyze, not a real computation expression. Must be AST DFS order. + std::vector read_buffers; + // BufferRange(s) which is written in this schedule block, it is used to + // analyze, not a real computation expression. Must be AST DFS order. + std::vector write_buffers; + // Additional attributes about this schedulable block, + // which take some auxiliary hints for future transformations. + std::map attrs; + // values of the iter_vars + std::vector iter_values; + std::string name; + + public: + ScheduleBlockContextNode() = default; + explicit ScheduleBlockContextNode(std::string name) : name(name) {} + void ExitWithContext() final; + const char* type_info() const override { return __type_info__; } + + public: + static constexpr const char* __type_info__ = "ScheduleBlockContextNode"; +}; + +class ScheduleBlockContext : public IRContext { + public: + explicit ScheduleBlockContext(ScheduleBlockContextNode* x) : IRContext(x) {} +}; + +class ForContextNode : public IRContextNode { + public: + //! The loop variable. + Var loop_var; + //! The minimum value of the iteration. + Expr min; + //! The extent of the iteration. + Expr extent; + + public: + void ExitWithContext() final; + const char* type_info() const override { return __type_info__; } + + public: + static constexpr const char* __type_info__ = "ForContextNode"; +}; + +class LowerFuncContextNode : public IRContextNode { + public: + //! The name of this function. + std::string name; + //! The Arguments used in the body of the function. + std::vector args; + + public: + LowerFuncContextNode() = default; + explicit LowerFuncContextNode(std::string name) : name(name) {} + void ExitWithContext() final; + const char* type_info() const override { return __type_info__; } + + public: + static constexpr const char* __type_info__ = "LowerFuncContextNode"; +}; + +class IfContextNode : public IRContextNode { + public: + Expr condition; + Expr true_case; + Expr false_case; + + public: + IfContextNode() = default; + explicit IfContextNode(Expr condition) + : condition(condition), true_case(Expr()), false_case(Expr()) {} + const char* type_info() const override { return __type_info__; } + + void ExitWithContext() final; + + public: + static constexpr const char* __type_info__ = "IfContextNode"; +}; + +class ThenContextNode : public IRContextNode { + public: + ThenContextNode() = default; + const char* type_info() const override { return __type_info__; } + + void ExitWithContext() final; + + public: + static constexpr const char* __type_info__ = "ThenContextNode"; +}; + +class ElseContextNode : public IRContextNode { + public: + ElseContextNode() = default; + const char* type_info() const override { return __type_info__; } + void ExitWithContext() final; + + public: + static constexpr const char* __type_info__ = "ElseContextNode"; +}; + +/** + * A stack used to store current IRContext + */ +class IRBuilderNode : public common::Object { + public: + std::vector contexts; + Expr result; + const char* type_info() const override { return __type_info__; } + Expr GetResult() const; + void Reset(); + + template + IRContext GetLastContext() const; + + template + IRContext FindContext() const; + + public: + static constexpr const char* __type_info__ = "IRBuilderNode"; +}; + +/** + * The lifecycle of RAII resource management for IRBuilderNode + * is determined at the Python. + */ +class IRBuilder { + public: + IRBuilder(); + void EnterWithContext(); + void ExitWithContext(); + static IRBuilder CurrentIRBuilder(); + + public: + common::Shared data_; +}; + +std::vector* IRBuilderStack(); +void LinkToParentContext(ir::Expr); + +template +IRContext IRBuilderNode::GetLastContext() const { + if (!(contexts.back().As())) { + LOG(FATAL) << "TypeError: The last context is not " + << TIRContextNode::__type_info__; + } + return contexts.back(); +} + +template +IRContext IRBuilderNode::FindContext() const { + for (auto it = contexts.rbegin(); it != contexts.rend(); ++it) { + if (const TIRContextNode* p = it->As()) { + return *it; + } + } + return IRContext(); +} + +} // namespace pybind + +} // namespace cinn diff --git a/paddle/cinn/runtime/cinn_runtime.h b/paddle/cinn/runtime/cinn_runtime.h old mode 100755 new mode 100644 index 39ed8cbe5ee09c..17b5a400fd122b --- a/paddle/cinn/runtime/cinn_runtime.h +++ b/paddle/cinn/runtime/cinn_runtime.h @@ -128,7 +128,8 @@ typedef enum cinn_device_kind_t { cinn_unk_device = -1, // Undefined device. cinn_x86_device = 0, // X86 device cinn_opencl_device = 1, // OpenCL device - cinn_arm_device = 2 // ARM device + cinn_arm_device = 2, // ARM device + cinn_nvgpu_device = 3 // NVIDIA GPU device } cinn_device_kind_t; //! Help to tell where the buffer locates.