From 3c2e5e9b2e68362739a67407c5bf4a2eb4781a7c Mon Sep 17 00:00:00 2001 From: Taichi Gardener <62079278+taichi-gardener@users.noreply.github.com> Date: Thu, 15 Oct 2020 12:41:51 -0400 Subject: [PATCH] [type] [refactor] Remove DataType::data_type (#1960) --- taichi/backends/cc/codegen_cc.cpp | 8 +- taichi/backends/cpu/codegen_cpu.cpp | 4 +- taichi/backends/cuda/codegen_cuda.cpp | 41 +++++---- taichi/backends/metal/codegen_metal.cpp | 8 +- taichi/backends/opengl/codegen_opengl.cpp | 15 ++-- taichi/codegen/codegen_llvm.cpp | 103 +++++++++++----------- taichi/ir/control_flow_graph.cpp | 3 +- taichi/ir/ir.cpp | 2 +- taichi/ir/ir.h | 2 +- taichi/ir/state_machine.cpp | 2 +- taichi/ir/statements.h | 4 +- taichi/ir/type.cpp | 2 +- taichi/ir/type.h | 5 +- taichi/transforms/alg_simp.cpp | 39 ++++---- taichi/transforms/auto_diff.cpp | 21 +++-- taichi/transforms/binary_op_simplify.cpp | 8 +- taichi/transforms/constant_fold.cpp | 7 +- taichi/transforms/lower_access.cpp | 2 +- taichi/transforms/lower_ast.cpp | 2 +- taichi/transforms/make_thread_local.cpp | 2 +- taichi/transforms/offload.cpp | 4 +- taichi/transforms/simplify.cpp | 6 +- taichi/transforms/type_check.cpp | 81 ++++++++--------- 23 files changed, 176 insertions(+), 195 deletions(-) diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index b0fbbcea976bf..8e0709c198d66 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -401,7 +401,7 @@ class CCTransformer : public IRVisitor { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); - format += data_type_format(arg_stmt->ret_type.data_type); + format += data_type_format(arg_stmt->ret_type); values.push_back(arg_stmt->raw_name()); } else { @@ -527,10 +527,8 @@ class CCTransformer : public IRVisitor { } void visit(RandStmt *stmt) override { - auto var = define_var(cc_data_type_name(stmt->ret_type.data_type), - stmt->raw_name()); - emit("{} = Ti_rand_{}();", var, - data_type_short_name(stmt->ret_type.data_type)); + auto var = define_var(cc_data_type_name(stmt->ret_type), stmt->raw_name()); + emit("{} = Ti_rand_{}();", var, data_type_short_name(stmt->ret_type)); } void visit(StackAllocaStmt *stmt) override { diff --git a/taichi/backends/cpu/codegen_cpu.cpp b/taichi/backends/cpu/codegen_cpu.cpp index 1d2686fa26ef8..118f818faa52c 100644 --- a/taichi/backends/cpu/codegen_cpu.cpp +++ b/taichi/backends/cpu/codegen_cpu.cpp @@ -98,13 +98,13 @@ class CodeGenLLVMCPU : public CodeGenLLVM { for (auto s : stmt->arg_stmts) { TI_ASSERT(s->width() == 1); - arg_types.push_back(tlctx->get_data_type(s->ret_type.data_type)); + arg_types.push_back(tlctx->get_data_type(s->ret_type)); arg_values.push_back(llvm_val[s]); } for (auto s : stmt->output_stmts) { TI_ASSERT(s->width() == 1); - auto t = tlctx->get_data_type(s->ret_type.data_type); + auto t = tlctx->get_data_type(s->ret_type); auto ptr = llvm::PointerType::get(t, 0); arg_types.push_back(ptr); arg_values.push_back(llvm_val[s]); diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index 0356c3d7ebfad..c867cba851eb9 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -115,11 +115,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); - formats += data_type_format(arg_stmt->ret_type.data_type); + formats += data_type_format(arg_stmt->ret_type); - auto value_type = tlctx->get_data_type(arg_stmt->ret_type.data_type); + auto value_type = tlctx->get_data_type(arg_stmt->ret_type); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type.data_type == PrimitiveType::f32) { + if (arg_stmt->ret_type == PrimitiveType::f32) { value_type = tlctx->get_data_type(PrimitiveType::f64); value = builder->CreateFPExt(value, value_type); } @@ -158,7 +158,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { void emit_extra_unary(UnaryOpStmt *stmt) override { // functions from libdevice auto input = llvm_val[stmt->operand]; - auto input_taichi_type = stmt->operand->ret_type.data_type; + auto input_taichi_type = stmt->operand->ret_type; auto op = stmt->op_type; #define UNARY_STD(x) \ @@ -232,16 +232,16 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { for (int l = 0; l < stmt->width(); l++) { llvm::Value *old_value; if (stmt->op_type == AtomicOpType::add) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { + } else if (stmt->val->ret_type == PrimitiveType::f32) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest], llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { + } else if (stmt->val->ret_type == PrimitiveType::f64) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest], llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent); @@ -249,16 +249,16 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::min) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { + } else if (stmt->val->ret_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { + } else if (stmt->val->ret_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -266,16 +266,16 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::max) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { + } else if (stmt->val->ret_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { + } else if (stmt->val->ret_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -283,7 +283,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_and) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::And, llvm_val[stmt->dest], llvm_val[stmt->val], @@ -292,7 +292,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_or) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Or, llvm_val[stmt->dest], llvm_val[stmt->val], @@ -301,7 +301,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_xor) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Xor, llvm_val[stmt->dest], llvm_val[stmt->val], @@ -317,10 +317,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { } void visit(RandStmt *stmt) override { - llvm_val[stmt] = - create_call(fmt::format("cuda_rand_{}", - data_type_short_name(stmt->ret_type.data_type)), - {get_context()}); + llvm_val[stmt] = create_call( + fmt::format("cuda_rand_{}", data_type_short_name(stmt->ret_type)), + {get_context()}); } void visit(RangeForStmt *for_stmt) override { create_naive_range_for(for_stmt); @@ -397,7 +396,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { if (should_cache_as_read_only) { // Issue an CUDA "__ldg" instruction so that data are cached in // the CUDA read-only data cache. - auto dtype = stmt->ret_type.data_type; + auto dtype = stmt->ret_type; auto llvm_dtype = llvm_type(dtype); auto llvm_dtype_ptr = llvm::PointerType::get(llvm_type(dtype), 0); llvm::Intrinsic::ID intrin; diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 3287296bb1d11..14a51441b49bd 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -276,7 +276,7 @@ class KernelCodegen : public IRVisitor { } } else if (opty == SNodeOpType::append) { TI_ASSERT(is_dynamic); - TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); + TI_ASSERT(stmt->ret_type == PrimitiveType::i32); emit("{} = {}.append({});", result_var, parent, stmt->val->raw_name()); } else if (opty == SNodeOpType::length) { TI_ASSERT(is_dynamic); @@ -412,7 +412,7 @@ class KernelCodegen : public IRVisitor { const auto bin_name = bin->raw_name(); const auto op_type = bin->op_type; if (op_type == BinaryOpType::floordiv) { - if (is_integral(bin->ret_type.data_type)) { + if (is_integral(bin->ret_type)) { emit("const {} {} = ifloordiv({}, {});", dt_name, bin_name, lhs_name, rhs_name); } else { @@ -421,7 +421,7 @@ class KernelCodegen : public IRVisitor { } return; } - if (op_type == BinaryOpType::pow && is_integral(bin->ret_type.data_type)) { + if (op_type == BinaryOpType::pow && is_integral(bin->ret_type)) { // TODO(k-ye): Make sure the type is not i64? emit("const {} {} = pow_i32({}, {});", dt_name, bin_name, lhs_name, rhs_name); @@ -604,7 +604,7 @@ class KernelCodegen : public IRVisitor { void visit(RandStmt *stmt) override { emit("const auto {} = metal_rand_{}({});", stmt->raw_name(), - data_type_short_name(stmt->ret_type.data_type), kRandStateVarName); + data_type_short_name(stmt->ret_type), kRandStateVarName); } void visit(PrintStmt *stmt) override { diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index dd2d78eea1466..03c9686e0c31f 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -267,8 +267,8 @@ class KernelGen : public IRVisitor { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); emit("_msg_set_{}({}, {}, {});", - opengl_data_type_short_name(arg_stmt->ret_type.data_type), - msgid_name, i, arg_stmt->short_name()); + opengl_data_type_short_name(arg_stmt->ret_type), msgid_name, i, + arg_stmt->short_name()); } else { auto str = std::get(content); @@ -281,9 +281,8 @@ class KernelGen : public IRVisitor { void visit(RandStmt *stmt) override { used.random = true; - emit("{} {} = _rand_{}();", opengl_data_type_name(stmt->ret_type.data_type), - stmt->short_name(), - opengl_data_type_short_name(stmt->ret_type.data_type)); + emit("{} {} = _rand_{}();", opengl_data_type_name(stmt->ret_type), + stmt->short_name(), opengl_data_type_short_name(stmt->ret_type)); } void visit(LinearizeStmt *stmt) override { @@ -361,7 +360,7 @@ class KernelGen : public IRVisitor { } } else if (stmt->op_type == SNodeOpType::is_active) { - TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); + TI_ASSERT(stmt->ret_type == PrimitiveType::i32); if (stmt->snode->type == SNodeType::dense || stmt->snode->type == SNodeType::root) { emit("int {} = 1;", stmt->short_name()); @@ -374,7 +373,7 @@ class KernelGen : public IRVisitor { } else if (stmt->op_type == SNodeOpType::append) { TI_ASSERT(stmt->snode->type == SNodeType::dynamic); - TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); + TI_ASSERT(stmt->ret_type == PrimitiveType::i32); emit("int {} = atomicAdd(_data_i32_[{} >> 2], 1);", stmt->short_name(), get_snode_meta_address(stmt->snode)); auto dt = stmt->val->element_type(); @@ -388,7 +387,7 @@ class KernelGen : public IRVisitor { } else if (stmt->op_type == SNodeOpType::length) { TI_ASSERT(stmt->snode->type == SNodeType::dynamic); - TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); + TI_ASSERT(stmt->ret_type == PrimitiveType::i32); emit("int {} = _data_i32_[{} >> 2];", stmt->short_name(), get_snode_meta_address(stmt->snode)); diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 84028c6b18816..6ada24eeebc2e 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -130,22 +130,22 @@ void CodeGenLLVM::visit(Block *stmt_list) { void CodeGenLLVM::visit(AllocaStmt *stmt) { TI_ASSERT(stmt->width() == 1); - llvm_val[stmt] = create_entry_block_alloca(stmt->ret_type.data_type, - stmt->ret_type.is_pointer()); + llvm_val[stmt] = + create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer()); // initialize as zero if element is not a pointer if (!stmt->ret_type.is_pointer()) - builder->CreateStore(tlctx->get_constant(stmt->ret_type.data_type, 0), + builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0), llvm_val[stmt]); } void CodeGenLLVM::visit(RandStmt *stmt) { - llvm_val[stmt] = create_call( - fmt::format("rand_{}", data_type_short_name(stmt->ret_type.data_type))); + llvm_val[stmt] = + create_call(fmt::format("rand_{}", data_type_short_name(stmt->ret_type))); } void CodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) { auto input = llvm_val[stmt->operand]; - auto input_taichi_type = stmt->operand->ret_type.data_type; + auto input_taichi_type = stmt->operand->ret_type; auto op = stmt->op_type; auto input_type = input->getType(); @@ -298,7 +298,7 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { } if (stmt->op_type == UnaryOpType::cast_value) { llvm::CastInst::CastOps cast_op; - auto from = stmt->operand->ret_type.data_type; + auto from = stmt->operand->ret_type; auto to = stmt->cast_type; TI_ASSERT(from != to); if (is_real(from) != is_real(to)) { @@ -332,7 +332,7 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { } } } else if (stmt->op_type == UnaryOpType::cast_bits) { - TI_ASSERT(data_type_size(stmt->ret_type.data_type) == + TI_ASSERT(data_type_size(stmt->ret_type) == data_type_size(stmt->cast_type)); llvm_val[stmt] = builder->CreateBitCast( llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); @@ -341,11 +341,11 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { module.get(), llvm::Intrinsic::sqrt, input->getType()); auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt"); llvm_val[stmt] = builder->CreateFDiv( - tlctx->get_constant(stmt->ret_type.data_type, 1.0), intermediate); + tlctx->get_constant(stmt->ret_type, 1.0), intermediate); } else if (op == UnaryOpType::bit_not) { llvm_val[stmt] = builder->CreateNot(input); } else if (op == UnaryOpType::neg) { - if (is_real(stmt->operand->ret_type.data_type)) { + if (is_real(stmt->operand->ret_type)) { llvm_val[stmt] = builder->CreateFNeg(input, "neg"); } else { llvm_val[stmt] = builder->CreateNeg(input, "neg"); @@ -359,9 +359,9 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { void CodeGenLLVM::visit(BinaryOpStmt *stmt) { auto op = stmt->op_type; - auto ret_type = stmt->ret_type.data_type; + auto ret_type = stmt->ret_type; if (op == BinaryOpType::add) { - if (is_real(stmt->ret_type.data_type)) { + if (is_real(stmt->ret_type)) { llvm_val[stmt] = builder->CreateFAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -369,7 +369,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { builder->CreateAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::sub) { - if (is_real(stmt->ret_type.data_type)) { + if (is_real(stmt->ret_type)) { llvm_val[stmt] = builder->CreateFSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -377,7 +377,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { builder->CreateSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::mul) { - if (is_real(stmt->ret_type.data_type)) { + if (is_real(stmt->ret_type)) { llvm_val[stmt] = builder->CreateFMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -395,7 +395,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { llvm::Intrinsic::floor, {tlctx->get_data_type(ret_type)}, {div}); } } else if (op == BinaryOpType::div) { - if (is_real(stmt->ret_type.data_type)) { + if (is_real(stmt->ret_type)) { llvm_val[stmt] = builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -513,7 +513,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } } else if (is_comparison(op)) { llvm::Value *cmp = nullptr; - auto input_type = stmt->lhs->ret_type.data_type; + auto input_type = stmt->lhs->ret_type; if (op == BinaryOpType::cmp_eq) { if (is_real(input_type)) { cmp = builder->CreateFCmpOEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); @@ -655,11 +655,11 @@ void CodeGenLLVM::visit(PrintStmt *stmt) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type.data_type == PrimitiveType::f32) + if (arg_stmt->ret_type == PrimitiveType::f32) value = builder->CreateFPExt(value, tlctx->get_data_type(PrimitiveType::f64)); args.push_back(value); - formats += data_type_format(arg_stmt->ret_type.data_type); + formats += data_type_format(arg_stmt->ret_type); } else { auto arg_str = std::get(content); auto value = builder->CreateGlobalStringPtr(arg_str, "content_string"); @@ -867,8 +867,8 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); } else { - TI_ASSERT(!stmt->ret_type.data_type->is()); - dest_ty = tlctx->get_data_type(stmt->ret_type.data_type); + TI_ASSERT(!stmt->ret_type->is()); + dest_ty = tlctx->get_data_type(stmt->ret_type); auto dest_bits = dest_ty->getPrimitiveSizeInBits(); auto truncated = builder->CreateTrunc( raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits)); @@ -881,8 +881,7 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) { TI_NOT_IMPLEMENTED } else { auto intermediate_bits = - tlctx->get_data_type(stmt->value->ret_type.data_type) - ->getPrimitiveSizeInBits(); + tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits(); llvm::Type *intermediate_type = llvm::Type::getIntNTy(*llvm_context, intermediate_bits); llvm::Type *dest_ty = tlctx->get_data_type(); @@ -928,8 +927,7 @@ void CodeGenLLVM::visit(AssertStmt *stmt) { // First convert the argument to an integral type with the same number of // bits: auto cast_type = llvm::Type::getIntNTy( - *llvm_context, - 8 * (std::size_t)data_type_size(arg->ret_type.data_type)); + *llvm_context, 8 * (std::size_t)data_type_size(arg->ret_type)); auto cast_int = builder->CreateBitCast(llvm_val[arg], cast_type); // Then zero-extend the conversion result into int64: @@ -953,7 +951,7 @@ void CodeGenLLVM::visit(SNodeOpStmt *stmt) { auto snode = stmt->snode; if (stmt->op_type == SNodeOpType::append) { TI_ASSERT(snode->type == SNodeType::dynamic); - TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); + TI_ASSERT(stmt->ret_type == PrimitiveType::i32); llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "append", {llvm_val[stmt->val]}); } else if (stmt->op_type == SNodeOpType::length) { @@ -985,15 +983,15 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { for (int l = 0; l < stmt->width(); l++) { llvm::Value *old_value; if (stmt->op_type == AtomicOpType::add) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { + } else if (stmt->val->ret_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_add_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { + } else if (stmt->val->ret_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_add_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -1001,15 +999,15 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::min) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { + } else if (stmt->val->ret_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { + } else if (stmt->val->ret_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -1017,15 +1015,15 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::max) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { + } else if (stmt->val->ret_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { + } else if (stmt->val->ret_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -1033,7 +1031,7 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_and) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::And, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); @@ -1041,7 +1039,7 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_or) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Or, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); @@ -1049,7 +1047,7 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_xor) { - if (is_integral(stmt->val->ret_type.data_type)) { + if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Xor, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); @@ -1077,8 +1075,8 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { int width = stmt->width(); TI_ASSERT(width == 1); - llvm_val[stmt] = builder->CreateLoad( - tlctx->get_data_type(stmt->ret_type.data_type), llvm_val[stmt->ptr]); + llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), + llvm_val[stmt->ptr]); } void CodeGenLLVM::visit(ElementShuffleStmt *stmt){ @@ -1090,7 +1088,7 @@ void CodeGenLLVM::visit(ElementShuffleStmt *stmt){ }, "{"); if (stmt->pointer) { - emit("{} * const {} [{}] {};", data_type_name(stmt->ret_type.data_type), + emit("{} * const {} [{}] {};", data_type_name(stmt->ret_type), stmt->raw_name(), stmt->width(), init); } else { emit("const {} {} ({});", stmt->ret_data_type_name(), stmt->raw_name(), @@ -1228,7 +1226,7 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { sizes[i] = raw_arg; } - auto dt = stmt->ret_type.data_type.ptr_removed(); + auto dt = stmt->ret_type.ptr_removed(); auto base = builder->CreateBitCast( llvm_val[stmt->base_ptrs[0]], llvm::PointerType::get(tlctx->get_data_type(dt), 0)); @@ -1621,7 +1619,7 @@ void CodeGenLLVM::visit(GlobalTemporaryStmt *stmt) { TI_ASSERT(stmt->width() == 1); auto ptr_type = llvm::PointerType::get( - tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0); + tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); } @@ -1630,7 +1628,7 @@ void CodeGenLLVM::visit(ThreadLocalPtrStmt *stmt) { TI_ASSERT(stmt->width() == 1); auto ptr = builder->CreateGEP(base, tlctx->get_constant(stmt->offset)); auto ptr_type = llvm::PointerType::get( - tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0); + tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } @@ -1641,7 +1639,7 @@ void CodeGenLLVM::visit(BlockLocalPtrStmt *stmt) { auto ptr = builder->CreateGEP( base, {tlctx->get_constant(0), llvm_val[stmt->offset]}); auto ptr_type = llvm::PointerType::get( - tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0); + tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } @@ -1678,8 +1676,8 @@ void CodeGenLLVM::visit(StackPushStmt *stmt) { auto primal_ptr = call("stack_top_primal", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); primal_ptr = builder->CreateBitCast( - primal_ptr, llvm::PointerType::get( - tlctx->get_data_type(stmt->ret_type.data_type), 0)); + primal_ptr, + llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type), 0)); builder->CreateStore(llvm_val[stmt->v], primal_ptr); } @@ -1688,8 +1686,8 @@ void CodeGenLLVM::visit(StackLoadTopStmt *stmt) { auto primal_ptr = call("stack_top_primal", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); primal_ptr = builder->CreateBitCast( - primal_ptr, llvm::PointerType::get( - tlctx->get_data_type(stmt->ret_type.data_type), 0)); + primal_ptr, + llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type), 0)); llvm_val[stmt] = builder->CreateLoad(primal_ptr); } @@ -1698,8 +1696,7 @@ void CodeGenLLVM::visit(StackLoadTopAdjStmt *stmt) { auto adjoint = call("stack_top_adjoint", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); adjoint = builder->CreateBitCast( - adjoint, llvm::PointerType::get( - tlctx->get_data_type(stmt->ret_type.data_type), 0)); + adjoint, llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type), 0)); llvm_val[stmt] = builder->CreateLoad(adjoint); } @@ -1708,10 +1705,10 @@ void CodeGenLLVM::visit(StackAccAdjointStmt *stmt) { auto adjoint_ptr = call("stack_top_adjoint", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); adjoint_ptr = builder->CreateBitCast( - adjoint_ptr, llvm::PointerType::get( - tlctx->get_data_type(stack->ret_type.data_type), 0)); + adjoint_ptr, + llvm::PointerType::get(tlctx->get_data_type(stack->ret_type), 0)); auto old_val = builder->CreateLoad(adjoint_ptr); - TI_ASSERT(is_real(stmt->v->ret_type.data_type)); + TI_ASSERT(is_real(stmt->v->ret_type)); auto new_val = builder->CreateFAdd(old_val, llvm_val[stmt->v]); builder->CreateStore(new_val, adjoint_ptr); } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 6b8daae92ac5d..d4c5b44ea2a0b 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -266,8 +266,7 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { if (result) { if (result->is()) { // special case of alloca (initialized to 0) - auto zero = - Stmt::make(TypedConstant(result->ret_type.data_type, 0)); + auto zero = Stmt::make(TypedConstant(result->ret_type, 0)); zero->repeat(result->width()); replace_with(i, std::move(zero), true); } else { diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 6dcbe90bf3984..ec93edbd1ae04 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -218,7 +218,7 @@ void Stmt::replace_operand_with(Stmt *old_stmt, Stmt *new_stmt) { } std::string Stmt::type_hint() const { - if (ret_type.data_type == PrimitiveType::unknown) + if (ret_type == PrimitiveType::unknown) return ""; else return fmt::format("<{}> ", ret_type.to_string()); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 531408db99772..aa2197c854360 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -512,7 +512,7 @@ class Stmt : public IRNode { } DataType &element_type() { - return ret_type.data_type; + return ret_type; } std::string ret_data_type_name() const { diff --git a/taichi/ir/state_machine.cpp b/taichi/ir/state_machine.cpp index 1181dd873a0f1..8dfeaaa712334 100644 --- a/taichi/ir/state_machine.cpp +++ b/taichi/ir/state_machine.cpp @@ -134,7 +134,7 @@ void StateMachine::load(Stmt *load_stmt) { if (stored == never) { auto zero = load_stmt->insert_after_me(Stmt::make( - LaneAttribute(load_stmt->ret_type.data_type))); + LaneAttribute(load_stmt->ret_type))); zero->repeat(load_stmt->width()); int current_stmt_id = load_stmt->parent->locate(load_stmt); load_stmt->replace_with(zero); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 3e49c1be1d5af..0ecf3c666073f 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -121,7 +121,7 @@ class ArgLoadStmt : public Stmt { class RandStmt : public Stmt { public: RandStmt(DataType dt) { - ret_type.data_type = dt; + ret_type = dt; TI_STMT_REG_FIELDS; } @@ -1018,7 +1018,7 @@ class StackAllocaStmt : public Stmt { } std::size_t element_size_in_bytes() const { - return data_type_size(ret_type.data_type); + return data_type_size(ret_type); } std::size_t entry_size_in_bytes() const { diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index e98e77842a28f..fd65cbf1cbe40 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -16,7 +16,7 @@ TLANG_NAMESPACE_BEGIN #include "taichi/inc/data_type.inc.h" #undef PER_TYPE -DataType::DataType() : data_type(*this), ptr_(PrimitiveType::unknown.ptr_) { +DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) { } DataType PrimitiveType::get(PrimitiveType::primitive_type t) { diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 29f31f090103c..efcb14323c015 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -39,10 +39,10 @@ class DataType { public: DataType(); - DataType(Type *ptr) : data_type(*this), ptr_(ptr) { + DataType(Type *ptr) : ptr_(ptr) { } - DataType(const DataType &o) : data_type(*this), ptr_(o.ptr_) { + DataType(const DataType &o) : ptr_(o.ptr_) { } bool operator==(const DataType &o) const { @@ -67,7 +67,6 @@ class DataType { // Temporary API and members // for LegacyVectorType-compatibility int width{1}; - DataType &data_type; Type *operator->() const { return ptr_; diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index bcb13c3ed6830..9c79ee5cea0e4 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -10,10 +10,10 @@ TLANG_NAMESPACE_BEGIN class AlgSimp : public BasicStmtVisitor { private: void cast_to_result_type(Stmt *&a, Stmt *stmt) { - if (stmt->ret_type.data_type != a->ret_type.data_type) { + if (stmt->ret_type != a->ret_type) { auto cast = Stmt::make_typed(UnaryOpType::cast_value, a); - cast->cast_type = stmt->ret_type.data_type; - cast->ret_type.data_type = stmt->ret_type.data_type; + cast->cast_type = stmt->ret_type; + cast->ret_type = stmt->ret_type; a = cast.get(); modifier.insert_before(stmt, std::move(cast)); } @@ -30,8 +30,7 @@ class AlgSimp : public BasicStmtVisitor { } void visit(UnaryOpStmt *stmt) override { - if (stmt->is_cast() && - stmt->cast_type == stmt->operand->ret_type.data_type) { + if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) { stmt->replace_with(stmt->operand); modifier.erase(stmt); } @@ -68,21 +67,19 @@ class AlgSimp : public BasicStmtVisitor { // 1 * a -> a stmt->replace_with(stmt->rhs); modifier.erase(stmt); - } else if ((fast_math || is_integral(stmt->ret_type.data_type)) && + } else if ((fast_math || is_integral(stmt->ret_type)) && stmt->op_type == BinaryOpType::mul && (alg_is_zero(lhs) || alg_is_zero(rhs))) { // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 - if (alg_is_zero(lhs) && - lhs->ret_type.data_type == stmt->ret_type.data_type) { + if (alg_is_zero(lhs) && lhs->ret_type == stmt->ret_type) { stmt->replace_with(stmt->lhs); modifier.erase(stmt); - } else if (alg_is_zero(rhs) && - rhs->ret_type.data_type == stmt->ret_type.data_type) { + } else if (alg_is_zero(rhs) && rhs->ret_type == stmt->ret_type) { stmt->replace_with(stmt->rhs); modifier.erase(stmt); } else { auto zero = Stmt::make( - LaneAttribute(stmt->ret_type.data_type)); + LaneAttribute(stmt->ret_type)); stmt->replace_with(zero.get()); modifier.insert_before(stmt, std::move(zero)); modifier.erase(stmt); @@ -95,22 +92,22 @@ class AlgSimp : public BasicStmtVisitor { a = stmt->rhs; cast_to_result_type(a, stmt); auto sum = Stmt::make(BinaryOpType::add, a, a); - sum->ret_type.data_type = a->ret_type.data_type; + sum->ret_type = a->ret_type; stmt->replace_with(sum.get()); modifier.insert_before(stmt, std::move(sum)); modifier.erase(stmt); } else if (fast_math && stmt->op_type == BinaryOpType::div && rhs && - is_real(rhs->ret_type.data_type)) { + is_real(rhs->ret_type)) { if (alg_is_zero(rhs)) { TI_WARN("Potential division by 0"); } else { // a / const -> a * (1 / const) auto reciprocal = Stmt::make_typed( - LaneAttribute(rhs->ret_type.data_type)); - if (rhs->ret_type.data_type == PrimitiveType::f64) { + LaneAttribute(rhs->ret_type)); + if (rhs->ret_type == PrimitiveType::f64) { reciprocal->val[0].val_float64() = (float64)1.0 / rhs->val[0].val_float64(); - } else if (rhs->ret_type.data_type == PrimitiveType::f32) { + } else if (rhs->ret_type == PrimitiveType::f32) { reciprocal->val[0].val_float32() = (float32)1.0 / rhs->val[0].val_float32(); } else { @@ -118,7 +115,7 @@ class AlgSimp : public BasicStmtVisitor { } auto product = Stmt::make(BinaryOpType::mul, stmt->lhs, reciprocal.get()); - product->ret_type.data_type = stmt->ret_type.data_type; + product->ret_type = stmt->ret_type; stmt->replace_with(product.get()); modifier.insert_before(stmt, std::move(reciprocal)); modifier.insert_before(stmt, std::move(product)); @@ -144,7 +141,7 @@ class AlgSimp : public BasicStmtVisitor { auto a = stmt->lhs; cast_to_result_type(a, stmt); auto result = Stmt::make(UnaryOpType::sqrt, a); - result->ret_type.data_type = a->ret_type.data_type; + result->ret_type = a->ret_type; stmt->replace_with(result.get()); modifier.insert_before(stmt, std::move(result)); modifier.erase(stmt); @@ -164,7 +161,7 @@ class AlgSimp : public BasicStmtVisitor { else { auto new_result = Stmt::make(BinaryOpType::mul, result, a_power_of_2); - new_result->ret_type.data_type = a->ret_type.data_type; + new_result->ret_type = a->ret_type; result = new_result.get(); modifier.insert_before(stmt, std::move(new_result)); } @@ -174,7 +171,7 @@ class AlgSimp : public BasicStmtVisitor { break; auto new_a_power = Stmt::make( BinaryOpType::mul, a_power_of_2, a_power_of_2); - new_a_power->ret_type.data_type = a->ret_type.data_type; + new_a_power->ret_type = a->ret_type; a_power_of_2 = new_a_power.get(); modifier.insert_before(stmt, std::move(new_a_power)); } @@ -190,7 +187,7 @@ class AlgSimp : public BasicStmtVisitor { auto new_exponent = Stmt::make(UnaryOpType::neg, rhs); auto a_to_n = Stmt::make(BinaryOpType::pow, stmt->lhs, new_exponent.get()); - a_to_n->ret_type.data_type = stmt->ret_type.data_type; + a_to_n->ret_type = stmt->ret_type; auto result = Stmt::make(BinaryOpType::div, one_raw, a_to_n.get()); stmt->replace_with(result.get()); diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index b85d55aef41c1..20f463d65ef03 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -153,7 +153,7 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor { return; } // Create a alloc - auto alloc = Stmt::make(1, stmt->ret_type.data_type); + auto alloc = Stmt::make(1, stmt->ret_type); auto alloc_ptr = alloc.get(); TI_ASSERT(alloca_block); alloca_block->insert(std::move(alloc), 0); @@ -201,7 +201,7 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { }) .empty(); if (!load_only) { - auto dtype = alloc->ret_type.data_type; + auto dtype = alloc->ret_type; auto stack_alloca = Stmt::make( dtype, alloc->get_kernel()->program.config.ad_stack_size); auto stack_alloca_ptr = stack_alloca.get(); @@ -381,7 +381,7 @@ class MakeAdjoint : public IRVisitor { return; // primal may be int variable if (alloca_->is()) { auto alloca = alloca_->cast(); - if (needs_grad(alloca->ret_type.data_type)) { + if (needs_grad(alloca->ret_type)) { insert(alloca, load(value)); } } else { @@ -394,7 +394,7 @@ class MakeAdjoint : public IRVisitor { } Stmt *adjoint(Stmt *stmt) { - if (!needs_grad(stmt->ret_type.data_type)) { + if (!needs_grad(stmt->ret_type)) { return constant(0); } if (adjoint_stmt.find(stmt) == adjoint_stmt.end()) { @@ -404,7 +404,7 @@ class MakeAdjoint : public IRVisitor { // auto alloca = // Stmt::make(1, get_current_program().config.gradient_dt); // maybe it's better to use the statement data type than the default type - auto alloca = Stmt::make(1, stmt->ret_type.data_type); + auto alloca = Stmt::make(1, stmt->ret_type); adjoint_stmt[stmt] = alloca.get(); alloca_block->insert(std::move(alloca), 0); } @@ -461,8 +461,7 @@ class MakeAdjoint : public IRVisitor { accumulate(stmt->operand, mul(adjoint(stmt), div(constant(0.5f), sqrt(stmt->operand)))); } else if (stmt->op_type == UnaryOpType::cast_value) { - if (is_real(stmt->cast_type) && - is_real(stmt->operand->ret_type.data_type)) { + if (is_real(stmt->cast_type) && is_real(stmt->operand->ret_type)) { accumulate(stmt->operand, adjoint(stmt)); } } else if (stmt->op_type == UnaryOpType::logic_not) { @@ -505,7 +504,7 @@ class MakeAdjoint : public IRVisitor { bin->op_type == BinaryOpType::max) { auto cmp = bin->op_type == BinaryOpType::min ? cmp_lt(bin->lhs, bin->rhs) : cmp_lt(bin->rhs, bin->lhs); - auto zero = insert(TypedConstant(bin->ret_type.data_type)); + auto zero = insert(TypedConstant(bin->ret_type)); accumulate(bin->lhs, sel(cmp, adjoint(bin), zero)); accumulate(bin->rhs, sel(cmp, zero, adjoint(bin))); } else if (bin->op_type == BinaryOpType::floordiv) { @@ -520,7 +519,7 @@ class MakeAdjoint : public IRVisitor { void visit(TernaryOpStmt *stmt) override { TI_ASSERT(stmt->op_type == TernaryOpType::select); - auto zero = insert(TypedConstant(stmt->ret_type.data_type)); + auto zero = insert(TypedConstant(stmt->ret_type)); accumulate(stmt->op2, insert(TernaryOpType::select, stmt->op1, load(adjoint(stmt)), zero)); @@ -612,11 +611,11 @@ class MakeAdjoint : public IRVisitor { } void visit(LocalLoadStmt *stmt) override { - // TI_ASSERT(!needs_grad(stmt->ret_type.data_type)); + // TI_ASSERT(!needs_grad(stmt->ret_type)); } void visit(StackLoadTopStmt *stmt) override { - if (needs_grad(stmt->ret_type.data_type)) + if (needs_grad(stmt->ret_type)) insert(stmt->stack, load(adjoint(stmt))); } diff --git a/taichi/transforms/binary_op_simplify.cpp b/taichi/transforms/binary_op_simplify.cpp index a3a478469cb34..aa9936f8ffa7b 100644 --- a/taichi/transforms/binary_op_simplify.cpp +++ b/taichi/transforms/binary_op_simplify.cpp @@ -29,7 +29,7 @@ class BinaryOpSimp : public BasicStmtVisitor { } // Disable other optimizations if fast_math=True and the data type is not // integral. - if (!fast_math && !is_integral(stmt->ret_type.data_type)) { + if (!fast_math && !is_integral(stmt->ret_type)) { return; } auto binary_lhs = stmt->lhs->cast(); @@ -45,7 +45,7 @@ class BinaryOpSimp : public BasicStmtVisitor { auto op2 = stmt->op_type; // Disables (a / b) * c -> a / (b / c), (a * b) / c -> a * (b / c) // when the data type is integral. - if (is_integral(stmt->ret_type.data_type) && + if (is_integral(stmt->ret_type) && ((op1 == BinaryOpType::div && op2 == BinaryOpType::mul) || (op1 == BinaryOpType::mul && op2 == BinaryOpType::div))) { return; @@ -57,10 +57,10 @@ class BinaryOpSimp : public BasicStmtVisitor { // stmt = a op1 (b op2 c) if (can_rearrange_associative(op1, op2, new_op2)) { auto bin_op = Stmt::make(new_op2, const_lhs_rhs, const_rhs); - bin_op->ret_type.data_type = stmt->ret_type.data_type; + bin_op->ret_type = stmt->ret_type; auto new_stmt = Stmt::make(op1, binary_lhs->lhs, bin_op.get()); - new_stmt->ret_type.data_type = stmt->ret_type.data_type; + new_stmt->ret_type = stmt->ret_type; modifier.insert_before(stmt, std::move(bin_op)); stmt->replace_with(new_stmt.get()); diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index ebaa5c02613fd..eaaf47e0d1ba3 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -133,7 +133,7 @@ class ConstantFold : public BasicStmtVisitor { return; if (stmt->width() != 1) return; - auto dst_type = stmt->ret_type.data_type; + auto dst_type = stmt->ret_type; TypedConstant new_constant(dst_type); if (jit_evaluate_binary_op(new_constant, stmt, lhs->val[0], rhs->val[0])) { auto evaluated = @@ -145,8 +145,7 @@ class ConstantFold : public BasicStmtVisitor { } void visit(UnaryOpStmt *stmt) override { - if (stmt->is_cast() && - stmt->cast_type == stmt->operand->ret_type.data_type) { + if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) { stmt->replace_with(stmt->operand); modifier.erase(stmt); return; @@ -156,7 +155,7 @@ class ConstantFold : public BasicStmtVisitor { return; if (stmt->width() != 1) return; - auto dst_type = stmt->ret_type.data_type; + auto dst_type = stmt->ret_type; TypedConstant new_constant(dst_type); if (jit_evaluate_unary_op(new_constant, stmt, operand->val[0])) { auto evaluated = diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 79285c5519b76..6a792a2acbddc 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -183,7 +183,7 @@ class LowerAccess : public IRVisitor { lanes.push_back(VectorElement(lowered_pointers[i], 0)); } auto merge = Stmt::make(lanes, true); - merge->ret_type.data_type = ptr->snodes[0]->dt; + merge->ret_type = ptr->snodes[0]->dt; lowered.push_back(std::move(merge)); return lowered; } diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 32474778e4e61..d642b6522b25a 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -69,7 +69,7 @@ class LowerAST : public IRVisitor { auto ident = stmt->ident; TI_ASSERT(block->local_var_to_stmt.find(ident) == block->local_var_to_stmt.end()); - auto lowered = std::make_unique(stmt->ret_type.data_type); + auto lowered = std::make_unique(stmt->ret_type); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); throw IRModified(); diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index 7c291a61e6ec5..d28c3792d39d9 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -113,7 +113,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { // TODO: sort thread local storage variables according to dtype_size to // reduce buffer fragmentation. for (auto dest : valid_reduction_values) { - auto data_type = dest->ret_type.data_type.ptr_removed(); + auto data_type = dest->ret_type.ptr_removed(); auto dtype_size = data_type_size(data_type); // Step 1: // Create thread local storage diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 6be056bbcf6b6..461ed943bbae4 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -261,7 +261,7 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { std::size_t allocate_global(DataType type) { TI_ASSERT(type.width == 1); auto ret = global_offset; - global_offset += data_type_size(type.data_type); + global_offset += data_type_size(type); TI_ASSERT(global_offset < taichi_global_tmp_buffer_size); return ret; } @@ -417,7 +417,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { auto ptr = replacement.push_back( local_to_global_offset[stmt], ret_type); LaneAttribute zeros(std::vector( - stmt->width(), TypedConstant(stmt->ret_type.data_type))); + stmt->width(), TypedConstant(stmt->ret_type))); auto const_zeros = replacement.push_back(zeros); replacement.push_back(ptr, const_zeros); diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 9ec76b1f99922..2d0b046721787 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -184,7 +184,7 @@ class BasicBlockSimplify : public IRVisitor { if (k == num_loop_vars - 1) { auto load = stmt->insert_before_me( Stmt::make(current_struct_for, k)); - load->ret_type.data_type = PrimitiveType::i32; + load->ret_type = PrimitiveType::i32; stmt->input = load; int64 bound = 1LL << stmt->bit_end; auto offset = (((int64)diff.low % bound + bound) % bound) & @@ -215,12 +215,12 @@ class BasicBlockSimplify : public IRVisitor { // insert constant auto load = stmt->insert_before_me( Stmt::make(current_struct_for, k)); - load->ret_type.data_type = PrimitiveType::i32; + load->ret_type = PrimitiveType::i32; auto constant = stmt->insert_before_me( Stmt::make(TypedConstant(diff.low))); auto add = stmt->insert_before_me( Stmt::make(BinaryOpType::add, load, constant)); - add->ret_type.data_type = PrimitiveType::i32; + add->ret_type = PrimitiveType::i32; stmt->input = add; } stmt->simplified = true; diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index db4d08aa83447..dabf894133c52 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -41,7 +41,7 @@ class TypeCheck : public IRVisitor { void visit(IfStmt *if_stmt) { // TODO: use PrimitiveType::u1 when it's supported TI_ASSERT_INFO( - if_stmt->cond->ret_type.data_type == PrimitiveType::i32, + if_stmt->cond->ret_type == PrimitiveType::i32, "`if` conditions must be of type int32, consider using `if x != 0:` " "instead of `if x:` for float values."); if (if_stmt->true_statements) @@ -66,7 +66,7 @@ class TypeCheck : public IRVisitor { if (stmt->val->ret_type != stmt->dest->ret_type.ptr_removed()) { // TODO: make sure the ptr_removed type is indeed a numerical type TI_WARN("[{}] Atomic add ({} to {}) may lose precision.", stmt->name(), - data_type_name(stmt->val->ret_type.data_type), + data_type_name(stmt->val->ret_type), data_type_name(stmt->dest->ret_type.ptr_removed())); stmt->val = insert_type_cast_before(stmt, stmt->val, stmt->dest->ret_type.ptr_removed()); @@ -84,19 +84,19 @@ class TypeCheck : public IRVisitor { } void visit(LocalStoreStmt *stmt) { - if (stmt->ptr->ret_type.data_type == PrimitiveType::unknown) { + if (stmt->ptr->ret_type == PrimitiveType::unknown) { // Infer data type for alloca stmt->ptr->ret_type = stmt->data->ret_type; } - auto common_container_type = promoted_type(stmt->ptr->ret_type.data_type, - stmt->data->ret_type.data_type); + auto common_container_type = + promoted_type(stmt->ptr->ret_type, stmt->data->ret_type); auto old_data = stmt->data; - if (stmt->ptr->ret_type.data_type != stmt->data->ret_type.data_type) { - stmt->data = insert_type_cast_before(stmt, stmt->data, - stmt->ptr->ret_type.data_type); + if (stmt->ptr->ret_type != stmt->data->ret_type) { + stmt->data = + insert_type_cast_before(stmt, stmt->data, stmt->ptr->ret_type); } - if (stmt->ptr->ret_type.data_type != common_container_type) { + if (stmt->ptr->ret_type != common_container_type) { TI_WARN( "[{}] Local store may lose precision (target = {}, value = {}) at", stmt->name(), stmt->ptr->ret_data_type_name(), @@ -122,7 +122,7 @@ class TypeCheck : public IRVisitor { void visit(GlobalPtrStmt *stmt) { stmt->ret_type.set_is_pointer(true); if (stmt->snodes) - stmt->ret_type.data_type = stmt->snodes[0]->dt; + stmt->ret_type = stmt->snodes[0]->dt; else TI_WARN("[{}] Type inference failed: snode is nullptr.", stmt->name()); for (int l = 0; l < stmt->snodes.size(); l++) { @@ -135,7 +135,7 @@ class TypeCheck : public IRVisitor { } } for (int i = 0; i < stmt->indices.size(); i++) { - if (!is_integral(stmt->indices[i]->ret_type.data_type)) { + if (!is_integral(stmt->indices[i]->ret_type)) { TI_WARN( "[{}] Field index {} not integral, casting into int32 implicitly", stmt->name(), i); @@ -147,11 +147,10 @@ class TypeCheck : public IRVisitor { } void visit(GlobalStoreStmt *stmt) { - auto promoted = promoted_type(stmt->ptr->ret_type.ptr_removed(), - stmt->data->ret_type.data_type); + auto promoted = + promoted_type(stmt->ptr->ret_type.ptr_removed(), stmt->data->ret_type); auto input_type = stmt->data->ret_data_type_name(); - if (stmt->ptr->ret_type.data_type.ptr_removed() != - stmt->data->ret_type.data_type) { + if (stmt->ptr->ret_type.ptr_removed() != stmt->data->ret_type) { stmt->data = insert_type_cast_before(stmt, stmt->data, stmt->ptr->ret_type.ptr_removed()); } @@ -180,9 +179,9 @@ class TypeCheck : public IRVisitor { void visit(UnaryOpStmt *stmt) { stmt->ret_type = stmt->operand->ret_type; if (stmt->is_cast()) { - stmt->ret_type.data_type = stmt->cast_type; + stmt->ret_type = stmt->cast_type; } - if (!is_real(stmt->operand->ret_type.data_type)) { + if (!is_real(stmt->operand->ret_type)) { if (is_trigonometric(stmt->op_type)) { TI_ERROR("[{}] Trigonometric operator takes real inputs only. At {}", stmt->name(), stmt->tb); @@ -242,32 +241,31 @@ class TypeCheck : public IRVisitor { TI_WARN("Compilation stopped due to type mismatch."); throw std::runtime_error("Binary operator type mismatch"); }; - if (stmt->lhs->ret_type.data_type == PrimitiveType::unknown && - stmt->rhs->ret_type.data_type == PrimitiveType::unknown) + if (stmt->lhs->ret_type == PrimitiveType::unknown && + stmt->rhs->ret_type == PrimitiveType::unknown) error(); // lower truediv into div if (stmt->op_type == BinaryOpType::truediv) { auto default_fp = config.default_fp; - if (!is_real(stmt->lhs->ret_type.data_type)) { + if (!is_real(stmt->lhs->ret_type)) { cast(stmt->lhs, default_fp); } - if (!is_real(stmt->rhs->ret_type.data_type)) { + if (!is_real(stmt->rhs->ret_type)) { cast(stmt->rhs, default_fp); } stmt->op_type = BinaryOpType::div; } - if (stmt->lhs->ret_type.data_type != stmt->rhs->ret_type.data_type) { - auto ret_type = promoted_type(stmt->lhs->ret_type.data_type, - stmt->rhs->ret_type.data_type); - if (ret_type != stmt->lhs->ret_type.data_type) { + if (stmt->lhs->ret_type != stmt->rhs->ret_type) { + auto ret_type = promoted_type(stmt->lhs->ret_type, stmt->rhs->ret_type); + if (ret_type != stmt->lhs->ret_type) { // promote rhs auto cast_stmt = insert_type_cast_before(stmt, stmt->lhs, ret_type); stmt->lhs = cast_stmt; } - if (ret_type != stmt->rhs->ret_type.data_type) { + if (ret_type != stmt->rhs->ret_type) { // promote rhs auto cast_stmt = insert_type_cast_before(stmt, stmt->rhs, ret_type); stmt->rhs = cast_stmt; @@ -276,16 +274,14 @@ class TypeCheck : public IRVisitor { bool matching = true; matching = matching && (stmt->lhs->ret_type.width == stmt->rhs->ret_type.width); - matching = - matching && (stmt->lhs->ret_type.data_type != PrimitiveType::unknown); - matching = - matching && (stmt->rhs->ret_type.data_type != PrimitiveType::unknown); + matching = matching && (stmt->lhs->ret_type != PrimitiveType::unknown); + matching = matching && (stmt->rhs->ret_type != PrimitiveType::unknown); matching = matching && (stmt->lhs->ret_type == stmt->rhs->ret_type); if (!matching) { error(); } if (binary_is_bitwise(stmt->op_type)) { - if (!is_integral(stmt->lhs->ret_type.data_type)) { + if (!is_integral(stmt->lhs->ret_type)) { error("Error: bitwise operations can only apply to integral types."); } } @@ -299,16 +295,15 @@ class TypeCheck : public IRVisitor { void visit(TernaryOpStmt *stmt) { if (stmt->op_type == TernaryOpType::select) { - auto ret_type = promoted_type(stmt->op2->ret_type.data_type, - stmt->op3->ret_type.data_type); - TI_ASSERT(stmt->op1->ret_type.data_type == PrimitiveType::i32) + auto ret_type = promoted_type(stmt->op2->ret_type, stmt->op3->ret_type); + TI_ASSERT(stmt->op1->ret_type == PrimitiveType::i32) TI_ASSERT(stmt->op1->ret_type.width == stmt->op2->ret_type.width); TI_ASSERT(stmt->op2->ret_type.width == stmt->op3->ret_type.width); - if (ret_type != stmt->op2->ret_type.data_type) { + if (ret_type != stmt->op2->ret_type) { auto cast_stmt = insert_type_cast_before(stmt, stmt->op2, ret_type); stmt->op2 = cast_stmt; } - if (ret_type != stmt->op3->ret_type.data_type) { + if (ret_type != stmt->op3->ret_type) { auto cast_stmt = insert_type_cast_before(stmt, stmt->op3, ret_type); stmt->op3 = cast_stmt; } @@ -333,7 +328,7 @@ class TypeCheck : public IRVisitor { // TODO: Maybe have a type_inference() pass, which takes in the args/rets // defined by the kernel. After that, type_check() pass will purely do // verification, without modifying any types. - TI_ASSERT(rt.data_type != PrimitiveType::unknown); + TI_ASSERT(rt != PrimitiveType::unknown); TI_ASSERT(rt.width == 1); stmt->ret_type.set_is_pointer(stmt->is_ptr); } @@ -341,14 +336,14 @@ class TypeCheck : public IRVisitor { void visit(KernelReturnStmt *stmt) { // TODO: Support stmt->ret_id? const auto &rt = stmt->ret_type; - TI_ASSERT(stmt->value->element_type() == rt.data_type); + TI_ASSERT(stmt->value->element_type() == rt); TI_ASSERT(rt.width == 1); } void visit(ExternalPtrStmt *stmt) { stmt->ret_type.set_is_pointer(true); - stmt->ret_type = LegacyVectorType(stmt->base_ptrs.size(), - stmt->base_ptrs[0]->ret_type.data_type); + stmt->ret_type = + LegacyVectorType(stmt->base_ptrs.size(), stmt->base_ptrs[0]->ret_type); } void visit(LoopIndexStmt *stmt) { @@ -388,15 +383,15 @@ class TypeCheck : public IRVisitor { } void visit(LinearizeStmt *stmt) { - stmt->ret_type.data_type = PrimitiveType::i32; + stmt->ret_type = PrimitiveType::i32; } void visit(IntegerOffsetStmt *stmt) { - stmt->ret_type.data_type = PrimitiveType::i32; + stmt->ret_type = PrimitiveType::i32; } void visit(StackAllocaStmt *stmt) { - stmt->ret_type.data_type = stmt->dt; + stmt->ret_type = stmt->dt; // ret_type stands for its element type. stmt->ret_type.set_is_pointer(false); }