diff --git a/paddle/cinn/ast_gen_ius/ast_gen.cc b/paddle/cinn/ast_gen_ius/ast_gen.cc index d10560209e6ae..d3ba3226e10ea 100644 --- a/paddle/cinn/ast_gen_ius/ast_gen.cc +++ b/paddle/cinn/ast_gen_ius/ast_gen.cc @@ -18,34 +18,114 @@ #include "paddle/cinn/ir/operation.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/utils/ir_printer.h" +#include "paddle/cinn/lang/compute.h" namespace cinn { namespace ast_gen_ius { -ir::Expr AstGen::Build(const ir::Tensor& tensor) { +ir::Expr ConvertReduceBody(ir::Expr body, + ir::Tensor tensor, + const std::vector& axis_exprs) { + ir::Reduce* reduce_node = body.As(); + if (!reduce_node) { + return ir::Store::Make(tensor, body, axis_exprs); + } + + switch (reduce_node->reduce_type) { + case ir::Reduce::kSum: + return ir::Store::Make( + tensor, tensor(axis_exprs) + reduce_node->body, axis_exprs); + case ir::Reduce::kMul: + return ir::Store::Make( + tensor, tensor(axis_exprs) * reduce_node->body, axis_exprs); + case ir::Reduce::kMax: + return ir::Store::Make( + tensor, + ir::Max::Make(tensor(axis_exprs), reduce_node->body), + axis_exprs); + case ir::Reduce::kMin: + return ir::Store::Make( + tensor, + ir::Min::Make(tensor(axis_exprs), reduce_node->body), + axis_exprs); + case ir::Reduce::kAll: + return ir::Store::Make( + tensor, tensor(axis_exprs) && reduce_node->body, axis_exprs); + case ir::Reduce::kAny: + return ir::Store::Make( + tensor, tensor(axis_exprs) || reduce_node->body, axis_exprs); + default: + CINN_NOT_IMPLEMENTED + } +} + +ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { const std::vector& axis = tensor->axis(); const std::vector& shape = tensor->shape; size_t axis_len = axis.size(); - CHECK_EQ(shape.size(), axis_len) - << "Internal Error: Tensor has different shape and axis length in AstGen"; - + CHECK_EQ(shape.size(), axis_len) << "Internal Error: Tensor has different " + "shape and axis length in AstGen"; std::vector axis_exprs; for (const auto& a : axis) { 
axis_exprs.push_back(a); } - ir::Expr body = ir::Store::Make(tensor, tensor->body(), axis_exprs); - - for (int i = static_cast(axis_len) - 1; i >= 0; --i) { - ir::Var loop_var = axis[i]; - ir::Expr loop_extent = shape[i]; - body = ir::For::Make(loop_var, - Expr(0), - loop_extent, - ir::ForType::Serial, - ir::DeviceAPI::Host, - ir::Block::Make({body})); + + if (tensor->is_reduce_tensor()) { + // Make an init Tensor for domain without reduce axis + Expr init_value = tensor->GetReduceInitVal(); + // TODO(zhhsplendid): Clean the handcoded "__reduce_init" string + std::string reduce_init_name = tensor->name + "__reduce_init"; + const std::vector& domain = tensor->domain_without_reduce_axis(); + ir::Tensor init_tensor = lang::Compute( + domain, + [=](const std::vector& axis) { return init_value; }, + reduce_init_name); + tensor_group->Insert(init_tensor); + tensor_group->MarkShareMemBuffer(tensor, init_tensor); + tensor_group->CtrlDepend(tensor, init_tensor); + Expr init_body = ir::Store::Make(init_tensor, init_value, axis_exprs); + + // For the remaining reduce axis, make reduce body + const std::vector& reduce_axis = tensor->reduce_axis; + ir::Expr reduce_body = + ConvertReduceBody(tensor->body(), tensor, axis_exprs); + for (int i = static_cast(reduce_axis.size()) - 1; i >= 0; --i) { + reduce_body = ir::For::Make(reduce_axis[i], + reduce_axis[i]->lower_bound, + reduce_axis[i]->upper_bound, + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({reduce_body})); + } + + // Put the two parts together + ir::Expr body = ir::Block::Make({init_body, reduce_body}); + for (int i = static_cast(axis_len) - 1; i >= 0; --i) { + ir::Var loop_var = axis[i]; + ir::Expr loop_extent = shape[i]; + body = ir::For::Make( + loop_var, + Expr(0), + loop_extent, + ir::ForType::Serial, + ir::DeviceAPI::Host, + i == static_cast(axis_len) - 1 ? 
body : ir::Block::Make({body})); + } + return body; + } else { + ir::Expr body = ir::Store::Make(tensor, tensor->body(), axis_exprs); + for (int i = static_cast(axis_len) - 1; i >= 0; --i) { + ir::Var loop_var = axis[i]; + ir::Expr loop_extent = shape[i]; + body = ir::For::Make(loop_var, + Expr(0), + loop_extent, + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({body})); + } + return body; } - return body; } } // namespace ast_gen_ius diff --git a/paddle/cinn/ast_gen_ius/ast_gen.h b/paddle/cinn/ast_gen_ius/ast_gen.h index 2e9dc7fde8d8e..53b5131beef67 100644 --- a/paddle/cinn/ast_gen_ius/ast_gen.h +++ b/paddle/cinn/ast_gen_ius/ast_gen.h @@ -23,7 +23,7 @@ namespace ast_gen_ius { class AstGen { public: - static ir::Expr Build(const ir::Tensor& tensor); + static ir::Expr Build(const ir::Tensor& tensor, TensorGroup* tensor_group); }; } // namespace ast_gen_ius diff --git a/paddle/cinn/ast_gen_ius/ast_gen_test.cc b/paddle/cinn/ast_gen_ius/ast_gen_test.cc index e91c0f4ca0e28..70ebe85fb9f0d 100644 --- a/paddle/cinn/ast_gen_ius/ast_gen_test.cc +++ b/paddle/cinn/ast_gen_ius/ast_gen_test.cc @@ -16,6 +16,7 @@ #include #include "paddle/cinn/ast_gen_ius/ast_gen.h" +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/tensor.h" @@ -36,7 +37,8 @@ TEST(AstGen, Build) { shape, [&](const std::vector& indice) { return lang::Relu(A(indice), 0); }, "relu_test"); - Expr out = AstGen::Build(B); + TensorGroup tensor_group({B}); + Expr out = AstGen::Build(B, &tensor_group); LOG(INFO) << out; } diff --git a/paddle/cinn/ast_gen_ius/tensor_group.cc b/paddle/cinn/ast_gen_ius/tensor_group.cc index cca8b4136ba1b..2b604f2c383cb 100644 --- a/paddle/cinn/ast_gen_ius/tensor_group.cc +++ b/paddle/cinn/ast_gen_ius/tensor_group.cc @@ -75,10 +75,12 @@ std::vector TensorGroup::GetGenFuncTopoOrder( } std::vector ret; - std::vector stack; + + // Using set instead of vector/stack in order to get fix 
alphabetical order topo + std::set node_set; for (const auto& name_tensor : name_to_tensor_) { if (!in_degree.count(name_tensor.first)) { - stack.emplace_back(name_tensor.first); + node_set.insert(name_tensor.first); } } @@ -90,9 +92,9 @@ std::vector TensorGroup::GetGenFuncTopoOrder( input_arg_names.erase(name); } - while (!stack.empty()) { - const std::string& cur = stack.back(); - stack.pop_back(); + while (!node_set.empty()) { + const std::string cur = *(node_set.begin()); + node_set.erase(node_set.begin()); if (!input_arg_names.count(cur)) { ret.push_back(name_to_tensor_[cur]); @@ -103,7 +105,7 @@ if (dep_tensor_names.count(cur)) { --in_degree[dep_pair.first]; if (in_degree[dep_pair.first] == 0) { - stack.emplace_back(dep_pair.first); + node_set.insert(dep_pair.first); } } } @@ -111,15 +113,6 @@ return ret; } -bool TensorGroup::HasMarkedReduceInit(const std::string& tensor_name) const { - return tensor_name_needs_reduce_init_.count(tensor_name); -} - -ir::Tensor TensorGroup::MarkReduceInit(const std::string& tensor_name) { - // TODO(zhhsplendid): add check - tensor_name_needs_reduce_init_.insert(tensor_name); -} - void TensorGroup::CtrlDepend(const ir::Tensor& tensor, const ir::Tensor& to_dep) { ctrl_dep_[tensor->name].insert(to_dep->name); @@ -156,8 +149,8 @@ std::string TensorGroup::GetShareMemRootName(const std::string& tensor_name) { return share_memory_tensor_[tensor_name]; } -void TensorGroup::ShareMemoryBuffer(const ir::Tensor& tensor, - const ir::Tensor& to_share) { +void TensorGroup::MarkShareMemBuffer(const ir::Tensor& tensor, + const ir::Tensor& to_share) { share_memory_tensor_[GetShareMemRootName(to_share->name)] = GetShareMemRootName(tensor->name); } diff --git a/paddle/cinn/ast_gen_ius/tensor_group.h b/paddle/cinn/ast_gen_ius/tensor_group.h index 1fa37c730c455..c6e12690e9dcc 100644 --- a/paddle/cinn/ast_gen_ius/tensor_group.h +++ 
b/paddle/cinn/ast_gen_ius/tensor_group.h @@ -28,54 +28,95 @@ namespace cinn { namespace ast_gen_ius { -/* Collection used for Tensors, used in AST generation */ +/** + * Collection which maintains the relation between Tensor(s) such as control + * dependency, memory sharing ... it is used in AST generation + */ class TensorGroup { public: + /** + * Constructor for a TensorGroup, the argument tensors should be output tensor + * arguments of the AST body to be generated. The dependent tensors of the + * output tensors will be collected during construction. + */ explicit TensorGroup(const std::vector& tensors); + + /** + * Destructor. + */ ~TensorGroup(); + /** + * Returns true if TensorGroup collection contains a tensor with input name. + */ bool Contain(const std::string& name) const; + /** + * Insert a Tensor into TensorGroup collection. + */ void Insert(const ir::Tensor& tensor); + /** + * Returns the Tensor in TensorGroup collection with the given name. + */ ir::Tensor Get(const std::string& name); + /** + * Returns all Tensors in TensorGroup. + */ std::set GetAllTensors(); + /** + * Mark `tensor` depends on `to_dep`. + */ void CtrlDepend(const ir::Tensor& tensor, const ir::Tensor& to_dep); + /** + * Get all tensors which the tensor with given name depends on. + */ std::set GetCrtlDepTensors(const std::string& tensor_name); + /** + * Get Union-Find set algorithm root tensor name which shares memory with the + * tensor whose name is the input. + */ std::string GetShareMemRootName(const std::string& tensor_name); - void ShareMemoryBuffer(const ir::Tensor& tensor, const ir::Tensor& to_share); + /** + * Mark two tensors share memory, it only marks using Union-Find set + * algorithm, doesn't do really memory sharing/allocation + */ + void MarkShareMemBuffer(const ir::Tensor& tensor, const ir::Tensor& to_share); + /** + * Allocate buffers for Tensors in TensorGroup, it handles the shared memory + * using Union-Find set algorithm. 
+ */ absl::flat_hash_map AllocateBuffers(); - // Returns tensors in topological order and remove those args - // Becuase the order is used for generating function body, we don't have to - // generate args + /** + * Returns tensors in topological order and remove those args + * Because the order is used for generating function body, we don't have to + * generate args + */ std::vector GetGenFuncTopoOrder( const std::vector& func_args = {}); - bool HasMarkedReduceInit(const std::string& tensor_name) const; - - // Marks a tensor needs to do reduce init - ir::Tensor MarkReduceInit(const std::string& tensor_name); - private: + /** collection of output tensor names */ std::set output_tensor_names_; + /** collection of all tensors in this TensorGroup */ absl::flat_hash_map name_to_tensor_; - // Stores vector of tensor names, which the key tensor depends on + /** Stores vector of tensor names, which the key tensor depends on */ std::unordered_map> ctrl_dep_; - // Keeps Union Find Set style, each tensor name whose buffer is shared maps to - // the same name tensor + /** + * Keeps Union Find Set style, each tensor name whose buffer is shared, maps + * to the same name tensor. + */ std::unordered_map share_memory_tensor_; - - std::unordered_set tensor_name_needs_reduce_init_; }; } // namespace ast_gen_ius diff --git a/paddle/cinn/ast_gen_ius/tensor_group_test.cc b/paddle/cinn/ast_gen_ius/tensor_group_test.cc index 3711419da9c56..6c602b312ffb0 100644 --- a/paddle/cinn/ast_gen_ius/tensor_group_test.cc +++ b/paddle/cinn/ast_gen_ius/tensor_group_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include @@ -55,6 +56,82 @@ TEST(TensorGroup, Easy) { tensor_group.GetGenFuncTopoOrder({A.tensor(), B}); ASSERT_EQ(topo_tensors.size(), 1UL); ASSERT_EQ(topo_tensors[0]->name, "B"); + + ASSERT_EQ(tensor_group.GetShareMemRootName("A"), "A"); + ASSERT_EQ(tensor_group.GetShareMemRootName("B"), "B"); + tensor_group.MarkShareMemBuffer(tensor_group.Get("A"), tensor_group.Get("B")); + + absl::flat_hash_map buffered_tensors = + tensor_group.AllocateBuffers(); + ASSERT_EQ(buffered_tensors["A"]->buffer->name, + buffered_tensors["B"]->buffer->name); +} + +TEST(TensorGroup, GraphTopo) { + auto M = Expr(16); + auto N = Expr(16); + + /* + * A B + * / \ / + * C D + * \ / + * E + */ + + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + Tensor C = Compute( + {M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "C"); + + Tensor D = Compute( + {M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + B(i, j); }, "D"); + + Tensor E = Compute( + {M, N}, [=](Var i, Var j) -> Expr { return C(i, j) / D(i, j); }, "E"); + + TensorGroup tensor_group({C, D, E}); + + std::vector check_names = {"A", "B", "C", "D", "E"}; + ASSERT_EQ(tensor_group.GetAllTensors().size(), check_names.size()); + for (const std::string& name : check_names) { + ASSERT_TRUE(tensor_group.Contain(name)); + ASSERT_EQ(tensor_group.Get(name)->name, name); + } + + ASSERT_TRUE(tensor_group.GetCrtlDepTensors("E").count(D)); + ASSERT_TRUE(tensor_group.GetCrtlDepTensors("E").count(C)); + ASSERT_TRUE(tensor_group.GetCrtlDepTensors("D").count(A)); + ASSERT_TRUE(tensor_group.GetCrtlDepTensors("D").count(B)); + ASSERT_TRUE(tensor_group.GetCrtlDepTensors("C").count(A)); + + std::vector topo_tensors = tensor_group.GetGenFuncTopoOrder(); + ASSERT_EQ(topo_tensors.size(), check_names.size()); + for (size_t i = 0; i < check_names.size(); ++i) { + ASSERT_EQ(topo_tensors[i]->name, check_names[i]); + } + + std::vector topo_except_argu = + tensor_group.GetGenFuncTopoOrder({A.tensor(), B.tensor()}); + 
ASSERT_EQ(topo_except_argu.size(), 3); + for (int i = 0; i < 3; ++i) { + ASSERT_EQ(topo_except_argu[i]->name, check_names[i + 2]); + } + + for (size_t i = 0; i < check_names.size(); ++i) { + ASSERT_EQ(tensor_group.GetShareMemRootName(check_names[i]), check_names[i]); + } + tensor_group.MarkShareMemBuffer(tensor_group.Get("A"), tensor_group.Get("B")); + tensor_group.MarkShareMemBuffer(tensor_group.Get("B"), tensor_group.Get("C")); + tensor_group.MarkShareMemBuffer(tensor_group.Get("C"), tensor_group.Get("D")); + + ASSERT_EQ(tensor_group.GetShareMemRootName("A"), + tensor_group.GetShareMemRootName("D")); + absl::flat_hash_map buffered_tensors = + tensor_group.AllocateBuffers(); + ASSERT_EQ(buffered_tensors["A"]->buffer->name, + buffered_tensors["D"]->buffer->name); } } // namespace ast_gen_ius diff --git a/paddle/cinn/backends/codegen_c_test.cc b/paddle/cinn/backends/codegen_c_test.cc old mode 100755 new mode 100644 index f0e6e238734f5..8db31b6c6007f --- a/paddle/cinn/backends/codegen_c_test.cc +++ b/paddle/cinn/backends/codegen_c_test.cc @@ -19,6 +19,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/module.h" @@ -65,8 +66,10 @@ TEST(CodeGenC, module) { target.os = Target::OS ::Linux; Module::Builder builder("module1", target); - auto stages = CreateStages({A, B, C}); - auto func = Lower("add1", stages, {A, B, C}); + ast_gen_ius::TensorGroup tensor_group({A, B, C}); + auto func = lang::LowerToAst("add1", {A, B, C}, &tensor_group); + + LOG(INFO) << "Huihuang debug: " << func << std::endl; builder.AddFunction(func); @@ -74,7 +77,7 @@ TEST(CodeGenC, module) { CodeGenC codegen(target); codegen.SetInlineBuiltinCodes(false); auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); - std::cout << "codegen C:" << std::endl << out << std::endl; + LOG(INFO) << "codegen C:" << std::endl << out << std::endl; std::string target_str = R"ROC( #include diff 
--git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc index 7631141d115cd..3297b714630e1 100644 --- a/paddle/cinn/ir/tensor.cc +++ b/paddle/cinn/ir/tensor.cc @@ -251,11 +251,6 @@ Expr *_Tensor_::mutable_body() { CINN_NOT_IMPLEMENTED } -ir::Tensor _Tensor_::InitReduction( - ast_gen_ius::TensorGroup *tensor_group) const { - return tensor_group->MarkReduceInit(this->name); -} - ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target) const { CHECK(contains_reduce_axis()) diff --git a/paddle/cinn/ir/tensor.h b/paddle/cinn/ir/tensor.h index fd8e79f73ffdd..5c252d35faceb 100644 --- a/paddle/cinn/ir/tensor.h +++ b/paddle/cinn/ir/tensor.h @@ -274,8 +274,6 @@ class _Tensor_ : public ExprNode<_Tensor_> { poly::StageMap stages, const Target& target = common::DefaultHostTarget()) const; - ir::Tensor InitReduction(ast_gen_ius::TensorGroup* tensor_group) const; - /** * Create the initialization tensor. * @param stages The stages. diff --git a/paddle/cinn/lang/lower.cc b/paddle/cinn/lang/lower.cc index 667c0646c43cd..58ae00fe8771e 100644 --- a/paddle/cinn/lang/lower.cc +++ b/paddle/cinn/lang/lower.cc @@ -243,25 +243,6 @@ std::set CollectTempTensorsFromCtrlDepends( return res; } -void InitReduceTensor(TensorGroup* tensor_group, - const Tensor& tensor, - const Target& target) { - if (tensor->is_reduce_tensor()) { - tensor_group->MarkReduceInit(tensor->name); - } - auto uninited_reduce_tensors = - ir::CollectIRNodes(tensor->body(), [&](const Expr* x) { - return x && x->defined() && x->as_tensor() && - x->as_tensor()->is_reduce_tensor() && - !tensor_group->HasMarkedReduceInit(x->as_tensor()->name); - }); - for (auto& t : uninited_reduce_tensors) { - std::string reduce_name = t.as_tensor()->name; - VLOG(3) << "Init reduce tensor: " << reduce_name; - tensor_group->MarkReduceInit(reduce_name); - } -} - void InitReduceTensor(StageMap stages, const Tensor& tensor, const Target& target) { @@ -301,10 +282,6 @@ ir::LoweredFunc LowerToAst(const std::string& name, 
const std::vector& tensor_args, ast_gen_ius::TensorGroup* tensor_group, const Target& target) { - // Init the reduce tensors first before any process. - for (auto& t : tensor_args) { - InitReduceTensor(tensor_group, t, target); - } // Merge the ctrl_deps with the given temp_tensors ang get a new temp_tensors std::set ctrl_deps = CollectTempTensorsFromCtrlDepends(tensor_group, tensor_args); diff --git a/paddle/cinn/lang/lower_tensor_group.cc b/paddle/cinn/lang/lower_tensor_group.cc index 6fb8e72f43c68..200b608387560 100644 --- a/paddle/cinn/lang/lower_tensor_group.cc +++ b/paddle/cinn/lang/lower_tensor_group.cc @@ -169,8 +169,8 @@ std::vector LowerTensorGroup::GenerateFunctionArgumentList( if (!tensor_node->buffer.defined()) { continue; } - // if a argument is already marked as kInput, mark it as kOutput and move it - // to the back. + // if a argument is already marked as kInput, mark it as kOutput and move + // it to the back. if (arg_names.count(tensor_node->buffer->name)) { auto it = std::find_if(args.begin(), args.end(), [&](const ir::Argument& x) { @@ -201,7 +201,9 @@ ir::Expr LowerTensorGroup::GenerateFunctionBody( tensor_group->GetGenFuncTopoOrder(tensor_args_); std::vector bodies; for (const ir::Tensor& tensor : ordered_tensors) { - bodies.emplace_back(ast_gen_ius::AstGen::Build(tensor)); + if (!tensor->is_placeholder_node()) { + bodies.emplace_back(ast_gen_ius::AstGen::Build(tensor, tensor_group)); + } } if (bodies.size() == 1) { return bodies[0]; diff --git a/paddle/cinn/lang/lower_test.cc b/paddle/cinn/lang/lower_test.cc old mode 100755 new mode 100644 index 431d73d075be6..e97d0f596a7ea --- a/paddle/cinn/lang/lower_test.cc +++ b/paddle/cinn/lang/lower_test.cc @@ -159,6 +159,7 @@ TEST(lower, temp_buffer_collects) { } TEST(lower_to_ast, basic) { + Context::Global().ResetNameId(); auto M = Expr(100); auto N = Expr(15); @@ -169,11 +170,12 @@ TEST(lower_to_ast, basic) { ast_gen_ius::TensorGroup tensor_group({B}); - auto lower_funcs = LowerToAst("cal_B", {A, 
B}, &tensor_group); + ir::LoweredFunc lower_func = LowerToAst("cal_B", {A, B}, &tensor_group); - LOG(INFO) << "lower_func " << lower_funcs; + LOG(INFO) << "lower_func " << lower_func; auto out = R"ROC( +function cal_B (_A, _B) { serial for (i, 0, 100) { @@ -184,7 +186,81 @@ TEST(lower_to_ast, basic) { } } )ROC"; - TEST_SOUTPUT(lower_funcs->body, out); + TEST_SOUTPUT(lower_func, out); +} + +TEST(lower_to_ast, three_dim) { + Context::Global().ResetNameId(); + Expr M(100); + Expr N(15); + Expr K(200); + + Placeholder A("A", {Expr(M), Expr(N)}); + Placeholder B("B", {Expr(N), Expr(K)}); + + auto C = Compute( + {M, N, K}, + [=](Var i, Var j, Var k) -> Expr { return A(i, j) * B(j, k); }, + "C"); + + ast_gen_ius::TensorGroup tensor_group({C}); + + ir::LoweredFunc lower_func = LowerToAst("cal_C", {A, B, C}, &tensor_group); + + LOG(INFO) << "func:\n" << lower_func << std::endl; + + auto out = R"ROC( +function cal_C (_A, _B, _C) +{ + serial for (i, 0, 100) + { + serial for (j, 0, 15) + { + serial for (k, 0, 200) + { + C[i, j, k] = (A[i, j] * B[j, k]) + } + } + } +} +)ROC"; + TEST_SOUTPUT(lower_func, out); +} + +TEST(lower_to_ast, matmul_with_reduce_sum) { + Context::Global().ResetNameId(); + Placeholder A("A", {Expr(100), Expr(20)}); + Placeholder B("B", {Expr(20), Expr(50)}); + + Target target{}; + // C = A * B + Var k(20, "k0"); + Tensor C = Compute( + {Expr(100), Expr(50)}, + [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, + "C"); + + ast_gen_ius::TensorGroup tensor_group({C}); + ir::LoweredFunc lower_func = LowerToAst("matmul", {A, B, C}, &tensor_group); + LOG(INFO) << "func:\n" << lower_func << std::endl; + + auto out = R"ROC( +function matmul (_A, _B, _C) +{ + serial for (i, 0, 100) + { + serial for (j, 0, 50) + { + C__reduce_init[i, j] = 0.00000000f + serial for (k0, 0, 20) + { + C[i, j] = (C[i, j] + (A[i, k0] * B[k0, j])) + } + } + } +} +)ROC"; + TEST_SOUTPUT(lower_func, out); } } // namespace lang