From 94fc09c7c9dd31dee5062cf5ac484b527e685af8 Mon Sep 17 00:00:00 2001 From: PaulCarabas Date: Tue, 4 Feb 2025 21:24:19 +0200 Subject: [PATCH] [mlir][LLVMIR] Add support for tan intrinsic op --- .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 7 +++++-- mlir/test/Target/LLVMIR/Import/intrinsic.ll | 19 +++++++++++++++---- .../test/Target/LLVMIR/llvmir-intrinsics.mlir | 18 ++++++++++++++---- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index a7d683438ee8a..72fae1bdbf461 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -107,7 +107,6 @@ def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure], } def LLVM_CopySignOp : LLVM_BinarySameArgsIntrOpF<"copysign">; -def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">; def LLVM_ExpOp : LLVM_UnaryIntrOpF<"exp">; def LLVM_Exp2Op : LLVM_UnaryIntrOpF<"exp2">; def LLVM_FAbsOp : LLVM_UnaryIntrOpF<"fabs">; @@ -125,7 +124,6 @@ def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0], > { let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache); } -def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">; def LLVM_RoundEvenOp : LLVM_UnaryIntrOpF<"roundeven">; def LLVM_RoundOp : LLVM_UnaryIntrOpF<"round">; def LLVM_FTruncOp : LLVM_UnaryIntrOpF<"trunc">; @@ -167,6 +165,11 @@ def LLVM_SMaxOp : LLVM_BinarySameArgsIntrOpI<"smax">; def LLVM_SMinOp : LLVM_BinarySameArgsIntrOpI<"smin">; def LLVM_UMaxOp : LLVM_BinarySameArgsIntrOpI<"umax">; def LLVM_UMinOp : LLVM_BinarySameArgsIntrOpI<"umin">; + +def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">; +def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">; +def LLVM_TanOp : LLVM_UnaryIntrOpF<"tan">; + def LLVM_SinhOp : LLVM_UnaryIntrOpF<"sinh">; def LLVM_CoshOp : LLVM_UnaryIntrOpF<"cosh">; def LLVM_TanhOp : LLVM_UnaryIntrOpF<"tanh">; diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index bd335323a2e1c..249a0552c87f3 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -101,12 +101,23 @@ define void @floor_test(float %0, <8 x float> %1) { %4 = call <8 x float> @llvm.floor.v8f32(<8 x float> %1) ret void } -; CHECK-LABEL: llvm.func @cos_test -define void @cos_test(float %0, <8 x float> %1) { +; CHECK-LABEL: llvm.func @trig_test +define void @trig_test(float %0, <8 x float> %1) { + ; CHECK: llvm.intr.sin(%{{.*}}) : (f32) -> f32 + %3 = call float @llvm.sin.f32(float %0) + ; CHECK: llvm.intr.sin(%{{.*}}) : (vector<8xf32>) -> vector<8xf32> + %4 = call <8 x float> @llvm.sin.v8f32(<8 x float> %1) + ; CHECK: llvm.intr.cos(%{{.*}}) : (f32) -> f32 - %3 = call float @llvm.cos.f32(float %0) + %5 = call float @llvm.cos.f32(float %0) ; CHECK: llvm.intr.cos(%{{.*}}) : (vector<8xf32>) -> vector<8xf32> - %4 = call <8 x float> @llvm.cos.v8f32(<8 x float> %1) + %6 = call <8 x float> @llvm.cos.v8f32(<8 x float> %1) + + ; CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32 + %7 = call float @llvm.tan.f32(float %0) + ; CHECK: llvm.intr.tan(%{{.*}}) : (vector<8xf32>) -> vector<8xf32> + %8 = call <8 x float> @llvm.tan.v8f32(<8 x float> %1) + ret void } ; CHECK-LABEL: llvm.func @hyperbolic_trig_test diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 382b2b9f3cd73..2c208789e36dd 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -103,12 +103,22 @@ llvm.func @floor_test(%arg0: f32, %arg1: vector<8xf32>) { llvm.return } -// CHECK-LABEL: @cos_test -llvm.func @cos_test(%arg0: f32, %arg1: vector<8xf32>) { +// CHECK-LABEL: @trig_test +llvm.func @trig_test(%arg0: f32, %arg1: vector<8xf32>) { + // CHECK: call float @llvm.sin.f32 + llvm.intr.sin(%arg0) : (f32) -> f32 + // CHECK: call <8 x float> @llvm.sin.v8f32 + llvm.intr.sin(%arg1) : (vector<8xf32>) -> vector<8xf32> + // CHECK: call float @llvm.cos.f32 - "llvm.intr.cos"(%arg0) : (f32) -> f32 + llvm.intr.cos(%arg0) : (f32) -> f32 // CHECK: call <8 x float> @llvm.cos.v8f32 - "llvm.intr.cos"(%arg1) : (vector<8xf32>) -> vector<8xf32> + llvm.intr.cos(%arg1) : (vector<8xf32>) -> vector<8xf32> + + // CHECK: call float @llvm.tan.f32 + llvm.intr.tan(%arg0) : (f32) -> f32 + // CHECK: call <8 x float> @llvm.tan.v8f32 + llvm.intr.tan(%arg1) : (vector<8xf32>) -> vector<8xf32> llvm.return }