Skip to content

Commit

Permalink
[arm64] Add tan intrinsic lowering (#94545)
Browse files Browse the repository at this point in the history
This change is an implementation of the
#87367 investigation into
supporting IEEE math operations as intrinsics,
which was discussed in this RFC:
https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294

This PR is just for Tan.

Now that the x86 tan backend has landed
(#90503), we can add other
backends, since the shared pieces are in tree now.

Changes:
- `llvm/include/llvm/Analysis/VecFuncs.def` - vectorization of tan for
arm64 backends.
- `llvm/lib/Target/AArch64/AArch64FastISel.cpp` - Add tan to the libcall
table
- `llvm/lib/Target/AArch64/AArch64ISelLowering.cpp` - Add tan expansion
for f128, f16, and vector/NEON operations
- `llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp` - Define
`G_FTAN` as a legal arm64 instruction

resolves #94755
  • Loading branch information
farzonl authored Jun 7, 2024
1 parent ac02168 commit 2f0308e
Show file tree
Hide file tree
Showing 16 changed files with 704 additions and 41 deletions.
14 changes: 14 additions & 0 deletions llvm/include/llvm/Analysis/VecFuncs.def
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ TLI_DEFINE_VECFUNC("llvm.sin.f64", "_simd_sin_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("sinf", "_simd_sin_f4", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.sin.f32", "_simd_sin_f4", FIXED(4), "_ZGV_LLVM_N4v")

TLI_DEFINE_VECFUNC("tan", "_simd_tan_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("llvm.tan.f64", "_simd_tan_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("tanf", "_simd_tan_f4", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "_simd_tan_f4", FIXED(4), "_ZGV_LLVM_N4v")

// Floating-Point Arithmetic and Auxiliary Functions
TLI_DEFINE_VECFUNC("cbrt", "_simd_cbrt_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("cbrtf", "_simd_cbrt_f4", FIXED(4), "_ZGV_LLVM_N4v")
Expand Down Expand Up @@ -584,6 +589,7 @@ TLI_DEFINE_VECFUNC("sinpi", "_ZGVnN2v_sinpi", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("sqrt", "_ZGVnN2v_sqrt", FIXED(2), "_ZGV_LLVM_N2v")

TLI_DEFINE_VECFUNC("tan", "_ZGVnN2v_tan", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("llvm.tan.f64", "_ZGVnN2v_tan", FIXED(2), "_ZGV_LLVM_N2v")

TLI_DEFINE_VECFUNC("tanh", "_ZGVnN2v_tanh", FIXED(2), "_ZGV_LLVM_N2v")

Expand Down Expand Up @@ -681,6 +687,7 @@ TLI_DEFINE_VECFUNC("sinpif", "_ZGVnN4v_sinpif", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("sqrtf", "_ZGVnN4v_sqrtf", FIXED(4), "_ZGV_LLVM_N4v")

TLI_DEFINE_VECFUNC("tanf", "_ZGVnN4v_tanf", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "_ZGVnN4v_tanf", FIXED(4), "_ZGV_LLVM_N4v")

TLI_DEFINE_VECFUNC("tanhf", "_ZGVnN4v_tanhf", FIXED(4), "_ZGV_LLVM_N4v")

Expand Down Expand Up @@ -828,6 +835,8 @@ TLI_DEFINE_VECFUNC("sqrtf", "_ZGVsMxv_sqrtf", SCALABLE(4), MASKED, "_ZGVsMxv")

TLI_DEFINE_VECFUNC("tan", "_ZGVsMxv_tan", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("tanf", "_ZGVsMxv_tanf", SCALABLE(4), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("llvm.tan.f64", "_ZGVsMxv_tan", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "_ZGVsMxv_tanf", SCALABLE(4), MASKED, "_ZGVsMxv")

TLI_DEFINE_VECFUNC("tanh", "_ZGVsMxv_tanh", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("tanhf", "_ZGVsMxv_tanhf", SCALABLE(4), MASKED, "_ZGVsMxv")
Expand Down Expand Up @@ -1087,6 +1096,11 @@ TLI_DEFINE_VECFUNC("tanf", "armpl_vtanq_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("tan", "armpl_svtan_f64_x", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("tanf", "armpl_svtan_f32_x", SCALABLE(4), MASKED, "_ZGVsMxv")

TLI_DEFINE_VECFUNC("llvm.tan.f64", "armpl_vtanq_f64", FIXED(2), NOMASK, "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "armpl_vtanq_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.tan.f64", "armpl_svtan_f64_x", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "armpl_svtan_f32_x", SCALABLE(4), MASKED, "_ZGVsMxv")

TLI_DEFINE_VECFUNC("tanh", "armpl_vtanhq_f64", FIXED(2), NOMASK, "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("tanhf", "armpl_vtanhq_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("tanh", "armpl_svtanh_f64_x", SCALABLE(2), MASKED, "_ZGVsMxv")
Expand Down
16 changes: 10 additions & 6 deletions llvm/lib/Target/AArch64/AArch64FastISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3534,6 +3534,7 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
}
case Intrinsic::sin:
case Intrinsic::cos:
case Intrinsic::tan:
case Intrinsic::pow: {
MVT RetVT;
if (!isTypeLegal(II->getType(), RetVT))
Expand All @@ -3542,11 +3543,11 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
if (RetVT != MVT::f32 && RetVT != MVT::f64)
return false;

static const RTLIB::Libcall LibCallTable[3][2] = {
{ RTLIB::SIN_F32, RTLIB::SIN_F64 },
{ RTLIB::COS_F32, RTLIB::COS_F64 },
{ RTLIB::POW_F32, RTLIB::POW_F64 }
};
static const RTLIB::Libcall LibCallTable[4][2] = {
{RTLIB::SIN_F32, RTLIB::SIN_F64},
{RTLIB::COS_F32, RTLIB::COS_F64},
{RTLIB::TAN_F32, RTLIB::TAN_F64},
{RTLIB::POW_F32, RTLIB::POW_F64}};
RTLIB::Libcall LC;
bool Is64Bit = RetVT == MVT::f64;
switch (II->getIntrinsicID()) {
Expand All @@ -3558,9 +3559,12 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
case Intrinsic::cos:
LC = LibCallTable[1][Is64Bit];
break;
case Intrinsic::pow:
case Intrinsic::tan:
LC = LibCallTable[2][Is64Bit];
break;
case Intrinsic::pow:
LC = LibCallTable[3][Is64Bit];
break;
}

ArgListTy Args;
Expand Down
56 changes: 30 additions & 26 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FSINCOS, MVT::f128, Expand);
setOperationAction(ISD::FSQRT, MVT::f128, Expand);
setOperationAction(ISD::FSUB, MVT::f128, LibCall);
setOperationAction(ISD::FTAN, MVT::f128, Expand);
setOperationAction(ISD::FTRUNC, MVT::f128, Expand);
setOperationAction(ISD::SETCC, MVT::f128, Custom);
setOperationAction(ISD::STRICT_FSETCC, MVT::f128, Custom);
Expand Down Expand Up @@ -727,14 +728,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
}

for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
ISD::FEXP, ISD::FEXP2, ISD::FEXP10,
ISD::FLOG, ISD::FLOG2, ISD::FLOG10,
ISD::STRICT_FREM,
ISD::STRICT_FPOW, ISD::STRICT_FPOWI, ISD::STRICT_FCOS,
ISD::STRICT_FSIN, ISD::STRICT_FEXP, ISD::STRICT_FEXP2,
ISD::STRICT_FLOG, ISD::STRICT_FLOG2, ISD::STRICT_FLOG10}) {
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
ISD::FTAN, ISD::FEXP, ISD::FEXP2,
ISD::FEXP10, ISD::FLOG, ISD::FLOG2,
ISD::FLOG10, ISD::STRICT_FREM, ISD::STRICT_FPOW,
ISD::STRICT_FPOWI, ISD::STRICT_FCOS, ISD::STRICT_FSIN,
ISD::STRICT_FEXP, ISD::STRICT_FEXP2, ISD::STRICT_FLOG,
ISD::STRICT_FLOG2, ISD::STRICT_FLOG10}) {
setOperationAction(Op, MVT::f16, Promote);
setOperationAction(Op, MVT::v4f16, Expand);
setOperationAction(Op, MVT::v8f16, Expand);
Expand Down Expand Up @@ -1171,26 +1172,27 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (Subtarget->isNeonAvailable()) {
// FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to
// silliness like this:
// clang-format off
for (auto Op :
{ISD::SELECT, ISD::SELECT_CC,
ISD::BR_CC, ISD::FADD, ISD::FSUB,
ISD::FMUL, ISD::FDIV, ISD::FMA,
ISD::FNEG, ISD::FABS, ISD::FCEIL,
ISD::FSQRT, ISD::FFLOOR, ISD::FNEARBYINT,
ISD::FSIN, ISD::FCOS, ISD::FPOW,
ISD::FLOG, ISD::FLOG2, ISD::FLOG10,
ISD::FEXP, ISD::FEXP2, ISD::FEXP10,
ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN,
ISD::FTRUNC, ISD::FMINNUM, ISD::FMAXNUM,
ISD::FMINIMUM, ISD::FMAXIMUM, ISD::STRICT_FADD,
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
ISD::STRICT_FMA, ISD::STRICT_FCEIL, ISD::STRICT_FFLOOR,
ISD::STRICT_FSQRT, ISD::STRICT_FRINT, ISD::STRICT_FNEARBYINT,
ISD::STRICT_FROUND, ISD::STRICT_FTRUNC, ISD::STRICT_FROUNDEVEN,
ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM, ISD::STRICT_FMINIMUM,
ISD::STRICT_FMAXIMUM})
{ISD::SELECT, ISD::SELECT_CC,
ISD::BR_CC, ISD::FADD, ISD::FSUB,
ISD::FMUL, ISD::FDIV, ISD::FMA,
ISD::FNEG, ISD::FABS, ISD::FCEIL,
ISD::FSQRT, ISD::FFLOOR, ISD::FNEARBYINT,
ISD::FSIN, ISD::FCOS, ISD::FTAN,
ISD::FPOW, ISD::FLOG, ISD::FLOG2,
ISD::FLOG10, ISD::FEXP, ISD::FEXP2,
ISD::FEXP10, ISD::FRINT, ISD::FROUND,
ISD::FROUNDEVEN, ISD::FTRUNC, ISD::FMINNUM,
ISD::FMAXNUM, ISD::FMINIMUM, ISD::FMAXIMUM,
ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
ISD::STRICT_FDIV, ISD::STRICT_FMA, ISD::STRICT_FCEIL,
ISD::STRICT_FFLOOR, ISD::STRICT_FSQRT, ISD::STRICT_FRINT,
ISD::STRICT_FNEARBYINT, ISD::STRICT_FROUND, ISD::STRICT_FTRUNC,
ISD::STRICT_FROUNDEVEN, ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM,
ISD::STRICT_FMINIMUM, ISD::STRICT_FMAXIMUM})
setOperationAction(Op, MVT::v1f64, Expand);

// clang-format on
for (auto Op :
{ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::UINT_TO_FP,
ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::MUL,
Expand Down Expand Up @@ -1622,6 +1624,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FCOS, VT, Expand);
setOperationAction(ISD::FSIN, VT, Expand);
setOperationAction(ISD::FSINCOS, VT, Expand);
setOperationAction(ISD::FTAN, VT, Expand);
setOperationAction(ISD::FEXP, VT, Expand);
setOperationAction(ISD::FEXP2, VT, Expand);
setOperationAction(ISD::FEXP10, VT, Expand);
Expand Down Expand Up @@ -1803,6 +1806,7 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64) {
setOperationAction(ISD::FSIN, VT, Expand);
setOperationAction(ISD::FCOS, VT, Expand);
setOperationAction(ISD::FTAN, VT, Expand);
setOperationAction(ISD::FPOW, VT, Expand);
setOperationAction(ISD::FLOG, VT, Expand);
setOperationAction(ISD::FLOG2, VT, Expand);
Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.libcallFor({{s64, s128}})
.minScalarOrElt(1, MinFPScalar);

getActionDefinitionsBuilder(
{G_FCOS, G_FSIN, G_FPOW, G_FLOG, G_FLOG2, G_FLOG10,
G_FEXP, G_FEXP2, G_FEXP10})
getActionDefinitionsBuilder({G_FCOS, G_FSIN, G_FPOW, G_FLOG, G_FLOG2,
G_FLOG10, G_FTAN, G_FEXP, G_FEXP2, G_FEXP10})
// We need a call for these, so we always need to scalarize.
.scalarize(0)
// Regardless of FP16 support, widen 16-bit elements to 32-bits.
Expand Down
8 changes: 8 additions & 0 deletions llvm/test/CodeGen/AArch64/GlobalISel/arm64-irtranslator.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,14 @@ define float @test_sin_f32(float %x) {
ret float %y
}

; A call to the llvm.tan.f32 intrinsic on a float argument is expected to
; appear in the output as a single G_FTAN producing an s32 value.
declare float @llvm.tan.f32(float)
define float @test_tan_f32(float %x) {
; CHECK-LABEL: name: test_tan_f32
; CHECK: %{{[0-9]+}}:_(s32) = G_FTAN %{{[0-9]+}}
%y = call float @llvm.tan.f32(float %x)
ret float %y
}

declare float @llvm.sqrt.f32(float)
define float @test_sqrt_f32(float %x) {
; CHECK-LABEL: name: test_sqrt_f32
Expand Down
Loading

0 comments on commit 2f0308e

Please sign in to comment.