diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index 466886dfbd5c1..f98d23fa2ae22 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -104,9 +104,7 @@ def decl_rw_texture_arg(num_dimensions, num_channels, channel_format, lod): def decl_ret(dtype, real_func=False): if isinstance(dtype, StructType): - for member in dtype.members.values(): - decl_ret(member, real_func) - return + dtype = dtype.dtype if isinstance(dtype, MatrixType): if real_func: for i in range(dtype.n * dtype.m): diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 91e8b00246e29..30bd64fb82d20 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -277,7 +277,7 @@ def func_call_rvalue(self, key, args): if id(self.return_type) in primitive_types.type_ids: return Expr(_ti_core.make_get_element_expr(func_call.ptr, (0, ))) if isinstance(self.return_type, StructType): - return self.return_type.from_real_func_ret(func_call)[0] + return self.return_type.from_real_func_ret(func_call, (0, )) raise TaichiTypeError(f"Unsupported return type: {self.return_type}") def do_compile(self, key, args): diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 178d54b62ee53..7c56dfc81f601 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1509,12 +1509,13 @@ def __call__(self, *args): # type cast return self.cast(Matrix(entries, dt=self.dtype, ndim=self.ndim)) - def from_real_func_ret(self, func_ret, ret_index=0): + def from_real_func_ret(self, func_ret, ret_index=()): return self([ - expr.Expr(ti_python_core.make_get_element_expr( - func_ret.ptr, (i, ))) - for i in range(ret_index, ret_index + self.m * self.n) - ]), ret_index + self.m * self.n + expr.Expr( + ti_python_core.make_get_element_expr(func_ret.ptr, + ret_index + (i, ))) + for i in range(self.m * self.n) + ]) def cast(self, mat): if 
in_python_scope(): diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index b819cb940a2d9..6a0c585f1a4e9 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -704,21 +704,20 @@ def __call__(self, *args, **kwargs): struct = self.cast(entries) return struct - def from_real_func_ret(self, func_ret, ret_index=0): + def from_real_func_ret(self, func_ret, ret_index=()): d = {} items = self.members.items() for index, pair in enumerate(items): name, dtype = pair if isinstance(dtype, CompoundType): - d[name], ret_index = dtype.from_real_func_ret( - func_ret, ret_index) + d[name] = dtype.from_real_func_ret(func_ret, + ret_index + (index, )) else: d[name] = expr.Expr( _ti_core.make_get_element_expr(func_ret.ptr, - (ret_index, ))) - ret_index += 1 + ret_index + (index, ))) - return Struct(d), ret_index + return Struct(d) def cast(self, struct): # sanity check members diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index b0d4ebd89dfde..97da38c9c3c4e 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1271,18 +1271,10 @@ void TaskCodeGenLLVM::visit(ReturnStmt *stmt) { if (std::any_of(types.begin(), types.end(), [](const DataType &t) { return t.is_pointer(); })) { TI_NOT_IMPLEMENTED - } else if (now_real_func) { - TI_ASSERT(stmt->values.size() == now_real_func->rets.size()); - auto *result_buf = call("RuntimeContext_get_result_buffer", get_context()); - auto *ret_type = get_real_func_ret_type(now_real_func); - result_buf = builder->CreatePointerCast( - result_buf, llvm::PointerType::get(ret_type, 0)); - for (int i = 0; i < stmt->values.size(); i++) { - auto *gep = - builder->CreateGEP(ret_type, result_buf, - {tlctx->get_constant(0), tlctx->get_constant(i)}); - builder->CreateStore(llvm_val[stmt->values[i]], gep); - } + } else if (current_real_func) { + TI_ASSERT(stmt->values.size() == + current_real_func->ret_type->get_num_elements()); + 
create_return(stmt->values); } else { TI_ASSERT(stmt->values.size() <= taichi_max_num_ret_value); int idx{0}; @@ -2707,11 +2699,11 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) { auto guard = get_function_creation_guard( {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0)}, stmt->func->get_name()); - Function *old_real_func = now_real_func; - now_real_func = stmt->func; + Function *old_real_func = current_real_func; + current_real_func = stmt->func; func_map.insert({stmt->func, guard.body}); stmt->func->ir->accept(this); - now_real_func = old_real_func; + current_real_func = old_real_func; } llvm::Function *llvm_func = func_map[stmt->func]; auto *new_ctx = call("allocate_runtime_context", get_runtime()); @@ -2749,14 +2741,50 @@ void TaskCodeGenLLVM::visit(GetElementStmt *stmt) { llvm_val[stmt] = val; } -llvm::Type *TaskCodeGenLLVM::get_real_func_ret_type(Function *real_func) { - std::vector tps; - for (auto &ret : real_func->rets) { - tps.push_back(tlctx->get_data_type(ret.dt)); +void TaskCodeGenLLVM::create_return(llvm::Value *buffer, + llvm::Type *buffer_type, + const std::vector &elements, + const Type *current_type, + int &current_element, + std::vector &current_index) { + if (auto primitive_type = current_type->cast()) { + TI_ASSERT((Type *)elements[current_element]->ret_type == current_type); + auto *gep = builder->CreateGEP(buffer_type, buffer, current_index); + builder->CreateStore(llvm_val[elements[current_element]], gep); + current_element++; + } else if (auto struct_type = current_type->cast()) { + int i = 0; + for (const auto &element_type : struct_type->elements()) { + current_index.push_back(tlctx->get_constant(i++)); + create_return(buffer, buffer_type, elements, element_type, + current_element, current_index); + current_index.pop_back(); + } + } else { + auto tensor_type = current_type->as(); + int num_elements = tensor_type->get_num_elements(); + Type *element_type = tensor_type->get_element_type(); + for (int i = 0; i < num_elements; i++) { +
current_index.push_back(tlctx->get_constant(i)); + create_return(buffer, buffer_type, elements, element_type, + current_element, current_index); + current_index.pop_back(); + } } - return llvm::StructType::get(*llvm_context, tps); } +void TaskCodeGenLLVM::create_return(const std::vector &elements) { + auto buffer = call("RuntimeContext_get_result_buffer", get_context()); + auto ret_type = current_real_func->ret_type; + auto buffer_type = tlctx->get_data_type(ret_type); + buffer = builder->CreatePointerCast(buffer, + llvm::PointerType::get(buffer_type, 0)); + int current_element = 0; + std::vector current_index = {tlctx->get_constant(0)}; + create_return(buffer, buffer_type, elements, ret_type, current_element, + current_index); +}; + LLVMCompiledTask LLVMCompiledTask::clone() const { return {tasks, llvm::CloneModule(*module), used_tree_ids, struct_for_tls_sizes}; diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index 52c9f2694d319..758a913813262 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -60,7 +60,7 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { bool returned{false}; std::unordered_set used_tree_ids; std::unordered_set struct_for_tls_sizes; - Function *now_real_func{nullptr}; + Function *current_real_func{nullptr}; std::unordered_map> loop_vars_llvm; @@ -95,8 +95,6 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Type *get_mesh_xlogue_function_type(); - llvm::Type *get_real_func_ret_type(Function *real_func); - llvm::Value *get_root(int snode_tree_id); llvm::Value *get_runtime(); @@ -138,6 +136,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *create_print(std::string tag, llvm::Value *value); + void create_return(const std::vector &elements); + llvm::Value *cast_pointer(llvm::Value *val, std::string dest_ty_name, int addr_space = 0); @@ -402,6 +402,14 @@ class TaskCodeGenLLVM : public 
IRVisitor, public LLVMModuleBuilder { llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type); ~TaskCodeGenLLVM() override = default; + + private: + void create_return(llvm::Value *buffer, + llvm::Type *buffer_type, + const std::vector &elements, + const Type *current_type, + int &current_element, + std::vector &current_index); }; } // namespace taichi::lang