Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[arm64] Add tan intrinsic lowering #94545

Merged
merged 4 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions llvm/include/llvm/Analysis/VecFuncs.def
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ TLI_DEFINE_VECFUNC("llvm.sin.f64", "_simd_sin_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("sinf", "_simd_sin_f4", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.sin.f32", "_simd_sin_f4", FIXED(4), "_ZGV_LLVM_N4v")

TLI_DEFINE_VECFUNC("tan", "_simd_tan_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("llvm.tan.f64", "_simd_tan_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("tanf", "_simd_tan_f4", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "_simd_tan_f4", FIXED(4), "_ZGV_LLVM_N4v")

// Floating-Point Arithmetic and Auxiliary Functions
TLI_DEFINE_VECFUNC("cbrt", "_simd_cbrt_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("cbrtf", "_simd_cbrt_f4", FIXED(4), "_ZGV_LLVM_N4v")
Expand Down Expand Up @@ -584,6 +589,7 @@ TLI_DEFINE_VECFUNC("sinpi", "_ZGVnN2v_sinpi", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("sqrt", "_ZGVnN2v_sqrt", FIXED(2), "_ZGV_LLVM_N2v")

TLI_DEFINE_VECFUNC("tan", "_ZGVnN2v_tan", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("llvm.tan.f64", "_ZGVnN2v_tan", FIXED(2), "_ZGV_LLVM_N2v")

TLI_DEFINE_VECFUNC("tanh", "_ZGVnN2v_tanh", FIXED(2), "_ZGV_LLVM_N2v")

Expand Down Expand Up @@ -681,6 +687,7 @@ TLI_DEFINE_VECFUNC("sinpif", "_ZGVnN4v_sinpif", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("sqrtf", "_ZGVnN4v_sqrtf", FIXED(4), "_ZGV_LLVM_N4v")

TLI_DEFINE_VECFUNC("tanf", "_ZGVnN4v_tanf", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "_ZGVnN4v_tanf", FIXED(4), "_ZGV_LLVM_N4v")

TLI_DEFINE_VECFUNC("tanhf", "_ZGVnN4v_tanhf", FIXED(4), "_ZGV_LLVM_N4v")

Expand Down Expand Up @@ -828,6 +835,8 @@ TLI_DEFINE_VECFUNC("sqrtf", "_ZGVsMxv_sqrtf", SCALABLE(4), MASKED, "_ZGVsMxv")

TLI_DEFINE_VECFUNC("tan", "_ZGVsMxv_tan", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("tanf", "_ZGVsMxv_tanf", SCALABLE(4), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("llvm.tan.f64", "_ZGVsMxv_tan", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "_ZGVsMxv_tanf", SCALABLE(4), MASKED, "_ZGVsMxv")

TLI_DEFINE_VECFUNC("tanh", "_ZGVsMxv_tanh", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("tanhf", "_ZGVsMxv_tanhf", SCALABLE(4), MASKED, "_ZGVsMxv")
Expand Down Expand Up @@ -1087,6 +1096,11 @@ TLI_DEFINE_VECFUNC("tanf", "armpl_vtanq_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("tan", "armpl_svtan_f64_x", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("tanf", "armpl_svtan_f32_x", SCALABLE(4), MASKED, "_ZGVsMxv")

TLI_DEFINE_VECFUNC("llvm.tan.f64", "armpl_vtanq_f64", FIXED(2), NOMASK, "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "armpl_vtanq_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.tan.f64", "armpl_svtan_f64_x", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("llvm.tan.f32", "armpl_svtan_f32_x", SCALABLE(4), MASKED, "_ZGVsMxv")

TLI_DEFINE_VECFUNC("tanh", "armpl_vtanhq_f64", FIXED(2), NOMASK, "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("tanhf", "armpl_vtanhq_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("tanh", "armpl_svtanh_f64_x", SCALABLE(2), MASKED, "_ZGVsMxv")
Expand Down
16 changes: 10 additions & 6 deletions llvm/lib/Target/AArch64/AArch64FastISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3534,6 +3534,7 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
}
case Intrinsic::sin:
case Intrinsic::cos:
case Intrinsic::tan:
case Intrinsic::pow: {
MVT RetVT;
if (!isTypeLegal(II->getType(), RetVT))
Expand All @@ -3542,11 +3543,11 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
if (RetVT != MVT::f32 && RetVT != MVT::f64)
return false;

static const RTLIB::Libcall LibCallTable[3][2] = {
{ RTLIB::SIN_F32, RTLIB::SIN_F64 },
{ RTLIB::COS_F32, RTLIB::COS_F64 },
{ RTLIB::POW_F32, RTLIB::POW_F64 }
};
static const RTLIB::Libcall LibCallTable[4][2] = {
{RTLIB::SIN_F32, RTLIB::SIN_F64},
{RTLIB::COS_F32, RTLIB::COS_F64},
{RTLIB::TAN_F32, RTLIB::TAN_F64},
{RTLIB::POW_F32, RTLIB::POW_F64}};
RTLIB::Libcall LC;
bool Is64Bit = RetVT == MVT::f64;
switch (II->getIntrinsicID()) {
Expand All @@ -3558,9 +3559,12 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
case Intrinsic::cos:
LC = LibCallTable[1][Is64Bit];
break;
case Intrinsic::pow:
case Intrinsic::tan:
LC = LibCallTable[2][Is64Bit];
break;
case Intrinsic::pow:
LC = LibCallTable[3][Is64Bit];
break;
}

ArgListTy Args;
Expand Down
25 changes: 15 additions & 10 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FSINCOS, MVT::f128, Expand);
setOperationAction(ISD::FSQRT, MVT::f128, Expand);
setOperationAction(ISD::FSUB, MVT::f128, LibCall);
setOperationAction(ISD::FTAN, MVT::f128, Expand);
setOperationAction(ISD::FTRUNC, MVT::f128, Expand);
setOperationAction(ISD::SETCC, MVT::f128, Custom);
setOperationAction(ISD::STRICT_FSETCC, MVT::f128, Custom);
Expand Down Expand Up @@ -727,14 +728,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
}

for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
ISD::FEXP, ISD::FEXP2, ISD::FEXP10,
ISD::FLOG, ISD::FLOG2, ISD::FLOG10,
ISD::STRICT_FREM,
ISD::STRICT_FPOW, ISD::STRICT_FPOWI, ISD::STRICT_FCOS,
ISD::STRICT_FSIN, ISD::STRICT_FEXP, ISD::STRICT_FEXP2,
ISD::STRICT_FLOG, ISD::STRICT_FLOG2, ISD::STRICT_FLOG10}) {
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
ISD::FTAN, ISD::FEXP, ISD::FEXP2,
ISD::FEXP10, ISD::FLOG, ISD::FLOG2,
ISD::FLOG10, ISD::STRICT_FREM, ISD::STRICT_FPOW,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a llvm.experimental.constrained.tan.f64 yet?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats in this PR: #94559

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I see, great. And will that eventually involve adding testing for AArch64? Or is it best to add a quick STRICT_FTAN in here now? I know we have some tests for constrained.cos, and it doesn't look like much is needed to support them.

Copy link
Member Author

@farzonl farzonl Jun 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks likere there are two places to add tests:

  • llvm/test/CodeGen/AArch64/fp-intrinsics-fp16.ll
  • llvm/test/CodeGen/AArch64/fp-intrinsics.ll

fp-intrinsics.ll I added now (ce4e2a0) to PR #94559 and everything works.

fp-intrinsics-fp16.ll will need to wait for this pr to merge before I add the tests because f16 is handled on a per target basis.

tan-aarch64-f16.patch.txt

ISD::STRICT_FPOWI, ISD::STRICT_FCOS, ISD::STRICT_FSIN,
ISD::STRICT_FEXP, ISD::STRICT_FEXP2, ISD::STRICT_FLOG,
ISD::STRICT_FLOG2, ISD::STRICT_FLOG10}) {
setOperationAction(Op, MVT::f16, Promote);
setOperationAction(Op, MVT::v4f16, Expand);
setOperationAction(Op, MVT::v8f16, Expand);
Expand Down Expand Up @@ -1171,13 +1172,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (Subtarget->isNeonAvailable()) {
// FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to
// silliness like this:
// clang-format off
for (auto Op :
{ISD::SELECT, ISD::SELECT_CC,
ISD::BR_CC, ISD::FADD, ISD::FSUB,
ISD::FMUL, ISD::FDIV, ISD::FMA,
ISD::FNEG, ISD::FABS, ISD::FCEIL,
ISD::FSQRT, ISD::FFLOOR, ISD::FNEARBYINT,
ISD::FSIN, ISD::FCOS, ISD::FPOW,
ISD::FSIN, ISD::FCOS, ISD::FTAN,
ISD::FPOW,
farzonl marked this conversation as resolved.
Show resolved Hide resolved
ISD::FLOG, ISD::FLOG2, ISD::FLOG10,
ISD::FEXP, ISD::FEXP2, ISD::FEXP10,
ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN,
Expand All @@ -1190,7 +1193,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM, ISD::STRICT_FMINIMUM,
ISD::STRICT_FMAXIMUM})
setOperationAction(Op, MVT::v1f64, Expand);

// clang-format on
for (auto Op :
{ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::UINT_TO_FP,
ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::MUL,
Expand Down Expand Up @@ -1622,6 +1625,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FCOS, VT, Expand);
setOperationAction(ISD::FSIN, VT, Expand);
setOperationAction(ISD::FSINCOS, VT, Expand);
setOperationAction(ISD::FTAN, VT, Expand);
setOperationAction(ISD::FEXP, VT, Expand);
setOperationAction(ISD::FEXP2, VT, Expand);
setOperationAction(ISD::FEXP10, VT, Expand);
Expand Down Expand Up @@ -1803,6 +1807,7 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64) {
setOperationAction(ISD::FSIN, VT, Expand);
setOperationAction(ISD::FCOS, VT, Expand);
setOperationAction(ISD::FTAN, VT, Expand);
setOperationAction(ISD::FPOW, VT, Expand);
setOperationAction(ISD::FLOG, VT, Expand);
setOperationAction(ISD::FLOG2, VT, Expand);
Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.libcallFor({{s64, s128}})
.minScalarOrElt(1, MinFPScalar);

getActionDefinitionsBuilder(
{G_FCOS, G_FSIN, G_FPOW, G_FLOG, G_FLOG2, G_FLOG10,
G_FEXP, G_FEXP2, G_FEXP10})
getActionDefinitionsBuilder({G_FCOS, G_FSIN, G_FPOW, G_FLOG, G_FLOG2,
G_FLOG10, G_FTAN, G_FEXP, G_FEXP2, G_FEXP10})
// We need a call for these, so we always need to scalarize.
.scalarize(0)
// Regardless of FP16 support, widen 16-bit elements to 32-bits.
Expand Down
8 changes: 8 additions & 0 deletions llvm/test/CodeGen/AArch64/GlobalISel/arm64-irtranslator.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,14 @@ define float @test_sin_f32(float %x) {
ret float %y
}

declare float @llvm.tan.f32(float)
define float @test_tan_f32(float %x) {
; CHECK-LABEL: name: test_tan_f32
; CHECK: %{{[0-9]+}}:_(s32) = G_FTAN %{{[0-9]+}}
%y = call float @llvm.tan.f32(float %x)
ret float %y
}

declare float @llvm.sqrt.f32(float)
define float @test_sqrt_f32(float %x) {
; CHECK-LABEL: name: test_sqrt_f32
Expand Down
Loading
Loading