diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 3013408d4bfef..550296807c543 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -15,7 +15,7 @@ from taichi.lang.ast.symbol_resolver import ASTResolver from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.field import Field -from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl, +from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl, _TiScopeMatrixImpl) from taichi.lang.snode import append from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type @@ -489,6 +489,12 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr + if (isinstance(node.func, ast.Attribute) and + (func == Matrix + or func == Vector)) and impl.current_cfg().real_matrix: + node.ptr = matrix.make_matrix(*args, **keywords) + return node.ptr + if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords): return node.ptr diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 4a5ae39f9ea95..027696989dc9d 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -15,7 +15,8 @@ from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType, - _IntermediateMatrix, _MatrixFieldElement) + _IntermediateMatrix, _MatrixFieldElement, + make_matrix) from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance, MeshRelationAccessProxy, MeshReorderedMatrixFieldProxy, @@ -36,6 +37,12 @@ def expr_init_local_tensor(shape, element_type, elements): get_runtime().get_current_src_info()) +@taichi_scope +def make_matrix_expr(shape, element_type, elements): + return get_runtime().prog.current_ast_builder().make_matrix_expr( + shape, element_type, elements) + + @taichi_scope def expr_init_shared_array(shape, element_type): return get_runtime().prog.current_ast_builder().expr_alloca_shared_array( @@ -49,6 +56,13 @@ def expr_init(rhs): if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")): return Matrix(*rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, Matrix): + if current_cfg().real_matrix: + if rhs.ndim == 1: + entries = [rhs(i) for i in range(rhs.n)] + else: + entries = [[rhs(i, j) for j in range(rhs.m)] + for i in range(rhs.n)] + return make_matrix(entries) return Matrix(rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, SharedArray): return rhs diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 9e019f29db6c5..2bbb23c1acf93 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -98,6 +98,21 @@ def prop_setter(instance, value): return cls +def make_matrix(arr, dt=None): + assert len(arr) > 0, "Cannot create empty matrix" + is_matrix = isinstance(arr[0], Iterable) + if dt is None: + dt = _make_entries_initializer(is_matrix).infer_dt(arr) + if not is_matrix: + return impl.Expr( + impl.make_matrix_expr([len(arr)], dt, + [expr.Expr(elt).ptr for elt in arr])) + return impl.Expr( + impl.make_matrix_expr( + [len(arr), len(arr[0])], dt, + [expr.Expr(elt).ptr for row in arr for elt in row])) + + class _MatrixBaseImpl: def __init__(self, m, n, entries): self.m = m diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 473ebe12ae58c..51edc1bcc9ee1 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -166,6 +166,14 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr->indices.exprs); } + void visit(MatrixExpression *expr) override { + emit(ExprOpCode::MatrixExpression); + emit(expr->dt); + for (auto elt : expr->elements) { + emit(elt); + } + } + void visit(StrideExpression *expr) override { emit(ExprOpCode::StrideExpression); emit(expr->var); diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 9cbef443844a2..3207e0d887834 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -65,6 +65,7 @@ static std::vector get_offline_cache_key_of_compile_config( serializer(config->demote_no_access_mesh_fors); serializer(config->experimental_auto_mesh_local); serializer(config->auto_mesh_local_default_occupacy); + serializer(config->real_matrix); serializer.finalize(); return serializer.data; diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 771f9f055dbf0..a6ddb17d88555 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,9 +124,11 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type->get_element_type()); - auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); - // Return type is [array_size x type]*. + auto type = kernel->program->config.real_matrix + ? tlctx->get_data_type(tensor_type) + : tlctx->get_data_type(tensor_type->get_element_type()); + // Return type is vector* if use real matrix. + // otherwise the return type is [type * array_size]* if (stmt->is_shared) { size_t data_element_size = tlctx->get_type_size( tlctx->get_data_type(tensor_type->get_element_type())); @@ -148,7 +150,12 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { tlctx->get_data_type(tensor_type->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } else { - llvm_val[stmt] = create_entry_block_alloca(type, 0, array_size); + if (kernel->program->config.real_matrix) + llvm_val[stmt] = + create_entry_block_alloca(type, stmt->ret_type.is_pointer()); + else + llvm_val[stmt] = create_entry_block_alloca( + type, 0, tlctx->get_constant(tensor_type->get_num_elements())); } } else { TI_ASSERT(stmt->width() == 1); @@ -688,6 +695,13 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { return llvm::Type::getDoubleTy(*llvm_context); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*llvm_context); + } else if (dt->is()) { + TI_ASSERT_INFO(kernel->program->config.real_matrix, + "Real matrix not enabled but got TensorType"); + auto tensor_type = dt->cast(); + auto element_type = llvm_type(tensor_type->get_element_type()); + return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), + /*scalable=*/false); } else { TI_NOT_IMPLEMENTED; } @@ -796,16 +810,29 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { TI_ASSERT(stmt->width() == 1); std::vector args; std::string formats; + auto value_for_printf = [this](llvm::Value *to_print, DataType dtype) { + if (dtype->is_primitive(PrimitiveTypeID::f32) || + dtype->is_primitive(PrimitiveTypeID::f16)) + return this->builder->CreateFPExt( + to_print, this->tlctx->get_data_type(PrimitiveType::f64)); + return to_print; + }; for (auto const &content : stmt->contents) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || - arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) - value = builder->CreateFPExt(value, - tlctx->get_data_type(PrimitiveType::f64)); - args.push_back(value); - formats += data_type_format(arg_stmt->ret_type); + if (arg_stmt->ret_type->is()) { + auto dtype = arg_stmt->ret_type->cast(); + auto elem_type = dtype->get_element_type(); + for (int i = 0; i < dtype->get_num_elements(); ++i) { + auto elem_value = builder->CreateExtractElement(value, i); + args.push_back(value_for_printf(elem_value, elem_type)); + } + formats += data_type_format(arg_stmt->ret_type); + } else { + args.push_back(value_for_printf(value, arg_stmt->ret_type)); + formats += data_type_format(arg_stmt->ret_type); + } } else { auto arg_str = std::get(content); auto value = builder->CreateGlobalStringPtr(arg_str, "content_string"); @@ -2515,6 +2542,16 @@ void TaskCodeGenLLVM::visit(MeshPatchIndexStmt *stmt) { llvm_val[stmt] = get_arg(2); } +void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) { + auto type = tlctx->get_data_type(stmt->ret_type->as()); + llvm::Value *vec = llvm::UndefValue::get(type); + for (int i = 0; i < stmt->values.size(); ++i) { + auto *elem = llvm_val[stmt->values[i]]; + vec = builder->CreateInsertElement(vec, elem, i); + } + llvm_val[stmt] = vec; +} + void TaskCodeGenLLVM::eliminate_unused_functions() { TaichiLLVMContext::eliminate_unused_functions( module.get(), [&](std::string func_name) { diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index 6f97ed7dff0f4..356866b12b8c4 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -369,6 +369,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(ReferenceStmt *stmt) override; + void visit(MatrixInitStmt *stmt) override; + llvm::Value *create_xlogue(std::unique_ptr &block); llvm::Value *create_mesh_xlogue(std::unique_ptr &block); diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h index 9b20ba86bd80a..ac3d3b7bc9b1b 100644 --- a/taichi/inc/expressions.inc.h +++ b/taichi/inc/expressions.inc.h @@ -7,6 +7,7 @@ PER_EXPRESSION(InternalFuncCallExpression) PER_EXPRESSION(ExternalTensorExpression) PER_EXPRESSION(GlobalVariableExpression) PER_EXPRESSION(IndexExpression) +PER_EXPRESSION(MatrixExpression) PER_EXPRESSION(StrideExpression) PER_EXPRESSION(RangeAssumptionExpression) PER_EXPRESSION(LoopUniqueExpression) diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index fe12a8941f7f5..05056ce46b9ae 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -38,6 +38,7 @@ PER_STATEMENT(LoopUniqueStmt) PER_STATEMENT(AssertStmt) PER_STATEMENT(ExternalFuncCallStmt) PER_STATEMENT(ExternalTensorShapeAlongAxisStmt) +PER_STATEMENT(MatrixInitStmt) // Locals with reverse-mode autodiff PER_STATEMENT(AdStackAllocaStmt) diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 74887c890e099..a32a0468d6907 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -110,6 +110,13 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } } + void visit(MatrixExpression *expr) override { + emit('['); + emit_vector(expr->elements); + emit(']'); + emit(fmt::format(" (dt={})", expr->dt->to_string())); + } + void visit(IndexExpression *expr) override { expr->var->accept(this); emit('['); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index a553e3ca0da4e..9ef38c9701896 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -432,6 +432,26 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, return ctx->push_back(var->stmt, offset_stmt); } +void MatrixExpression::type_check(CompileConfig *config) { + // TODO: typecheck matrix + for (auto &arg : elements) { + TI_ASSERT_TYPE_CHECKED(arg); + } + ret_type = dt; +} + +void MatrixExpression::flatten(FlattenContext *ctx) { + // TODO: implement flatten + TI_ASSERT(this->dt->is()); + std::vector values; + for (auto &elt : elements) { + flatten_rvalue(elt, ctx); + values.push_back(elt->stmt); + } + stmt = ctx->push_back(values); + stmt->ret_type = this->dt; +} + bool IndexExpression::is_field() const { return var.is(); } @@ -970,6 +990,12 @@ Expr ASTBuilder::expr_alloca() { return var; } +Expr ASTBuilder::make_matrix_expr(const std::vector &shape, + const DataType &dt, + const std::vector &elements) { + return Expr(std::make_shared(elements, shape, dt)); +} + Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 4e1e9041531ac..d4da9b635820f 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -504,6 +504,29 @@ class GlobalVariableExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +/** + * Creating a local matrix; + * lowered from ti.Matrix with real_matrix=True + */ +class MatrixExpression : public Expression { + public: + std::vector elements; + DataType dt; + + MatrixExpression(const std::vector &elements, + std::vector shape, + DataType element_type) + : elements(elements) { + this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + } + + void type_check(CompileConfig *config) override; + + void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION +}; + class IndexExpression : public Expression { public: // `var` is one of GlobalVariableExpression, ExternalTensorExpression, @@ -865,6 +888,9 @@ class ASTBuilder { const std::function &func); Expr make_id_expr(const std::string &name); + Expr make_matrix_expr(const std::vector &shape, + const DataType &dt, + const std::vector &elements); Expr insert_thread_idx_expr(); Expr insert_patch_idx_expr(); void create_kernel_exprgroup_return(const ExprGroup &group); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 9a2ea841e6a66..11c87dd98b67c 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1807,5 +1807,20 @@ class MeshPatchIndexStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +/** + * Initialization of a local matrix + */ +class MatrixInitStmt : public Stmt { + public: + std::vector values; + + MatrixInitStmt(const std::vector &values) : values(values) { + TI_STMT_REG_FIELDS; + } + + TI_STMT_DEF_FIELDS(ret_type, values); + TI_DEFINE_ACCEPT_AND_CLONE +}; + } // namespace lang } // namespace taichi diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index e49428b022445..76ee3aa1f7b30 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -53,6 +53,36 @@ int data_type_size(DataType t) { } } +std::string tensor_type_format_helper(const std::vector &shape, + std::string format_str, + int dim) { + std::string fmt = "["; + for (int i = 0; i < shape[dim]; ++i) { + if (dim != shape.size() - 1) { + fmt += tensor_type_format_helper(shape, format_str, dim + 1); + } else { + fmt += format_str; + } + if (i != shape[dim] - 1) { + fmt += ", "; + if (dim == 0 && dim != shape.size() - 1) { + fmt += "\n"; + } + } + } + fmt += "]"; + return fmt; +} + +std::string tensor_type_format(DataType t) { + TI_ASSERT(t->is()); + auto tensor_type = t->as(); + auto shape = tensor_type->get_shape(); + auto element_type = tensor_type->get_element_type(); + auto element_type_format = data_type_format(element_type); + return tensor_type_format_helper(shape, element_type_format, 0); +} + std::string data_type_format(DataType dt) { if (dt->is_primitive(PrimitiveTypeID::i16)) { return "%hd"; @@ -79,6 +109,8 @@ std::string data_type_format(DataType dt) { // TaskCodeGenLLVM::visit(PrintStmt *stmt) and // TaskCodeGenCUDA::visit(PrintStmt *stmt) for more details. return "%f"; + } else if (dt->is()) { + return tensor_type_format(dt); } else { TI_NOT_IMPLEMENTED } diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index a15ebcf7f9c5a..6f3d7a16deb2f 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -48,6 +48,7 @@ CompileConfig::CompileConfig() { detect_read_only = true; ndarray_use_cached_allocator = true; use_mesh = false; + real_matrix = false; saturating_grid_dim = 0; max_block_dim = 0; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 7820834066041..14f17e419fd6d 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -44,6 +44,7 @@ struct CompileConfig { bool detect_read_only; bool ndarray_use_cached_allocator; bool use_mesh; + bool real_matrix; DataType default_fp; DataType default_ip; DataType default_up; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 563362125bffc..8b44098656b3d 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -200,6 +200,7 @@ void export_lang(py::module &m) { .def_readwrite("ndarray_use_cached_allocator", &CompileConfig::ndarray_use_cached_allocator) .def_readwrite("use_mesh", &CompileConfig::use_mesh) + .def_readwrite("real_matrix", &CompileConfig::real_matrix) .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd) .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd) .def_readwrite("quant_opt_store_fusion", @@ -290,6 +291,7 @@ void export_lang(py::module &m) { .def("insert_deactivate", &ASTBuilder::insert_snode_deactivate) .def("insert_activate", &ASTBuilder::insert_snode_activate) .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) + .def("make_matrix_expr", &ASTBuilder::make_matrix_expr) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 65dc87c88049b..8e99f32503b58 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -64,7 +64,7 @@ namespace lang { using namespace llvm; TaichiLLVMContext::TaichiLLVMContext(CompileConfig *config, Arch arch) - : arch_(arch) { + : config_(config), arch_(arch) { TI_TRACE("Creating Taichi llvm context for arch: {}", arch_name(arch)); main_thread_id_ = std::this_thread::get_id(); main_thread_data_ = get_this_thread_data(); @@ -142,6 +142,13 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { return llvm::Type::getInt64Ty(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); + } else if (dt->is()) { + TI_ASSERT_INFO(config_->real_matrix, + "Real matrix not enabled but got TensorType"); + auto vectorty = dt->as(); + auto dtype = this->get_data_type(vectorty->get_element_type()); + return llvm::VectorType::get(dtype, vectorty->get_num_elements(), + /*scalable=*/false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED diff --git a/taichi/runtime/llvm/llvm_context.h b/taichi/runtime/llvm/llvm_context.h index afcccebbbbcfe..ae87699f48484 100644 --- a/taichi/runtime/llvm/llvm_context.h +++ b/taichi/runtime/llvm/llvm_context.h @@ -33,6 +33,7 @@ class TaichiLLVMContext { std::unique_ptr struct_module{nullptr}; ~ThreadLocalData(); }; + CompileConfig *config_; public: std::unique_ptr jit{nullptr}; diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 4ce787066f230..b8bdf2a32f698 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -213,6 +213,11 @@ class AlgSimp : public BasicStmtVisitor { } void visit(BinaryOpStmt *stmt) override { + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); if (stmt->width() != 1) { diff --git a/taichi/transforms/die.cpp b/taichi/transforms/die.cpp index 3176d8f576949..f6d696c4da498 100644 --- a/taichi/transforms/die.cpp +++ b/taichi/transforms/die.cpp @@ -108,6 +108,13 @@ class DIE : public IRVisitor { } stmt->all_blocks_accept(this, true); } + + void visit(MatrixInitStmt *stmt) override { + register_usage(stmt); + for (auto &elts : stmt->values) { + elts->accept(this); + } + } }; namespace irpass { diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index ca462e42773e8..97538ee93acce 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -794,6 +794,19 @@ class IRPrinter : public IRVisitor { print("{}{} = ref({})", stmt->type_hint(), stmt->name(), stmt->var->name()); } + void visit(MatrixInitStmt *stmt) override { + std::string result = ""; + result += fmt::format("{}{} = [", stmt->type_hint(), stmt->name()); + for (int i = 0; i < stmt->values.size(); ++i) { + result += stmt->values[i]->name(); + if (i != stmt->values.size() - 1) { + result += ", "; + } + } + result += "]"; + print(result); + } + private: std::string expr_to_string(Expr &expr) { return expr_to_string(expr.expr.get()); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 2d691c573055b..b6cb7112e1cf8 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -555,6 +555,19 @@ class TypeCheck : public IRVisitor { stmt->ret_type = stmt->var->ret_type; stmt->ret_type.set_is_pointer(true); } + + void visit(MatrixInitStmt *stmt) override { + TI_ASSERT_INFO(stmt->ret_type->is(), + "Matrix should have tensor type, got {}", + stmt->ret_type->to_string()); + auto tensor_type = stmt->ret_type->as(); + auto element_dtype = tensor_type->get_element_type(); + for (int i = 0; i < stmt->values.size(); ++i) { + if (element_dtype != stmt->values[i]->ret_type) { + cast(stmt->values[i], element_dtype); + } + } + } }; namespace irpass {