From 6349d60487ee5a1db408f03b93a007b4382de0b2 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Mon, 13 Jun 2022 12:48:00 +0800 Subject: [PATCH] [type] [refactor] Misc improvements to quant codegen (#5129) * Replace is_custom_type() with is_quant() * Rename two functions * Use get_constant() if possible * Rename two metal functions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/lang/matrix.py | 4 +-- taichi/backends/metal/codegen_metal.cpp | 22 ++++++++-------- .../metal/shaders/snode_bit_pointer.metal.h | 2 +- taichi/codegen/codegen_llvm.h | 8 +++--- taichi/codegen/codegen_llvm_quant.cpp | 25 +++++++++---------- taichi/ir/frontend_ir.cpp | 6 ++--- taichi/ir/type_utils.h | 2 +- taichi/python/export_lang.cpp | 2 +- taichi/transforms/type_check.cpp | 2 +- 9 files changed, 35 insertions(+), 38 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 1c7add5cd19c3..103246daeb381 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1457,8 +1457,8 @@ def _calc_dynamic_index_stride(self): return length = len(paths[0]) if any( - len(path) != length or ti_core.is_custom_type(path[length - - 1]._dtype) + len(path) != length or ti_core.is_quant(path[length - + 1]._dtype) for path in paths): return for i in range(length): diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index a020c4635d26a..e3c9aa95f282d 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -981,7 +981,7 @@ class KernelCodegenImpl : public IRVisitor { validate_cft_for_metal(cft); auto *digits_cit = cft->get_digits_type()->as(); cit = digits_cit; - store_value_expr = construct_float_to_custom_int_expr( + store_value_expr = construct_quant_fixed_to_quant_int_expr( stmt->val, cft->get_scale(), digits_cit); } else { TI_NOT_IMPLEMENTED; @@ -1004,10 +1004,10 @@ class KernelCodegenImpl : public IRVisitor { TI_ASSERT(ptr_type->is_bit_pointer()); auto *pointee_type = ptr_type->get_pointee_type(); if (auto *cit = pointee_type->cast()) { - return construct_load_as_custom_int(stmt->src, cit); + return construct_load_quant_int(stmt->src, cit); } else if (auto *cft = pointee_type->cast()) { validate_cft_for_metal(cft); - const auto loaded = construct_load_as_custom_int( + const auto loaded = construct_load_quant_int( stmt->src, cft->get_digits_type()->as()); // Computes `float(digits_expr) * scale` // See LLVM backend's reconstruct_quant_fixed() @@ -1033,8 +1033,8 @@ class KernelCodegenImpl : public IRVisitor { val_expr = stmt->val->raw_name(); } else if (auto *cft = pointee_type->cast()) { cit = cft->get_digits_type()->as(); - val_expr = - construct_float_to_custom_int_expr(stmt->val, cft->get_scale(), cit); + val_expr = construct_quant_fixed_to_quant_int_expr(stmt->val, + cft->get_scale(), cit); } else { TI_NOT_IMPLEMENTED; } @@ -1051,7 +1051,7 @@ class KernelCodegenImpl : public IRVisitor { } // Returns the expression of `int(val_stmt * (1.0f / scale) + 0.5f)` - std::string construct_float_to_custom_int_expr( + std::string construct_quant_fixed_to_quant_int_expr( const Stmt *val_stmt, float64 scale, CustomIntType *digits_cit) const { @@ -1062,14 +1062,14 @@ class KernelCodegenImpl : public IRVisitor { // variables) because |val_stmt| could be used multiple times. If the // intermediate variables are named based on |val_stmt|, it would result in // symbol redefinitions. - return fmt::format("mtl_float_to_custom_int<{}>(/*inv_scale=*/{} * {})", - metal_data_type_name(compute_dt), inv_scale, - val_stmt->raw_name()); + return fmt::format( + "mtl_quant_fixed_to_quant_int<{}>(/*inv_scale=*/{} * {})", + metal_data_type_name(compute_dt), inv_scale, val_stmt->raw_name()); } // Returns expression of the loaded integer. - std::string construct_load_as_custom_int(const Stmt *bit_ptr_stmt, - CustomIntType *cit) const { + std::string construct_load_quant_int(const Stmt *bit_ptr_stmt, + CustomIntType *cit) const { DataType compute_dt(cit->get_compute_type()->as()); const auto num_bits = cit->get_num_bits(); if (is_full_bits(num_bits)) { diff --git a/taichi/backends/metal/shaders/snode_bit_pointer.metal.h b/taichi/backends/metal/shaders/snode_bit_pointer.metal.h index 63310fe1ff39f..4b390e5ceac42 100644 --- a/taichi/backends/metal/shaders/snode_bit_pointer.metal.h +++ b/taichi/backends/metal/shaders/snode_bit_pointer.metal.h @@ -38,7 +38,7 @@ STR( // |f| should already be scaled. |C| is the compute type. template - C mtl_float_to_custom_int(float f) { + C mtl_quant_fixed_to_quant_int(float f) { // Branch free implementation of `f + sign(f) * 0.5`. // See rounding_prepare_f* in taichi/runtime/llvm/runtime.cpp const int32_t delta_bits = diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index af7a0bf091740..662ef2d9d7199 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -265,9 +265,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(GlobalStoreStmt *stmt) override; - llvm::Value *custom_type_to_bits(llvm::Value *val, - Type *input_type, - Type *output_type); + llvm::Value *quant_int_or_quant_fixed_to_bits(llvm::Value *val, + Type *input_type, + Type *output_type); void visit(BitStructStoreStmt *stmt) override; @@ -399,7 +399,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *extract_digits_from_f32(llvm::Value *f, bool full); - llvm::Value *extract_digits_from_quant_float_with_shared_exponent( + llvm::Value *extract_digits_from_f32_with_shared_exponent( llvm::Value *f, llvm::Value *shared_exp); diff --git a/taichi/codegen/codegen_llvm_quant.cpp b/taichi/codegen/codegen_llvm_quant.cpp index ec3c764bfd78d..6e1aaa175ebde 100644 --- a/taichi/codegen/codegen_llvm_quant.cpp +++ b/taichi/codegen/codegen_llvm_quant.cpp @@ -51,9 +51,8 @@ llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(CustomFloatType *cft, // Compute int(real * (1.0 / scale) + 0.5) auto s_numeric = 1.0 / cft->get_scale(); auto compute_type = cft->get_compute_type(); - s = builder->CreateFPCast( - llvm::ConstantFP::get(*llvm_context, llvm::APFloat(s_numeric)), - llvm_type(compute_type)); + s = builder->CreateFPCast(tlctx->get_constant(s_numeric), + llvm_type(compute_type)); auto input_real = builder->CreateFPCast(real, llvm_type(compute_type)); auto scaled = builder->CreateFMul(input_real, s); @@ -128,9 +127,9 @@ llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent, tlctx->get_constant(0)); } -llvm::Value *CodeGenLLVM::custom_type_to_bits(llvm::Value *val, - Type *input_type, - Type *output_type) { +llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val, + Type *input_type, + Type *output_type) { CustomIntType *cit = nullptr; if (auto cft = input_type->cast()) { TI_ASSERT(cft->get_exponent_type() == nullptr); @@ -262,7 +261,8 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { val = builder->CreateBitCast(val, llvm_type(bit_struct_physical_type)); val = builder->CreateShl(val, digits_snode->bit_offset); } else { - val = custom_type_to_bits(val, dtype, bit_struct_physical_type); + val = quant_int_or_quant_fixed_to_bits(val, dtype, + bit_struct_physical_type); val = builder->CreateShl(val, bit_struct_snode->ch[ch_id]->bit_offset); } @@ -374,8 +374,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( for (int c = 0; c < (int)exp->exponent_users.size(); c++) { auto user = exp->exponent_users[c]; auto ch_id = snode->child_id(user); - auto digits = extract_digits_from_quant_float_with_shared_exponent( - floats[c], max_exp_bits); + auto digits = + extract_digits_from_f32_with_shared_exponent(floats[c], max_exp_bits); auto digits_snode = snode->ch[ch_id].get(); auto cft = digits_snode->dt->as(); auto digits_bit_offset = digits_snode->bit_offset; @@ -435,7 +435,7 @@ llvm::Value *CodeGenLLVM::extract_digits_from_f32(llvm::Value *f, bool full) { return digits; } -llvm::Value *CodeGenLLVM::extract_digits_from_quant_float_with_shared_exponent( +llvm::Value *CodeGenLLVM::extract_digits_from_f32_with_shared_exponent( llvm::Value *f, llvm::Value *shared_exp) { auto exp = extract_exponent_from_f32(f); @@ -518,13 +518,12 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits, // Compute float(digits) * scale llvm::Value *cast = nullptr; auto compute_type = cft->get_compute_type()->as(); - if (cft->get_digits_type()->cast()->get_is_signed()) { + if (cft->get_is_signed()) { cast = builder->CreateSIToFP(digits, llvm_type(compute_type)); } else { cast = builder->CreateUIToFP(digits, llvm_type(compute_type)); } - llvm::Value *s = - llvm::ConstantFP::get(*llvm_context, llvm::APFloat(cft->get_scale())); + llvm::Value *s = tlctx->get_constant(cft->get_scale()); s = builder->CreateFPCast(s, llvm_type(compute_type)); return builder->CreateFMul(cast, s); } diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 93fca8a78e6ef..46cf48778ce1b 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -519,10 +519,8 @@ void AtomicOpExpression::type_check(CompileConfig *) { }; if (!val->ret_type->is()) error(); - if (auto cit = dest->ret_type->cast()) { - ret_type = cit->get_compute_type(); - } else if (auto cft = dest->ret_type->cast()) { - ret_type = cft->get_compute_type(); + if (is_quant(dest->ret_type)) { + ret_type = dest->ret_type->get_compute_type(); } else if (dest->ret_type->is()) { ret_type = dest->ret_type; } else { diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index b5e10ed0f29c5..da1087bc9d7e5 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -73,7 +73,7 @@ inline PrimitiveTypeID get_primitive_data_type() { } } -inline bool is_custom_type(DataType dt) { +inline bool is_quant(DataType dt) { return dt->is() || dt->is(); } diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index ffb7161dda1eb..fd1443c9f849b 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -865,7 +865,7 @@ void export_lang(py::module &m) { #undef PER_TYPE m.def("data_type_size", data_type_size); - m.def("is_custom_type", is_custom_type); + m.def("is_quant", is_quant); m.def("is_integral", is_integral); m.def("is_signed", is_signed); m.def("is_real", is_real); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index a7cc5b60ecfb8..10132bafd3d40 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -23,7 +23,7 @@ class TypeCheck : public IRVisitor { Stmt *&val, const std::string &stmt_name) { auto dst_type = dst->ret_type.ptr_removed(); - if (dst_type->is() || dst_type->is()) { + if (is_quant(dst_type)) { // We force the value type to be the compute_type of the bit pointer. // Casting from compute_type to physical_type is handled in codegen. dst_type = dst_type->get_compute_type();