diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index f792d84535173..a02068fa32b70 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -663,14 +663,12 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { TI_NOT_IMPLEMENTED } } else { + // Note that ret_type here cannot be integral because pow with an + // integral exponent has been demoted in the demote_operations pass if (ret_type->is_primitive(PrimitiveTypeID::f32)) { llvm_val[stmt] = create_call("__nv_powf", {lhs, rhs}); } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = create_call("__nv_pow", {lhs, rhs}); - } else if (ret_type->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = create_call("pow_i32", {lhs, rhs}); - } else if (ret_type->is_primitive(PrimitiveTypeID::i64)) { - llvm_val[stmt] = create_call("pow_i64", {lhs, rhs}); } else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 92c7d65141cbe..09027b174224b 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -658,14 +658,12 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { } } else if (op == BinaryOpType::pow) { if (arch_is_cpu(current_arch())) { + // Note that ret_type here cannot be integral because pow with an + // integral exponent has been demoted in the demote_operations pass if (ret_type->is_primitive(PrimitiveTypeID::f32)) { llvm_val[stmt] = create_call("pow_f32", {lhs, rhs}); } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = create_call("pow_f64", {lhs, rhs}); - } else if (ret_type->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = create_call("pow_i32", {lhs, rhs}); - } else if (ret_type->is_primitive(PrimitiveTypeID::i64)) { - llvm_val[stmt] = create_call("pow_i64", {lhs, rhs}); } else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED diff --git a/taichi/codegen/metal/codegen_metal.cpp b/taichi/codegen/metal/codegen_metal.cpp index 70a4265026f7a..3b3ceb875a309 100644 --- a/taichi/codegen/metal/codegen_metal.cpp +++ b/taichi/codegen/metal/codegen_metal.cpp @@ -559,12 +559,6 @@ class KernelCodegenImpl : public IRVisitor { } return; } - if (op_type == BinaryOpType::pow && is_integral(bin->ret_type)) { - // TODO(k-ye): Make sure the type is not i64? - emit("const {} {} = pow_i32({}, {});", dt_name, bin_name, lhs_name, - rhs_name); - return; - } const auto binop = metal_binary_op_type_symbol(op_type); if (is_metal_binary_op_infix(op_type)) { if (is_comparison(op_type)) { diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index fc3bcb648ec0f..85995d20fb490 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -846,35 +846,6 @@ class TaskCodegen : public IRVisitor { BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne) #undef BINARY_OP_TO_SPIRV_LOGICAL -#define INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, \ - instruction_id, max_bits) \ - else if (op_type == BinaryOpType::op) { \ - const uint32_t instruction = instruction_id; \ - if (is_real(bin->element_type()) || is_integral(bin->element_type())) { \ - if (data_type_bits(bin->element_type()) > max_bits) { \ - TI_ERROR( \ - "[glsl450] the operand type of instruction {}({}) must <= {}bits", \ - #instruction, instruction_id, max_bits); \ - } \ - if (is_integral(bin->element_type())) { \ - bin_value = ir_->cast( \ - dst_type, \ - ir_->add(ir_->call_glsl450(ir_->f32_type(), instruction, \ - ir_->cast(ir_->f32_type(), lhs_value), \ - ir_->cast(ir_->f32_type(), rhs_value)), \ - ir_->float_immediate_number(ir_->f32_type(), 0.5f))); \ - } else { \ - bin_value = \ - ir_->call_glsl450(dst_type, instruction, lhs_value, rhs_value); \ - } \ - } else { \ - TI_NOT_IMPLEMENTED \ - } \ - } - - INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(pow, Pow, 26, 32) -#undef INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC - #define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \ max_bits) \ else if (op_type == BinaryOpType::op) { \ @@ -893,6 +864,7 @@ class TaskCodegen : public IRVisitor { } FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(atan2, Atan2, 25, 32) + FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(pow, Pow, 26, 32) #undef FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC #define BINARY_OP_TO_SPIRV_FUNC(op, S_inst, S_inst_id, U_inst, U_inst_id, \ diff --git a/taichi/runtime/llvm/runtime_module/runtime.cpp b/taichi/runtime/llvm/runtime_module/runtime.cpp index 229a106959d3e..93023fb71eb1e 100644 --- a/taichi/runtime/llvm/runtime_module/runtime.cpp +++ b/taichi/runtime/llvm/runtime_module/runtime.cpp @@ -205,22 +205,6 @@ DEFINE_UNARY_REAL_FUNC(asin) DEFINE_UNARY_REAL_FUNC(cos) DEFINE_UNARY_REAL_FUNC(sin) -#define DEFINE_FAST_POW(T) \ - T pow_##T(T x, T n) { \ - T ans = 1; \ - T tmp = x; \ - while (n > 0) { \ - if (n & 1) \ - ans *= tmp; \ - tmp *= tmp; \ - n >>= 1; \ - } \ - return ans; \ - } - -DEFINE_FAST_POW(i32) -DEFINE_FAST_POW(i64) - i32 abs_i32(i32 a) { return a >= 0 ? a : -a; } diff --git a/taichi/runtime/metal/shaders/helpers.metal.h b/taichi/runtime/metal/shaders/helpers.metal.h index dacfc3e3c9eb2..5c8fd7dd1865e 100644 --- a/taichi/runtime/metal/shaders/helpers.metal.h +++ b/taichi/runtime/metal/shaders/helpers.metal.h @@ -38,18 +38,6 @@ STR( : intm); } - int32_t pow_i32(int32_t x, int32_t n) { - int32_t tmp = x; - int32_t ans = 1; - while (n > (int32_t)(0)) { - if (n & 1) - ans *= tmp; - tmp *= tmp; - n >>= 1; - } - return ans; - } - float fatomic_fetch_add(device float *dest, const float operand) { // A huge hack! Metal does not support atomic floating point numbers // natively.