Skip to content

Commit

Permalink
[type] [refactor] Misc improvements to quant codegen (#5129)
Browse files Browse the repository at this point in the history
* Replace is_custom_type() with is_quant()

* Rename two functions

* Use get_constant() if possible

* Rename two metal functions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Jun 13, 2022
1 parent aba2871 commit 6349d60
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 38 deletions.
4 changes: 2 additions & 2 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,8 +1457,8 @@ def _calc_dynamic_index_stride(self):
return
length = len(paths[0])
if any(
len(path) != length or ti_core.is_custom_type(path[length -
1]._dtype)
len(path) != length or ti_core.is_quant(path[length -
1]._dtype)
for path in paths):
return
for i in range(length):
Expand Down
22 changes: 11 additions & 11 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ class KernelCodegenImpl : public IRVisitor {
validate_cft_for_metal(cft);
auto *digits_cit = cft->get_digits_type()->as<CustomIntType>();
cit = digits_cit;
store_value_expr = construct_float_to_custom_int_expr(
store_value_expr = construct_quant_fixed_to_quant_int_expr(
stmt->val, cft->get_scale(), digits_cit);
} else {
TI_NOT_IMPLEMENTED;
Expand All @@ -1004,10 +1004,10 @@ class KernelCodegenImpl : public IRVisitor {
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
if (auto *cit = pointee_type->cast<CustomIntType>()) {
return construct_load_as_custom_int(stmt->src, cit);
return construct_load_quant_int(stmt->src, cit);
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
validate_cft_for_metal(cft);
const auto loaded = construct_load_as_custom_int(
const auto loaded = construct_load_quant_int(
stmt->src, cft->get_digits_type()->as<CustomIntType>());
// Computes `float(digits_expr) * scale`
// See LLVM backend's reconstruct_quant_fixed()
Expand All @@ -1033,8 +1033,8 @@ class KernelCodegenImpl : public IRVisitor {
val_expr = stmt->val->raw_name();
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
cit = cft->get_digits_type()->as<CustomIntType>();
val_expr =
construct_float_to_custom_int_expr(stmt->val, cft->get_scale(), cit);
val_expr = construct_quant_fixed_to_quant_int_expr(stmt->val,
cft->get_scale(), cit);
} else {
TI_NOT_IMPLEMENTED;
}
Expand All @@ -1051,7 +1051,7 @@ class KernelCodegenImpl : public IRVisitor {
}

// Returns the expression of `int(val_stmt * (1.0f / scale) + 0.5f)`
std::string construct_float_to_custom_int_expr(
std::string construct_quant_fixed_to_quant_int_expr(
const Stmt *val_stmt,
float64 scale,
CustomIntType *digits_cit) const {
Expand All @@ -1062,14 +1062,14 @@ class KernelCodegenImpl : public IRVisitor {
// variables) because |val_stmt| could be used multiple times. If the
// intermediate variables are named based on |val_stmt|, it would result in
// symbol redefinitions.
return fmt::format("mtl_float_to_custom_int<{}>(/*inv_scale=*/{} * {})",
metal_data_type_name(compute_dt), inv_scale,
val_stmt->raw_name());
return fmt::format(
"mtl_quant_fixed_to_quant_int<{}>(/*inv_scale=*/{} * {})",
metal_data_type_name(compute_dt), inv_scale, val_stmt->raw_name());
}

// Returns expression of the loaded integer.
std::string construct_load_as_custom_int(const Stmt *bit_ptr_stmt,
CustomIntType *cit) const {
std::string construct_load_quant_int(const Stmt *bit_ptr_stmt,
CustomIntType *cit) const {
DataType compute_dt(cit->get_compute_type()->as<PrimitiveType>());
const auto num_bits = cit->get_num_bits();
if (is_full_bits(num_bits)) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/metal/shaders/snode_bit_pointer.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ STR(

// |f| should already be scaled. |C| is the compute type.
template <typename C>
C mtl_float_to_custom_int(float f) {
C mtl_quant_fixed_to_quant_int(float f) {
// Branch free implementation of `f + sign(f) * 0.5`.
// See rounding_prepare_f* in taichi/runtime/llvm/runtime.cpp
const int32_t delta_bits =
Expand Down
8 changes: 4 additions & 4 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(GlobalStoreStmt *stmt) override;

llvm::Value *custom_type_to_bits(llvm::Value *val,
Type *input_type,
Type *output_type);
llvm::Value *quant_int_or_quant_fixed_to_bits(llvm::Value *val,
Type *input_type,
Type *output_type);

void visit(BitStructStoreStmt *stmt) override;

Expand Down Expand Up @@ -399,7 +399,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *extract_digits_from_f32(llvm::Value *f, bool full);

llvm::Value *extract_digits_from_quant_float_with_shared_exponent(
llvm::Value *extract_digits_from_f32_with_shared_exponent(
llvm::Value *f,
llvm::Value *shared_exp);

Expand Down
25 changes: 12 additions & 13 deletions taichi/codegen/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(CustomFloatType *cft,
// Compute int(real * (1.0 / scale) + 0.5)
auto s_numeric = 1.0 / cft->get_scale();
auto compute_type = cft->get_compute_type();
s = builder->CreateFPCast(
llvm::ConstantFP::get(*llvm_context, llvm::APFloat(s_numeric)),
llvm_type(compute_type));
s = builder->CreateFPCast(tlctx->get_constant(s_numeric),
llvm_type(compute_type));
auto input_real = builder->CreateFPCast(real, llvm_type(compute_type));
auto scaled = builder->CreateFMul(input_real, s);

Expand Down Expand Up @@ -128,9 +127,9 @@ llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent,
tlctx->get_constant(0));
}

llvm::Value *CodeGenLLVM::custom_type_to_bits(llvm::Value *val,
Type *input_type,
Type *output_type) {
llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val,
Type *input_type,
Type *output_type) {
CustomIntType *cit = nullptr;
if (auto cft = input_type->cast<CustomFloatType>()) {
TI_ASSERT(cft->get_exponent_type() == nullptr);
Expand Down Expand Up @@ -262,7 +261,8 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
val = builder->CreateBitCast(val, llvm_type(bit_struct_physical_type));
val = builder->CreateShl(val, digits_snode->bit_offset);
} else {
val = custom_type_to_bits(val, dtype, bit_struct_physical_type);
val = quant_int_or_quant_fixed_to_bits(val, dtype,
bit_struct_physical_type);
val = builder->CreateShl(val, bit_struct_snode->ch[ch_id]->bit_offset);
}

Expand Down Expand Up @@ -374,8 +374,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
for (int c = 0; c < (int)exp->exponent_users.size(); c++) {
auto user = exp->exponent_users[c];
auto ch_id = snode->child_id(user);
auto digits = extract_digits_from_quant_float_with_shared_exponent(
floats[c], max_exp_bits);
auto digits =
extract_digits_from_f32_with_shared_exponent(floats[c], max_exp_bits);
auto digits_snode = snode->ch[ch_id].get();
auto cft = digits_snode->dt->as<CustomFloatType>();
auto digits_bit_offset = digits_snode->bit_offset;
Expand Down Expand Up @@ -435,7 +435,7 @@ llvm::Value *CodeGenLLVM::extract_digits_from_f32(llvm::Value *f, bool full) {
return digits;
}

llvm::Value *CodeGenLLVM::extract_digits_from_quant_float_with_shared_exponent(
llvm::Value *CodeGenLLVM::extract_digits_from_f32_with_shared_exponent(
llvm::Value *f,
llvm::Value *shared_exp) {
auto exp = extract_exponent_from_f32(f);
Expand Down Expand Up @@ -518,13 +518,12 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
// Compute float(digits) * scale
llvm::Value *cast = nullptr;
auto compute_type = cft->get_compute_type()->as<PrimitiveType>();
if (cft->get_digits_type()->cast<CustomIntType>()->get_is_signed()) {
if (cft->get_is_signed()) {
cast = builder->CreateSIToFP(digits, llvm_type(compute_type));
} else {
cast = builder->CreateUIToFP(digits, llvm_type(compute_type));
}
llvm::Value *s =
llvm::ConstantFP::get(*llvm_context, llvm::APFloat(cft->get_scale()));
llvm::Value *s = tlctx->get_constant(cft->get_scale());
s = builder->CreateFPCast(s, llvm_type(compute_type));
return builder->CreateFMul(cast, s);
}
Expand Down
6 changes: 2 additions & 4 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,8 @@ void AtomicOpExpression::type_check(CompileConfig *) {
};
if (!val->ret_type->is<PrimitiveType>())
error();
if (auto cit = dest->ret_type->cast<CustomIntType>()) {
ret_type = cit->get_compute_type();
} else if (auto cft = dest->ret_type->cast<CustomFloatType>()) {
ret_type = cft->get_compute_type();
if (is_quant(dest->ret_type)) {
ret_type = dest->ret_type->get_compute_type();
} else if (dest->ret_type->is<PrimitiveType>()) {
ret_type = dest->ret_type;
} else {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ inline PrimitiveTypeID get_primitive_data_type() {
}
}

inline bool is_custom_type(DataType dt) {
inline bool is_quant(DataType dt) {
return dt->is<CustomIntType>() || dt->is<CustomFloatType>();
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ void export_lang(py::module &m) {
#undef PER_TYPE

m.def("data_type_size", data_type_size);
m.def("is_custom_type", is_custom_type);
m.def("is_quant", is_quant);
m.def("is_integral", is_integral);
m.def("is_signed", is_signed);
m.def("is_real", is_real);
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class TypeCheck : public IRVisitor {
Stmt *&val,
const std::string &stmt_name) {
auto dst_type = dst->ret_type.ptr_removed();
if (dst_type->is<CustomIntType>() || dst_type->is<CustomFloatType>()) {
if (is_quant(dst_type)) {
// We force the value type to be the compute_type of the bit pointer.
// Casting from compute_type to physical_type is handled in codegen.
dst_type = dst_type->get_compute_type();
Expand Down

0 comments on commit 6349d60

Please sign in to comment.